From e839940a26d965403d7872981681c5cf489318b2 Mon Sep 17 00:00:00 2001 From: glay Date: Sat, 14 Dec 2024 11:01:10 +0800 Subject: [PATCH] feat:add amazon.nova model tool use support. --- app/client/platforms/bedrock.ts | 64 +++++++++++++++++---------------- app/utils.ts | 3 +- app/utils/aws.ts | 38 +++++++++++--------- 3 files changed, 57 insertions(+), 48 deletions(-) diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index d1d519837..69e9b9137 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -139,7 +139,6 @@ export class BedrockApi implements LLMApi { if (item.text || typeof item === "string") { return { text: item.text || item }; } - // Handle image content if (item.image_url?.url) { const { url = "" } = item.image_url; @@ -170,6 +169,7 @@ export class BedrockApi implements LLMApi { top_p: modelConfig.top_p || 0.9, top_k: modelConfig.top_k || 50, max_new_tokens: modelConfig.max_tokens || 1000, + stopSequences: modelConfig.stop || [], }, }; @@ -182,7 +182,7 @@ export class BedrockApi implements LLMApi { ]; } - // Add tools if available - now in correct format + // Add tools if available - exact Nova format if (toolsArray.length > 0) { requestBody.toolConfig = { tools: toolsArray.map((tool) => ({ @@ -192,14 +192,13 @@ export class BedrockApi implements LLMApi { inputSchema: { json: { type: "object", - properties: tool?.function?.parameters || {}, - required: Object.keys(tool?.function?.parameters || {}), + properties: tool?.function?.parameters?.properties || {}, + required: tool?.function?.parameters?.required || [], }, }, }, })), toolChoice: { auto: {} }, - // toolChoice: { any: {} } }; } @@ -501,7 +500,7 @@ export class BedrockApi implements LLMApi { })), ); } else if (isNova) { - // Format for Nova + // Format for Nova - Updated format // @ts-ignore requestPayload?.messages?.splice( // @ts-ignore @@ -511,35 +510,39 @@ export class BedrockApi implements LLMApi { role: "assistant", content: [ { - text: "", // Add empty text content to satisfy type requirements - tool_calls: toolCallMessage.tool_calls.map( - (tool: ChatMessageTool) => ({ - id: tool.id, - name: tool?.function?.name, - arguments: tool?.function?.arguments - ? JSON.parse(tool?.function?.arguments) - : {}, - }), - ), + toolUse: { + toolUseId: toolCallMessage.tool_calls[0].id, + name: toolCallMessage.tool_calls[0]?.function?.name, + input: + typeof toolCallMessage.tool_calls[0]?.function + ?.arguments === "string" + ? JSON.parse( + toolCallMessage.tool_calls[0]?.function + ?.arguments, + ) + : toolCallMessage.tool_calls[0]?.function + ?.arguments || {}, + }, }, ], }, - ...toolCallResult.map((result) => ({ + { role: "user", content: [ { - toolUseId: result.tool_call_id, - content: [ - { - json: - typeof result.content === "string" - ? JSON.parse(result.content) - : result.content, - }, - ], + toolResult: { + toolUseId: toolCallResult[0].tool_call_id, + content: [ + { + json: { + content: toolCallResult[0].content, + }, + }, + ], + }, }, ], - })), + }, ); } else { console.warn( @@ -573,7 +576,8 @@ export class BedrockApi implements LLMApi { const message = extractMessage(resJson); options.onFinish(message, res); } catch (e) { - const error = e instanceof Error ? e : new Error('Unknown error occurred'); + const error = + e instanceof Error ? e : new Error("Unknown error occurred"); console.error("[Bedrock Client] Chat failed:", error.message); options.onError?.(error); } @@ -829,7 +833,7 @@ function bedrockStream( } catch (err) { console.error( "[Bedrock Stream]:", - err instanceof Error ? err.message : "Stream processing failed" + err instanceof Error ? err.message : "Stream processing failed", ); throw new Error("Failed to process stream response"); } finally { @@ -843,7 +847,7 @@ function bedrockStream( } console.error( "[Bedrock Request] Failed:", - e instanceof Error ? e.message : "Request failed" + e instanceof Error ? e.message : "Request failed", ); options.onError?.(e); throw new Error("Request processing failed"); diff --git a/app/utils.ts b/app/utils.ts index a2a4f21dc..9af7aadee 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -296,7 +296,8 @@ export function showPlugins(provider: ServiceProvider, model: string) { } if ( (provider == ServiceProvider.Bedrock && model.includes("claude-3")) || - model.includes("mistral-large") + model.includes("mistral-large") || + model.includes("amazon.nova") ) { return true; } diff --git a/app/utils/aws.ts b/app/utils/aws.ts index 5fc78bc7e..a976e4d0d 100644 --- a/app/utils/aws.ts +++ b/app/utils/aws.ts @@ -435,25 +435,31 @@ export function processMessage( if (!data) return { remainText, index }; try { - // Handle Nova's tool calls + // Handle Nova's tool calls with exact schema match // console.log("processMessage data=========================",data); - if ( - data.stopReason === "tool_use" && - data.output?.message?.content?.[0]?.toolUse - ) { - const toolUse = data.output.message.content[0].toolUse; + if (data.contentBlockStart?.start?.toolUse) { + const toolUse = data.contentBlockStart.start.toolUse; index += 1; runTools.push({ - id: `tool-${Date.now()}`, + id: toolUse.toolUseId, type: "function", function: { - name: toolUse.name, - arguments: JSON.stringify(toolUse.input), + name: toolUse.name || "", // Ensure name is always present + arguments: "{}", // Initialize empty arguments }, }); return { remainText, index }; } + // Handle Nova's tool input in contentBlockDelta + if (data.contentBlockDelta?.delta?.toolUse?.input) { + if (runTools[index]) { + runTools[index].function.arguments = + data.contentBlockDelta.delta.toolUse.input; + } + return { remainText, index }; + } + // Handle Nova's text content if (data.output?.message?.content?.[0]?.text) { remainText += data.output.message.content[0].text; @@ -465,11 +471,9 @@ export function processMessage( return { remainText, index }; } - // Handle Nova's contentBlockDelta event - if (data.contentBlockDelta) { - if (data.contentBlockDelta.delta?.text) { - remainText += data.contentBlockDelta.delta.text; - } + // Handle Nova's text delta + if (data.contentBlockDelta?.delta?.text) { + remainText += data.contentBlockDelta.delta.text; return { remainText, index }; } @@ -496,7 +500,7 @@ export function processMessage( id: data.content_block.id, type: "function", function: { - name: data.content_block.name, + name: data.content_block.name || "", // Ensure name is always present arguments: "", }, }); @@ -523,7 +527,7 @@ export function processMessage( id: toolCall.id || `tool-${Date.now()}`, type: "function", function: { - name: toolCall.function?.name, + name: toolCall.function?.name || "", // Ensure name is always present arguments: toolCall.function?.arguments || "", }, }); @@ -554,7 +558,7 @@ export function processMessage( remainText += newText; } } catch (e) { - console.warn("Failed to process Bedrock message"); + console.warn("Failed to process Bedrock message:"); } return { remainText, index };