diff --git a/.env.template b/.env.template index 1f4773195..1661f77e1 100644 --- a/.env.template +++ b/.env.template @@ -71,8 +71,4 @@ WHITE_WEBDAV_ENDPOINTS= ### bedrock (optional) AWS_REGION= AWS_ACCESS_KEY=AKIA -AWS_SECRET_KEY= -### Assign this with a secure, randomly generated key; -### Generate a secure, random key that is at least 32 characters long. You can use a password generator or a command like this: -### openssl rand -base64 32 -ENCRYPTION_KEY= \ No newline at end of file +AWS_SECRET_KEY= \ No newline at end of file diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index fa52363ae..d65fc3f50 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -1,209 +1,23 @@ import { NextRequest, NextResponse } from "next/server"; import { auth } from "./auth"; -import { sign, decrypt } from "../utils/aws"; +import { + sign, + decrypt, + getBedrockEndpoint, + getModelHeaders, + transformBedrockStream, + parseEventData, + BedrockCredentials, +} from "../utils/aws"; import { getServerSideConfig } from "../config/server"; import { ModelProvider } from "../constant"; import { prettyObject } from "../utils/format"; const ALLOWED_PATH = new Set(["chat", "models"]); -function parseEventData(chunk: Uint8Array): any { - const decoder = new TextDecoder(); - const text = decoder.decode(chunk); - try { - const parsed = JSON.parse(text); - // AWS Bedrock wraps the response in a 'body' field - if (typeof parsed.body === "string") { - try { - const bodyJson = JSON.parse(parsed.body); - return bodyJson; - } catch (e) { - return { output: parsed.body }; - } - } - return parsed.body || parsed; - } catch (e) { - // console.error("Error parsing event data:", e); - try { - // Handle base64 encoded responses - const base64Match = text.match(/:"([A-Za-z0-9+/=]+)"/); - if (base64Match) { - const decoded = Buffer.from(base64Match[1], "base64").toString("utf-8"); - try { - return JSON.parse(decoded); - } catch (e) { - return { output: decoded }; - } - } - - // Handle event-type responses - const eventMatch = text.match(/:event-type[^\{]+({.*})/); - if (eventMatch) { - try { - return JSON.parse(eventMatch[1]); - } catch (e) { - return { output: eventMatch[1] }; - } - } - - // Handle plain text responses - if (text.trim()) { - // Clean up any malformed JSON characters - const cleanText = text.replace(/[\x00-\x1F\x7F-\x9F]/g, ""); - return { output: cleanText }; - } - } catch (innerError) { - console.error("Error in fallback parsing:", innerError); - } - } - return null; -} - -async function* transformBedrockStream( - stream: ReadableStream, - modelId: string, -) { - const reader = stream.getReader(); - let buffer = ""; - - try { - while (true) { - const { done, value } = await reader.read(); - if (done) { - if (buffer) { - yield `data: ${JSON.stringify({ - delta: { text: buffer }, - })}\n\n`; - } - break; - } - - const parsed = parseEventData(value); - if (!parsed) continue; - - // Handle Titan models - if (modelId.startsWith("amazon.titan")) { - const text = parsed.outputText || ""; - if (text) { - yield `data: ${JSON.stringify({ - delta: { text }, - })}\n\n`; - } - } - // Handle LLaMA models - else if (modelId.startsWith("us.meta.llama")) { - let text = ""; - if (parsed.outputs?.[0]?.text) { - text = parsed.outputs[0].text; - } else if (parsed.generation) { - text = parsed.generation; - } else if (parsed.output) { - text = parsed.output; - } else if (typeof parsed === "string") { - text = parsed; - } - - if (text) { - yield `data: ${JSON.stringify({ - delta: { text }, - })}\n\n`; - } - } - // Handle Mistral models - else if (modelId.startsWith("mistral.mistral")) { - let text = ""; - if (parsed.outputs?.[0]?.text) { - text = parsed.outputs[0].text; - } else if (parsed.output) { - text = parsed.output; - } else if (parsed.completion) { - text = parsed.completion; - } else if (typeof parsed === "string") { - text = parsed; - } - - if (text) { - yield `data: ${JSON.stringify({ - delta: { text }, - })}\n\n`; - } - } - // Handle Claude models - else if (modelId.startsWith("anthropic.claude")) { - if (parsed.type === "content_block_delta") { - if (parsed.delta?.type === "text_delta") { - yield `data: ${JSON.stringify({ - delta: { text: parsed.delta.text }, - })}\n\n`; - } else if (parsed.delta?.type === "input_json_delta") { - yield `data: ${JSON.stringify(parsed)}\n\n`; - } - } else if ( - parsed.type === "message_delta" && - parsed.delta?.stop_reason - ) { - yield `data: ${JSON.stringify({ - delta: { stop_reason: parsed.delta.stop_reason }, - })}\n\n`; - } else if ( - parsed.type === "content_block_start" && - parsed.content_block?.type === "tool_use" - ) { - yield `data: ${JSON.stringify(parsed)}\n\n`; - } else if (parsed.type === "content_block_stop") { - yield `data: ${JSON.stringify(parsed)}\n\n`; - } - } - } - } finally { - reader.releaseLock(); - } -} - -function validateRequest(body: any, modelId: string): void { - if (!modelId) throw new Error("Model ID is required"); - - // Handle nested body structure - const bodyContent = body.body || body; - - if (modelId.startsWith("anthropic.claude")) { - if ( - !body.anthropic_version || - body.anthropic_version !== "bedrock-2023-05-31" - ) { - throw new Error("anthropic_version must be 'bedrock-2023-05-31'"); - } - if (typeof body.max_tokens !== "number" || body.max_tokens < 0) { - throw new Error("max_tokens must be a positive number"); - } - if (modelId.startsWith("anthropic.claude-3")) { - if (!Array.isArray(body.messages)) - throw new Error("messages array is required for Claude 3"); - } else if (typeof body.prompt !== "string") { - throw new Error("prompt is required for Claude 2 and earlier"); - } - } else if (modelId.startsWith("us.meta.llama")) { - if (!bodyContent.prompt || typeof bodyContent.prompt !== "string") { - throw new Error("prompt string is required for LLaMA models"); - } - if ( - !bodyContent.max_gen_len || - typeof bodyContent.max_gen_len !== "number" - ) { - throw new Error("max_gen_len must be a positive number for LLaMA models"); - } - } else if (modelId.startsWith("mistral.mistral")) { - if (!bodyContent.prompt) { - throw new Error("prompt is required for Mistral models"); - } - } else if (modelId.startsWith("amazon.titan")) { - if (!bodyContent.inputText) throw new Error("Titan requires inputText"); - } -} - -async function requestBedrock(req: NextRequest) { - const controller = new AbortController(); - +async function getBedrockCredentials( + req: NextRequest, +): Promise { // Get AWS credentials from server config first const config = getServerSideConfig(); let awsRegion = config.awsRegion; @@ -224,90 +38,99 @@ async function requestBedrock(req: NextRequest) { if (!encryptedRegion || !encryptedAccessKey || !encryptedSecretKey) { throw new Error("Invalid Authorization header format"); } - + const encryptionKey = req.headers.get("XEncryptionKey") || ""; // Decrypt the credentials - awsRegion = decrypt(encryptedRegion); - awsAccessKey = decrypt(encryptedAccessKey); - awsSecretKey = decrypt(encryptedSecretKey); + awsRegion = decrypt(encryptedRegion, encryptionKey); + awsAccessKey = decrypt(encryptedAccessKey, encryptionKey); + awsSecretKey = decrypt(encryptedSecretKey, encryptionKey); if (!awsRegion || !awsAccessKey || !awsSecretKey) { - throw new Error("Failed to decrypt AWS credentials"); + throw new Error( + "Failed to decrypt AWS credentials. Please ensure ENCRYPTION_KEY is set correctly.", + ); } } - let modelId = req.headers.get("ModelID"); - let shouldStream = req.headers.get("ShouldStream"); - if (!awsRegion || !awsAccessKey || !awsSecretKey || !modelId) { - throw new Error("Missing required AWS credentials or model ID"); - } + return { + region: awsRegion, + accessKeyId: awsAccessKey, + secretAccessKey: awsSecretKey, + }; +} - // Construct the base endpoint - const baseEndpoint = `https://bedrock-runtime.${awsRegion}.amazonaws.com`; - - // Set up timeout +async function requestBedrock(req: NextRequest) { + const controller = new AbortController(); const timeoutId = setTimeout(() => controller.abort(), 10 * 60 * 1000); try { - // Determine the endpoint and request body based on model type - let endpoint; + // Get credentials and model info + const credentials = await getBedrockCredentials(req); + const modelId = req.headers.get("XModelID"); + const shouldStream = req.headers.get("ShouldStream") !== "false"; + if (!modelId) { + throw new Error("Missing model ID"); + } + + // Parse and validate request body const bodyText = await req.clone().text(); if (!bodyText) { throw new Error("Request body is empty"); } - - const bodyJson = JSON.parse(bodyText); - - // Debug log the request body - console.log("Original request body:", JSON.stringify(bodyJson, null, 2)); - - validateRequest(bodyJson, modelId); - - // For all models, use standard endpoints - if (shouldStream === "false") { - endpoint = `${baseEndpoint}/model/${modelId}/invoke`; - } else { - endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`; + let bodyJson; + try { + bodyJson = JSON.parse(bodyText); + } catch (e) { + throw new Error(`Invalid JSON in request body: ${e}`); } - // Set additional headers based on model type - const additionalHeaders: Record = {}; - if ( - modelId.startsWith("us.meta.llama") || - modelId.startsWith("mistral.mistral") - ) { - additionalHeaders["content-type"] = "application/json"; - additionalHeaders["accept"] = "application/json"; + // Extract tool configuration if present + let tools: any[] | undefined; + if (bodyJson.tools) { + tools = bodyJson.tools; + delete bodyJson.tools; // Remove from main request body } - // For Mistral models, unwrap the body object - const finalRequestBody = - modelId.startsWith("mistral.mistral") && bodyJson.body - ? bodyJson.body - : bodyJson; + // Get endpoint and prepare request + const endpoint = getBedrockEndpoint( + credentials.region, + modelId, + shouldStream, + ); + const additionalHeaders = getModelHeaders(modelId); - // Set content type and accept headers for specific models + console.log("[Bedrock Request] Endpoint:", endpoint); + console.log("[Bedrock Request] Model ID:", modelId); + + // Only include tools for Claude models + const isClaudeModel = modelId.toLowerCase().includes("claude3"); + const requestBody = { + ...bodyJson, + ...(isClaudeModel && tools && { tools }), + }; + + // Sign request const headers = await sign({ method: "POST", url: endpoint, - region: awsRegion, - accessKeyId: awsAccessKey, - secretAccessKey: awsSecretKey, - body: JSON.stringify(finalRequestBody), + region: credentials.region, + accessKeyId: credentials.accessKeyId, + secretAccessKey: credentials.secretAccessKey, + body: JSON.stringify(requestBody), service: "bedrock", - isStreaming: shouldStream !== "false", + isStreaming: shouldStream, additionalHeaders, }); - // Debug log the final request body - // console.log("Final request endpoint:", endpoint); - // console.log(headers); - // console.log("Final request body:", JSON.stringify(finalRequestBody, null, 2)); - + // Make request to AWS Bedrock + console.log( + "[Bedrock Request] Body:", + JSON.stringify(requestBody, null, 2), + ); const res = await fetch(endpoint, { method: "POST", headers, - body: JSON.stringify(finalRequestBody), + body: JSON.stringify(requestBody), redirect: "manual", // @ts-ignore duplex: "half", @@ -316,24 +139,35 @@ async function requestBedrock(req: NextRequest) { if (!res.ok) { const error = await res.text(); - console.error("AWS Bedrock error response:", error); + console.error("[Bedrock Error] Status:", res.status); + console.error("[Bedrock Error] Response:", error); try { const errorJson = JSON.parse(error); throw new Error(errorJson.message || error); } catch { - throw new Error(error || "Failed to get response from Bedrock"); + throw new Error( + `Bedrock request failed with status ${res.status}: ${ + error || "No error message" + }`, + ); } } if (!res.body) { - throw new Error("Empty response from Bedrock"); + console.error("[Bedrock Error] Empty response body"); + throw new Error( + "Empty response from Bedrock. Please check AWS credentials and permissions.", + ); } // Handle non-streaming response - if (shouldStream === "false") { + if (!shouldStream) { const responseText = await res.text(); - console.error("AWS Bedrock shouldStream === false:", responseText); + console.log("[Bedrock Response] Non-streaming:", responseText); const parsed = parseEventData(new TextEncoder().encode(responseText)); + if (!parsed) { + throw new Error("Failed to parse Bedrock response"); + } return NextResponse.json(parsed); } @@ -347,7 +181,7 @@ async function requestBedrock(req: NextRequest) { } controller.close(); } catch (err) { - console.error("Stream error:", err); + console.error("[Bedrock Stream Error]:", err); controller.error(err); } }, @@ -362,7 +196,7 @@ async function requestBedrock(req: NextRequest) { }, }); } catch (e) { - console.error("Request error:", e); + console.error("[Bedrock Request Error]:", e); throw e; } finally { clearTimeout(timeoutId); @@ -384,12 +218,14 @@ export async function handle( { status: 403 }, ); } + const authResult = auth(req, ModelProvider.Bedrock); if (authResult.error) { return NextResponse.json(authResult, { status: 401, }); } + try { return await requestBedrock(req); } catch (e) { diff --git a/app/client/api.ts b/app/client/api.ts index 04da39ac1..06537d1de 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -280,11 +280,11 @@ export function getHeaders(ignoreHeaders: boolean = false) { ? accessStore.awsRegion && accessStore.awsAccessKey && accessStore.awsSecretKey - ? encrypt(accessStore.awsRegion) + + ? encrypt(accessStore.awsRegion, accessStore.encryptionKey) + ":" + - encrypt(accessStore.awsAccessKey) + + encrypt(accessStore.awsAccessKey, accessStore.encryptionKey) + ":" + - encrypt(accessStore.awsSecretKey) + encrypt(accessStore.awsSecretKey, accessStore.encryptionKey) : "" : accessStore.openaiApiKey; return { diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index 5c661c86f..7d1bda3dd 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -1,24 +1,19 @@ "use client"; -import { - ChatOptions, - getHeaders, - LLMApi, - SpeechOptions, - RequestMessage, - MultimodalContent, - MessageRole, -} from "../api"; +import { ChatOptions, getHeaders, LLMApi, SpeechOptions } from "../api"; import { useAppConfig, usePluginStore, useChatStore, useAccessStore, ChatMessageTool, -} from "../../store"; -import { preProcessImageContent, stream } from "../../utils/chat"; -import { getMessageTextContent, isVisionModel } from "../../utils"; -import { ApiPath, BEDROCK_BASE_URL } from "../../constant"; -import { getClientConfig } from "../../config/client"; +} from "@/app/store"; +import { preProcessImageContent, stream } from "@/app/utils/chat"; +import { getMessageTextContent, isVisionModel } from "@/app/utils"; +import { ApiPath, BEDROCK_BASE_URL } from "@/app/constant"; +import { getClientConfig } from "@/app/config/client"; +import { extractMessage } from "@/app/utils/aws"; +import { RequestPayload } from "./openai"; +import { fetch } from "@/app/utils/stream"; const ClaudeMapper = { assistant: "assistant", @@ -28,184 +23,41 @@ const ClaudeMapper = { type ClaudeRole = keyof typeof ClaudeMapper; -interface ToolDefinition { - function?: { - name: string; - description?: string; - parameters?: any; - }; -} - export class BedrockApi implements LLMApi { - private disableListModels = true; - - path(path: string): string { - const accessStore = useAccessStore.getState(); - - let baseUrl = ""; - - if (accessStore.useCustomConfig) { - baseUrl = accessStore.bedrockUrl; - } - - if (baseUrl.length === 0) { - const isApp = !!getClientConfig()?.isApp; - const apiPath = ApiPath.Bedrock; - baseUrl = isApp ? BEDROCK_BASE_URL : apiPath; - } - - if (baseUrl.endsWith("/")) { - baseUrl = baseUrl.slice(0, baseUrl.length - 1); - } - if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Bedrock)) { - baseUrl = "https://" + baseUrl; - } - - console.log("[Proxy Endpoint] ", baseUrl, path); - - return [baseUrl, path].join("/"); - } - speech(options: SpeechOptions): Promise { throw new Error("Speech not implemented for Bedrock."); } - extractMessage(res: any, modelId: string = "") { - try { - // Handle Titan models - if (modelId.startsWith("amazon.titan")) { - let text = ""; - if (res?.delta?.text) { - text = res.delta.text; - } else { - text = res?.outputText || ""; - } - // Clean up Titan response by removing leading question mark and whitespace - return text.replace(/^[\s?]+/, ""); - } - - // Handle LLaMA models - if (modelId.startsWith("us.meta.llama")) { - if (res?.delta?.text) { - return res.delta.text; - } - if (res?.generation) { - return res.generation; - } - if (res?.outputs?.[0]?.text) { - return res.outputs[0].text; - } - if (res?.output) { - return res.output; - } - if (typeof res === "string") { - return res; - } - return ""; - } - - // Handle Mistral models - if (modelId.startsWith("mistral.mistral")) { - if (res?.delta?.text) { - return res.delta.text; - } - if (res?.outputs?.[0]?.text) { - return res.outputs[0].text; - } - if (res?.content?.[0]?.text) { - return res.content[0].text; - } - if (res?.output) { - return res.output; - } - if (res?.completion) { - return res.completion; - } - if (typeof res === "string") { - return res; - } - return ""; - } - - // Handle Claude models - if (res?.content?.[0]?.text) return res.content[0].text; - if (res?.messages?.[0]?.content?.[0]?.text) - return res.messages[0].content[0].text; - if (res?.delta?.text) return res.delta.text; - if (res?.completion) return res.completion; - if (res?.generation) return res.generation; - if (res?.outputText) return res.outputText; - if (res?.output) return res.output; - - if (typeof res === "string") return res; - - return ""; - } catch (e) { - console.error("Error extracting message:", e); - return ""; - } - } - - formatRequestBody( - messages: RequestMessage[], - systemMessage: string, - modelConfig: any, - ) { + formatRequestBody(messages: ChatOptions["messages"], modelConfig: any) { const model = modelConfig.model; + const visionModel = isVisionModel(modelConfig.model); + // Handle Titan models if (model.startsWith("amazon.titan")) { - const allMessages = systemMessage - ? [ - { role: "system" as MessageRole, content: systemMessage }, - ...messages, - ] - : messages; - - const inputText = allMessages - .map((m) => { - if (m.role === "system") { - return getMessageTextContent(m); - } - return getMessageTextContent(m); + const inputText = messages + .map((message) => { + return `${message.role}: ${message.content}`; }) .join("\n\n"); return { - body: { - inputText, - textGenerationConfig: { - maxTokenCount: modelConfig.max_tokens, - temperature: modelConfig.temperature, - stopSequences: [], - }, + inputText, + textGenerationConfig: { + maxTokenCount: modelConfig.max_tokens, + temperature: modelConfig.temperature, + stopSequences: [], }, }; } // Handle LLaMA models if (model.startsWith("us.meta.llama")) { - const allMessages = systemMessage - ? [ - { role: "system" as MessageRole, content: systemMessage }, - ...messages, - ] - : messages; - - const prompt = allMessages - .map((m) => { - const content = getMessageTextContent(m); - if (m.role === "system") { - return `System: ${content}`; - } else if (m.role === "user") { - return `User: ${content}`; - } else if (m.role === "assistant") { - return `Assistant: ${content}`; - } - return content; + const prompt = messages + .map((message) => { + return `${message.role}: ${message.content}`; }) .join("\n\n"); - return { prompt, max_gen_len: modelConfig.max_tokens || 512, @@ -217,116 +69,124 @@ export class BedrockApi implements LLMApi { // Handle Mistral models if (model.startsWith("mistral.mistral")) { - const allMessages = systemMessage - ? [ - { role: "system" as MessageRole, content: systemMessage }, - ...messages, - ] - : messages; - - const formattedConversation = allMessages - .map((m) => { - const content = getMessageTextContent(m); - if (m.role === "system") { - return content; - } else if (m.role === "user") { - return content; - } else if (m.role === "assistant") { - return content; - } - return content; + const prompt = messages + .map((message) => { + return `${message.role}: ${message.content}`; }) - .join("\n"); - - // Format according to Mistral's requirements + .join("\n\n"); return { - prompt: formattedConversation, + prompt, max_tokens: modelConfig.max_tokens || 4096, temperature: modelConfig.temperature || 0.7, }; } // Handle Claude models - const isClaude3 = model.startsWith("anthropic.claude-3"); - const formattedMessages = messages - .filter( - (v) => v.content && (typeof v.content !== "string" || v.content.trim()), - ) + const keys = ["system", "user"]; + // roles must alternate between "user" and "assistant" in claude, so add a fake assistant message between two user messages + for (let i = 0; i < messages.length - 1; i++) { + const message = messages[i]; + const nextMessage = messages[i + 1]; + + if (keys.includes(message.role) && keys.includes(nextMessage.role)) { + messages[i] = [ + message, + { + role: "assistant", + content: ";", + }, + ] as any; + } + } + const prompt = messages + .flat() + .filter((v) => { + if (!v.content) return false; + if (typeof v.content === "string" && !v.content.trim()) return false; + return true; + }) .map((v) => { const { role, content } = v; - const insideRole = ClaudeMapper[role as ClaudeRole] ?? "user"; + const insideRole = ClaudeMapper[role] ?? "user"; - if (!isVisionModel(model) || typeof content === "string") { + if (!visionModel || typeof content === "string") { return { role: insideRole, - content: [{ type: "text", text: getMessageTextContent(v) }], + content: getMessageTextContent(v), }; } - return { role: insideRole, - content: (content as MultimodalContent[]) + content: content .filter((v) => v.image_url || v.text) .map(({ type, text, image_url }) => { - if (type === "text") return { type, text: text! }; - + if (type === "text") { + return { + type, + text: text!, + }; + } const { url = "" } = image_url || {}; const colonIndex = url.indexOf(":"); const semicolonIndex = url.indexOf(";"); const comma = url.indexOf(","); + const mimeType = url.slice(colonIndex + 1, semicolonIndex); + const encodeType = url.slice(semicolonIndex + 1, comma); + const data = url.slice(comma + 1); + return { - type: "image", + type: "image" as const, source: { - type: url.slice(semicolonIndex + 1, comma), - media_type: url.slice(colonIndex + 1, semicolonIndex), - data: url.slice(comma + 1), + type: encodeType, + media_type: mimeType, + data, }, }; }), }; }); - return { - body: { - anthropic_version: "bedrock-2023-05-31", - max_tokens: modelConfig.max_tokens, - messages: formattedMessages, - ...(systemMessage && { system: systemMessage }), - temperature: modelConfig.temperature, - ...(isClaude3 && { top_k: modelConfig.top_k || 50 }), - }, + if (prompt[0]?.role === "assistant") { + prompt.unshift({ + role: "user", + content: ";", + }); + } + const requestBody: any = { + anthropic_version: useAccessStore.getState().bedrockAnthropicVersion, + max_tokens: modelConfig.max_tokens, + messages: prompt, + temperature: modelConfig.temperature, + top_p: modelConfig.top_p || 0.9, + top_k: modelConfig.top_k || 5, }; + return requestBody; } async chat(options: ChatOptions) { + const accessStore = useAccessStore.getState(); + + const shouldStream = !!options.config.stream; + const modelConfig = { ...useAppConfig.getState().modelConfig, ...useChatStore.getState().currentSession().mask.modelConfig, - model: options.config.model, + ...{ + model: options.config.model, + }, }; - let systemMessage = ""; - const messages = []; - for (const msg of options.messages) { - const content = await preProcessImageContent(msg.content); - if (msg.role === "system") { - systemMessage = getMessageTextContent(msg); - } else { - messages.push({ role: msg.role, content }); - } + // try get base64image from local cache image_url + const messages: ChatOptions["messages"] = []; + for (const v of options.messages) { + const content = await preProcessImageContent(v.content); + messages.push({ role: v.role, content }); } - const requestBody = this.formatRequestBody( - messages, - systemMessage, - modelConfig, - ); - const controller = new AbortController(); options.onController?.(controller); - const accessStore = useAccessStore.getState(); if (!accessStore.isValidBedrock()) { throw new Error( "Invalid AWS credentials. Please check your configuration and ensure ENCRYPTION_KEY is set.", @@ -336,29 +196,30 @@ export class BedrockApi implements LLMApi { try { const chatPath = this.path("chat"); const headers = getHeaders(); - headers.ModelID = modelConfig.model; + headers.XModelID = modelConfig.model; + headers.XEncryptionKey = accessStore.encryptionKey; - // For LLaMA and Mistral models, send the request body directly without the 'body' wrapper - const finalRequestBody = - modelConfig.model.startsWith("us.meta.llama") || - modelConfig.model.startsWith("mistral.mistral") - ? requestBody - : requestBody.body; + console.log("[Bedrock Client] Request:", { + path: chatPath, + model: modelConfig.model, + messages: messages.length, + stream: shouldStream, + }); - if (options.config.stream) { + const finalRequestBody = this.formatRequestBody(messages, modelConfig); + if (shouldStream) { let index = -1; - let currentToolArgs = ""; const [tools, funcs] = usePluginStore .getState() .getAsTools( useChatStore.getState().currentSession().mask?.plugin || [], ); - return stream( chatPath, finalRequestBody, headers, - (tools as ToolDefinition[]).map((tool) => ({ + // @ts-ignore + tools.map((tool) => ({ name: tool?.function?.name, description: tool?.function?.description, input_schema: tool?.function?.parameters, @@ -366,96 +227,86 @@ export class BedrockApi implements LLMApi { funcs, controller, (text: string, runTools: ChatMessageTool[]) => { - try { - const chunkJson = JSON.parse(text); - if (chunkJson?.content_block?.type === "tool_use") { - index += 1; - currentToolArgs = ""; - const id = chunkJson.content_block?.id; - const name = chunkJson.content_block?.name; - if (id && name) { - runTools.push({ - id, - type: "function", - function: { name, arguments: "" }, - }); - } - } else if ( - chunkJson?.delta?.type === "input_json_delta" && - chunkJson.delta?.partial_json - ) { - currentToolArgs += chunkJson.delta.partial_json; - try { - JSON.parse(currentToolArgs); - if (index >= 0 && index < runTools.length) { - runTools[index].function!.arguments = currentToolArgs; - } - } catch (e) {} - } else if ( - chunkJson?.type === "content_block_stop" && - currentToolArgs && - index >= 0 && - index < runTools.length - ) { - try { - if (currentToolArgs.trim().endsWith(",")) { - currentToolArgs = currentToolArgs.slice(0, -1) + "}"; - } else if (!currentToolArgs.endsWith("}")) { - currentToolArgs += "}"; - } - JSON.parse(currentToolArgs); - runTools[index].function!.arguments = currentToolArgs; - } catch (e) {} - } - const message = this.extractMessage(chunkJson, modelConfig.model); - return message; - } catch (e) { - console.error("Error parsing chunk:", e); - return ""; + // console.log("parseSSE", text, runTools); + let chunkJson: + | undefined + | { + type: "content_block_delta" | "content_block_stop"; + content_block?: { + type: "tool_use"; + id: string; + name: string; + }; + delta?: { + type: "text_delta" | "input_json_delta"; + text?: string; + partial_json?: string; + }; + index: number; + }; + chunkJson = JSON.parse(text); + + if (chunkJson?.content_block?.type == "tool_use") { + index += 1; + const id = chunkJson?.content_block.id; + const name = chunkJson?.content_block.name; + runTools.push({ + id, + type: "function", + function: { + name, + arguments: "", + }, + }); } + if ( + chunkJson?.delta?.type == "input_json_delta" && + chunkJson?.delta?.partial_json + ) { + // @ts-ignore + runTools[index]["function"]["arguments"] += + chunkJson?.delta?.partial_json; + } + return chunkJson?.delta?.text; }, + // processToolMessage, include tool_calls message and tool call results ( - requestPayload: any, + requestPayload: RequestPayload, toolCallMessage: any, toolCallResult: any[], ) => { + // reset index value index = -1; - currentToolArgs = ""; - if (requestPayload?.messages) { - requestPayload.messages.splice( - requestPayload.messages.length, - 0, - { - role: "assistant", - content: [ - { - type: "text", - text: JSON.stringify( - toolCallMessage.tool_calls.map( - (tool: ChatMessageTool) => ({ - type: "tool_use", - id: tool.id, - name: tool?.function?.name, - input: tool?.function?.arguments - ? JSON.parse(tool?.function?.arguments) - : {}, - }), - ), - ), - }, - ], - }, - ...toolCallResult.map((result) => ({ - role: "user", - content: [ - { - type: "text", - text: `Tool '${result.tool_call_id}' returned: ${result.content}`, - }, - ], - })), - ); - } + // @ts-ignore + requestPayload?.messages?.splice( + // @ts-ignore + requestPayload?.messages?.length, + 0, + { + role: "assistant", + content: toolCallMessage.tool_calls.map( + (tool: ChatMessageTool) => ({ + type: "tool_use", + id: tool.id, + name: tool?.function?.name, + input: tool?.function?.arguments + ? JSON.parse(tool?.function?.arguments) + : {}, + }), + ), + }, + // @ts-ignore + ...toolCallResult.map((result) => ({ + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: result.tool_call_id, + content: result.content, + }, + ], + })), + ); }, options, ); @@ -467,15 +318,48 @@ export class BedrockApi implements LLMApi { body: JSON.stringify(finalRequestBody), }); + if (!res.ok) { + const errorText = await res.text(); + console.error("[Bedrock Client] Error response:", errorText); + throw new Error(`Request failed: ${errorText}`); + } + const resJson = await res.json(); - const message = this.extractMessage(resJson, modelConfig.model); + if (!resJson) { + throw new Error("Empty response from server"); + } + + const message = extractMessage(resJson, modelConfig.model); + if (!message) { + throw new Error("Failed to extract message from response"); + } + options.onFinish(message, res); } } catch (e) { - console.error("Chat error:", e); + console.error("[Bedrock Client] Chat error:", e); options.onError?.(e as Error); } } + path(path: string): string { + const accessStore = useAccessStore.getState(); + let baseUrl = accessStore.useCustomConfig ? accessStore.bedrockUrl : ""; + + if (baseUrl.length === 0) { + const isApp = !!getClientConfig()?.isApp; + const apiPath = ApiPath.Bedrock; + baseUrl = isApp ? BEDROCK_BASE_URL : apiPath; + } + + baseUrl = baseUrl.endsWith("/") ? baseUrl.slice(0, -1) : baseUrl; + if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Bedrock)) { + baseUrl = "https://" + baseUrl; + } + + console.log("[Bedrock Client] API Endpoint:", baseUrl, path); + + return [baseUrl, path].join("/"); + } async usage() { return { used: 0, total: 0 }; diff --git a/app/components/settings.tsx b/app/components/settings.tsx index bc251d47e..c71569b27 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -1,5 +1,4 @@ import { useState, useEffect, useMemo } from "react"; - import styles from "./settings.module.scss"; import ResetIcon from "../icons/reload.svg"; @@ -1027,14 +1026,22 @@ export function Settings() { > { accessStore.update( - (access) => (access.bedrockEncryptionKey = e.currentTarget.value), + (access) => (access.encryptionKey = e.currentTarget.value), ); }} + onBlur={(e) => { + const value = e.currentTarget.value; + if (!value || value.length < 8) { + showToast(Locale.Settings.Access.Bedrock.EncryptionKey.Invalid); + accessStore.update((access) => (access.encryptionKey = "")); + return; + } + }} maskWhenShow={true} /> diff --git a/app/config/server.ts b/app/config/server.ts index e8fbf131a..1f59c805f 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -186,7 +186,7 @@ export const getServerSideConfig = () => { awsRegion: process.env.AWS_REGION, awsAccessKey: process.env.AWS_ACCESS_KEY, awsSecretKey: process.env.AWS_SECRET_KEY, - bedrockEncryptionKey: process.env.ENCRYPTION_KEY, + encryptionKey: process.env.ENCRYPTION_KEY, isStability, stabilityUrl: process.env.STABILITY_URL, diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 9a8f5cc19..90970ba5e 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -365,6 +365,7 @@ const cn = { Title: "加密密钥", SubTitle: "用于配置数据的加密密钥", Placeholder: "输入加密密钥", + Invalid: "无效的加密密钥。必须至少包含8个字符!", }, }, Azure: { diff --git a/app/locales/en.ts b/app/locales/en.ts index 670f822c8..5e800049e 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -369,6 +369,8 @@ const en: LocaleType = { Title: "Encryption Key", SubTitle: "Your encryption key for configuration data", Placeholder: "Enter encryption key", + Invalid: + "Invalid encryption key format. Must no less than 8 characters long!", }, }, Azure: { diff --git a/app/store/access.ts b/app/store/access.ts index c0b3268cf..156bd6ff1 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -113,7 +113,8 @@ const DEFAULT_ACCESS_STATE = { awsRegion: "", awsAccessKey: "", awsSecretKey: "", - bedrockEncryptionKey: "", + encryptionKey: "", + bedrockAnthropicVersion: "bedrock-2023-05-31", // server config needCode: true, @@ -194,7 +195,7 @@ export const useAccessStore = createPersistStore( "awsRegion", "awsAccessKey", "awsSecretKey", - "bedrockEncryptionKey", + "encryptionKey", ]); }, @@ -256,13 +257,19 @@ export const useAccessStore = createPersistStore( // Override the set method to encrypt AWS credentials before storage set: (partial: { [key: string]: any }) => { if (partial.awsAccessKey) { - partial.awsAccessKey = encrypt(partial.awsAccessKey); + partial.awsAccessKey = encrypt( + partial.awsAccessKey, + partial.encryptionKey, + ); } if (partial.awsSecretKey) { - partial.awsSecretKey = encrypt(partial.awsSecretKey); + partial.awsSecretKey = encrypt( + partial.awsSecretKey, + partial.encryptionKey, + ); } if (partial.awsRegion) { - partial.awsRegion = encrypt(partial.awsRegion); + partial.awsRegion = encrypt(partial.awsRegion, partial.encryptionKey); } set(partial); }, @@ -272,9 +279,15 @@ export const useAccessStore = createPersistStore( const state = get(); return { ...state, - awsRegion: state.awsRegion ? decrypt(state.awsRegion) : "", - awsAccessKey: state.awsAccessKey ? decrypt(state.awsAccessKey) : "", - awsSecretKey: state.awsSecretKey ? decrypt(state.awsSecretKey) : "", + awsRegion: state.awsRegion + ? decrypt(state.awsRegion, state.encryptionKey) + : "", + awsAccessKey: state.awsAccessKey + ? decrypt(state.awsAccessKey, state.encryptionKey) + : "", + awsSecretKey: state.awsSecretKey + ? decrypt(state.awsSecretKey, state.encryptionKey) + : "", }; }, }), diff --git a/app/utils/aws.ts b/app/utils/aws.ts index d2997412f..75b359195 100644 --- a/app/utils/aws.ts +++ b/app/utils/aws.ts @@ -3,36 +3,59 @@ import HmacSHA256 from "crypto-js/hmac-sha256"; import Hex from "crypto-js/enc-hex"; import Utf8 from "crypto-js/enc-utf8"; import { AES, enc } from "crypto-js"; -import { getServerSideConfig } from "../config/server"; -const serverConfig = getServerSideConfig(); -const SECRET_KEY = serverConfig.bedrockEncryptionKey || ""; -if (serverConfig.isBedrock && !SECRET_KEY) { - console.error("When use Bedrock modle,ENCRYPTION_KEY should been set!"); +// Types and Interfaces +export interface BedrockCredentials { + region: string; + accessKeyId: string; + secretAccessKey: string; } -export function encrypt(data: string): string { +export interface BedrockRequestConfig { + modelId: string; + shouldStream: boolean; + body: any; + credentials: BedrockCredentials; +} + +export interface ModelValidationConfig { + requiredFields: string[]; + optionalFields?: string[]; + customValidation?: (body: any) => string | null; +} + +// Encryption utilities +export function encrypt(data: string, encryptionKey: string): string { if (!data) return ""; + if (!encryptionKey) { + console.error("[AWS Encryption Error] Encryption key is required"); + throw new Error("Encryption key is required for AWS credential encryption"); + } try { - return AES.encrypt(data, SECRET_KEY).toString(); + return AES.encrypt(data, encryptionKey).toString(); } catch (error) { - console.error("Encryption failed:", error); - return ""; + console.error("[AWS Encryption Error]:", error); + throw new Error("Failed to encrypt AWS credentials"); } } -export function decrypt(encryptedData: string): string { +export function decrypt(encryptedData: string, encryptionKey: string): string { if (!encryptedData) return ""; + if (!encryptionKey) { + console.error("[AWS Decryption Error] Encryption key is required"); + throw new Error("Encryption key is required for AWS credential decryption"); + } try { - const bytes = AES.decrypt(encryptedData, SECRET_KEY); + const bytes = AES.decrypt(encryptedData, encryptionKey); const decrypted = bytes.toString(enc.Utf8); if (!decrypted && encryptedData) { - return encryptedData; + console.error("[AWS Decryption Error] Failed to decrypt data"); + throw new Error("Failed to decrypt AWS credentials"); } return decrypted; } catch (error) { - console.error("Decryption failed:", error); - return ""; + console.error("[AWS Decryption Error]:", error); + throw new Error("Failed to decrypt AWS credentials"); } } @@ -42,6 +65,7 @@ export function maskSensitiveValue(value: string): string { return "*".repeat(value.length - 4) + value.slice(-4); } +// AWS Signing export interface SignParams { method: string; url: string; @@ -138,74 +162,271 @@ export async function sign({ isStreaming = true, additionalHeaders = {}, }: SignParams): Promise> { - const endpoint = new URL(url); - const canonicalUri = "/" + encodeURI_RFC3986(endpoint.pathname.slice(1)); - const canonicalQueryString = endpoint.search.slice(1); + try { + const endpoint = new URL(url); + const canonicalUri = "/" + encodeURI_RFC3986(endpoint.pathname.slice(1)); + const canonicalQueryString = endpoint.search.slice(1); - const now = new Date(); - const amzDate = now.toISOString().replace(/[:-]|\.\d{3}/g, ""); - const dateStamp = amzDate.slice(0, 8); + const now = new Date(); + const amzDate = now.toISOString().replace(/[:-]|\.\d{3}/g, ""); + const dateStamp = amzDate.slice(0, 8); - const payloadHash = SHA256(body).toString(Hex); + const payloadHash = SHA256(body).toString(Hex); - const headers: Record = { - accept: isStreaming - ? "application/vnd.amazon.eventstream" - : "application/json", - "content-type": "application/json", - host: endpoint.host, - "x-amz-content-sha256": payloadHash, - "x-amz-date": amzDate, - ...additionalHeaders, - }; + const headers: Record = { + accept: isStreaming + ? "application/vnd.amazon.eventstream" + : "application/json", + "content-type": "application/json", + host: endpoint.host, + "x-amz-content-sha256": payloadHash, + "x-amz-date": amzDate, + ...additionalHeaders, + }; - if (isStreaming) { - headers["x-amzn-bedrock-accept"] = "*/*"; + if (isStreaming) { + headers["x-amzn-bedrock-accept"] = "*/*"; + } + + const sortedHeaderKeys = Object.keys(headers).sort((a, b) => + a.toLowerCase().localeCompare(b.toLowerCase()), + ); + + const canonicalHeaders = sortedHeaderKeys + .map( + (key) => `${key.toLowerCase()}:${normalizeHeaderValue(headers[key])}\n`, + ) + .join(""); + + const signedHeaders = sortedHeaderKeys + .map((key) => key.toLowerCase()) + .join(";"); + + const canonicalRequest = [ + method.toUpperCase(), + canonicalUri, + canonicalQueryString, + canonicalHeaders, + signedHeaders, + payloadHash, + ].join("\n"); + + const algorithm = "AWS4-HMAC-SHA256"; + const credentialScope = `${dateStamp}/${region}/${service}/aws4_request`; + const stringToSign = [ + algorithm, + amzDate, + credentialScope, + SHA256(canonicalRequest).toString(Hex), + ].join("\n"); + + const signingKey = getSigningKey( + secretAccessKey, + dateStamp, + region, + service, + ); + const signature = hmac(signingKey, stringToSign).toString(Hex); + + const authorization = [ + `${algorithm} Credential=${accessKeyId}/${credentialScope}`, + `SignedHeaders=${signedHeaders}`, + `Signature=${signature}`, + ].join(", "); + + return { + ...headers, + Authorization: authorization, + }; + } catch (error) { + console.error("[AWS Signing Error]:", error); + throw new Error("Failed to sign AWS request"); + } +} + +// Bedrock utilities +export function parseEventData(chunk: Uint8Array): any { + const decoder = new TextDecoder(); + const text = decoder.decode(chunk); + try { + const parsed = JSON.parse(text); + // AWS Bedrock wraps the response in a 'body' field + if (typeof parsed.body === "string") { + try { + return JSON.parse(parsed.body); + } catch (e) { + return { output: parsed.body }; + } + } + return parsed.body || parsed; + } catch (e) { + try { + // Handle base64 encoded responses + const base64Match = text.match(/:"([A-Za-z0-9+/=]+)"/); + if (base64Match) { + const decoded = Buffer.from(base64Match[1], "base64").toString("utf-8"); + try { + return JSON.parse(decoded); + } catch (e) { + return { output: decoded }; + } + } + + // Handle event-type responses + const eventMatch = text.match(/:event-type[^\{]+({.*})/); + if (eventMatch) { + try { + return JSON.parse(eventMatch[1]); + } catch (e) { + return { output: eventMatch[1] }; + } + } + + // Handle plain text responses + if (text.trim()) { + // Clean up any malformed JSON characters + const cleanText = text.replace(/[\x00-\x1F\x7F-\x9F]/g, ""); + return { output: cleanText }; + } + } catch (innerError) { + console.error("[AWS Parse Error] Inner parsing failed:", innerError); + } + } + return null; +} + +export function getBedrockEndpoint( + region: string, + modelId: string, + shouldStream: boolean, +): string { + if (!region || !modelId) { + throw new Error("Region and model ID are required for Bedrock endpoint"); + } + const baseEndpoint = `https://bedrock-runtime.${region}.amazonaws.com`; + const endpoint = + shouldStream === false + ? `${baseEndpoint}/model/${modelId}/invoke` + : `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`; + return endpoint; +} + +export function getModelHeaders(modelId: string): Record { + if (!modelId) { + throw new Error("Model ID is required for headers"); } - const sortedHeaderKeys = Object.keys(headers).sort((a, b) => - a.toLowerCase().localeCompare(b.toLowerCase()), - ); + const headers: Record = {}; - const canonicalHeaders = sortedHeaderKeys - .map( - (key) => `${key.toLowerCase()}:${normalizeHeaderValue(headers[key])}\n`, - ) - .join(""); + if ( + modelId.startsWith("us.meta.llama") || + modelId.startsWith("mistral.mistral") + ) { + headers["content-type"] = "application/json"; + headers["accept"] = "application/json"; + } - const signedHeaders = sortedHeaderKeys - .map((key) => key.toLowerCase()) - .join(";"); - - const canonicalRequest = [ - method.toUpperCase(), - canonicalUri, - canonicalQueryString, - canonicalHeaders, - signedHeaders, - payloadHash, - ].join("\n"); - - const algorithm = "AWS4-HMAC-SHA256"; - const credentialScope = `${dateStamp}/${region}/${service}/aws4_request`; - const stringToSign = [ - algorithm, - amzDate, - credentialScope, - SHA256(canonicalRequest).toString(Hex), - ].join("\n"); - - const signingKey = getSigningKey(secretAccessKey, dateStamp, region, service); - const signature = hmac(signingKey, stringToSign).toString(Hex); - - const authorization = [ - `${algorithm} Credential=${accessKeyId}/${credentialScope}`, - `SignedHeaders=${signedHeaders}`, - `Signature=${signature}`, - ].join(", "); - - return { - ...headers, - Authorization: authorization, - }; + return headers; +} + +export function extractMessage(res: any, modelId: string = ""): string { + if (!res) { + console.error("[AWS Extract Error] extractMessage Empty response"); + return ""; + } + console.log("[Response] extractMessage response: ", res); + return res?.content?.[0]?.text; + return ""; +} + +export async function* transformBedrockStream( + stream: ReadableStream, + modelId: string, +) { + const reader = stream.getReader(); + let buffer = ""; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) { + if (buffer) { + yield `data: ${JSON.stringify({ + delta: { text: buffer }, + })}\n\n`; + } + break; + } + + const parsed = parseEventData(value); + if (!parsed) continue; + + // Handle Titan models + if (modelId.startsWith("amazon.titan")) { + const text = parsed.outputText || ""; + if (text) { + yield `data: ${JSON.stringify({ + delta: { text }, + })}\n\n`; + } + } + // Handle LLaMA3 models + else if (modelId.startsWith("us.meta.llama3")) { + let text = ""; + if (parsed.generation) { + text = parsed.generation; + } else if (parsed.output) { + text = parsed.output; + } else if (typeof parsed === "string") { + text = parsed; + } + + if (text) { + // Clean up any control characters or invalid JSON characters + text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, ""); + yield `data: ${JSON.stringify({ + delta: { text }, + })}\n\n`; + } + } + // Handle Mistral models + else if (modelId.startsWith("mistral.mistral")) { + const text = + parsed.output || parsed.outputs?.[0]?.text || parsed.completion || ""; + if (text) { + yield `data: ${JSON.stringify({ + delta: { text }, + })}\n\n`; + } + } + // Handle Claude models + else if (modelId.startsWith("anthropic.claude")) { + if (parsed.type === "content_block_delta") { + if (parsed.delta?.type === "text_delta") { + yield `data: ${JSON.stringify({ + delta: { text: parsed.delta.text }, + })}\n\n`; + } else if (parsed.delta?.type === "input_json_delta") { + yield `data: ${JSON.stringify(parsed)}\n\n`; + } + } else if ( + parsed.type === "message_delta" && + parsed.delta?.stop_reason + ) { + yield `data: ${JSON.stringify({ + delta: { stop_reason: parsed.delta.stop_reason }, + })}\n\n`; + } else if ( + parsed.type === "content_block_start" && + parsed.content_block?.type === "tool_use" + ) { + yield `data: ${JSON.stringify(parsed)}\n\n`; + } else if (parsed.type === "content_block_stop") { + yield `data: ${JSON.stringify(parsed)}\n\n`; + } + } + } + } finally { + reader.releaseLock(); + } }