diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index 043fa7aa2..5e1f9f0a2 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -3,22 +3,49 @@ import { ChatOptions, getHeaders, LLMApi, + LLMModel, LLMUsage, - MultimodalContent, SpeechOptions, } from "../api"; -import { useAccessStore, useAppConfig } from "../../store"; -import Locale from "../../locales"; import { - getMessageImages, - getMessageTextContent, - isVisionModel, -} from "../../utils"; + useAccessStore, + useAppConfig, + usePluginStore, + useChatStore, + ChatMessageTool, +} from "../../store"; +import { getMessageTextContent, isVisionModel } from "../../utils"; import { fetch } from "../../utils/stream"; +import { preProcessImageContent, stream } from "../../utils/chat"; -const MAX_IMAGE_SIZE = 1024 * 1024 * 4; // 4MB limit +export type MultiBlockContent = { + type: "image" | "text"; + source?: { + type: string; + media_type: string; + data: string; + }; + text?: string; +}; + +export type AnthropicMessage = { + role: (typeof ClaudeMapper)[keyof typeof ClaudeMapper]; + content: string | MultiBlockContent[]; +}; + +const ClaudeMapper = { + assistant: "assistant", + user: "user", + system: "user", +} as const; export class BedrockApi implements LLMApi { + usage(): Promise { + throw new Error("Method not implemented."); + } + models(): Promise { + throw new Error("Method not implemented."); + } speech(options: SpeechOptions): Promise { throw new Error("Speech not implemented for Bedrock."); } @@ -31,154 +58,17 @@ export class BedrockApi implements LLMApi { return res; } - async processDocument( - file: File, - ): Promise<{ display: string; content: MultimodalContent }> { - return new Promise((resolve, reject) => { - const reader = new FileReader(); - reader.onload = async () => { - try { - const arrayBuffer = reader.result as ArrayBuffer; - const format = file.name.split(".").pop()?.toLowerCase(); - - if (!format) { - throw new Error("Could not determine file format"); - } - - // Format file size - const size = file.size; - let sizeStr = ""; - if (size < 1024) { - sizeStr = size + " B"; - } else if (size < 1024 * 1024) { - sizeStr = (size / 1024).toFixed(2) + " KB"; - } else { - sizeStr = (size / (1024 * 1024)).toFixed(2) + " MB"; - } - - // Create display text - const displayText = `Document: ${file.name} (${sizeStr})`; - - // Create actual content - const content: MultimodalContent = { - type: "document", - document: { - format: format as - | "pdf" - | "csv" - | "doc" - | "docx" - | "xls" - | "xlsx" - | "html" - | "txt" - | "md", - name: file.name, - source: { - bytes: Buffer.from(arrayBuffer).toString("base64"), - }, - }, - }; - - resolve({ - display: displayText, - content: content, - }); - } catch (e) { - reject(e); - } - }; - reader.onerror = () => reject(reader.error); - reader.readAsArrayBuffer(file); - }); - } - - async processImage(url: string): Promise { - if (url.startsWith("data:")) { - const base64Match = url.match(/^data:image\/([a-zA-Z]*);base64,([^"]*)/); - if (base64Match) { - const format = base64Match[1].toLowerCase(); - const base64Data = base64Match[2]; - - // Check base64 size - const binarySize = atob(base64Data).length; - if (binarySize > MAX_IMAGE_SIZE) { - throw new Error( - `Image size (${(binarySize / (1024 * 1024)).toFixed( - 2, - )}MB) exceeds maximum allowed size of 4MB`, - ); - } - - return { - type: "image_url", - image_url: { - url: url, - }, - }; - } - throw new Error("Invalid data URL format"); - } - - // For non-data URLs, fetch and convert to base64 - try { - const response = await fetch(url); - if (!response.ok) { - throw new Error(`Failed to fetch image: ${response.statusText}`); - } - - const blob = await response.blob(); - if (blob.size > MAX_IMAGE_SIZE) { - throw new Error( - `Image size (${(blob.size / (1024 * 1024)).toFixed( - 2, - )}MB) exceeds maximum allowed size of 4MB`, - ); - } - - const reader = new FileReader(); - const base64 = await new Promise((resolve, reject) => { - reader.onloadend = () => resolve(reader.result as string); - reader.onerror = () => reject(new Error("Failed to read image data")); - reader.readAsDataURL(blob); - }); - - return { - type: "image_url", - image_url: { - url: base64, - }, - }; - } catch (error) { - console.error("[Bedrock] Image processing error:", error); - throw error; - } - } - async chat(options: ChatOptions): Promise { + const visionModel = isVisionModel(options.config.model); const accessStore = useAccessStore.getState(); + const shouldStream = !!options.config.stream; const modelConfig = { ...useAppConfig.getState().modelConfig, - ...options.config, + ...useChatStore.getState().currentSession().mask.modelConfig, + ...{ + model: options.config.model, + }, }; - - if ( - !accessStore.awsRegion || - !accessStore.awsAccessKeyId || - !accessStore.awsSecretAccessKey - ) { - console.log("AWS credentials are not set"); - let responseText = ""; - const responseTexts = [responseText]; - responseTexts.push(Locale.Error.Unauthorized); - responseText = responseTexts.join("\n\n"); - options.onFinish(responseText); - return; - } - - const controller = new AbortController(); - options.onController?.(controller); - const headers: Record = { ...getHeaders(), "X-Region": accessStore.awsRegion, @@ -186,200 +76,212 @@ export class BedrockApi implements LLMApi { "X-Secret-Key": accessStore.awsSecretAccessKey, }; - if (accessStore.awsSessionToken) { - headers["X-Session-Token"] = accessStore.awsSessionToken; + // try get base64image from local cache image_url + const messages: ChatOptions["messages"] = []; + for (const v of options.messages) { + const content = await preProcessImageContent(v.content); + messages.push({ role: v.role, content }); } - try { - // Process messages to handle multimodal content - const messages = await Promise.all( - options.messages.map(async (msg) => { - if (Array.isArray(msg.content)) { - // For vision models, include both text and images - if (isVisionModel(options.config.model)) { - const images = getMessageImages(msg); - const content: MultimodalContent[] = []; + const keys = ["system", "user"]; - // Process documents first - for (const item of msg.content) { - // Check for document content - if (item && typeof item === "object") { - if ("file" in item && item.file instanceof File) { - try { - console.log( - "[Bedrock] Processing document:", - item.file.name, - ); - const { content: docContent } = - await this.processDocument(item.file); - content.push(docContent); - } catch (e) { - console.error("[Bedrock] Failed to process document:", e); - } - } else if ("document" in item && item.document) { - // If document content is already processed, include it directly - content.push(item as MultimodalContent); - } - } - } + // roles must alternate between "user" and "assistant" in claude, so add a fake assistant message between two user messages + for (let i = 0; i < messages.length - 1; i++) { + const message = messages[i]; + const nextMessage = messages[i + 1]; - // Add text content if it's not a document display text - const text = getMessageTextContent(msg); - if (text && !text.startsWith("Document: ")) { - content.push({ type: "text", text }); - } + if (keys.includes(message.role) && keys.includes(nextMessage.role)) { + messages[i] = [ + message, + { + role: "assistant", + content: ";", + }, + ] as any; + } + } - // Process images with size check and error handling - for (const url of images) { - try { - const imageContent = await this.processImage(url); - content.push(imageContent); - } catch (e) { - console.error("[Bedrock] Failed to process image:", e); - // Add error message as text content - content.push({ - type: "text", - text: `Error processing image: ${e}`, - }); - } - } + const prompt = messages + .flat() + .filter((v) => { + if (!v.content) return false; + if (typeof v.content === "string" && !v.content.trim()) return false; + return true; + }) + .map((v) => { + const { role, content } = v; + const insideRole = ClaudeMapper[role] ?? "user"; - // Only return content if there is any - if (content.length > 0) { - return { ...msg, content }; - } - } - // For non-vision models, only include text - return { ...msg, content: getMessageTextContent(msg) }; - } - return msg; - }), - ); - - // Filter out empty messages - const filteredMessages = messages.filter((msg) => { - if (Array.isArray(msg.content)) { - return msg.content.length > 0; + if (!visionModel || typeof content === "string") { + return { + role: insideRole, + content: getMessageTextContent(v), + }; } - return msg.content !== ""; + return { + role: insideRole, + content: content + .filter((v) => v.image_url || v.text) + .map(({ type, text, image_url }) => { + if (type === "text") { + return { + type, + text: text!, + }; + } + const { url = "" } = image_url || {}; + const colonIndex = url.indexOf(":"); + const semicolonIndex = url.indexOf(";"); + const comma = url.indexOf(","); + + const mimeType = url.slice(colonIndex + 1, semicolonIndex); + const encodeType = url.slice(semicolonIndex + 1, comma); + const data = url.slice(comma + 1); + + return { + type: "image" as const, + source: { + type: encodeType, + media_type: mimeType, + data, + }, + }; + }), + }; }); - const requestBody = { - messages: filteredMessages, - modelId: options.config.model, - inferenceConfig: { - maxTokens: modelConfig.max_tokens, - temperature: modelConfig.temperature, - topP: modelConfig.top_p, - stopSequences: [], - }, - }; + if (prompt[0]?.role === "assistant") { + prompt.unshift({ + role: "user", + content: ";", + }); + } - console.log( - "[Bedrock] Request body:", - JSON.stringify( - { - ...requestBody, - messages: requestBody.messages.map((msg) => ({ - ...msg, - content: Array.isArray(msg.content) - ? msg.content.map((c) => ({ - type: c.type, - ...(c.document - ? { - document: { - format: c.document.format, - name: c.document.name, - }, - } - : {}), - ...(c.image_url ? { image_url: { url: "[BINARY]" } } : {}), - ...(c.text ? { text: c.text } : {}), - })) - : msg.content, - })), - }, - null, - 2, - ), - ); + const [tools, funcs] = usePluginStore + .getState() + .getAsTools(useChatStore.getState().currentSession().mask?.plugin || []); - const shouldStream = !!options.config.stream; - const conversePath = `${ApiPath.Bedrock}/converse`; - - if (shouldStream) { - let response = await fetch(conversePath, { - method: "POST", - headers: { - ...headers, - "X-Stream": "true", - }, - body: JSON.stringify(requestBody), - signal: controller.signal, - }); - - if (!response.ok) { - const error = await response.text(); - throw new Error(`Bedrock API error: ${error}`); - } - - let buffer = ""; - const reader = response.body?.getReader(); - if (!reader) { - throw new Error("No response body reader available"); - } - - let currentContent = ""; - let isFirstMessage = true; - - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - // Convert the chunk to text and add to buffer - const chunk = new TextDecoder().decode(value); - buffer += chunk; - - // Process complete messages from buffer - let newlineIndex; - while ((newlineIndex = buffer.indexOf("\n")) !== -1) { - const line = buffer.slice(0, newlineIndex).trim(); - buffer = buffer.slice(newlineIndex + 1); - - if (line.startsWith("data: ")) { - try { - const event = JSON.parse(line.slice(6)); - - if (event.type === "messageStart") { - if (isFirstMessage) { - isFirstMessage = false; - } - continue; - } - - if (event.type === "text" && event.content) { - currentContent += event.content; - options.onUpdate?.(currentContent, event.content); - } - - if (event.type === "messageStop") { - options.onFinish(currentContent); - return; - } - - if (event.type === "error") { - throw new Error(event.message || "Unknown error"); - } - } catch (e) { - console.error("[Bedrock] Failed to parse stream event:", e); - } + const requestBody = { + modelId: options.config.model, + messages: messages.filter((msg) => msg.content.length > 0), + inferenceConfig: { + maxTokens: modelConfig.max_tokens, + temperature: modelConfig.temperature, + topP: modelConfig.top_p, + stopSequences: [], + }, + toolConfig: + Array.isArray(tools) && tools.length > 0 + ? { + tools: tools.map((tool: any) => ({ + toolSpec: { + name: tool?.function?.name, + description: tool?.function?.description, + inputSchema: { + json: tool?.function?.parameters, + }, + }, + })), + toolChoice: { auto: {} }, } - } - } + : undefined, + }; - // If we reach here without a messageStop event, finish with current content - options.onFinish(currentContent); - } else { + const conversePath = `${ApiPath.Bedrock}/converse`; + const controller = new AbortController(); + options.onController?.(controller); + + if (shouldStream) { + let currentToolUse: ChatMessageTool | null = null; + return stream( + conversePath, + requestBody, + headers, + Array.isArray(tools) + ? tools.map((tool: any) => ({ + name: tool?.function?.name, + description: tool?.function?.description, + input_schema: tool?.function?.parameters, + })) + : [], + funcs, + controller, + // parseSSE + (text: string, runTools: ChatMessageTool[]) => { + const event = JSON.parse(text); + + if (event.type === "messageStart") { + return ""; + } + + if (event.type === "contentBlockStart" && event.start?.toolUse) { + const { toolUseId, name } = event.start.toolUse; + currentToolUse = { + id: toolUseId, + type: "function", + function: { + name, + arguments: "", + }, + }; + runTools.push(currentToolUse); + return ""; + } + + if (event.type === "text" && event.content) { + return event.content; + } + + if ( + event.type === "toolUse" && + event.input && + currentToolUse?.function + ) { + currentToolUse.function.arguments += event.input; + return ""; + } + + if (event.type === "error") { + throw new Error(event.message || "Unknown error"); + } + + return ""; + }, + // processToolMessage + (requestPayload: any, toolCallMessage: any, toolCallResult: any[]) => { + currentToolUse = null; + requestPayload?.messages?.splice( + requestPayload?.messages?.length, + 0, + { + role: "assistant", + content: toolCallMessage.tool_calls.map( + (tool: ChatMessageTool) => ({ + type: "tool_use", + id: tool.id, + name: tool?.function?.name, + input: tool?.function?.arguments + ? JSON.parse(tool?.function?.arguments) + : {}, + }), + ), + }, + ...toolCallResult.map((result) => ({ + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: result.tool_call_id, + content: result.content, + }, + ], + })), + ); + }, + options, + ); + } else { + try { const response = await fetch(conversePath, { method: "POST", headers, @@ -395,23 +297,10 @@ export class BedrockApi implements LLMApi { const responseBody = await response.json(); const content = this.extractMessage(responseBody); options.onFinish(content); + } catch (e: any) { + console.error("[Bedrock] Chat error:", e); + throw e; } - } catch (e) { - console.error("[Bedrock] Chat error:", e); - options.onError?.(e as Error); } } - - async usage(): Promise { - // Bedrock usage is tracked through AWS billing - return { - used: 0, - total: 0, - }; - } - - async models() { - // Return empty array as models are configured through AWS console - return []; - } }