From 448babd27f80601aaa43c7e71214435664981b72 Mon Sep 17 00:00:00 2001 From: glay Date: Tue, 26 Nov 2024 10:10:34 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84mistral=20tool=20use=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/bedrock.ts | 22 +-- app/utils/aws.ts | 113 +++++++------- docs/bedrock-response-format.md | 258 ++++++++++++++++++++++++++++++++ 3 files changed, 331 insertions(+), 62 deletions(-) create mode 100644 docs/bedrock-response-format.md diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index 79063e03c..78fc4b4dc 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -184,7 +184,6 @@ async function requestBedrock(req: NextRequest) { // Handle non-streaming response if (!shouldStream) { const responseText = await res.text(); - console.log("[Bedrock Response] Non-streaming:", responseText); const parsed = parseEventData(new TextEncoder().encode(responseText)); if (!parsed) { throw new Error("Failed to parse Bedrock response"); @@ -212,13 +211,18 @@ async function requestBedrock(req: NextRequest) { }, }); + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + newHeaders.set("Content-Type", "text/event-stream"); + newHeaders.set("Cache-Control", "no-cache"); + newHeaders.set("Connection", "keep-alive"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + return new Response(stream, { - headers: { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - Connection: "keep-alive", - "X-Accel-Buffering": "no", - }, + status: res.status, + statusText: res.statusText, + headers: newHeaders, }); } catch (e) { console.error("[Bedrock Request Error]:", e); @@ -232,10 +236,6 @@ export async function handle( req: NextRequest, { params }: { params: { path: string[] } }, ) { - if (req.method === "OPTIONS") { - return NextResponse.json({ body: "OK" }, { status: 200 }); - } - const subpath = params.path.join("/"); if (!ALLOWED_PATH.has(subpath)) { return NextResponse.json( diff --git a/app/utils/aws.ts b/app/utils/aws.ts index cb23f60e2..2f65e81d6 100644 --- a/app/utils/aws.ts +++ b/app/utils/aws.ts @@ -245,7 +245,7 @@ export async function sign({ export function parseEventData(chunk: Uint8Array): any { const decoder = new TextDecoder(); const text = decoder.decode(chunk); - // console.info("[AWS Parse ] parsing:", text); + try { const parsed = JSON.parse(text); // AWS Bedrock wraps the response in a 'body' field @@ -282,7 +282,6 @@ export function parseEventData(chunk: Uint8Array): any { // Handle plain text responses if (text.trim()) { - // Clean up any malformed JSON characters const cleanText = text.replace(/[\x00-\x1F\x7F-\x9F]/g, ""); return { output: cleanText }; } @@ -314,7 +313,6 @@ export function extractMessage(res: any, modelId: string = ""): string { console.error("[AWS Extract Error] extractMessage Empty response"); return ""; } - // console.log("[Response] extractMessage response: ", res); // Handle Mistral model response format if (modelId.toLowerCase().includes("mistral")) { @@ -329,6 +327,11 @@ export function extractMessage(res: any, modelId: string = ""): string { return res?.generation || ""; } + // Handle Titan model response format + if (modelId.toLowerCase().includes("titan")) { + return res?.outputText || ""; + } + // Handle Claude and other models return res?.content?.[0]?.text || ""; } @@ -338,12 +341,10 @@ export async function* transformBedrockStream( modelId: string, ) { const reader = 
stream.getReader(); - let toolInput = ""; try { while (true) { const { done, value } = await reader.read(); - if (done) break; const parsed = parseEventData(value); @@ -351,14 +352,40 @@ export async function* transformBedrockStream( // console.log("parseEventData========================="); // console.log(parsed); - + // Handle Claude 3 models + if (modelId.startsWith("anthropic.claude")) { + if (parsed.type === "message_start") { + // Initialize message + continue; + } else if (parsed.type === "content_block_start") { + if (parsed.content_block?.type === "tool_use") { + yield `data: ${JSON.stringify(parsed)}\n\n`; + } + continue; + } else if (parsed.type === "content_block_delta") { + if (parsed.delta?.type === "text_delta") { + yield `data: ${JSON.stringify({ + delta: { text: parsed.delta.text }, + })}\n\n`; + } else if (parsed.delta?.type === "input_json_delta") { + yield `data: ${JSON.stringify(parsed)}\n\n`; + } + } else if (parsed.type === "content_block_stop") { + yield `data: ${JSON.stringify(parsed)}\n\n`; + } else if ( + parsed.type === "message_delta" && + parsed.delta?.stop_reason + ) { + yield `data: ${JSON.stringify({ + delta: { stop_reason: parsed.delta.stop_reason }, + })}\n\n`; + } + } // Handle Mistral models - if (modelId.toLowerCase().includes("mistral")) { - // Handle tool calls + else if (modelId.toLowerCase().includes("mistral")) { if (parsed.choices?.[0]?.message?.tool_calls) { const toolCalls = parsed.choices[0].message.tool_calls; for (const toolCall of toolCalls) { - // Emit tool call start yield `data: ${JSON.stringify({ type: "content_block_start", content_block: { @@ -368,7 +395,6 @@ export async function* transformBedrockStream( }, })}\n\n`; - // Emit tool arguments if (toolCall.function?.arguments) { yield `data: ${JSON.stringify({ type: "content_block_delta", @@ -379,66 +405,51 @@ export async function* transformBedrockStream( })}\n\n`; } - // Emit tool call stop yield `data: ${JSON.stringify({ type: "content_block_stop", })}\n\n`; } - continue; - } - - // Handle regular content - const content = parsed.choices?.[0]?.message?.content; - if (content?.trim()) { + } else if (parsed.choices?.[0]?.message?.content) { yield `data: ${JSON.stringify({ - delta: { text: content }, + delta: { text: parsed.choices[0].message.content }, })}\n\n`; } - // Handle stop reason if (parsed.choices?.[0]?.finish_reason) { yield `data: ${JSON.stringify({ delta: { stop_reason: parsed.choices[0].finish_reason }, })}\n\n`; } } - // Handle Claude models - else if (modelId.startsWith("anthropic.claude")) { - if (parsed.type === "content_block_delta") { - if (parsed.delta?.type === "text_delta") { - yield `data: ${JSON.stringify({ - delta: { text: parsed.delta.text }, - })}\n\n`; - } else if (parsed.delta?.type === "input_json_delta") { - yield `data: ${JSON.stringify(parsed)}\n\n`; - } - } else if ( - parsed.type === "message_delta" && - parsed.delta?.stop_reason - ) { + // Handle Llama models + else if (modelId.toLowerCase().includes("llama")) { + if (parsed.generation) { yield `data: ${JSON.stringify({ - delta: { stop_reason: parsed.delta.stop_reason }, + delta: { text: parsed.generation }, + })}\n\n`; + } + if (parsed.stop_reason) { + yield `data: ${JSON.stringify({ + delta: { stop_reason: parsed.stop_reason }, })}\n\n`; - } else if ( - parsed.type === "content_block_start" && - parsed.content_block?.type === "tool_use" - ) { - yield `data: ${JSON.stringify(parsed)}\n\n`; - } else if (parsed.type === "content_block_stop") { - yield `data: ${JSON.stringify(parsed)}\n\n`; - } else { 
- // Handle regular text responses - const text = parsed.response || parsed.output || ""; - if (text) { - yield `data: ${JSON.stringify({ - delta: { text }, - })}\n\n`; - } } } - // Handle other models + // Handle Titan models + else if (modelId.toLowerCase().includes("titan")) { + if (parsed.outputText) { + yield `data: ${JSON.stringify({ + delta: { text: parsed.outputText }, + })}\n\n`; + } + if (parsed.completionReason) { + yield `data: ${JSON.stringify({ + delta: { stop_reason: parsed.completionReason }, + })}\n\n`; + } + } + // Handle other models with basic text output else { - const text = parsed.outputText || parsed.generation || ""; + const text = parsed.response || parsed.output || ""; if (text) { yield `data: ${JSON.stringify({ delta: { text }, diff --git a/docs/bedrock-response-format.md b/docs/bedrock-response-format.md new file mode 100644 index 000000000..0b0d05c59 --- /dev/null +++ b/docs/bedrock-response-format.md @@ -0,0 +1,258 @@ +# Understanding Bedrock Response Format + +The AWS Bedrock streaming response format consists of multiple Server-Sent Events (SSE) chunks. Each chunk follows this structure: + +``` +:event-type chunk +:content-type application/json +:message-type event +{"bytes":"base64_encoded_data","p":"signature"} +``` + +## Model-Specific Response Formats + +### Claude 3 Format + +When using Claude 3 models (e.g., claude-3-haiku-20240307), the decoded messages include: + +1. **message_start** +```json +{ + "type": "message_start", + "message": { + "id": "msg_bdrk_01A6sahWac4XVTR9sX3rgvsZ", + "type": "message", + "role": "assistant", + "model": "claude-3-haiku-20240307", + "content": [], + "stop_reason": null, + "stop_sequence": null, + "usage": { + "input_tokens": 8, + "output_tokens": 1 + } + } +} +``` + +2. **content_block_start** +```json +{ + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "text", + "text": "" + } +} +``` + +3. **content_block_delta** +```json +{ + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "Hello" + } +} +``` + +### Mistral Format + +When using Mistral models (e.g., mistral-large-2407), the decoded messages have a different structure: + +```json +{ + "id": "b0098812-0ad9-42da-9f17-a5e2f554eb6b", + "object": "chat.completion.chunk", + "created": 1732582566, + "model": "mistral-large-2407", + "choices": [{ + "index": 0, + "logprobs": null, + "context_logits": null, + "generation_logits": null, + "message": { + "role": null, + "content": "Hello", + "tool_calls": null, + "index": null, + "tool_call_id": null + }, + "stop_reason": null + }], + "usage": null, + "p": null +} +``` + +### Llama Format + +When using Llama models (3.1 or 3.2), the decoded messages use a simpler structure focused on generation tokens: + +```json +{ + "generation": "Hello", + "prompt_token_count": null, + "generation_token_count": 2, + "stop_reason": null +} +``` + +Each chunk contains: +- generation: The generated text piece +- prompt_token_count: Token count of the input (only present in first chunk) +- generation_token_count: Running count of generated tokens +- stop_reason: Indicates completion (null until final chunk) + +First chunk example (includes prompt_token_count): +```json +{ + "generation": "\n\n", + "prompt_token_count": 10, + "generation_token_count": 1, + "stop_reason": null +} +``` + +### Titan Text Format + +When using Amazon's Titan models (text or TG1), the response comes as a single chunk with complete text and metrics: + +```json +{ + "outputText": "\nBot: Hello! 
How can I help you today?", + "index": 0, + "totalOutputTextTokenCount": 13, + "completionReason": "FINISH", + "inputTextTokenCount": 3, + "amazon-bedrock-invocationMetrics": { + "inputTokenCount": 3, + "outputTokenCount": 13, + "invocationLatency": 833, + "firstByteLatency": 833 + } +} +``` + +Both Titan text and Titan TG1 use the same response format, with only minor differences in token counts and latency values. For example, here's a TG1 response: + +```json +{ + "outputText": "\nBot: Hello! How can I help you?", + "index": 0, + "totalOutputTextTokenCount": 12, + "completionReason": "FINISH", + "inputTextTokenCount": 3, + "amazon-bedrock-invocationMetrics": { + "inputTokenCount": 3, + "outputTokenCount": 12, + "invocationLatency": 845, + "firstByteLatency": 845 + } +} +``` + +Key fields: +- outputText: The complete generated response +- totalOutputTextTokenCount: Total tokens in the response +- completionReason: Reason for completion (e.g., "FINISH") +- inputTextTokenCount: Number of input tokens +- amazon-bedrock-invocationMetrics: Detailed performance metrics + +## Model-Specific Completion Metrics + +### Mistral +```json +{ + "usage": { + "prompt_tokens": 5, + "total_tokens": 29, + "completion_tokens": 24 + }, + "amazon-bedrock-invocationMetrics": { + "inputTokenCount": 5, + "outputTokenCount": 24, + "invocationLatency": 719, + "firstByteLatency": 148 + } +} +``` + +### Claude 3 +Included in the message_delta with stop_reason. + +### Llama +Included in the final chunk with stop_reason "stop": +```json +{ + "amazon-bedrock-invocationMetrics": { + "inputTokenCount": 10, + "outputTokenCount": 11, + "invocationLatency": 873, + "firstByteLatency": 550 + } +} +``` + +### Titan +Both Titan text and TG1 include metrics in the single response chunk: +```json +{ + "amazon-bedrock-invocationMetrics": { + "inputTokenCount": 3, + "outputTokenCount": 12, + "invocationLatency": 845, + "firstByteLatency": 845 + } +} +``` + +## How the Response is Processed + +1. The raw response is first split into chunks based on SSE format +2. For each chunk: + - The base64 encoded data is decoded + - The JSON is parsed to extract the message content + - Based on the model type and message type, different processing is applied: + +### Claude 3 Processing +- message_start: Initializes a new message with model info and usage stats +- content_block_start: Starts a new content block (text, tool use, etc.) +- content_block_delta: Adds incremental content to the current block +- message_delta: Updates message metadata + +### Mistral Processing +- Each chunk contains a complete message object with choices array +- The content is streamed through the message.content field +- Final chunk includes token usage and invocation metrics + +### Llama Processing +- Each chunk contains a generation field with the text piece +- First chunk includes prompt_token_count +- Tracks generation progress through generation_token_count +- Simple streaming format focused on text generation +- Final chunk includes complete metrics + +### Titan Processing +- Single chunk response with complete text +- No streaming - returns full response at once +- Includes comprehensive metrics in the same chunk + +## Handling in Code + +The response is processed by the `transformBedrockStream` function in `app/utils/aws.ts`, which: + +1. Reads the stream chunks +2. Parses each chunk using `parseEventData` +3. 
Handles model-specific formats:
   - For Claude: Processes message_start, content_block_start, content_block_delta
   - For Mistral: Extracts content from choices[0].message.content
   - For Llama: Uses the generation field directly
   - For Titan: Uses the outputText field from the single response
4. Transforms the parsed data into a consistent format for the client
5. Yields the transformed data as SSE events

This allows for real-time streaming of the model's response while maintaining a consistent format for the client application, regardless of which model is being used.
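
As a rough illustration of the consumer side, the sketch below reads the normalized SSE stream and accumulates `delta.text` until a `delta.stop_reason` event arrives. This is a hypothetical example rather than code from this change: the `readDeltaText` helper and the way the `Response` is obtained are assumptions; only the `delta.text` / `delta.stop_reason` event shapes are taken from the format described above.

```typescript
// Minimal sketch of a client-side consumer for the normalized SSE events
// produced by transformBedrockStream. Hypothetical helper, not part of this patch.
async function readDeltaText(res: Response): Promise<string> {
  const reader = res.body!.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  let text = "";

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });

    // SSE events are separated by a blank line; each event here is "data: <json>".
    const events = buffer.split("\n\n");
    buffer = events.pop() ?? "";

    for (const event of events) {
      if (!event.startsWith("data: ")) continue;
      const parsed = JSON.parse(event.slice("data: ".length));
      if (parsed.delta?.text) text += parsed.delta.text; // incremental text
      if (parsed.delta?.stop_reason) return text; // the model signalled completion
      // Tool-use events (content_block_start, input_json_delta, content_block_stop)
      // pass through unchanged and would be dispatched to a tool-calling layer here.
    }
  }
  return text;
}
```

A fuller consumer would also surface errors and tool-call arguments, but the accumulation loop above covers the plain-text case for every model format listed in this document.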