From 58837f6deccea7d09b09611d646549ba8b20cc0f Mon Sep 17 00:00:00 2001 From: glay Date: Tue, 5 Nov 2024 17:28:19 +0800 Subject: [PATCH] =?UTF-8?q?=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20ap?= =?UTF-8?q?p/api/bedrock.ts=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20?= =?UTF-8?q?=20app/client/platforms/bedrock.ts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/bedrock.ts | 549 +++++++------------------------- app/client/platforms/bedrock.ts | 179 +++++------ 2 files changed, 193 insertions(+), 535 deletions(-) diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index 8b5ddc47e..aeee8eb55 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -5,17 +5,13 @@ import { BedrockRuntimeClient, ConverseStreamCommand, ConverseStreamCommandInput, + Message, + ContentBlock, ConverseStreamOutput, - ModelStreamErrorException, - type Message, - type ContentBlock, - type SystemContentBlock, - type Tool, - type ToolChoice, - type ToolResultContentBlock, } from "@aws-sdk/client-bedrock-runtime"; -// 解密函数 +const ALLOWED_PATH = new Set(["converse"]); + function decrypt(str: string): string { try { return Buffer.from(str, "base64").toString().split("").reverse().join(""); @@ -24,14 +20,11 @@ function decrypt(str: string): string { } } -// Constants and Types -const ALLOWED_PATH = new Set(["converse"]); - export interface ConverseRequest { modelId: string; messages: { role: "user" | "assistant" | "system"; - content: string | ContentItem[]; + content: string | any[]; }[]; inferenceConfig?: { maxTokens?: number; @@ -39,324 +32,89 @@ export interface ConverseRequest { topP?: number; stopSequences?: string[]; }; - toolConfig?: { - tools: Tool[]; - toolChoice?: ToolChoice; - }; -} - -interface ContentItem { - type: "text" | "image_url" | "document" | "tool_use" | "tool_result"; - text?: string; - image_url?: { - url: string; // base64 data URL - }; - document?: { - format: DocumentFormat; - name: string; - source: { - bytes: string; // base64 - }; - }; - tool_use?: { - tool_use_id: string; - name: string; - input: any; - }; - tool_result?: { - tool_use_id: string; - content: ToolResultItem[]; - status: "success" | "error"; - }; -} - -interface ToolResultItem { - type: "text" | "image" | "document" | "json"; - text?: string; - image?: { - format: "png" | "jpeg" | "gif" | "webp"; - source: { - bytes: string; // base64 - }; - }; - document?: { - format: DocumentFormat; - name: string; - source: { - bytes: string; // base64 - }; - }; - json?: any; -} - -type DocumentFormat = - | "pdf" - | "csv" - | "doc" - | "docx" - | "xls" - | "xlsx" - | "html" - | "txt" - | "md"; - -function validateImageSize(base64Data: string): boolean { - const sizeInBytes = (base64Data.length * 3) / 4; - const maxSize = 3.75 * 1024 * 1024; - if (sizeInBytes > maxSize) { - throw new Error("Image size exceeds 3.75 MB limit"); - } - return true; -} - -// Content Processing Functions -function convertContentToAWSBlock(item: ContentItem): ContentBlock | null { - if (item.type === "text" && item.text) { - return { text: item.text }; - } - - if (item.type === "image_url" && item.image_url?.url) { - const base64Match = item.image_url.url.match( - /^data:image\/([a-zA-Z]*);base64,([^"]*)/, - ); - if (base64Match) { - const format = base64Match[1].toLowerCase(); - if (["png", "jpeg", "gif", "webp"].includes(format)) { - validateImageSize(base64Match[2]); - return { - image: { - format: format as "png" | "jpeg" | "gif" | "webp", - source: { - bytes: Uint8Array.from(Buffer.from(base64Match[2], "base64")), - }, - }, - }; - } - } - } - - if (item.type === "tool_use" && item.tool_use) { - return { - toolUse: { - toolUseId: item.tool_use.tool_use_id, - name: item.tool_use.name, - input: item.tool_use.input, - }, - }; - } - - if (item.type === "tool_result" && item.tool_result) { - const toolResultContent = item.tool_result.content - .map((resultItem) => { - if (resultItem.type === "text" && resultItem.text) { - return { text: resultItem.text } as ToolResultContentBlock; - } - if (resultItem.type === "image" && resultItem.image) { - return { - image: { - format: resultItem.image.format, - source: { - bytes: Uint8Array.from( - Buffer.from(resultItem.image.source.bytes, "base64"), - ), - }, - }, - } as ToolResultContentBlock; - } - if (resultItem.type === "document" && resultItem.document) { - return { - document: { - format: resultItem.document.format, - name: resultItem.document.name, - source: { - bytes: Uint8Array.from( - Buffer.from(resultItem.document.source.bytes, "base64"), - ), - }, - }, - } as ToolResultContentBlock; - } - if (resultItem.type === "json" && resultItem.json) { - return { json: resultItem.json } as ToolResultContentBlock; - } - return null; - }) - .filter((content): content is ToolResultContentBlock => content !== null); - - if (toolResultContent.length === 0) { - return null; - } - - return { - toolResult: { - toolUseId: item.tool_result.tool_use_id, - content: toolResultContent, - status: item.tool_result.status, - }, - }; - } - - return null; -} - -function convertContentToAWS(content: string | ContentItem[]): ContentBlock[] { - if (typeof content === "string") { - return [{ text: content }]; - } - - const blocks = content - .map(convertContentToAWSBlock) - .filter((block): block is ContentBlock => block !== null); - - return blocks.length > 0 ? blocks : [{ text: "" }]; -} - -function formatMessages(messages: ConverseRequest["messages"]): { - messages: Message[]; - systemPrompt?: SystemContentBlock[]; -} { - const systemMessages = messages.filter((msg) => msg.role === "system"); - const nonSystemMessages = messages.filter((msg) => msg.role !== "system"); - - const systemPrompt = - systemMessages.length > 0 - ? systemMessages.map((msg) => { - if (typeof msg.content === "string") { - return { text: msg.content } as SystemContentBlock; - } - const blocks = convertContentToAWS(msg.content); - return blocks[0] as SystemContentBlock; - }) - : undefined; - - const formattedMessages = nonSystemMessages.reduce( - (acc: Message[], curr, idx) => { - if (idx > 0 && curr.role === nonSystemMessages[idx - 1].role) { - return acc; - } - - const content = convertContentToAWS(curr.content); - if (content.length > 0) { - acc.push({ - role: curr.role as "user" | "assistant", - content, - }); - } - return acc; - }, - [], - ); - - if (formattedMessages.length === 0 || formattedMessages[0].role !== "user") { - formattedMessages.unshift({ - role: "user", - content: [{ text: "Hello" }], - }); - } - - if (formattedMessages[formattedMessages.length - 1].role !== "user") { - formattedMessages.push({ - role: "user", - content: [{ text: "Continue" }], - }); - } - - return { messages: formattedMessages, systemPrompt }; } function formatRequestBody( request: ConverseRequest, ): ConverseStreamCommandInput { - const { messages, systemPrompt } = formatMessages(request.messages); - const input: ConverseStreamCommandInput = { + const messages: Message[] = request.messages.map((msg) => ({ + role: msg.role === "system" ? "user" : msg.role, + content: Array.isArray(msg.content) + ? msg.content.map((item) => { + if (item.type === "tool_use") { + return { + toolUse: { + toolUseId: item.id, + name: item.name, + input: item.input || "{}", + }, + } as ContentBlock; + } + if (item.type === "tool_result") { + return { + toolResult: { + toolUseId: item.tool_use_id, + content: [{ text: item.content || ";" }], + status: "success", + }, + } as ContentBlock; + } + if (item.type === "text") { + return { text: item.text || ";" } as ContentBlock; + } + if (item.type === "image") { + return { + image: { + format: item.source.media_type.split("/")[1] as + | "png" + | "jpeg" + | "gif" + | "webp", + source: { + bytes: Uint8Array.from( + Buffer.from(item.source.data, "base64"), + ), + }, + }, + } as ContentBlock; + } + return { text: ";" } as ContentBlock; + }) + : [{ text: msg.content || ";" } as ContentBlock], + })); + + return { modelId: request.modelId, messages, - ...(systemPrompt && { system: systemPrompt }), + ...(request.inferenceConfig && { + inferenceConfig: request.inferenceConfig, + }), }; - - if (request.inferenceConfig) { - input.inferenceConfig = { - maxTokens: request.inferenceConfig.maxTokens, - temperature: request.inferenceConfig.temperature, - topP: request.inferenceConfig.topP, - stopSequences: request.inferenceConfig.stopSequences, - }; - } - - if (request.toolConfig) { - input.toolConfig = { - tools: request.toolConfig.tools, - toolChoice: request.toolConfig.toolChoice, - }; - } - - const logInput = { - ...input, - messages: messages.map((msg) => ({ - role: msg.role, - content: msg.content?.map((content) => { - if ("image" in content && content.image) { - return { - image: { - format: content.image.format, - source: { bytes: "[BINARY]" }, - }, - }; - } - if ("document" in content && content.document) { - return { - document: { ...content.document, source: { bytes: "[BINARY]" } }, - }; - } - return content; - }), - })), - }; - - console.log( - "[Bedrock] Formatted request:", - JSON.stringify(logInput, null, 2), - ); - return input; } -// Main Request Handler export async function handle( req: NextRequest, { params }: { params: { path: string[] } }, ) { - console.log("[Bedrock Route] params ", params); - if (req.method === "OPTIONS") { return NextResponse.json({ body: "OK" }, { status: 200 }); } const subpath = params.path.join("/"); - if (!ALLOWED_PATH.has(subpath)) { - console.log("[Bedrock Route] forbidden path ", subpath); return NextResponse.json( - { - error: true, - msg: "you are not allowed to request " + subpath, - }, - { - status: 403, - }, + { error: true, msg: "Path not allowed: " + subpath }, + { status: 403 }, ); } const serverConfig = getServerSideConfig(); - - // 首先尝试使用环境变量中的凭证 let region = serverConfig.awsRegion; let accessKeyId = serverConfig.awsAccessKey; let secretAccessKey = serverConfig.awsSecretKey; let sessionToken = undefined; - // 如果环境变量中没有配置,则尝试使用前端传来的加密凭证 if (!region || !accessKeyId || !secretAccessKey) { - // 解密前端传来的凭证 region = decrypt(req.headers.get("X-Region") ?? ""); accessKeyId = decrypt(req.headers.get("X-Access-Key") ?? ""); secretAccessKey = decrypt(req.headers.get("X-Secret-Key") ?? ""); @@ -367,50 +125,19 @@ export async function handle( if (!region || !accessKeyId || !secretAccessKey) { return NextResponse.json( - { - error: true, - msg: "AWS credentials not found in environment variables or request headers", - }, - { - status: 401, - }, + { error: true, msg: "Missing AWS credentials" }, + { status: 401 }, ); } try { const client = new BedrockRuntimeClient({ region, - credentials: { - accessKeyId, - secretAccessKey, - sessionToken, - }, + credentials: { accessKeyId, secretAccessKey, sessionToken }, }); - const response = await handleConverseRequest(req, client); - return response; - } catch (e) { - console.error("[Bedrock] ", e); - return NextResponse.json( - { - error: true, - message: e instanceof Error ? e.message : "Unknown error", - details: prettyObject(e), - }, - { status: 500 }, - ); - } -} - -async function handleConverseRequest( - req: NextRequest, - client: BedrockRuntimeClient, -) { - try { const body = (await req.json()) as ConverseRequest; - const { modelId } = body; - - console.log("[Bedrock] Invoking model:", modelId); + console.log("[Bedrock] Request:", body.modelId); const command = new ConverseStreamCommand(formatRequestBody(body)); const response = await client.send(command); @@ -422,128 +149,71 @@ async function handleConverseRequest( const stream = new ReadableStream({ async start(controller) { try { - const responseStream = response.stream; - if (!responseStream) { - throw new Error("No stream in response"); - } - + const responseStream = + response.stream as AsyncIterable; for await (const event of responseStream) { - const output = event as ConverseStreamOutput; + if ( + "contentBlockStart" in event && + event.contentBlockStart?.start?.toolUse && + event.contentBlockStart.contentBlockIndex !== undefined + ) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "content_block", + content_block: { + type: "tool_use", + id: event.contentBlockStart.start.toolUse.toolUseId, + name: event.contentBlockStart.start.toolUse.name, + }, + index: event.contentBlockStart.contentBlockIndex, + })}\n\n`, + ); + } else if ( + "contentBlockDelta" in event && + event.contentBlockDelta?.delta && + event.contentBlockDelta.contentBlockIndex !== undefined + ) { + const delta = event.contentBlockDelta.delta; - if ("messageStart" in output && output.messageStart?.role) { - controller.enqueue( - `data: ${JSON.stringify({ - stream: { - messageStart: { role: output.messageStart.role }, - }, - })}\n\n`, - ); - } else if ( - "contentBlockStart" in output && - output.contentBlockStart - ) { - controller.enqueue( - `data: ${JSON.stringify({ - stream: { - contentBlockStart: { - contentBlockIndex: - output.contentBlockStart.contentBlockIndex, - start: output.contentBlockStart.start, - }, - }, - })}\n\n`, - ); - } else if ( - "contentBlockDelta" in output && - output.contentBlockDelta?.delta - ) { - if ("text" in output.contentBlockDelta.delta) { + if ("text" in delta && delta.text) { controller.enqueue( `data: ${JSON.stringify({ - stream: { - contentBlockDelta: { - delta: { text: output.contentBlockDelta.delta.text }, - contentBlockIndex: - output.contentBlockDelta.contentBlockIndex, - }, + type: "content_block_delta", + delta: { + type: "text_delta", + text: delta.text, }, + index: event.contentBlockDelta.contentBlockIndex, })}\n\n`, ); - } else if ("toolUse" in output.contentBlockDelta.delta) { + } else if ("toolUse" in delta && delta.toolUse?.input) { controller.enqueue( `data: ${JSON.stringify({ - stream: { - contentBlockDelta: { - delta: { - toolUse: { - input: - output.contentBlockDelta.delta.toolUse?.input, - }, - }, - contentBlockIndex: - output.contentBlockDelta.contentBlockIndex, - }, + type: "content_block_delta", + delta: { + type: "input_json_delta", + partial_json: delta.toolUse.input, }, + index: event.contentBlockDelta.contentBlockIndex, })}\n\n`, ); } } else if ( - "contentBlockStop" in output && - output.contentBlockStop + "contentBlockStop" in event && + event.contentBlockStop?.contentBlockIndex !== undefined ) { controller.enqueue( `data: ${JSON.stringify({ - stream: { - contentBlockStop: { - contentBlockIndex: - output.contentBlockStop.contentBlockIndex, - }, - }, - })}\n\n`, - ); - } else if ("messageStop" in output && output.messageStop) { - controller.enqueue( - `data: ${JSON.stringify({ - stream: { - messageStop: { - stopReason: output.messageStop.stopReason, - additionalModelResponseFields: - output.messageStop.additionalModelResponseFields, - }, - }, - })}\n\n`, - ); - } else if ("metadata" in output && output.metadata) { - controller.enqueue( - `data: ${JSON.stringify({ - stream: { - metadata: { - usage: output.metadata.usage, - metrics: output.metadata.metrics, - trace: output.metadata.trace, - }, - }, + type: "content_block_stop", + index: event.contentBlockStop.contentBlockIndex, })}\n\n`, ); } } controller.close(); } catch (error) { - const errorResponse = { - stream: { - error: - error instanceof Error - ? error.constructor.name - : "UnknownError", - message: error instanceof Error ? error.message : "Unknown error", - ...(error instanceof ModelStreamErrorException && { - originalStatusCode: error.originalStatusCode, - originalMessage: error.originalMessage, - }), - }, - }; - controller.enqueue(`data: ${JSON.stringify(errorResponse)}\n\n`); - controller.close(); + console.error("[Bedrock] Stream error:", error); + controller.error(error); } }, }); @@ -555,8 +225,15 @@ async function handleConverseRequest( Connection: "keep-alive", }, }); - } catch (error) { - console.error("[Bedrock] Request error:", error); - throw error; + } catch (e) { + console.error("[Bedrock] Error:", e); + return NextResponse.json( + { + error: true, + message: e instanceof Error ? e.message : "Unknown error", + details: prettyObject(e), + }, + { status: 500 }, + ); } } diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index b44070352..a2324b818 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -16,6 +16,7 @@ import { import { getMessageTextContent, isVisionModel } from "../../utils"; import { fetch } from "../../utils/stream"; import { preProcessImageContent, stream } from "../../utils/chat"; +import { RequestPayload } from "./openai"; export type MultiBlockContent = { type: "image" | "text"; @@ -39,12 +40,6 @@ const ClaudeMapper = { } as const; export class BedrockApi implements LLMApi { - usage(): Promise { - throw new Error("Method not implemented."); - } - models(): Promise { - throw new Error("Method not implemented."); - } speech(options: SpeechOptions): Promise { throw new Error("Speech not implemented for Bedrock."); } @@ -149,34 +144,15 @@ export class BedrockApi implements LLMApi { }); } - const [tools, funcs] = usePluginStore - .getState() - .getAsTools(useChatStore.getState().currentSession().mask?.plugin || []); - const requestBody = { modelId: options.config.model, - messages: messages.filter((msg) => msg.content.length > 0), + messages: prompt, inferenceConfig: { maxTokens: modelConfig.max_tokens, temperature: modelConfig.temperature, topP: modelConfig.top_p, stopSequences: [], }, - toolConfig: - Array.isArray(tools) && tools.length > 0 - ? { - tools: tools.map((tool: any) => ({ - toolSpec: { - name: tool?.function?.name, - description: tool?.function?.description, - inputSchema: { - json: tool?.function?.parameters, - }, - }, - })), - toolChoice: { auto: {} }, - } - : undefined, }; const conversePath = `${ApiPath.Bedrock}/converse`; @@ -185,83 +161,80 @@ export class BedrockApi implements LLMApi { if (shouldStream) { let currentToolUse: ChatMessageTool | null = null; + let index = -1; + const [tools, funcs] = usePluginStore + .getState() + .getAsTools( + useChatStore.getState().currentSession().mask?.plugin || [], + ); return stream( conversePath, requestBody, getHeaders(), - Array.isArray(tools) - ? tools.map((tool: any) => ({ - name: tool?.function?.name, - description: tool?.function?.description, - input_schema: tool?.function?.parameters, - })) - : [], + // @ts-ignore + tools.map((tool) => ({ + name: tool?.function?.name, + description: tool?.function?.description, + input_schema: tool?.function?.parameters, + })), funcs, controller, // parseSSE + // parseSSE (text: string, runTools: ChatMessageTool[]) => { - const parsed = JSON.parse(text); - const event = parsed.stream; + // console.log("parseSSE", text, runTools); + let chunkJson: + | undefined + | { + type: "content_block_delta" | "content_block_stop"; + content_block?: { + type: "tool_use"; + id: string; + name: string; + }; + delta?: { + type: "text_delta" | "input_json_delta"; + text?: string; + partial_json?: string; + }; + index: number; + }; + chunkJson = JSON.parse(text); - if (!event) { - console.warn("[Bedrock] Unexpected event format:", parsed); - return ""; - } - - if (event.messageStart) { - return ""; - } - - if (event.contentBlockStart?.start?.toolUse) { - const { toolUseId, name } = event.contentBlockStart.start.toolUse; - currentToolUse = { - id: toolUseId, + if (chunkJson?.content_block?.type == "tool_use") { + index += 1; + const id = chunkJson?.content_block.id; + const name = chunkJson?.content_block.name; + runTools.push({ + id, type: "function", function: { name, arguments: "", }, - }; - runTools.push(currentToolUse); - return ""; + }); } - - if (event.contentBlockDelta?.delta?.text) { - return event.contentBlockDelta.delta.text; - } - if ( - event.contentBlockDelta?.delta?.toolUse?.input && - currentToolUse?.function + chunkJson?.delta?.type == "input_json_delta" && + chunkJson?.delta?.partial_json ) { - currentToolUse.function.arguments += - event.contentBlockDelta.delta.toolUse.input; - return ""; + // @ts-ignore + runTools[index]["function"]["arguments"] += + chunkJson?.delta?.partial_json; } - - if ( - event.internalServerException || - event.modelStreamErrorException || - event.validationException || - event.throttlingException || - event.serviceUnavailableException - ) { - const errorMessage = - event.internalServerException?.message || - event.modelStreamErrorException?.message || - event.validationException?.message || - event.throttlingException?.message || - event.serviceUnavailableException?.message || - "Unknown error"; - throw new Error(errorMessage); - } - - return ""; + return chunkJson?.delta?.text; }, - // processToolMessage - (requestPayload: any, toolCallMessage: any, toolCallResult: any[]) => { - currentToolUse = null; + // processToolMessage, include tool_calls message and tool call results + ( + requestPayload: RequestPayload, + toolCallMessage: any, + toolCallResult: any[], + ) => { + // reset index value + index = -1; + // @ts-ignore requestPayload?.messages?.splice( + // @ts-ignore requestPayload?.messages?.length, 0, { @@ -277,6 +250,7 @@ export class BedrockApi implements LLMApi { }), ), }, + // @ts-ignore ...toolCallResult.map((result) => ({ role: "user", content: [ @@ -292,26 +266,33 @@ export class BedrockApi implements LLMApi { options, ); } else { + const payload = { + method: "POST", + body: JSON.stringify(requestBody), + signal: controller.signal, + headers: { + ...getHeaders(), // get common headers + }, + }; + try { - const response = await fetch(conversePath, { - method: "POST", - headers: getHeaders(), - body: JSON.stringify(requestBody), - signal: controller.signal, - }); + controller.signal.onabort = () => options.onFinish(""); - if (!response.ok) { - const error = await response.text(); - throw new Error(`Bedrock API error: ${error}`); - } + const res = await fetch(conversePath, payload); + const resJson = await res.json(); - const responseBody = await response.json(); - const content = this.extractMessage(responseBody); - options.onFinish(content); - } catch (e: any) { - console.error("[Bedrock] Chat error:", e); - throw e; + const message = this.extractMessage(resJson); + options.onFinish(message); + } catch (e) { + console.error("failed to chat", e); + options.onError?.(e as Error); } } } + usage(): Promise { + throw new Error("Method not implemented."); + } + models(): Promise { + throw new Error("Method not implemented."); + } }