From ff356f0c8c78eacafeba403a0444aef743808e92 Mon Sep 17 00:00:00 2001 From: glay Date: Tue, 29 Oct 2024 22:20:26 +0800 Subject: [PATCH] =?UTF-8?q?=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20ap?= =?UTF-8?q?p/api/[provider]/[...path]/route.ts=20=09=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=EF=BC=9A=20=20=20=20=20app/api/auth.ts=20=09=E6=96=B0=E6=96=87?= =?UTF-8?q?=E4=BB=B6=EF=BC=9A=20=20=20app/api/bedrock.ts=20=09=E6=96=B0?= =?UTF-8?q?=E6=96=87=E4=BB=B6=EF=BC=9A=20=20=20app/api/bedrock/models.ts?= =?UTF-8?q?=20=09=E6=96=B0=E6=96=87=E4=BB=B6=EF=BC=9A=20=20=20app/api/bedr?= =?UTF-8?q?ock/utils.ts=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app?= =?UTF-8?q?/client/api.ts=20=09=E6=96=B0=E6=96=87=E4=BB=B6=EF=BC=9A=20=20?= =?UTF-8?q?=20app/client/platforms/bedrock.ts=20=09=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=EF=BC=9A=20=20=20=20=20app/components/settings.tsx=20=09?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app/config/server.ts?= =?UTF-8?q?=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app/constant.ts?= =?UTF-8?q?=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app/locales/cn.?= =?UTF-8?q?ts=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app/locales/e?= =?UTF-8?q?n.ts=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app/store/a?= =?UTF-8?q?ccess.ts=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app/uti?= =?UTF-8?q?ls.ts=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20package.js?= =?UTF-8?q?on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/[provider]/[...path]/route.ts | 4 + app/api/auth.ts | 22 ++ app/api/bedrock.ts | 288 ++++++++++++++++++++++++++ app/api/bedrock/models.ts | 280 +++++++++++++++++++++++++ app/api/bedrock/utils.ts | 218 +++++++++++++++++++ app/client/api.ts | 43 +++- app/client/platforms/bedrock.ts | 223 ++++++++++++++++++++ app/components/settings.tsx | 71 ++++++- app/config/server.ts | 10 +- app/constant.ts | 36 ++++ app/locales/cn.ts | 26 +++ app/locales/en.ts | 26 +++ app/store/access.ts | 21 ++ 
app/utils.ts | 3 + package.json | 3 +- 15 files changed, 1261 insertions(+), 13 deletions(-) create mode 100644 app/api/bedrock.ts create mode 100644 app/api/bedrock/models.ts create mode 100644 app/api/bedrock/utils.ts create mode 100644 app/client/platforms/bedrock.ts diff --git a/app/api/[provider]/[...path]/route.ts b/app/api/[provider]/[...path]/route.ts index 5ac248d0c..100c43714 100644 --- a/app/api/[provider]/[...path]/route.ts +++ b/app/api/[provider]/[...path]/route.ts @@ -1,6 +1,7 @@ import { ApiPath } from "@/app/constant"; import { NextRequest } from "next/server"; import { handle as openaiHandler } from "../../openai"; +import { handle as bedrockHandler } from "../../bedrock"; import { handle as azureHandler } from "../../azure"; import { handle as googleHandler } from "../../google"; import { handle as anthropicHandler } from "../../anthropic"; @@ -20,12 +21,15 @@ async function handle( const apiPath = `/api/${params.provider}`; console.log(`[${params.provider} Route] params `, params); switch (apiPath) { + case ApiPath.Bedrock: + return bedrockHandler(req, { params }); case ApiPath.Azure: return azureHandler(req, { params }); case ApiPath.Google: return googleHandler(req, { params }); case ApiPath.Anthropic: return anthropicHandler(req, { params }); + case ApiPath.Baidu: return baiduHandler(req, { params }); case ApiPath.ByteDance: diff --git a/app/api/auth.ts b/app/api/auth.ts index d4ac66a11..1a0ae0b43 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -52,6 +52,28 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) { msg: "you are not allowed to access with your own api key", }; } + // Special handling for Bedrock + if (modelProvider === ModelProvider.Bedrock) { + const region = req.headers.get("X-Region"); + const accessKeyId = req.headers.get("X-Access-Key"); + const secretKey = req.headers.get("X-Secret-Key"); + + console.log("[Auth] Bedrock credentials:", { + region, + accessKeyId: accessKeyId ? 
"***" : undefined, + secretKey: secretKey ? "***" : undefined, + }); + + // Check if AWS credentials are provided + if (!region || !accessKeyId || !secretKey) { + return { + error: true, + msg: "Missing AWS credentials. Please configure Region, Access Key ID, and Secret Access Key in settings.", + }; + } + + return { error: false }; + } // if user does not provide an api key, inject system api key if (!apiKey) { diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts new file mode 100644 index 000000000..e2d212817 --- /dev/null +++ b/app/api/bedrock.ts @@ -0,0 +1,288 @@ +import { ModelProvider } from "../constant"; +import { prettyObject } from "../utils/format"; +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "./auth"; +import { + BedrockRuntimeClient, + InvokeModelCommand, + ValidationException, +} from "@aws-sdk/client-bedrock-runtime"; +import { validateModelId } from "./bedrock/utils"; +import { + ConverseRequest, + formatRequestBody, + parseModelResponse, +} from "./bedrock/models"; + +interface ContentItem { + type: string; + text?: string; + image_url?: { + url: string; + }; +} + +const ALLOWED_PATH = new Set(["invoke", "converse"]); + +export async function handle( + req: NextRequest, + { params }: { params: { path: string[] } }, +) { + console.log("[Bedrock Route] params ", params); + + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + + const subpath = params.path.join("/"); + + if (!ALLOWED_PATH.has(subpath)) { + console.log("[Bedrock Route] forbidden path ", subpath); + return NextResponse.json( + { + error: true, + msg: "you are not allowed to request " + subpath, + }, + { + status: 403, + }, + ); + } + + const authResult = auth(req, ModelProvider.Bedrock); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + try { + if (subpath === "converse") { + const response = await handleConverseRequest(req); + return response; + } else { + 
const response = await handleInvokeRequest(req); + return response; + } + } catch (e) { + console.error("[Bedrock] ", e); + + // Handle specific error cases + if (e instanceof ValidationException) { + return NextResponse.json( + { + error: true, + message: + "Model validation error. If using a Llama model, please provide a valid inference profile ARN.", + details: e.message, + }, + { status: 400 }, + ); + } + + return NextResponse.json( + { + error: true, + message: e instanceof Error ? e.message : "Unknown error", + details: prettyObject(e), + }, + { status: 500 }, + ); + } +} + +async function handleConverseRequest(req: NextRequest) { + const controller = new AbortController(); + + const region = req.headers.get("X-Region") || "us-east-1"; + const accessKeyId = req.headers.get("X-Access-Key") || ""; + const secretAccessKey = req.headers.get("X-Secret-Key") || ""; + const sessionToken = req.headers.get("X-Session-Token"); + + if (!accessKeyId || !secretAccessKey) { + return NextResponse.json( + { + error: true, + message: "Missing AWS credentials", + }, + { + status: 401, + }, + ); + } + + console.log("[Bedrock] Using region:", region); + + const client = new BedrockRuntimeClient({ + region, + credentials: { + accessKeyId, + secretAccessKey, + sessionToken: sessionToken || undefined, + }, + }); + + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + + try { + const body = (await req.json()) as ConverseRequest; + const { modelId } = body; + + // Validate model ID + const validationError = validateModelId(modelId); + if (validationError) { + throw new ValidationException({ + message: validationError, + $metadata: {}, + }); + } + + console.log("[Bedrock] Invoking model:", modelId); + console.log("[Bedrock] Messages:", body.messages); + + const requestBody = formatRequestBody(body); + + const jsonString = JSON.stringify(requestBody); + const input = { + modelId, + contentType: "application/json", + accept: "application/json", + 
body: Uint8Array.from(Buffer.from(jsonString)), + }; + + console.log("[Bedrock] Request input:", { + ...input, + body: requestBody, + }); + // Forward the abort signal so the 10-minute timer can actually cancel the in-flight AWS call + const command = new InvokeModelCommand(input); + const response = await client.send(command, { abortSignal: controller.signal }); + + console.log("[Bedrock] Got response"); + + // Parse and format the response based on model type + const responseBody = new TextDecoder().decode(response.body); + const formattedResponse = parseModelResponse(responseBody, modelId); + + return NextResponse.json(formattedResponse); + } catch (e) { + console.error("[Bedrock] Request error:", e); + throw e; // Let the main error handler deal with it + } finally { + clearTimeout(timeoutId); + } +} + +async function handleInvokeRequest(req: NextRequest) { + const controller = new AbortController(); + + const region = req.headers.get("X-Region") || "us-east-1"; + const accessKeyId = req.headers.get("X-Access-Key") || ""; + const secretAccessKey = req.headers.get("X-Secret-Key") || ""; + const sessionToken = req.headers.get("X-Session-Token"); + + if (!accessKeyId || !secretAccessKey) { + return NextResponse.json( + { + error: true, + message: "Missing AWS credentials", + }, + { + status: 401, + }, + ); + } + + const client = new BedrockRuntimeClient({ + region, + credentials: { + accessKeyId, + secretAccessKey, + sessionToken: sessionToken || undefined, + }, + }); + + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + + try { + const body = await req.json(); + const { messages, model } = body; + + // Validate model ID + const validationError = validateModelId(model); + if (validationError) { + throw new ValidationException({ + message: validationError, + $metadata: {}, + }); + } + + console.log("[Bedrock] Invoking model:", model); + console.log("[Bedrock] Messages:", messages); + + const requestBody = formatRequestBody({ + modelId: model, + messages, + inferenceConfig: { + maxTokens: 2048, + temperature: 0.7, + topP: 0.9, + }, + }); + + const 
jsonString = JSON.stringify(requestBody); + const input = { + modelId: model, + contentType: "application/json", + accept: "application/json", + body: Uint8Array.from(Buffer.from(jsonString)), + }; + + console.log("[Bedrock] Request input:", { + ...input, + body: requestBody, + }); + // Forward the abort signal so the 10-minute timer can actually cancel the in-flight AWS call + const command = new InvokeModelCommand(input); + const response = await client.send(command, { abortSignal: controller.signal }); + + console.log("[Bedrock] Got response"); + + // Parse and format the response + const responseBody = new TextDecoder().decode(response.body); + const formattedResponse = parseModelResponse(responseBody, model); + + // Extract text content from the response + let textContent = ""; + if (formattedResponse.content && Array.isArray(formattedResponse.content)) { + textContent = formattedResponse.content + .filter((item: ContentItem) => item.type === "text") + .map((item: ContentItem) => item.text || "") + .join(""); + } else if (typeof formattedResponse.content === "string") { + textContent = formattedResponse.content; + } + + // Return plain text response + return new NextResponse(textContent, { + headers: { + "Content-Type": "text/plain", + }, + }); + } catch (e) { + console.error("[Bedrock] Request error:", e); + throw e; + } finally { + clearTimeout(timeoutId); + } +} diff --git a/app/api/bedrock/models.ts b/app/api/bedrock/models.ts new file mode 100644 index 000000000..b9a0fee50 --- /dev/null +++ b/app/api/bedrock/models.ts @@ -0,0 +1,280 @@ +import { + Message, + validateMessageOrder, + processDocumentContent, + BedrockTextBlock, + BedrockImageBlock, + BedrockDocumentBlock, +} from "./utils"; + +export interface ConverseRequest { + modelId: string; + messages: Message[]; + inferenceConfig?: { + maxTokens?: number; + temperature?: number; + topP?: number; + }; + system?: string; + tools?: Array<{ + type: "function"; + function: { + name: string; + description: string; + parameters: { + type: string; + properties: Record; + required: string[]; + }; + }; + }>; +} + +interface ContentItem { 
+ type: string; + text?: string; + image_url?: { + url: string; + }; + document?: { + format: string; + name: string; + source: { + bytes: string; + }; + }; +} + +type ProcessedContent = + | ContentItem + | BedrockTextBlock + | BedrockImageBlock + | BedrockDocumentBlock + | { + type: string; + source: { type: string; media_type: string; data: string }; + }; + +// Helper function to format request body based on model type +export function formatRequestBody(request: ConverseRequest) { + const baseModel = request.modelId; + const messages = validateMessageOrder(request.messages).map((msg) => ({ + role: msg.role, + content: Array.isArray(msg.content) + ? msg.content.map((item: ContentItem) => { + if (item.type === "image_url" && item.image_url?.url) { + // If it's a base64 image URL + const base64Match = item.image_url.url.match( + /^data:image\/([a-zA-Z]*);base64,([^"]*)$/, + ); + if (base64Match) { + return { + type: "image", + source: { + type: "base64", + media_type: `image/${base64Match[1]}`, + data: base64Match[2], + }, + }; + } + // If it's not a base64 URL, return as is + return item; + } + if ("document" in item) { + try { + return processDocumentContent(item); + } catch (error) { + console.error("Error processing document:", error); + return { + type: "text", + text: `[Document: ${item.document?.name || "Unknown"}]`, + }; + } + } + return { type: "text", text: item.text }; + }) + : [{ type: "text", text: msg.content }], + })); + + const systemPrompt = request.system + ? 
[{ type: "text", text: request.system }] + : undefined; + + const baseConfig = { + max_tokens: request.inferenceConfig?.maxTokens || 2048, + temperature: request.inferenceConfig?.temperature ?? 0.7, // ?? keeps an explicit 0 instead of swapping in the default + top_p: request.inferenceConfig?.topP ?? 0.9, // falls back only when topP is null/undefined + }; + + if (baseModel.startsWith("anthropic.claude")) { + return { + messages, + system: systemPrompt, + anthropic_version: "bedrock-2023-05-31", + ...baseConfig, + ...(request.tools && { tools: request.tools }), + }; + } else if ( + baseModel.startsWith("meta.llama") || + baseModel.startsWith("mistral.") + ) { + return { + messages: messages.map((m) => ({ + role: m.role, + content: Array.isArray(m.content) + ? m.content.map((c: ProcessedContent) => { + if ("text" in c) return { type: "text", text: c.text || "" }; + if ("image_url" in c) + return { + type: "text", + text: `[Image: ${c.image_url?.url || "URL not provided"}]`, + }; + if ("document" in c) + return { + type: "text", + text: `[Document: ${c.document?.name || "Unknown"}]`, + }; + return { type: "text", text: "" }; + }) + : [{ type: "text", text: m.content }], + })), + ...baseConfig, + stop_sequences: ["\n\nHuman:", "\n\nAssistant:"], + }; + } else if (baseModel.startsWith("amazon.titan")) { + const formattedText = messages.map((m) => ({ + role: m.role, + content: [ + { + type: "text", + text: `${m.role === "user" ? "Human" : "Assistant"}: ${ + Array.isArray(m.content) + ? 
m.content + .map((c: ProcessedContent) => { + if ("text" in c) return c.text || ""; + if ("image_url" in c) + return `[Image: ${ + c.image_url?.url || "URL not provided" + }]`; + if ("document" in c) + return `[Document: ${c.document?.name || "Unknown"}]`; + return ""; + }) + .join("") + : m.content + }`, + }, + ], + })); + + return { + messages: formattedText, + textGenerationConfig: { + maxTokenCount: baseConfig.max_tokens, + temperature: baseConfig.temperature, + topP: baseConfig.top_p, + stopSequences: ["Human:", "Assistant:"], + }, + }; + } + + throw new Error(`Unsupported model: ${baseModel}`); +} + +// Helper function to parse and format response based on model type +export function parseModelResponse(responseBody: string, modelId: string): any { + const baseModel = modelId; + + try { + const response = JSON.parse(responseBody); + + // Common response format for all models + const formatResponse = (content: string | any[]) => ({ + role: "assistant", + content: Array.isArray(content) + ? 
content.map((item) => { + if (typeof item === "string") { + return { type: "text", text: item }; + } + // Handle different content types + if ("text" in item) { + return { type: "text", text: item.text || "" }; + } + if ("image" in item) { // media_type already has the form "image/png", so it is not prefixed with "image/" again + return { + type: "image_url", + image_url: { + url: `data:${ + item.source?.media_type || "image/png" + };base64,${item.source?.data || ""}`, + }, + }; + } + // Document responses are converted to text + if ("document" in item) { + return { + type: "text", + text: `[Document Content]\n${item.text || ""}`, + }; + } + return { type: "text", text: item.text || "" }; + }) + : [{ type: "text", text: content }], + stop_reason: response.stop_reason || response.stopReason || "end_turn", + usage: response.usage || { + input_tokens: 0, + output_tokens: 0, + total_tokens: 0, + }, + }); + + if (baseModel.startsWith("anthropic.claude")) { + // Handle the new Converse API response format + if (response.output?.message) { + return { + role: response.output.message.role, + content: response.output.message.content.map((item: any) => { + if ("text" in item) return { type: "text", text: item.text || "" }; + if ("image" in item) { + return { + type: "image_url", + image_url: { + url: `data:${item.source?.media_type || "image/png"};base64,${ + item.source?.data || "" + }`, + }, + }; + } + return { type: "text", text: item.text || "" }; + }), + stop_reason: response.stopReason, + usage: response.usage, + }; + } + // Fallback for older format + return formatResponse( + response.content || + (response.completion + ? 
[{ type: "text", text: response.completion }] + : []), + ); + } else if (baseModel.startsWith("meta.llama")) { + return formatResponse(response.generation || response.completion || ""); + } else if (baseModel.startsWith("amazon.titan")) { + return formatResponse(response.results?.[0]?.outputText || ""); + } else if (baseModel.startsWith("mistral.")) { + return formatResponse( + response.outputs?.[0]?.text || response.response || "", + ); + } + + throw new Error(`Unsupported model: ${baseModel}`); + } catch (e) { + console.error("[Bedrock] Failed to parse response:", e); + // Return raw text as fallback + return { + role: "assistant", + content: [{ type: "text", text: responseBody }], + }; + } +} diff --git a/app/api/bedrock/utils.ts b/app/api/bedrock/utils.ts new file mode 100644 index 000000000..85cd517b4 --- /dev/null +++ b/app/api/bedrock/utils.ts @@ -0,0 +1,218 @@ +import { MultimodalContent } from "../../client/api"; + +export interface Message { + role: string; + content: string | MultimodalContent[]; +} + +export interface ImageSource { + bytes: string; // base64 encoded image bytes +} + +export interface DocumentSource { + bytes: string; // base64 encoded document bytes +} + +export interface BedrockImageBlock { + image: { + format: "png" | "jpeg" | "gif" | "webp"; + source: ImageSource; + }; +} + +export interface BedrockDocumentBlock { + document: { + format: + | "pdf" + | "csv" + | "doc" + | "docx" + | "xls" + | "xlsx" + | "html" + | "txt" + | "md"; + name: string; + source: DocumentSource; + }; +} + +export interface BedrockTextBlock { + text: string; +} + +export type BedrockContentBlock = + | BedrockTextBlock + | BedrockImageBlock + | BedrockDocumentBlock; + +export interface BedrockResponse { + content?: any[]; + completion?: string; + stop_reason?: string; + usage?: { + input_tokens: number; + output_tokens: number; + total_tokens: number; + }; + tool_calls?: any[]; +} + +// Helper function to get the base model type from modelId +export function 
getModelType(modelId: string): string { + if (modelId.includes("inference-profile")) { + const match = modelId.match(/us\.(meta\.llama.+?)$/); + if (match) return match[1]; + } + return modelId; +} + +// Validates a model ID: returns an error message string when the ID is rejected, or null when it is accepted +export function validateModelId(modelId: string): string | null { + // IDs starting with "meta.llama" are rejected unless they also contain "inference-profile" + if ( + modelId.startsWith("meta.llama") && + !modelId.includes("inference-profile") + ) { + return "Llama models require an inference profile. Please use the full inference profile ARN."; + } + return null; +} + +// Converts an item shaped { document: { format, name, source: { bytes } } } into a Bedrock document block; throws when a field is missing or the lower-cased format is not in the allowed list +export function processDocumentContent(content: any): BedrockContentBlock { + if ( + !content?.document?.format || + !content?.document?.name || + !content?.document?.source?.bytes + ) { + throw new Error("Invalid document content format"); + } + + const format = content.document.format.toLowerCase(); + if ( + !["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"].includes( + format, + ) + ) { + throw new Error(`Unsupported document format: ${format}`); + } + + return { + document: { + format: format as BedrockDocumentBlock["document"]["format"], + name: sanitizeDocumentName(content.document.name), + source: { + bytes: content.document.source.bytes, + }, + }, + }; +} + +// Converts a plain string or a MultimodalContent array into Bedrock content blocks; items that match none of the handled shapes are left out of the result +export function formatContent( + content: string | MultimodalContent[], +): BedrockContentBlock[] { + if (typeof content === "string") { + return [{ text: content }]; + } + + const formattedContent: BedrockContentBlock[] = []; + + for (const item of content) { + if (item.type === "text" && item.text) { + formattedContent.push({ text: item.text }); + } else if (item.type === "image_url" && item.image_url?.url) { + // Capture the image subtype (group 1) and the base64 payload (group 2) from a data:image/...;base64,... URL + const base64Match = item.image_url.url.match( + /^data:image\/([a-zA-Z]*);base64,([^"]*)$/, + ); + if (base64Match) { + const format = base64Match[1].toLowerCase(); + if 
(["png", "jpeg", "gif", "webp"].includes(format)) { + formattedContent.push({ + image: { + format: format as "png" | "jpeg" | "gif" | "webp", + source: { + bytes: base64Match[2], + }, + }, + }); + } + } + } else if ("document" in item) { + try { + formattedContent.push(processDocumentContent(item)); + } catch (error) { + console.error("Error processing document:", error); + // Convert document to text as fallback + formattedContent.push({ + text: `[Document: ${(item as any).document?.name || "Unknown"}]`, + }); + } + } + } + + return formattedContent; +} + +// Enforces role alternation by keeping only the first message of each run of same-role messages; later ones in the run are dropped, not merged. NOTE(review): confirm that discarding them is intended +export function validateMessageOrder(messages: Message[]): Message[] { + const validatedMessages: Message[] = []; + let lastRole = ""; + + for (const message of messages) { + if (message.role === lastRole) { + // Same role as the previously kept message: drop this one + continue; + } + validatedMessages.push(message); + lastRole = message.role; + } + + return validatedMessages; +} + +// Helper function to sanitize document names (presumably to satisfy Bedrock's document-name rules; the rules themselves are not shown here) +function sanitizeDocumentName(name: string): string { + // Remove any characters that aren't alphanumeric, whitespace, hyphens, parentheses, or square brackets + let sanitized = name.replace(/[^a-zA-Z0-9\s\-\(\)\[\]]/g, ""); + // Replace multiple whitespace characters with a single space + sanitized = sanitized.replace(/\s+/g, " "); + // Trim whitespace from start and end + return sanitized.trim(); +} + +// Helper function to convert Bedrock response back to MultimodalContent format +export function convertBedrockResponseToMultimodal( + response: BedrockResponse, +): string | MultimodalContent[] { + if (response.completion) { + return response.completion; + } + + if (!response.content) { + return ""; + } + + return response.content.map((block) => { + if ("text" in block) { + return { + type: "text", + text: block.text, + }; + } else if ("image" in block) { + return { + type: "image_url", + image_url: { + 
url: `data:image/${block.image.format};base64,${block.image.source.bytes}`, + }, + }; + } + // Document responses are converted to text content + return { + type: "text", + text: block.text || "", + }; + }); +} diff --git a/app/client/api.ts b/app/client/api.ts index 4238c2a26..e547bea0a 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -12,6 +12,7 @@ import { useChatStore, } from "../store"; import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai"; +import { BedrockApi } from "./platforms/bedrock"; import { GeminiProApi } from "./platforms/google"; import { ClaudeApi } from "./platforms/anthropic"; import { ErnieApi } from "./platforms/baidu"; @@ -129,6 +130,9 @@ export class ClientApi { constructor(provider: ModelProvider = ModelProvider.GPT) { switch (provider) { + case ModelProvider.Bedrock: + this.llm = new BedrockApi(); + break; case ModelProvider.GeminiPro: this.llm = new GeminiProApi(); break; @@ -235,6 +239,7 @@ export function getHeaders(ignoreHeaders: boolean = false) { function getConfig() { const modelConfig = chatStore.currentSession().mask.modelConfig; + const isBedrock = modelConfig.providerName === ServiceProvider.Bedrock; const isGoogle = modelConfig.providerName === ServiceProvider.Google; const isAzure = modelConfig.providerName === ServiceProvider.Azure; const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic; @@ -247,6 +252,8 @@ export function getHeaders(ignoreHeaders: boolean = false) { const isEnabledAccessControl = accessStore.enabledAccessControl(); const apiKey = isGoogle ? accessStore.googleApiKey + : isBedrock + ? accessStore.awsAccessKeyId // Use AWS access key for Bedrock : isAzure ? accessStore.azureApiKey : isAnthropic @@ -265,6 +272,7 @@ export function getHeaders(ignoreHeaders: boolean = false) { : "" : accessStore.openaiApiKey; return { + isBedrock, isGoogle, isAzure, isAnthropic, @@ -286,10 +294,13 @@ export function getHeaders(ignoreHeaders: boolean = false) { ? "x-api-key" : isGoogle ? 
"x-goog-api-key" + : isBedrock + ? "x-api-key" : "Authorization"; } const { + isBedrock, isGoogle, isAzure, isAnthropic, @@ -302,17 +313,27 @@ export function getHeaders(ignoreHeaders: boolean = false) { const authHeader = getAuthHeader(); - const bearerToken = getBearerToken( - apiKey, - isAzure || isAnthropic || isGoogle, - ); - - if (bearerToken) { - headers[authHeader] = bearerToken; - } else if (isEnabledAccessControl && validString(accessStore.accessCode)) { - headers["Authorization"] = getBearerToken( - ACCESS_CODE_PREFIX + accessStore.accessCode, + if (isBedrock) { + // Add AWS credentials for Bedrock + headers["X-Region"] = accessStore.awsRegion; + headers["X-Access-Key"] = accessStore.awsAccessKeyId; + headers["X-Secret-Key"] = accessStore.awsSecretAccessKey; + if (accessStore.awsSessionToken) { + headers["X-Session-Token"] = accessStore.awsSessionToken; + } + } else { + const bearerToken = getBearerToken( + apiKey, + isAzure || isAnthropic || isGoogle, ); + + if (bearerToken) { + headers[authHeader] = bearerToken; + } else if (isEnabledAccessControl && validString(accessStore.accessCode)) { + headers["Authorization"] = getBearerToken( + ACCESS_CODE_PREFIX + accessStore.accessCode, + ); + } } return headers; @@ -320,6 +341,8 @@ export function getHeaders(ignoreHeaders: boolean = false) { export function getClientApi(provider: ServiceProvider): ClientApi { switch (provider) { + case ServiceProvider.Bedrock: + return new ClientApi(ModelProvider.Bedrock); case ServiceProvider.Google: return new ClientApi(ModelProvider.GeminiPro); case ServiceProvider.Anthropic: diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts new file mode 100644 index 000000000..f8954f9d7 --- /dev/null +++ b/app/client/platforms/bedrock.ts @@ -0,0 +1,223 @@ +import { ApiPath } from "../../constant"; +import { ChatOptions, getHeaders, LLMApi, SpeechOptions } from "../api"; +import { + useAccessStore, + useAppConfig, + useChatStore, + usePluginStore, +} from 
"../../store"; +import { preProcessImageContent, stream } from "../../utils/chat"; +import Locale from "../../locales"; + +export interface BedrockChatRequest { + model: string; + messages: Array<{ + role: string; + content: + | string + | Array<{ + type: string; + text?: string; + image_url?: { url: string }; + document?: { + format: string; + name: string; + source: { + bytes: string; + }; + }; + }>; + }>; + temperature?: number; + top_p?: number; + max_tokens?: number; + stream?: boolean; +} + +export class BedrockApi implements LLMApi { + speech(options: SpeechOptions): Promise { + throw new Error("Method not implemented."); + } + + extractMessage(res: any) { + console.log("[Response] bedrock response: ", res); + return res; + } + + async chat(options: ChatOptions): Promise { + const shouldStream = !!options.config.stream; + + const modelConfig = { + ...useAppConfig.getState().modelConfig, + ...useChatStore.getState().currentSession().mask.modelConfig, + ...{ + model: options.config.model, + }, + }; + + const accessStore = useAccessStore.getState(); + + if ( + !accessStore.awsRegion || + !accessStore.awsAccessKeyId || + !accessStore.awsSecretAccessKey + ) { + console.log("AWS credentials are not set"); + let responseText = ""; + const responseTexts = [responseText]; + responseTexts.push(Locale.Error.Unauthorized); + responseText = responseTexts.join("\n\n"); + options.onFinish(responseText); + return; + } + + // Process messages to handle image and document content + const messages = await Promise.all( + options.messages.map(async (v) => { + const content = await preProcessImageContent(v.content); + // If content is an array (multimodal), ensure each item is properly formatted + if (Array.isArray(content)) { + return { + role: v.role, + content: content.map((item) => { + if (item.type === "image_url" && item.image_url?.url) { + // If the URL is a base64 data URL, use it directly + if (item.image_url.url.startsWith("data:image/")) { + return item; + } + // 
Otherwise, it's a regular URL that needs to be converted to base64 + // The conversion should have been handled by preProcessImageContent + return item; + } + if ("document" in item) { + // Handle document content + const doc = item as any; + if ( + doc?.document?.format && + doc?.document?.name && + doc?.document?.source?.bytes + ) { + return { + type: "document", + document: { + format: doc.document.format, + name: doc.document.name, + source: { + bytes: doc.document.source.bytes, + }, + }, + }; + } + } + return item; + }), + }; + } + // If content is a string, return it as is + return { + role: v.role, + content, + }; + }), + ); + + const requestBody: BedrockChatRequest = { + messages, + stream: shouldStream, + model: modelConfig.model, + max_tokens: modelConfig.max_tokens, + temperature: modelConfig.temperature, + top_p: modelConfig.top_p, + }; + + console.log("[Bedrock] Request:", { + model: modelConfig.model, + messages: messages, + }); + + const controller = new AbortController(); + options.onController?.(controller); + + const headers: Record = { + ...getHeaders(), + "X-Region": accessStore.awsRegion, + "X-Access-Key": accessStore.awsAccessKeyId, + "X-Secret-Key": accessStore.awsSecretAccessKey, + }; + + if (accessStore.awsSessionToken) { + headers["X-Session-Token"] = accessStore.awsSessionToken; + } + + try { + if (shouldStream) { + let responseText = ""; + const pluginStore = usePluginStore.getState(); + const currentSession = useChatStore.getState().currentSession(); + const [tools, funcs] = pluginStore.getAsTools( + currentSession.mask?.plugin || [], + ); + + await stream( + `${ApiPath.Bedrock}/invoke`, + requestBody, + headers, + Array.isArray(tools) ? 
tools : [], + funcs || {}, + controller, + (chunk: string) => { + try { + responseText += chunk; + return chunk; + } catch (e) { + console.error("[Request] parse error", chunk, e); + return ""; + } + }, + ( + requestPayload: any, + toolCallMessage: any, + toolCallResult: any[], + ) => { + console.log("[Bedrock] processToolMessage", { + requestPayload, + toolCallMessage, + toolCallResult, + }); + }, + options, + ); + } else { + const response = await fetch(`${ApiPath.Bedrock}/invoke`, { + method: "POST", + headers, + body: JSON.stringify(requestBody), + signal: controller.signal, + }); + + if (!response.ok) { + const error = await response.text(); + console.error("[Bedrock] Error response:", error); + throw new Error(`Bedrock API error: ${error}`); + } + + const text = await response.text(); + options.onFinish(text); + } + } catch (e) { + console.error("[Bedrock] Chat error:", e); + options.onError?.(e as Error); + } + } + + async usage() { + return { + used: 0, + total: 0, + }; + } + + async models() { + return []; + } +} diff --git a/app/components/settings.tsx b/app/components/settings.tsx index 666caece8..9c6d9793c 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -963,7 +963,75 @@ export function Settings() { ); - + const bedrockConfigComponent = accessStore.provider === + ServiceProvider.Bedrock && ( + <> + + + accessStore.update( + (access) => (access.awsRegion = e.currentTarget.value), + ) + } + /> + + + { + accessStore.update( + (access) => (access.awsAccessKeyId = e.currentTarget.value), + ); + }} + /> + + + { + accessStore.update( + (access) => (access.awsSecretAccessKey = e.currentTarget.value), + ); + }} + /> + + + { + accessStore.update( + (access) => (access.awsSessionToken = e.currentTarget.value), + ); + }} + /> + + + ); const baiduConfigComponent = accessStore.provider === ServiceProvider.Baidu && ( <> @@ -1682,6 +1750,7 @@ export function Settings() { {openAIConfigComponent} + {bedrockConfigComponent} 
{azureConfigComponent} {googleConfigComponent} {anthropicConfigComponent} diff --git a/app/config/server.ts b/app/config/server.ts index eac4ba0cf..7e130aa0e 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -12,6 +12,10 @@ declare global { BASE_URL?: string; OPENAI_ORG_ID?: string; // openai only + // bedrock only + BEDROCK_URL?: string; + BEDROCK_API_KEY?: string; + VERCEL?: string; BUILD_MODE?: "standalone" | "export"; BUILD_APP?: string; // is building desktop app @@ -139,7 +143,7 @@ export const getServerSideConfig = () => { } const isStability = !!process.env.STABILITY_API_KEY; - + const isBedrock = !!process.env.BEDROCK_API_KEY; const isAzure = !!process.env.AZURE_URL; const isGoogle = !!process.env.GOOGLE_API_KEY; const isAnthropic = !!process.env.ANTHROPIC_API_KEY; @@ -168,6 +172,10 @@ export const getServerSideConfig = () => { apiKey: getApiKey(process.env.OPENAI_API_KEY), openaiOrgId: process.env.OPENAI_ORG_ID, + isBedrock, + bedrockUrl: process.env.BEDROCK_URL, + bedrockApiKey: getApiKey(process.env.BEDROCK_API_KEY), + isStability, stabilityUrl: process.env.STABILITY_URL, stabilityApiKey: getApiKey(process.env.STABILITY_API_KEY), diff --git a/app/constant.ts b/app/constant.ts index 9774bb594..0a9039878 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -12,6 +12,8 @@ export const RUNTIME_CONFIG_DOM = "danger-runtime-config"; export const STABILITY_BASE_URL = "https://api.stability.ai"; export const OPENAI_BASE_URL = "https://api.openai.com"; +export const BEDROCK_BASE_URL = + "https://bedrock-runtime.us-west-2.amazonaws.com"; export const ANTHROPIC_BASE_URL = "https://api.anthropic.com"; export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/"; @@ -49,6 +51,7 @@ export enum Path { export enum ApiPath { Cors = "", + Bedrock = "/api/bedrock", Azure = "/api/azure", OpenAI = "/api/openai", Anthropic = "/api/anthropic", @@ -115,6 +118,7 @@ export enum ServiceProvider { Stability = "Stability", Iflytek = "Iflytek", XAI = 
"XAI", + Bedrock = "Bedrock", } // Google API safety settings, see https://ai.google.dev/gemini-api/docs/safety-settings @@ -128,6 +132,7 @@ export enum GoogleSafetySettingsThreshold { export enum ModelProvider { Stability = "Stability", + Bedrock = "Bedrock", GPT = "GPT", GeminiPro = "GeminiPro", Claude = "Claude", @@ -304,6 +309,26 @@ const openaiModels = [ "o1-preview", ]; +const bedrockModels = [ + // Claude Models + "anthropic.claude-3-haiku-20240307-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-opus-20240229-v1:0", + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-5-sonnet-20241022-v2:0", + // Amazon Titan Models + "amazon.titan-text-express-v1", + "amazon.titan-text-lite-v1", + // Meta Llama Models + "meta.llama3-2-1b-instruct-v1:0", + "meta.llama3-2-3b-instruct-v1:0", + "meta.llama3-2-11b-instruct-v1:0", + //Mistral + "mistral.mistral-7b-instruct-v0:2", + "mistral.mixtral-8x7b-instruct-v0:1", + "mistral.mistral-large-2407-v1:0", +]; + const googleModels = [ "gemini-1.0-pro", "gemini-1.5-pro-latest", @@ -499,6 +524,17 @@ export const DEFAULT_MODELS = [ sorted: 11, }, })), + ...bedrockModels.map((name) => ({ + name, + available: true, + sorted: seq++, + provider: { + id: "bedrock", + providerName: "Bedrock", + providerType: "bedrock", + sorted: 12, + }, + })), ] as const; export const CHAT_PAGE_SIZE = 15; diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 006fc8162..573969be7 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -342,6 +342,32 @@ const cn = { SubTitle: "除默认地址外,必须包含 http(s)://", }, }, + Bedrock: { + Region: { + Title: "AWS Region", + SubTitle: "The AWS region where Bedrock service is located", + Placeholder: "us-west-2", + }, + AccessKey: { + Title: "AWS Access Key ID", + SubTitle: "Your AWS access key ID for Bedrock service", + Placeholder: "AKIA...", + }, + SecretKey: { + Title: "AWS Secret Access Key", + SubTitle: "Your AWS secret access key for Bedrock service", + Placeholder: 
"****", + }, + SessionToken: { + Title: "AWS Session Token (Optional)", + SubTitle: "Your AWS session token if using temporary credentials", + Placeholder: "Optional session token", + }, + Endpoint: { + Title: "AWS Bedrock Endpoint", + SubTitle: "Custom endpoint for AWS Bedrock API. Default: ", + }, + }, Azure: { ApiKey: { Title: "接口密钥", diff --git a/app/locales/en.ts b/app/locales/en.ts index 7204bd946..9d3097ef8 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -346,6 +346,32 @@ const en: LocaleType = { SubTitle: "Must start with http(s):// or use /api/openai as default", }, }, + Bedrock: { + Region: { + Title: "AWS Region", + SubTitle: "The AWS region where Bedrock service is located", + Placeholder: "us-west-2", + }, + AccessKey: { + Title: "AWS Access Key ID", + SubTitle: "Your AWS access key ID for Bedrock service", + Placeholder: "AKIA...", + }, + SecretKey: { + Title: "AWS Secret Access Key", + SubTitle: "Your AWS secret access key for Bedrock service", + Placeholder: "****", + }, + SessionToken: { + Title: "AWS Session Token (Optional)", + SubTitle: "Your AWS session token if using temporary credentials", + Placeholder: "Optional session token", + }, + Endpoint: { + Title: "AWS Bedrock Endpoint", + SubTitle: "Custom endpoint for AWS Bedrock API. Default: ", + }, + }, Azure: { ApiKey: { Title: "Azure Api Key", diff --git a/app/store/access.ts b/app/store/access.ts index b3d412a2d..11127cbed 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -4,6 +4,7 @@ import { StoreKey, ApiPath, OPENAI_BASE_URL, + BEDROCK_BASE_URL, ANTHROPIC_BASE_URL, GEMINI_BASE_URL, BAIDU_BASE_URL, @@ -26,6 +27,7 @@ let fetchState = 0; // 0 not fetch, 1 fetching, 2 done const isApp = getClientConfig()?.buildMode === "export"; const DEFAULT_OPENAI_URL = isApp ? OPENAI_BASE_URL : ApiPath.OpenAI; +const DEFAULT_BEDROCK_URL = isApp ? BEDROCK_BASE_URL : ApiPath.Bedrock; const DEFAULT_GOOGLE_URL = isApp ? 
GEMINI_BASE_URL : ApiPath.Google; @@ -57,6 +59,16 @@ const DEFAULT_ACCESS_STATE = { openaiUrl: DEFAULT_OPENAI_URL, openaiApiKey: "", + // bedrock + bedrockUrl: DEFAULT_BEDROCK_URL, + bedrockApiKey: "", + awsRegion: "", + awsAccessKeyId: "", + awsSecretAccessKey: "", + awsSessionToken: "", + awsCognitoUser: false, + awsInferenceProfile: "", // Added inference profile field + // azure azureUrl: "", azureApiKey: "", @@ -141,6 +153,14 @@ export const useAccessStore = createPersistStore( return ensure(get(), ["openaiApiKey"]); }, + isValidBedrock() { + return ensure(get(), [ + "awsAccessKeyId", + "awsSecretAccessKey", + "awsRegion", + ]); + }, + isValidAzure() { return ensure(get(), ["azureUrl", "azureApiKey", "azureApiVersion"]); }, @@ -186,6 +206,7 @@ export const useAccessStore = createPersistStore( // has token or has code or disabled access control return ( this.isValidOpenAI() || + this.isValidBedrock() || this.isValidAzure() || this.isValidGoogle() || this.isValidAnthropic() || diff --git a/app/utils.ts b/app/utils.ts index d8fc46330..78cfe5a0e 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -285,6 +285,9 @@ export function showPlugins(provider: ServiceProvider, model: string) { if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) { return true; } + if (provider == ServiceProvider.Bedrock && !model.includes("claude-2")) { + return true; + } if (provider == ServiceProvider.Google && !model.includes("vision")) { return true; } diff --git a/package.json b/package.json index 5bca3b327..d1c8e8278 100644 --- a/package.json +++ b/package.json @@ -51,7 +51,8 @@ "sass": "^1.59.2", "spark-md5": "^3.0.2", "use-debounce": "^9.0.4", - "zustand": "^4.3.8" + "zustand": "^4.3.8", + "@aws-sdk/client-bedrock-runtime": "^3.679.0" }, "devDependencies": { "@tauri-apps/api": "^1.6.0",