diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index 1417e9475..d0da02ed0 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -7,32 +7,117 @@ function parseEventData(chunk: Uint8Array): any { const decoder = new TextDecoder(); const text = decoder.decode(chunk); try { - return JSON.parse(text); + const parsed = JSON.parse(text); + // AWS Bedrock wraps the response in a 'body' field + if (typeof parsed.body === "string") { + try { + return JSON.parse(parsed.body); + } catch (e) { + return { output: parsed.body }; + } + } + return parsed.body || parsed; } catch (e) { + console.error("Error parsing event data:", e); try { + // Handle base64 encoded responses const base64Match = text.match(/:"([A-Za-z0-9+/=]+)"/); if (base64Match) { const decoded = Buffer.from(base64Match[1], "base64").toString("utf-8"); - return JSON.parse(decoded); + try { + return JSON.parse(decoded); + } catch (e) { + return { output: decoded }; + } } + + // Handle event-type responses const eventMatch = text.match(/:event-type[^\{]+({.*})/); if (eventMatch) { - return JSON.parse(eventMatch[1]); + try { + return JSON.parse(eventMatch[1]); + } catch (e) { + return { output: eventMatch[1] }; + } } - } catch (innerError) {} + + // Handle plain text responses + if (text.trim()) { + // Clean up any malformed JSON characters + const cleanText = text.replace(/[\x00-\x1F\x7F-\x9F]/g, ""); + return { output: cleanText }; + } + } catch (innerError) { + console.error("Error in fallback parsing:", innerError); + } } return null; } -async function* transformBedrockStream(stream: ReadableStream) { +async function* transformBedrockStream( + stream: ReadableStream, + modelId: string, +) { const reader = stream.getReader(); + let buffer = ""; + try { while (true) { const { done, value } = await reader.read(); - if (done) break; + if (done) { + if (buffer) { + yield `data: ${JSON.stringify({ + delta: { text: buffer }, + })}\n\n`; + } + break; + } const parsed = parseEventData(value); - if (parsed) { + if (!parsed) continue; + + console.log("Parsed response:", JSON.stringify(parsed, null, 2)); + + // Handle Titan models + if (modelId.startsWith("amazon.titan")) { + const text = parsed.outputText || ""; + if (text) { + yield `data: ${JSON.stringify({ + delta: { text }, + })}\n\n`; + } + } + // Handle LLaMA3 models + else if (modelId.startsWith("us.meta.llama3")) { + let text = ""; + if (parsed.generation) { + text = parsed.generation; + } else if (parsed.output) { + text = parsed.output; + } else if (typeof parsed === "string") { + text = parsed; + } + + if (text) { + // Clean up any control characters or invalid JSON characters + text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, ""); + yield `data: ${JSON.stringify({ + delta: { text }, + })}\n\n`; + } + } + // Handle Mistral models + else if (modelId.startsWith("mistral.mistral")) { + const text = + parsed.output || parsed.outputs?.[0]?.text || parsed.completion || ""; + if (text) { + yield `data: ${JSON.stringify({ + delta: { text }, + })}\n\n`; + } + } + // Handle Claude models + else if (modelId.startsWith("anthropic.claude")) { if (parsed.type === "content_block_delta") { if (parsed.delta?.type === "text_delta") { yield `data: ${JSON.stringify({ @@ -66,6 +151,8 @@ async function* transformBedrockStream(stream: ReadableStream) { function validateRequest(body: any, modelId: string): void { if (!modelId) throw new Error("Model ID is required"); + const bodyContent = body.body || body; + if (modelId.startsWith("anthropic.claude")) { if ( !body.anthropic_version || @@ -82,13 +169,14 @@ function validateRequest(body: any, modelId: string): void { } else if (typeof body.prompt !== "string") { throw new Error("prompt is required for Claude 2 and earlier"); } - } else if (modelId.startsWith("meta.llama")) { - if (!body.prompt) throw new Error("Llama requires a prompt"); + } else if (modelId.startsWith("us.meta.llama3")) { + if (!bodyContent.prompt) { + throw new Error("prompt is required for LLaMA3 models"); + } } else if (modelId.startsWith("mistral.mistral")) { - if (!Array.isArray(body.messages)) - throw new Error("Mistral requires a messages array"); + if (!bodyContent.prompt) throw new Error("Mistral requires a prompt"); } else if (modelId.startsWith("amazon.titan")) { - if (!body.inputText) throw new Error("Titan requires inputText"); + if (!bodyContent.inputText) throw new Error("Titan requires inputText"); } } @@ -114,14 +202,35 @@ async function requestBedrock(req: NextRequest) { throw new Error("Failed to decrypt AWS credentials"); } - const endpoint = `https://bedrock-runtime.${awsRegion}.amazonaws.com/model/${modelId}/invoke-with-response-stream`; + // Construct the base endpoint + const baseEndpoint = `https://bedrock-runtime.${awsRegion}.amazonaws.com`; + + // Set up timeout const timeoutId = setTimeout(() => controller.abort(), 10 * 60 * 1000); try { + // Determine the endpoint and request body based on model type + let endpoint; + let requestBody; + let additionalHeaders = {}; + const bodyText = await req.clone().text(); + if (!bodyText) { + throw new Error("Request body is empty"); + } + const bodyJson = JSON.parse(bodyText); validateRequest(bodyJson, modelId); - const canonicalBody = JSON.stringify(bodyJson); + + // For all other models, use standard endpoint + endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`; + requestBody = JSON.stringify(bodyJson.body || bodyJson); + + console.log("Request to AWS Bedrock:", { + endpoint, + modelId, + body: requestBody, + }); const headers = await sign({ method: "POST", @@ -130,14 +239,17 @@ async function requestBedrock(req: NextRequest) { accessKeyId: decryptedAccessKey, secretAccessKey: decryptedSecretKey, sessionToken: decryptedSessionToken, - body: canonicalBody, + body: requestBody, service: "bedrock", }); const res = await fetch(endpoint, { method: "POST", - headers, - body: canonicalBody, + headers: { + ...headers, + ...additionalHeaders, + }, + body: requestBody, redirect: "manual", // @ts-ignore duplex: "half", @@ -146,15 +258,20 @@ async function requestBedrock(req: NextRequest) { if (!res.ok) { const error = await res.text(); + console.error("AWS Bedrock error response:", error); try { const errorJson = JSON.parse(error); throw new Error(errorJson.message || error); } catch { - throw new Error(error); + throw new Error(error || "Failed to get response from Bedrock"); } } - const transformedStream = transformBedrockStream(res.body!); + if (!res.body) { + throw new Error("Empty response from Bedrock"); + } + + const transformedStream = transformBedrockStream(res.body, modelId); const stream = new ReadableStream({ async start(controller) { try { @@ -163,6 +280,7 @@ async function requestBedrock(req: NextRequest) { } controller.close(); } catch (err) { + console.error("Stream error:", err); controller.error(err); } }, @@ -177,6 +295,7 @@ async function requestBedrock(req: NextRequest) { }, }); } catch (e) { + console.error("Request error:", e); throw e; } finally { clearTimeout(timeoutId); @@ -202,6 +321,7 @@ export async function handle( try { return await requestBedrock(req); } catch (e) { + console.error("Handler error:", e); return NextResponse.json( { error: true, msg: e instanceof Error ? e.message : "Unknown error" }, { status: 500 }, diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index 4c6371b17..c13aa4410 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -1,4 +1,11 @@ -import { ChatOptions, LLMApi, SpeechOptions } from "../api"; +import { + ChatOptions, + LLMApi, + SpeechOptions, + RequestMessage, + MultimodalContent, + MessageRole, +} from "../api"; import { useAppConfig, usePluginStore, @@ -15,6 +22,8 @@ const ClaudeMapper = { system: "user", } as const; +type ClaudeRole = keyof typeof ClaudeMapper; + interface ToolDefinition { function?: { name: string; @@ -28,44 +37,131 @@ export class BedrockApi implements LLMApi { throw new Error("Speech not implemented for Bedrock."); } - extractMessage(res: any) { - if (res?.content?.[0]?.text) return res.content[0].text; - if (res?.messages?.[0]?.content?.[0]?.text) - return res.messages[0].content[0].text; - if (res?.delta?.text) return res.delta.text; - return ""; + extractMessage(res: any, modelId: string = "") { + try { + // Handle Titan models + if (modelId.startsWith("amazon.titan")) { + if (res?.delta?.text) return res.delta.text; + return res?.outputText || ""; + } + + // Handle LLaMA models + if (modelId.startsWith("us.meta.llama3")) { + if (res?.delta?.text) return res.delta.text; + if (res?.generation) return res.generation; + if (typeof res?.output === "string") return res.output; + if (typeof res === "string") return res; + return ""; + } + + // Handle Mistral models + if (modelId.startsWith("mistral.mistral")) { + if (res?.delta?.text) return res.delta.text; + return res?.outputs?.[0]?.text || res?.output || res?.completion || ""; + } + + // Handle Claude models and fallback cases + if (res?.content?.[0]?.text) return res.content[0].text; + if (res?.messages?.[0]?.content?.[0]?.text) + return res.messages[0].content[0].text; + if (res?.delta?.text) return res.delta.text; + if (res?.completion) return res.completion; + if (res?.generation) return res.generation; + if (res?.outputText) return res.outputText; + if (res?.output) return res.output; + + if (typeof res === "string") return res; + + return ""; + } catch (e) { + console.error("Error extracting message:", e); + return ""; + } } - async chat(options: ChatOptions) { - const visionModel = isVisionModel(options.config.model); - const isClaude3 = options.config.model.startsWith("anthropic.claude-3"); + formatRequestBody( + messages: RequestMessage[], + systemMessage: string, + modelConfig: any, + ) { + const model = modelConfig.model; - const modelConfig = { - ...useAppConfig.getState().modelConfig, - ...useChatStore.getState().currentSession().mask.modelConfig, - model: options.config.model, - }; - - let systemMessage = ""; - const messages = []; - for (const msg of options.messages) { - const content = await preProcessImageContent(msg.content); - if (msg.role === "system") { - systemMessage = getMessageTextContent(msg); - } else { - messages.push({ role: msg.role, content }); - } + // Handle Titan models + if (model.startsWith("amazon.titan")) { + const allMessages = systemMessage + ? [ + { role: "system" as MessageRole, content: systemMessage }, + ...messages, + ] + : messages; + const inputText = allMessages + .map((m) => `${m.role}: ${getMessageTextContent(m)}`) + .join("\n"); + return { + body: { + inputText, + textGenerationConfig: { + maxTokenCount: modelConfig.max_tokens, + temperature: modelConfig.temperature, + stopSequences: [], + }, + }, + }; } + // Handle LLaMA3 models - simplified format + if (model.startsWith("us.meta.llama3")) { + const allMessages = systemMessage + ? [ + { role: "system" as MessageRole, content: systemMessage }, + ...messages, + ] + : messages; + + const prompt = allMessages + .map((m) => `${m.role}: ${getMessageTextContent(m)}`) + .join("\n"); + + return { + contentType: "application/json", + accept: "application/json", + body: { + prompt, + }, + }; + } + + // Handle Mistral models + if (model.startsWith("mistral.mistral")) { + const allMessages = systemMessage + ? [ + { role: "system" as MessageRole, content: systemMessage }, + ...messages, + ] + : messages; + const prompt = allMessages + .map((m) => `${m.role}: ${getMessageTextContent(m)}`) + .join("\n"); + return { + body: { + prompt, + temperature: modelConfig.temperature || 0.7, + max_tokens: modelConfig.max_tokens || 4096, + }, + }; + } + + // Handle Claude models (existing implementation) + const isClaude3 = model.startsWith("anthropic.claude-3"); const formattedMessages = messages .filter( (v) => v.content && (typeof v.content !== "string" || v.content.trim()), ) .map((v) => { const { role, content } = v; - const insideRole = ClaudeMapper[role] ?? "user"; + const insideRole = ClaudeMapper[role as ClaudeRole] ?? "user"; - if (!visionModel || typeof content === "string") { + if (!isVisionModel(model) || typeof content === "string") { return { role: insideRole, content: [{ type: "text", text: getMessageTextContent(v) }], @@ -74,7 +170,7 @@ export class BedrockApi implements LLMApi { return { role: insideRole, - content: content + content: (content as MultimodalContent[]) .filter((v) => v.image_url || v.text) .map(({ type, text, image_url }) => { if (type === "text") return { type, text: text! }; @@ -96,17 +192,40 @@ export class BedrockApi implements LLMApi { }; }); - const requestBody = { + return { anthropic_version: "bedrock-2023-05-31", max_tokens: modelConfig.max_tokens, messages: formattedMessages, ...(systemMessage && { system: systemMessage }), - ...(modelConfig.temperature !== undefined && { - temperature: modelConfig.temperature, - }), - ...(modelConfig.top_p !== undefined && { top_p: modelConfig.top_p }), + temperature: modelConfig.temperature, ...(isClaude3 && { top_k: 5 }), }; + } + + async chat(options: ChatOptions) { + const modelConfig = { + ...useAppConfig.getState().modelConfig, + ...useChatStore.getState().currentSession().mask.modelConfig, + model: options.config.model, + }; + + let systemMessage = ""; + const messages = []; + for (const msg of options.messages) { + const content = await preProcessImageContent(msg.content); + if (msg.role === "system") { + systemMessage = getMessageTextContent(msg); + } else { + messages.push({ role: msg.role, content }); + } + } + + const requestBody = this.formatRequestBody( + messages, + systemMessage, + modelConfig, + ); + // console.log("Request body:", JSON.stringify(requestBody, null, 2)); const controller = new AbortController(); options.onController?.(controller); @@ -121,7 +240,8 @@ export class BedrockApi implements LLMApi { try { const apiEndpoint = "/api/bedrock/chat"; const headers = { - "Content-Type": "application/json", + "Content-Type": requestBody.contentType || "application/json", + Accept: requestBody.accept || "application/json", "X-Region": accessStore.awsRegion, "X-Access-Key": accessStore.awsAccessKey, "X-Secret-Key": accessStore.awsSecretKey, @@ -154,6 +274,7 @@ export class BedrockApi implements LLMApi { (text: string, runTools: ChatMessageTool[]) => { try { const chunkJson = JSON.parse(text); + // console.log("Received chunk:", JSON.stringify(chunkJson, null, 2)); if (chunkJson?.content_block?.type === "tool_use") { index += 1; currentToolArgs = ""; @@ -193,8 +314,11 @@ export class BedrockApi implements LLMApi { runTools[index].function!.arguments = currentToolArgs; } catch (e) {} } - return this.extractMessage(chunkJson); + const message = this.extractMessage(chunkJson, modelConfig.model); + // console.log("Extracted message:", message); + return message; } catch (e) { + console.error("Error parsing chunk:", e); return ""; } }, @@ -251,10 +375,13 @@ export class BedrockApi implements LLMApi { }); const resJson = await res.json(); - const message = this.extractMessage(resJson); + // console.log("Response:", JSON.stringify(resJson, null, 2)); + const message = this.extractMessage(resJson, modelConfig.model); + // console.log("Extracted message:", message); options.onFinish(message, res); } } catch (e) { + console.error("Chat error:", e); options.onError?.(e as Error); } } diff --git a/app/constant.ts b/app/constant.ts index 32b051c76..5ac589167 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -330,40 +330,24 @@ const bedrockModels = [ // Amazon Titan Models "amazon.titan-text-express-v1", "amazon.titan-text-lite-v1", - "amazon.titan-text-agile-v1", - - // Cohere Models - "cohere.command-light-text-v14", - "cohere.command-r-plus-v1:0", - "cohere.command-r-v1:0", - "cohere.command-text-v14", - + "amazon.titan-tg1-large", // Claude Models "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-5-haiku-20241022-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-5-sonnet-20241022-v2:0", "anthropic.claude-3-opus-20240229-v1:0", - "anthropic.claude-2.1", - "anthropic.claude-v2", - "anthropic.claude-v1", - "anthropic.claude-instant-v1", - // Meta Llama Models - "meta.llama2-13b-chat-v1", - "meta.llama2-70b-chat-v1", - "meta.llama3-8b-instruct-v1:0", - "meta.llama3-2-11b-instruct-v1:0", - "meta.llama3-2-90b-instruct-v1:0", - + "us.meta.llama3-1-8b-instruct-v1:0", + "us.meta.llama3-1-70b-instruct-v1:0", + "us.meta.llama3-2-1b-instruct-v1:0", + "us.meta.llama3-2-3b-instruct-v1:0", + "us.meta.llama3-2-11b-instruct-v1:0", + "us.meta.llama3-2-90b-instruct-v1:0", // Mistral Models "mistral.mistral-7b-instruct-v0:2", "mistral.mistral-large-2402-v1:0", "mistral.mistral-large-2407-v1:0", - - // AI21 Models - "ai21.j2-mid-v1", - "ai21.j2-ultra-v1", ]; const googleModels = [