feat:add amazon.nova model tool use support.

This commit is contained in:
glay 2024-12-14 11:01:10 +08:00
parent 0ec1ae6276
commit e839940a26
3 changed files with 57 additions and 48 deletions

View File

@ -139,7 +139,6 @@ export class BedrockApi implements LLMApi {
if (item.text || typeof item === "string") { if (item.text || typeof item === "string") {
return { text: item.text || item }; return { text: item.text || item };
} }
// Handle image content // Handle image content
if (item.image_url?.url) { if (item.image_url?.url) {
const { url = "" } = item.image_url; const { url = "" } = item.image_url;
@ -170,6 +169,7 @@ export class BedrockApi implements LLMApi {
top_p: modelConfig.top_p || 0.9, top_p: modelConfig.top_p || 0.9,
top_k: modelConfig.top_k || 50, top_k: modelConfig.top_k || 50,
max_new_tokens: modelConfig.max_tokens || 1000, max_new_tokens: modelConfig.max_tokens || 1000,
stopSequences: modelConfig.stop || [],
}, },
}; };
@ -182,7 +182,7 @@ export class BedrockApi implements LLMApi {
]; ];
} }
// Add tools if available - now in correct format // Add tools if available - exact Nova format
if (toolsArray.length > 0) { if (toolsArray.length > 0) {
requestBody.toolConfig = { requestBody.toolConfig = {
tools: toolsArray.map((tool) => ({ tools: toolsArray.map((tool) => ({
@ -192,14 +192,13 @@ export class BedrockApi implements LLMApi {
inputSchema: { inputSchema: {
json: { json: {
type: "object", type: "object",
properties: tool?.function?.parameters || {}, properties: tool?.function?.parameters?.properties || {},
required: Object.keys(tool?.function?.parameters || {}), required: tool?.function?.parameters?.required || [],
}, },
}, },
}, },
})), })),
toolChoice: { auto: {} }, toolChoice: { auto: {} },
// toolChoice: { any: {} }
}; };
} }
@ -501,7 +500,7 @@ export class BedrockApi implements LLMApi {
})), })),
); );
} else if (isNova) { } else if (isNova) {
// Format for Nova // Format for Nova - Updated format
// @ts-ignore // @ts-ignore
requestPayload?.messages?.splice( requestPayload?.messages?.splice(
// @ts-ignore // @ts-ignore
@ -511,35 +510,39 @@ export class BedrockApi implements LLMApi {
role: "assistant", role: "assistant",
content: [ content: [
{ {
text: "", // Add empty text content to satisfy type requirements toolUse: {
tool_calls: toolCallMessage.tool_calls.map( toolUseId: toolCallMessage.tool_calls[0].id,
(tool: ChatMessageTool) => ({ name: toolCallMessage.tool_calls[0]?.function?.name,
id: tool.id, input:
name: tool?.function?.name, typeof toolCallMessage.tool_calls[0]?.function
arguments: tool?.function?.arguments ?.arguments === "string"
? JSON.parse(tool?.function?.arguments) ? JSON.parse(
: {}, toolCallMessage.tool_calls[0]?.function
}), ?.arguments,
), )
: toolCallMessage.tool_calls[0]?.function
?.arguments || {},
},
}, },
], ],
}, },
...toolCallResult.map((result) => ({ {
role: "user", role: "user",
content: [ content: [
{ {
toolUseId: result.tool_call_id, toolResult: {
content: [ toolUseId: toolCallResult[0].tool_call_id,
{ content: [
json: {
typeof result.content === "string" json: {
? JSON.parse(result.content) content: toolCallResult[0].content,
: result.content, },
}, },
], ],
},
}, },
], ],
})), },
); );
} else { } else {
console.warn( console.warn(
@ -573,7 +576,8 @@ export class BedrockApi implements LLMApi {
const message = extractMessage(resJson); const message = extractMessage(resJson);
options.onFinish(message, res); options.onFinish(message, res);
} catch (e) { } catch (e) {
const error = e instanceof Error ? e : new Error('Unknown error occurred'); const error =
e instanceof Error ? e : new Error("Unknown error occurred");
console.error("[Bedrock Client] Chat failed:", error.message); console.error("[Bedrock Client] Chat failed:", error.message);
options.onError?.(error); options.onError?.(error);
} }
@ -829,7 +833,7 @@ function bedrockStream(
} catch (err) { } catch (err) {
console.error( console.error(
"[Bedrock Stream]:", "[Bedrock Stream]:",
err instanceof Error ? err.message : "Stream processing failed" err instanceof Error ? err.message : "Stream processing failed",
); );
throw new Error("Failed to process stream response"); throw new Error("Failed to process stream response");
} finally { } finally {
@ -843,7 +847,7 @@ function bedrockStream(
} }
console.error( console.error(
"[Bedrock Request] Failed:", "[Bedrock Request] Failed:",
e instanceof Error ? e.message : "Request failed" e instanceof Error ? e.message : "Request failed",
); );
options.onError?.(e); options.onError?.(e);
throw new Error("Request processing failed"); throw new Error("Request processing failed");

View File

@ -296,7 +296,8 @@ export function showPlugins(provider: ServiceProvider, model: string) {
} }
if ( if (
(provider == ServiceProvider.Bedrock && model.includes("claude-3")) || (provider == ServiceProvider.Bedrock && model.includes("claude-3")) ||
model.includes("mistral-large") model.includes("mistral-large") ||
model.includes("amazon.nova")
) { ) {
return true; return true;
} }

View File

@ -435,25 +435,31 @@ export function processMessage(
if (!data) return { remainText, index }; if (!data) return { remainText, index };
try { try {
// Handle Nova's tool calls // Handle Nova's tool calls with exact schema match
// console.log("processMessage data=========================",data); // console.log("processMessage data=========================",data);
if ( if (data.contentBlockStart?.start?.toolUse) {
data.stopReason === "tool_use" && const toolUse = data.contentBlockStart.start.toolUse;
data.output?.message?.content?.[0]?.toolUse
) {
const toolUse = data.output.message.content[0].toolUse;
index += 1; index += 1;
runTools.push({ runTools.push({
id: `tool-${Date.now()}`, id: toolUse.toolUseId,
type: "function", type: "function",
function: { function: {
name: toolUse.name, name: toolUse.name || "", // Ensure name is always present
arguments: JSON.stringify(toolUse.input), arguments: "{}", // Initialize empty arguments
}, },
}); });
return { remainText, index }; return { remainText, index };
} }
// Handle Nova's tool input in contentBlockDelta
if (data.contentBlockDelta?.delta?.toolUse?.input) {
if (runTools[index]) {
runTools[index].function.arguments =
data.contentBlockDelta.delta.toolUse.input;
}
return { remainText, index };
}
// Handle Nova's text content // Handle Nova's text content
if (data.output?.message?.content?.[0]?.text) { if (data.output?.message?.content?.[0]?.text) {
remainText += data.output.message.content[0].text; remainText += data.output.message.content[0].text;
@ -465,11 +471,9 @@ export function processMessage(
return { remainText, index }; return { remainText, index };
} }
// Handle Nova's contentBlockDelta event // Handle Nova's text delta
if (data.contentBlockDelta) { if (data.contentBlockDelta?.delta?.text) {
if (data.contentBlockDelta.delta?.text) { remainText += data.contentBlockDelta.delta.text;
remainText += data.contentBlockDelta.delta.text;
}
return { remainText, index }; return { remainText, index };
} }
@ -496,7 +500,7 @@ export function processMessage(
id: data.content_block.id, id: data.content_block.id,
type: "function", type: "function",
function: { function: {
name: data.content_block.name, name: data.content_block.name || "", // Ensure name is always present
arguments: "", arguments: "",
}, },
}); });
@ -523,7 +527,7 @@ export function processMessage(
id: toolCall.id || `tool-${Date.now()}`, id: toolCall.id || `tool-${Date.now()}`,
type: "function", type: "function",
function: { function: {
name: toolCall.function?.name, name: toolCall.function?.name || "", // Ensure name is always present
arguments: toolCall.function?.arguments || "", arguments: toolCall.function?.arguments || "",
}, },
}); });
@ -554,7 +558,7 @@ export function processMessage(
remainText += newText; remainText += newText;
} }
} catch (e) { } catch (e) {
console.warn("Failed to process Bedrock message"); console.warn("Failed to process Bedrock message:");
} }
return { remainText, index }; return { remainText, index };