完善mistral tool use功能

This commit is contained in:
glay 2024-11-26 10:10:34 +08:00
parent e6633753a4
commit 448babd27f
3 changed files with 331 additions and 62 deletions

View File

@ -184,7 +184,6 @@ async function requestBedrock(req: NextRequest) {
// Handle non-streaming response // Handle non-streaming response
if (!shouldStream) { if (!shouldStream) {
const responseText = await res.text(); const responseText = await res.text();
console.log("[Bedrock Response] Non-streaming:", responseText);
const parsed = parseEventData(new TextEncoder().encode(responseText)); const parsed = parseEventData(new TextEncoder().encode(responseText));
if (!parsed) { if (!parsed) {
throw new Error("Failed to parse Bedrock response"); throw new Error("Failed to parse Bedrock response");
@ -212,13 +211,18 @@ async function requestBedrock(req: NextRequest) {
}, },
}); });
const newHeaders = new Headers(res.headers);
newHeaders.delete("www-authenticate");
newHeaders.set("Content-Type", "text/event-stream");
newHeaders.set("Cache-Control", "no-cache");
newHeaders.set("Connection", "keep-alive");
// to disable nginx buffering
newHeaders.set("X-Accel-Buffering", "no");
return new Response(stream, { return new Response(stream, {
headers: { status: res.status,
"Content-Type": "text/event-stream", statusText: res.statusText,
"Cache-Control": "no-cache", headers: newHeaders,
Connection: "keep-alive",
"X-Accel-Buffering": "no",
},
}); });
} catch (e) { } catch (e) {
console.error("[Bedrock Request Error]:", e); console.error("[Bedrock Request Error]:", e);
@ -232,10 +236,6 @@ export async function handle(
req: NextRequest, req: NextRequest,
{ params }: { params: { path: string[] } }, { params }: { params: { path: string[] } },
) { ) {
if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}
const subpath = params.path.join("/"); const subpath = params.path.join("/");
if (!ALLOWED_PATH.has(subpath)) { if (!ALLOWED_PATH.has(subpath)) {
return NextResponse.json( return NextResponse.json(

View File

@ -245,7 +245,7 @@ export async function sign({
export function parseEventData(chunk: Uint8Array): any { export function parseEventData(chunk: Uint8Array): any {
const decoder = new TextDecoder(); const decoder = new TextDecoder();
const text = decoder.decode(chunk); const text = decoder.decode(chunk);
// console.info("[AWS Parse ] parsing:", text);
try { try {
const parsed = JSON.parse(text); const parsed = JSON.parse(text);
// AWS Bedrock wraps the response in a 'body' field // AWS Bedrock wraps the response in a 'body' field
@ -282,7 +282,6 @@ export function parseEventData(chunk: Uint8Array): any {
// Handle plain text responses // Handle plain text responses
if (text.trim()) { if (text.trim()) {
// Clean up any malformed JSON characters
const cleanText = text.replace(/[\x00-\x1F\x7F-\x9F]/g, ""); const cleanText = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
return { output: cleanText }; return { output: cleanText };
} }
@ -314,7 +313,6 @@ export function extractMessage(res: any, modelId: string = ""): string {
console.error("[AWS Extract Error] extractMessage Empty response"); console.error("[AWS Extract Error] extractMessage Empty response");
return ""; return "";
} }
// console.log("[Response] extractMessage response: ", res);
// Handle Mistral model response format // Handle Mistral model response format
if (modelId.toLowerCase().includes("mistral")) { if (modelId.toLowerCase().includes("mistral")) {
@ -329,6 +327,11 @@ export function extractMessage(res: any, modelId: string = ""): string {
return res?.generation || ""; return res?.generation || "";
} }
// Handle Titan model response format
if (modelId.toLowerCase().includes("titan")) {
return res?.outputText || "";
}
// Handle Claude and other models // Handle Claude and other models
return res?.content?.[0]?.text || ""; return res?.content?.[0]?.text || "";
} }
@ -338,12 +341,10 @@ export async function* transformBedrockStream(
modelId: string, modelId: string,
) { ) {
const reader = stream.getReader(); const reader = stream.getReader();
let toolInput = "";
try { try {
while (true) { while (true) {
const { done, value } = await reader.read(); const { done, value } = await reader.read();
if (done) break; if (done) break;
const parsed = parseEventData(value); const parsed = parseEventData(value);
@ -351,14 +352,40 @@ export async function* transformBedrockStream(
// console.log("parseEventData========================="); // console.log("parseEventData=========================");
// console.log(parsed); // console.log(parsed);
// Handle Claude 3 models
if (modelId.startsWith("anthropic.claude")) {
if (parsed.type === "message_start") {
// Initialize message
continue;
} else if (parsed.type === "content_block_start") {
if (parsed.content_block?.type === "tool_use") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
}
continue;
} else if (parsed.type === "content_block_delta") {
if (parsed.delta?.type === "text_delta") {
yield `data: ${JSON.stringify({
delta: { text: parsed.delta.text },
})}\n\n`;
} else if (parsed.delta?.type === "input_json_delta") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
}
} else if (parsed.type === "content_block_stop") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
} else if (
parsed.type === "message_delta" &&
parsed.delta?.stop_reason
) {
yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.delta.stop_reason },
})}\n\n`;
}
}
// Handle Mistral models // Handle Mistral models
if (modelId.toLowerCase().includes("mistral")) { else if (modelId.toLowerCase().includes("mistral")) {
// Handle tool calls
if (parsed.choices?.[0]?.message?.tool_calls) { if (parsed.choices?.[0]?.message?.tool_calls) {
const toolCalls = parsed.choices[0].message.tool_calls; const toolCalls = parsed.choices[0].message.tool_calls;
for (const toolCall of toolCalls) { for (const toolCall of toolCalls) {
// Emit tool call start
yield `data: ${JSON.stringify({ yield `data: ${JSON.stringify({
type: "content_block_start", type: "content_block_start",
content_block: { content_block: {
@ -368,7 +395,6 @@ export async function* transformBedrockStream(
}, },
})}\n\n`; })}\n\n`;
// Emit tool arguments
if (toolCall.function?.arguments) { if (toolCall.function?.arguments) {
yield `data: ${JSON.stringify({ yield `data: ${JSON.stringify({
type: "content_block_delta", type: "content_block_delta",
@ -379,66 +405,51 @@ export async function* transformBedrockStream(
})}\n\n`; })}\n\n`;
} }
// Emit tool call stop
yield `data: ${JSON.stringify({ yield `data: ${JSON.stringify({
type: "content_block_stop", type: "content_block_stop",
})}\n\n`; })}\n\n`;
} }
continue; } else if (parsed.choices?.[0]?.message?.content) {
}
// Handle regular content
const content = parsed.choices?.[0]?.message?.content;
if (content?.trim()) {
yield `data: ${JSON.stringify({ yield `data: ${JSON.stringify({
delta: { text: content }, delta: { text: parsed.choices[0].message.content },
})}\n\n`; })}\n\n`;
} }
// Handle stop reason
if (parsed.choices?.[0]?.finish_reason) { if (parsed.choices?.[0]?.finish_reason) {
yield `data: ${JSON.stringify({ yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.choices[0].finish_reason }, delta: { stop_reason: parsed.choices[0].finish_reason },
})}\n\n`; })}\n\n`;
} }
} }
// Handle Claude models // Handle Llama models
else if (modelId.startsWith("anthropic.claude")) { else if (modelId.toLowerCase().includes("llama")) {
if (parsed.type === "content_block_delta") { if (parsed.generation) {
if (parsed.delta?.type === "text_delta") {
yield `data: ${JSON.stringify({
delta: { text: parsed.delta.text },
})}\n\n`;
} else if (parsed.delta?.type === "input_json_delta") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
}
} else if (
parsed.type === "message_delta" &&
parsed.delta?.stop_reason
) {
yield `data: ${JSON.stringify({ yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.delta.stop_reason }, delta: { text: parsed.generation },
})}\n\n`;
}
if (parsed.stop_reason) {
yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.stop_reason },
})}\n\n`; })}\n\n`;
} else if (
parsed.type === "content_block_start" &&
parsed.content_block?.type === "tool_use"
) {
yield `data: ${JSON.stringify(parsed)}\n\n`;
} else if (parsed.type === "content_block_stop") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
} else {
// Handle regular text responses
const text = parsed.response || parsed.output || "";
if (text) {
yield `data: ${JSON.stringify({
delta: { text },
})}\n\n`;
}
} }
} }
// Handle other models // Handle Titan models
else if (modelId.toLowerCase().includes("titan")) {
if (parsed.outputText) {
yield `data: ${JSON.stringify({
delta: { text: parsed.outputText },
})}\n\n`;
}
if (parsed.completionReason) {
yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.completionReason },
})}\n\n`;
}
}
// Handle other models with basic text output
else { else {
const text = parsed.outputText || parsed.generation || ""; const text = parsed.response || parsed.output || "";
if (text) { if (text) {
yield `data: ${JSON.stringify({ yield `data: ${JSON.stringify({
delta: { text }, delta: { text },

View File

@ -0,0 +1,258 @@
# Understanding Bedrock Response Format
The AWS Bedrock streaming response format consists of multiple Server-Sent Events (SSE) chunks. Each chunk follows this structure:
```
:event-type chunk
:content-type application/json
:message-type event
{"bytes":"base64_encoded_data","p":"signature"}
```
## Model-Specific Response Formats
### Claude 3 Format
When using Claude 3 models (e.g., claude-3-haiku-20240307), the decoded messages include:
1. **message_start**
```json
{
"type": "message_start",
"message": {
"id": "msg_bdrk_01A6sahWac4XVTR9sX3rgvsZ",
"type": "message",
"role": "assistant",
"model": "claude-3-haiku-20240307",
"content": [],
"stop_reason": null,
"stop_sequence": null,
"usage": {
"input_tokens": 8,
"output_tokens": 1
}
}
}
```
2. **content_block_start**
```json
{
"type": "content_block_start",
"index": 0,
"content_block": {
"type": "text",
"text": ""
}
}
```
3. **content_block_delta**
```json
{
"type": "content_block_delta",
"index": 0,
"delta": {
"type": "text_delta",
"text": "Hello"
}
}
```
### Mistral Format
When using Mistral models (e.g., mistral-large-2407), the decoded messages have a different structure:
```json
{
"id": "b0098812-0ad9-42da-9f17-a5e2f554eb6b",
"object": "chat.completion.chunk",
"created": 1732582566,
"model": "mistral-large-2407",
"choices": [{
"index": 0,
"logprobs": null,
"context_logits": null,
"generation_logits": null,
"message": {
"role": null,
"content": "Hello",
"tool_calls": null,
"index": null,
"tool_call_id": null
},
"stop_reason": null
}],
"usage": null,
"p": null
}
```
### Llama Format
When using Llama models (3.1 or 3.2), the decoded messages use a simpler structure focused on generation tokens:
```json
{
"generation": "Hello",
"prompt_token_count": null,
"generation_token_count": 2,
"stop_reason": null
}
```
Each chunk contains:
- generation: The generated text piece
- prompt_token_count: Token count of the input (only present in first chunk)
- generation_token_count: Running count of generated tokens
- stop_reason: Indicates completion (null until final chunk)
First chunk example (includes prompt_token_count):
```json
{
"generation": "\n\n",
"prompt_token_count": 10,
"generation_token_count": 1,
"stop_reason": null
}
```
### Titan Text Format
When using Amazon's Titan models (text or TG1), the response comes as a single chunk with complete text and metrics:
```json
{
"outputText": "\nBot: Hello! How can I help you today?",
"index": 0,
"totalOutputTextTokenCount": 13,
"completionReason": "FINISH",
"inputTextTokenCount": 3,
"amazon-bedrock-invocationMetrics": {
"inputTokenCount": 3,
"outputTokenCount": 13,
"invocationLatency": 833,
"firstByteLatency": 833
}
}
```
Both Titan text and Titan TG1 use the same response format, with only minor differences in token counts and latency values. For example, here's a TG1 response:
```json
{
"outputText": "\nBot: Hello! How can I help you?",
"index": 0,
"totalOutputTextTokenCount": 12,
"completionReason": "FINISH",
"inputTextTokenCount": 3,
"amazon-bedrock-invocationMetrics": {
"inputTokenCount": 3,
"outputTokenCount": 12,
"invocationLatency": 845,
"firstByteLatency": 845
}
}
```
Key fields:
- outputText: The complete generated response
- totalOutputTextTokenCount: Total tokens in the response
- completionReason: Reason for completion (e.g., "FINISH")
- inputTextTokenCount: Number of input tokens
- amazon-bedrock-invocationMetrics: Detailed performance metrics
## Model-Specific Completion Metrics
### Mistral
```json
{
"usage": {
"prompt_tokens": 5,
"total_tokens": 29,
"completion_tokens": 24
},
"amazon-bedrock-invocationMetrics": {
"inputTokenCount": 5,
"outputTokenCount": 24,
"invocationLatency": 719,
"firstByteLatency": 148
}
}
```
### Claude 3
Included in the message_delta with stop_reason.
### Llama
Included in the final chunk with stop_reason "stop":
```json
{
"amazon-bedrock-invocationMetrics": {
"inputTokenCount": 10,
"outputTokenCount": 11,
"invocationLatency": 873,
"firstByteLatency": 550
}
}
```
### Titan
Both Titan text and TG1 include metrics in the single response chunk:
```json
{
"amazon-bedrock-invocationMetrics": {
"inputTokenCount": 3,
"outputTokenCount": 12,
"invocationLatency": 845,
"firstByteLatency": 845
}
}
```
## How the Response is Processed
1. The raw response is first split into chunks based on SSE format
2. For each chunk:
- The base64 encoded data is decoded
- The JSON is parsed to extract the message content
- Based on the model type and message type, different processing is applied:
### Claude 3 Processing
- message_start: Initializes a new message with model info and usage stats
- content_block_start: Starts a new content block (text, tool use, etc.)
- content_block_delta: Adds incremental content to the current block
- message_delta: Updates message metadata
### Mistral Processing
- Each chunk contains a complete message object with choices array
- The content is streamed through the message.content field
- Final chunk includes token usage and invocation metrics
### Llama Processing
- Each chunk contains a generation field with the text piece
- First chunk includes prompt_token_count
- Tracks generation progress through generation_token_count
- Simple streaming format focused on text generation
- Final chunk includes complete metrics
### Titan Processing
- Single chunk response with complete text
- No streaming - returns full response at once
- Includes comprehensive metrics in the same chunk
## Handling in Code
The response is processed by the `transformBedrockStream` function in `app/utils/aws.ts`, which:
1. Reads the stream chunks
2. Parses each chunk using `parseEventData`
3. Handles model-specific formats:
- For Claude: Processes message_start, content_block_start, content_block_delta
- For Mistral: Extracts content from choices[0].message.content
- For Llama: Uses the generation field directly
- For Titan: Uses the outputText field from the single response
4. Transforms the parsed data into a consistent format for the client
5. Yields the transformed data as SSE events
This allows for real-time streaming of the model's response while maintaining a consistent format for the client application, regardless of which model is being used.