From 6f7a6350305c751f8ed5972c2206fee750347042 Mon Sep 17 00:00:00 2001
From: glay
Date: Sun, 24 Nov 2024 23:54:04 +0800
Subject: [PATCH] Improve inference support for the llama and mistral models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/api/bedrock.ts              |  48 +++++++--
 app/client/platforms/bedrock.ts |  57 +++++++---
 app/utils.ts                    |   5 +-
 app/utils/aws.ts                | 183 ++++++++++++++++++++------------
 4 files changed, 204 insertions(+), 89 deletions(-)

diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts
index d65fc3f50..45aa84b2f 100644
--- a/app/api/bedrock.ts
+++ b/app/api/bedrock.ts
@@ -4,7 +4,6 @@ import {
   sign,
   decrypt,
   getBedrockEndpoint,
-  getModelHeaders,
   transformBedrockStream,
   parseEventData,
   BedrockCredentials,
@@ -83,6 +82,10 @@ async function requestBedrock(req: NextRequest) {
   } catch (e) {
     throw new Error(`Invalid JSON in request body: ${e}`);
   }
+  console.log(
+    "[Bedrock Request] Original Body:",
+    JSON.stringify(bodyJson, null, 2),
+  );
 
   // Extract tool configuration if present
   let tools: any[] | undefined;
@@ -97,18 +100,44 @@ async function requestBedrock(req: NextRequest) {
     modelId,
     shouldStream,
   );
-  const additionalHeaders = getModelHeaders(modelId);
 
   console.log("[Bedrock Request] Endpoint:", endpoint);
   console.log("[Bedrock Request] Model ID:", modelId);
 
-  // Only include tools for Claude models
-  const isClaudeModel = modelId.toLowerCase().includes("claude3");
+  // Handle tools for different models
+  const isMistralModel = modelId.toLowerCase().includes("mistral");
+  const isClaudeModel = modelId.toLowerCase().includes("claude");
+
   const requestBody = {
     ...bodyJson,
-    ...(isClaudeModel && tools && { tools }),
   };
 
+  if (tools && tools.length > 0) {
+    if (isClaudeModel) {
+      // Claude models already use the expected tool format
+      requestBody.tools = tools;
+    } else if (isMistralModel) {
+      // Format messages for Mistral
+      if (typeof requestBody.prompt === "string") {
+        requestBody.messages = [
+          { role: "user", content: requestBody.prompt },
+        ];
+        delete requestBody.prompt;
+      }
+
+      // Add tools in Mistral's format
+      requestBody.tool_choice = "auto";
+      requestBody.tools = tools.map((tool) => ({
+        type: "function",
+        function: {
+          name: tool.name,
+          description: tool.description,
+          parameters: tool.input_schema,
+        },
+      }));
+    }
+  }
+
   // Sign request
   const headers = await sign({
     method: "POST",
@@ -119,12 +148,11 @@ async function requestBedrock(req: NextRequest) {
     body: JSON.stringify(requestBody),
     service: "bedrock",
     isStreaming: shouldStream,
-    additionalHeaders,
   });
 
   // Make request to AWS Bedrock
   console.log(
-    "[Bedrock Request] Body:",
+    "[Bedrock Request] Final Body:",
     JSON.stringify(requestBody, null, 2),
   );
   const res = await fetch(endpoint, {
@@ -173,11 +201,15 @@ async function requestBedrock(req: NextRequest) {
 
   // Handle streaming response
   const transformedStream = transformBedrockStream(res.body, modelId);
+  const encoder = new TextEncoder();
   const stream = new ReadableStream({
     async start(controller) {
       try {
         for await (const chunk of transformedStream) {
-          controller.enqueue(new TextEncoder().encode(chunk));
+          // Ensure we're sending non-empty chunks
+          if (chunk && chunk.trim()) {
+            controller.enqueue(encoder.encode(chunk));
+          }
         }
         controller.close();
       } catch (err) {

diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts
index 7d1bda3dd..1799e4dbc 100644
--- a/app/client/platforms/bedrock.ts
+++ b/app/client/platforms/bedrock.ts
@@ -37,7 +37,7 @@ export class BedrockApi implements LLMApi {
     if (model.startsWith("amazon.titan")) {
       const inputText = messages
         .map((message) => {
-          return `${message.role}: ${message.content}`;
+          return `${message.role}: ${getMessageTextContent(message)}`;
         })
         .join("\n\n");
 
@@ -52,32 +52,59 @@ export class BedrockApi implements LLMApi {
     }
 
     // Handle LLaMA models
-    if (model.startsWith("us.meta.llama")) {
-      const prompt = messages
-        .map((message) => {
-          return `${message.role}: ${message.content}`;
-        })
-        .join("\n\n");
+    if (model.includes("meta.llama")) {
+      // Format conversation for Llama models
+      let prompt = "";
+      let systemPrompt = "";
+
+      // Extract system message if present
+      const systemMessage = messages.find((m) => m.role === "system");
+      if (systemMessage) {
+        systemPrompt = getMessageTextContent(systemMessage);
+      }
+
+      // Format the conversation
+      const conversationMessages = messages.filter((m) => m.role !== "system");
+      prompt = `[INST] <<SYS>>\n${
+        systemPrompt || "You are a helpful, respectful and honest assistant."
+      }\n<</SYS>>\n\n`;
+
+      for (let i = 0; i < conversationMessages.length; i++) {
+        const message = conversationMessages[i];
+        const content = getMessageTextContent(message);
+        if (i === 0 && message.role === "user") {
+          // The first user message goes in the same [INST] block as the system prompt
+          prompt += `${content} [/INST]`;
+        } else {
+          if (message.role === "user") {
+            prompt += `\n\n[INST] ${content} [/INST]`;
+          } else {
+            prompt += ` ${content} `;
+          }
+        }
+      }
+
       return {
         prompt,
         max_gen_len: modelConfig.max_tokens || 512,
-        temperature: modelConfig.temperature || 0.6,
+        temperature: modelConfig.temperature || 0.7,
         top_p: modelConfig.top_p || 0.9,
-        stop: ["User:", "System:", "Assistant:", "\n\n"],
       };
     }
 
     // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
-      const prompt = messages
-        .map((message) => {
-          return `${message.role}: ${message.content}`;
-        })
-        .join("\n\n");
+      // Format messages for Mistral's chat format
+      const formattedMessages = messages.map((message) => ({
+        role: message.role,
+        content: getMessageTextContent(message),
+      }));
+
       return {
-        prompt,
+        messages: formattedMessages,
         max_tokens: modelConfig.max_tokens || 4096,
         temperature: modelConfig.temperature || 0.7,
+        top_p: modelConfig.top_p || 0.9,
       };
     }

diff --git a/app/utils.ts b/app/utils.ts
index f47856729..30c2dde5d 100644
--- a/app/utils.ts
+++ b/app/utils.ts
@@ -292,7 +292,10 @@ export function showPlugins(provider: ServiceProvider, model: string) {
   if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) {
     return true;
   }
-  if (provider == ServiceProvider.Bedrock && model.includes("claude-3")) {
+  if (
+    provider == ServiceProvider.Bedrock &&
+    (model.includes("claude-3") || model.includes("mistral-large"))
+  ) {
     return true;
   }
   if (provider == ServiceProvider.Google && !model.includes("vision")) {

diff --git a/app/utils/aws.ts b/app/utils/aws.ts
index 75b359195..127e82cf9 100644
--- a/app/utils/aws.ts
+++ b/app/utils/aws.ts
@@ -75,7 +75,6 @@ export interface SignParams {
   body: string;
   service: string;
   isStreaming?: boolean;
-  additionalHeaders?: Record<string, string>;
 }
 
 function hmac(
@@ -160,7 +159,6 @@ export async function sign({
   body,
   service,
   isStreaming = true,
-  additionalHeaders = {},
 }: SignParams): Promise<Record<string, string>> {
   try {
     const endpoint = new URL(url);
@@ -181,7 +179,6 @@ export async function sign({
     host: endpoint.host,
     "x-amz-content-sha256": payloadHash,
     "x-amz-date": amzDate,
-    ...additionalHeaders,
   };
 
   if (isStreaming) {
@@ -311,32 +308,25 @@ export function getBedrockEndpoint(
   return endpoint;
 }
 
-export function getModelHeaders(modelId: string): Record<string, string> {
-  if (!modelId) {
-    throw new Error("Model ID is required for headers");
-  }
-
-  const headers: Record<string, string> = {};
-
-  if (
-    modelId.startsWith("us.meta.llama") ||
-    modelId.startsWith("mistral.mistral")
-  ) {
-    headers["content-type"] = "application/json";
-    headers["accept"] = "application/json";
-  }
-
-  return headers;
-}
-
 export function extractMessage(res: any, modelId: string = ""): string {
   if (!res) {
     console.error("[AWS Extract Error] extractMessage Empty response");
     return "";
   }
   console.log("[Response] extractMessage response: ", res);
-  return res?.content?.[0]?.text;
-  return "";
+
+  // Handle Mistral model response format
+  if (modelId.toLowerCase().includes("mistral")) {
+    return res?.outputs?.[0]?.text || "";
+  }
+
+  // Handle Llama model response format
+  if (modelId.toLowerCase().includes("llama")) {
+    return res?.generation || "";
+  }
+
+  // Handle Claude and other models
+  return res?.content?.[0]?.text || "";
 }
 
 export async function* transformBedrockStream(
@@ -344,58 +334,105 @@ export async function* transformBedrockStream(
   modelId: string,
 ) {
   const reader = stream.getReader();
-  let buffer = "";
+  let accumulatedText = "";
+  let toolCallStarted = false;
+  let currentToolCall = null;
 
   try {
     while (true) {
       const { done, value } = await reader.read();
-      if (done) {
-        if (buffer) {
-          yield `data: ${JSON.stringify({
-            delta: { text: buffer },
-          })}\n\n`;
-        }
-        break;
-      }
+
+      if (done) break;
 
       const parsed = parseEventData(value);
       if (!parsed) continue;
 
-      // Handle Titan models
-      if (modelId.startsWith("amazon.titan")) {
-        const text = parsed.outputText || "";
-        if (text) {
-          yield `data: ${JSON.stringify({
-            delta: { text },
-          })}\n\n`;
-        }
-      }
-      // Handle LLaMA3 models
-      else if (modelId.startsWith("us.meta.llama3")) {
-        let text = "";
-        if (parsed.generation) {
-          text = parsed.generation;
-        } else if (parsed.output) {
-          text = parsed.output;
-        } else if (typeof parsed === "string") {
-          text = parsed;
-        }
-
-        if (text) {
-          // Clean up any control characters or invalid JSON characters
-          text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
-          yield `data: ${JSON.stringify({
-            delta: { text },
-          })}\n\n`;
-        }
-      }
+      console.log("[Bedrock Stream] parsed event:");
+      console.log(parsed);
       // Handle Mistral models
-      else if (modelId.startsWith("mistral.mistral")) {
-        const text =
-          parsed.output || parsed.outputs?.[0]?.text || parsed.completion || "";
-        if (text) {
-          yield `data: ${JSON.stringify({
-            delta: { text },
-          })}\n\n`;
-        }
-      }
+      if (modelId.toLowerCase().includes("mistral")) {
+        // If we have content, accumulate it
+        if (
+          parsed.choices?.[0]?.message?.role === "assistant" &&
+          parsed.choices?.[0]?.message?.content
+        ) {
+          accumulatedText += parsed.choices?.[0]?.message?.content;
+          console.log("[Bedrock Stream] accumulated text:");
+          console.log(accumulatedText);
+          // Check for a tool call in the accumulated text
+          if (!toolCallStarted && accumulatedText.includes("```json")) {
+            const jsonMatch = accumulatedText.match(
+              /```json\s*({[\s\S]*?})\s*```/,
+            );
+            if (jsonMatch) {
+              try {
+                const toolData = JSON.parse(jsonMatch[1]);
+                currentToolCall = {
+                  id: `tool-${Date.now()}`,
+                  name: toolData.name,
+                  arguments: toolData.arguments,
+                };
+
+                // Emit tool call start
+                yield `data: ${JSON.stringify({
+                  type: "content_block_start",
+                  content_block: {
+                    type: "tool_use",
+                    id: currentToolCall.id,
+                    name: currentToolCall.name,
+                  },
+                })}\n\n`;
+
+                // Emit tool arguments
+                yield `data: ${JSON.stringify({
+                  type: "content_block_delta",
+                  delta: {
"input_json_delta", + partial_json: JSON.stringify(currentToolCall.arguments), + }, + })}\n\n`; + + // Emit tool call stop + yield `data: ${JSON.stringify({ + type: "content_block_stop", + })}\n\n`; + + // Clear the accumulated text after processing the tool call + accumulatedText = accumulatedText.replace( + /```json\s*{[\s\S]*?}\s*```/, + "", + ); + toolCallStarted = false; + currentToolCall = null; + } catch (e) { + console.error("Failed to parse tool JSON:", e); + } + } + } + // emit the text content if it's not empty + if (parsed.choices?.[0]?.message?.content.trim()) { + yield `data: ${JSON.stringify({ + delta: { text: parsed.choices?.[0]?.message?.content }, + })}\n\n`; + } + // Handle stop reason if present + if (parsed.choices?.[0]?.stop_reason) { + yield `data: ${JSON.stringify({ + delta: { stop_reason: parsed.choices[0].stop_reason }, + })}\n\n`; + } + } + } + // Handle Llama models + else if (modelId.toLowerCase().includes("llama")) { + if (parsed.generation) { yield `data: ${JSON.stringify({ - delta: { text }, + delta: { text: parsed.generation }, + })}\n\n`; + } + if (parsed.stop_reason) { + yield `data: ${JSON.stringify({ + delta: { stop_reason: parsed.stop_reason }, })}\n\n`; } } @@ -423,6 +460,22 @@ export async function* transformBedrockStream( yield `data: ${JSON.stringify(parsed)}\n\n`; } else if (parsed.type === "content_block_stop") { yield `data: ${JSON.stringify(parsed)}\n\n`; + } else { + // Handle regular text responses + const text = parsed.response || parsed.output || ""; + if (text) { + yield `data: ${JSON.stringify({ + delta: { text }, + })}\n\n`; + } + } + } else { + // Handle other model text responses + const text = parsed.outputText || parsed.generation || ""; + if (text) { + yield `data: ${JSON.stringify({ + delta: { text }, + })}\n\n`; } } }