From 0c5585064143aa5d4513f4169b32fd3fe6c2300a Mon Sep 17 00:00:00 2001 From: glay Date: Sat, 7 Dec 2024 12:18:15 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96bedrock=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E5=A4=84=E7=90=86=E7=9A=84=E7=BC=93=E5=86=B2?= =?UTF-8?q?=E6=9C=BA=E5=88=B6=EF=BC=8C=E7=AE=80=E5=8C=96app=E5=92=8C?= =?UTF-8?q?=E5=90=8E=E5=8F=B0api=E8=B0=83=E7=94=A8=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E5=88=A4=E6=96=AD=E5=92=8C=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/bedrock.ts | 47 +--- app/client/platforms/bedrock.ts | 444 +++++++++++++++++++++-------- app/utils/aws.ts | 477 +++++++++++++++++++------------- 3 files changed, 619 insertions(+), 349 deletions(-) diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index 1cc177348..bf1313eb0 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -4,8 +4,6 @@ import { sign, decrypt, getBedrockEndpoint, - transformBedrockStream, - parseEventData, BedrockCredentials, } from "../utils/aws"; import { getServerSideConfig } from "../config/server"; @@ -178,50 +176,7 @@ async function requestBedrock(req: NextRequest) { "Empty response from Bedrock. Please check AWS credentials and permissions.", ); } - - // Handle non-streaming response - if (!shouldStream) { - const responseText = await res.text(); - const parsed = parseEventData(new TextEncoder().encode(responseText)); - if (!parsed) { - throw new Error("Failed to parse Bedrock response"); - } - return NextResponse.json(parsed); - } - - // Handle streaming response - const transformedStream = transformBedrockStream(res.body, modelId); - const encoder = new TextEncoder(); - const stream = new ReadableStream({ - async start(controller) { - try { - for await (const chunk of transformedStream) { - // Ensure we're sending non-empty chunks - if (chunk && chunk.trim()) { - controller.enqueue(encoder.encode(chunk)); - } - } - controller.close(); - } catch (err) { - console.error("[Bedrock Stream Error]:", err); - controller.error(err); - } - }, - }); - - const newHeaders = new Headers(res.headers); - newHeaders.delete("www-authenticate"); - newHeaders.set("Content-Type", "text/event-stream"); - newHeaders.set("Cache-Control", "no-cache"); - newHeaders.set("Connection", "keep-alive"); - // to disable nginx buffering - newHeaders.set("X-Accel-Buffering", "no"); - - return new Response(stream, { - status: res.status, - statusText: res.statusText, - headers: newHeaders, - }); + return res; } catch (e) { console.error("[Bedrock Request Error]:", e); throw e; diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index 2260eae05..9f5932698 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -7,13 +7,21 @@ import { useAccessStore, ChatMessageTool, } from "@/app/store"; -import { preProcessImageContent, stream } from "@/app/utils/chat"; +import { preProcessImageContent } from "@/app/utils/chat"; import { getMessageTextContent, isVisionModel } from "@/app/utils"; import { ApiPath, BEDROCK_BASE_URL } from "@/app/constant"; import { getClientConfig } from "@/app/config/client"; -import { extractMessage } from "@/app/utils/aws"; +import { + extractMessage, + processMessage, + processChunks, + parseEventData, + sign, +} from "@/app/utils/aws"; import { RequestPayload } from "./openai"; -import { fetch } from "@/app/utils/stream"; +import { REQUEST_TIMEOUT_MS } from "@/app/constant"; +import { prettyObject } from "@/app/utils/format"; +import Locale from "@/app/locales"; const ClaudeMapper = { assistant: "assistant", @@ -26,18 +34,7 @@ const MistralMapper = { user: "user", assistant: "assistant", } as const; - -type ClaudeRole = keyof typeof ClaudeMapper; type MistralRole = keyof typeof MistralMapper; - -interface Tool { - function?: { - name?: string; - description?: string; - parameters?: any; - }; -} - export class BedrockApi implements LLMApi { speech(options: SpeechOptions): Promise { throw new Error("Speech not implemented for Bedrock."); @@ -47,6 +44,24 @@ export class BedrockApi implements LLMApi { const model = modelConfig.model; const visionModel = isVisionModel(modelConfig.model); + // Handle Nova models + if (model.startsWith("us.amazon.nova")) { + return { + inferenceConfig: { + max_tokens: modelConfig.max_tokens || 1000, + }, + messages: messages.map((message) => ({ + role: message.role, + content: [ + { + type: "text", + text: getMessageTextContent(message), + }, + ], + })), + }; + } + // Handle Titan models if (model.startsWith("amazon.titan")) { const inputText = messages @@ -223,11 +238,34 @@ export class BedrockApi implements LLMApi { ); } + let finalRequestBody = this.formatRequestBody(messages, modelConfig); + try { - const chatPath = this.path("chat"); - const headers = getHeaders(); - headers.XModelID = modelConfig.model; - headers.XEncryptionKey = accessStore.encryptionKey; + const isApp = !!getClientConfig()?.isApp; + const bedrockAPIPath = `${BEDROCK_BASE_URL}/model/${ + modelConfig.model + }/invoke${shouldStream ? "-with-response-stream" : ""}`; + const chatPath = isApp ? bedrockAPIPath : ApiPath.Bedrock + "/chat"; + + const headers = isApp + ? await sign({ + method: "POST", + url: chatPath, + region: accessStore.awsRegion, + accessKeyId: accessStore.awsAccessKey, + secretAccessKey: accessStore.awsSecretKey, + body: finalRequestBody, + service: "bedrock", + isStreaming: shouldStream, + }) + : getHeaders(); + + if (!isApp) { + headers.XModelID = modelConfig.model; + headers.XEncryptionKey = accessStore.encryptionKey; + headers.ShouldStream = shouldStream + ""; + } + if (process.env.NODE_ENV !== "production") { console.debug("[Bedrock Client] Request:", { path: chatPath, @@ -236,20 +274,14 @@ export class BedrockApi implements LLMApi { stream: shouldStream, }); } - const finalRequestBody = this.formatRequestBody(messages, modelConfig); - console.log( - "[Bedrock Client] Request Body:", - JSON.stringify(finalRequestBody, null, 2), - ); if (shouldStream) { - let index = -1; const [tools, funcs] = usePluginStore .getState() .getAsTools( useChatStore.getState().currentSession().mask?.plugin || [], ); - return stream( + return bedrockStream( chatPath, finalRequestBody, headers, @@ -261,59 +293,12 @@ export class BedrockApi implements LLMApi { })), funcs, controller, - // parseSSE - (text: string, runTools: ChatMessageTool[]) => { - // console.log("parseSSE", text, runTools); - let chunkJson: - | undefined - | { - type: "content_block_delta" | "content_block_stop"; - content_block?: { - type: "tool_use"; - id: string; - name: string; - }; - delta?: { - type: "text_delta" | "input_json_delta"; - text?: string; - partial_json?: string; - }; - index: number; - }; - chunkJson = JSON.parse(text); - - if (chunkJson?.content_block?.type == "tool_use") { - index += 1; - const id = chunkJson?.content_block.id; - const name = chunkJson?.content_block.name; - runTools.push({ - id, - type: "function", - function: { - name, - arguments: "", - }, - }); - } - if ( - chunkJson?.delta?.type == "input_json_delta" && - chunkJson?.delta?.partial_json - ) { - // @ts-ignore - runTools[index]["function"]["arguments"] += - chunkJson?.delta?.partial_json; - } - return chunkJson?.delta?.text; - }, // processToolMessage, include tool_calls message and tool call results ( requestPayload: RequestPayload, toolCallMessage: any, toolCallResult: any[], ) => { - // reset index value - index = -1; - const modelId = modelConfig.model; const isMistral = modelId.startsWith("mistral.mistral"); const isClaude = modelId.includes("anthropic.claude"); @@ -384,30 +369,26 @@ export class BedrockApi implements LLMApi { options, ); } else { - headers.ShouldStream = "false"; - const res = await fetch(chatPath, { - method: "POST", - headers, - body: JSON.stringify(finalRequestBody), - }); - - if (!res.ok) { - const errorText = await res.text(); - console.error("[Bedrock Client] Error response:", errorText); - throw new Error(`Request failed: ${errorText}`); + try { + controller.signal.onabort = () => + options.onFinish("", new Response(null, { status: 400 })); + const res = await fetch(chatPath, { + method: "POST", + headers: headers, + body: JSON.stringify(finalRequestBody), + }); + const contentType = res.headers.get("content-type"); + console.log( + "[Bedrock Not Stream Request] response content type: ", + contentType, + ); + const resJson = await res.json(); + const message = extractMessage(resJson); + options.onFinish(message, res); + } catch (e) { + console.error("failed to chat", e); + options.onError?.(e as Error); } - - const resJson = await res.json(); - if (!resJson) { - throw new Error("Empty response from server"); - } - - const message = extractMessage(resJson, modelConfig.model); - if (!message) { - throw new Error("Failed to extract message from response"); - } - - options.onFinish(message, res); } } catch (e) { console.error("[Bedrock Client] Chat error:", e); @@ -415,26 +396,6 @@ export class BedrockApi implements LLMApi { } } - path(path: string): string { - const accessStore = useAccessStore.getState(); - let baseUrl = accessStore.useCustomConfig ? accessStore.bedrockUrl : ""; - - if (baseUrl.length === 0) { - const isApp = !!getClientConfig()?.isApp; - const apiPath = ApiPath.Bedrock; - baseUrl = isApp ? BEDROCK_BASE_URL : apiPath; - } - - baseUrl = baseUrl.endsWith("/") ? baseUrl.slice(0, -1) : baseUrl; - if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Bedrock)) { - baseUrl = "https://" + baseUrl; - } - - console.log("[Bedrock Client] API Endpoint:", baseUrl, path); - - return [baseUrl, path].join("/"); - } - async usage() { return { used: 0, total: 0 }; } @@ -443,3 +404,256 @@ export class BedrockApi implements LLMApi { return []; } } + +function bedrockStream( + chatPath: string, + requestPayload: any, + headers: any, + tools: any[], + funcs: Record, + controller: AbortController, + processToolMessage: ( + requestPayload: any, + toolCallMessage: any, + toolCallResult: any[], + ) => void, + options: any, +) { + let responseText = ""; + let remainText = ""; + let finished = false; + let running = false; + let runTools: any[] = []; + let responseRes: Response; + let index = -1; + let chunks: Uint8Array[] = []; // 使用数组存储二进制数据块 + let pendingChunk: Uint8Array | null = null; // 存储不完整的数据块 + + // Animate response to make it looks smooth + function animateResponseText() { + if (finished || controller.signal.aborted) { + responseText += remainText; + console.log("[Response Animation] finished"); + if (responseText?.length === 0) { + options.onError?.(new Error("empty response from server")); + } + return; + } + + if (remainText.length > 0) { + const fetchCount = Math.max(1, Math.round(remainText.length / 60)); + const fetchText = remainText.slice(0, fetchCount); + responseText += fetchText; + remainText = remainText.slice(fetchCount); + options.onUpdate?.(responseText, fetchText); + } + + requestAnimationFrame(animateResponseText); + } + + // Start animation + animateResponseText(); + + const finish = () => { + if (!finished) { + if (!running && runTools.length > 0) { + const toolCallMessage = { + role: "assistant", + tool_calls: [...runTools], + }; + running = true; + runTools.splice(0, runTools.length); // empty runTools + return Promise.all( + toolCallMessage.tool_calls.map((tool) => { + options?.onBeforeTool?.(tool); + return Promise.resolve( + funcs[tool.function.name]( + tool?.function?.arguments + ? JSON.parse(tool?.function?.arguments) + : {}, + ), + ) + .then((res) => { + let content = res.data || res?.statusText; + content = + typeof content === "string" + ? content + : JSON.stringify(content); + if (res.status >= 300) { + return Promise.reject(content); + } + return content; + }) + .then((content) => { + options?.onAfterTool?.({ + ...tool, + content, + isError: false, + }); + return content; + }) + .catch((e) => { + options?.onAfterTool?.({ + ...tool, + isError: true, + errorMsg: e.toString(), + }); + return e.toString(); + }) + .then((content) => ({ + name: tool.function.name, + role: "tool", + content, + tool_call_id: tool.id, + })); + }), + ).then((toolCallResult) => { + processToolMessage(requestPayload, toolCallMessage, toolCallResult); + setTimeout(() => { + // call again + console.debug("[BedrockAPI for toolCallResult] restart"); + running = false; + bedrockChatApi(chatPath, headers, requestPayload, tools); + }, 60); + }); + } + if (running) { + return; + } + console.debug("[BedrockAPI] end"); + finished = true; + options.onFinish(responseText + remainText, responseRes); + } + }; + + controller.signal.onabort = finish; + + async function bedrockChatApi( + chatPath: string, + headers: any, + requestPayload: any, + tools: any, + ) { + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + try { + const res = await fetch(chatPath, { + method: "POST", + headers, + body: JSON.stringify({ + ...requestPayload, + tools: tools && tools.length ? tools : undefined, + }), + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }); + + clearTimeout(requestTimeoutId); + responseRes = res; + + const contentType = res.headers.get("content-type"); + console.log( + "[Bedrock Stream Request] response content type: ", + contentType, + ); + + // Handle non-stream responses + if (contentType?.startsWith("text/plain")) { + responseText = await res.text(); + return finish(); + } + + // Handle error responses + if ( + !res.ok || + res.status !== 200 || + !contentType?.startsWith("application/vnd.amazon.eventstream") + ) { + const responseTexts = [responseText]; + let extraInfo = await res.text(); + try { + const resJson = await res.clone().json(); + extraInfo = prettyObject(resJson); + } catch {} + + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + + if (extraInfo) { + responseTexts.push(extraInfo); + } + + responseText = responseTexts.join("\n\n"); + return finish(); + } + + // Process the stream using chunks + const reader = res.body?.getReader(); + if (!reader) { + throw new Error("No response body reader available"); + } + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) { + // Process final pending chunk + if (pendingChunk) { + try { + const parsed = parseEventData(pendingChunk); + if (parsed) { + const result = processMessage( + parsed, + remainText, + runTools, + index, + ); + remainText = result.remainText; + index = result.index; + } + } catch (e) { + console.error("[Final Chunk Process Error]:", e); + } + } + break; + } + + // Add new chunk to queue + chunks.push(value); + + // Process chunk queue + const result = processChunks( + chunks, + pendingChunk, + remainText, + runTools, + index, + ); + chunks = result.chunks; + pendingChunk = result.pendingChunk; + remainText = result.remainText; + index = result.index; + } + } catch (err) { + console.error("[Bedrock Stream Error]:", err); + throw err; + } finally { + reader.releaseLock(); + finish(); + } + } catch (e) { + console.error("[Bedrock Request] error", e); + options.onError?.(e); + throw e; + } + } + + console.debug("[BedrockAPI] start"); + bedrockChatApi(chatPath, headers, requestPayload, tools); +} diff --git a/app/utils/aws.ts b/app/utils/aws.ts index 395c8fe3e..912df4811 100644 --- a/app/utils/aws.ts +++ b/app/utils/aws.ts @@ -67,8 +67,9 @@ export interface SignParams { region: string; accessKeyId: string; secretAccessKey: string; - body: string; + body: string | object; service: string; + headers?: Record; isStreaming?: boolean; } @@ -99,7 +100,7 @@ function normalizeHeaderValue(value: string): string { return value.replace(/\s+/g, " ").trim(); } -function encodeURIComponent_RFC3986(str: string): string { +function encodeRFC3986(str: string): string { return encodeURIComponent(str) .replace( /[!'()*]/g, @@ -108,41 +109,36 @@ function encodeURIComponent_RFC3986(str: string): string { .replace(/[-_.~]/g, (c) => c); } -function encodeURI_RFC3986(uri: string): string { - if (!uri || uri === "/") return ""; +function getCanonicalUri(path: string): string { + if (!path || path === "/") return "/"; - const segments = uri.split("/"); + return ( + "/" + + path + .split("/") + .map((segment) => { + if (!segment) return ""; + if (segment === "invoke-with-response-stream") return segment; - return segments - .map((segment) => { - if (!segment) return ""; - - if (segment.includes("model/")) { - const parts = segment.split(/(model\/)/); - return parts - .map((part) => { - if (part === "model/") return part; - if (part.includes(".") || part.includes(":")) { + if (segment.includes("model/")) { + return segment + .split(/(model\/)/) + .map((part) => { + if (part === "model/") return part; return part .split(/([.:])/g) - .map((subpart, i) => { - if (i % 2 === 1) return subpart; - return encodeURIComponent_RFC3986(subpart); - }) + .map((subpart, i) => + i % 2 === 1 ? subpart : encodeRFC3986(subpart), + ) .join(""); - } - return encodeURIComponent_RFC3986(part); - }) - .join(""); - } + }) + .join(""); + } - if (segment === "invoke-with-response-stream") { - return segment; - } - - return encodeURIComponent_RFC3986(segment); - }) - .join("/"); + return encodeRFC3986(segment); + }) + .join("/") + ); } export async function sign({ @@ -153,18 +149,20 @@ export async function sign({ secretAccessKey, body, service, + headers: customHeaders = {}, isStreaming = true, }: SignParams): Promise> { try { const endpoint = new URL(url); - const canonicalUri = "/" + encodeURI_RFC3986(endpoint.pathname.slice(1)); + const canonicalUri = getCanonicalUri(endpoint.pathname.slice(1)); const canonicalQueryString = endpoint.search.slice(1); const now = new Date(); const amzDate = now.toISOString().replace(/[:-]|\.\d{3}/g, ""); const dateStamp = amzDate.slice(0, 8); - const payloadHash = SHA256(body).toString(Hex); + const bodyString = typeof body === "string" ? body : JSON.stringify(body); + const payloadHash = SHA256(bodyString).toString(Hex); const headers: Record = { accept: isStreaming @@ -174,6 +172,7 @@ export async function sign({ host: endpoint.host, "x-amz-content-sha256": payloadHash, "x-amz-date": amzDate, + ...customHeaders, }; if (isStreaming) { @@ -237,54 +236,274 @@ export async function sign({ } // Bedrock utilities +function decodeBase64(base64String: string): string { + try { + return Buffer.from(base64String, "base64").toString("utf-8"); + } catch (e) { + console.error("[Base64 Decode Error]:", e); + return ""; + } +} + export function parseEventData(chunk: Uint8Array): any { const decoder = new TextDecoder(); const text = decoder.decode(chunk); + const results = []; try { + // First try to parse as JSON const parsed = JSON.parse(text); - // AWS Bedrock wraps the response in a 'body' field + + // Handle bytes field in the response + if (parsed.bytes) { + const decoded = decodeBase64(parsed.bytes); + try { + results.push(JSON.parse(decoded)); + } catch (e) { + results.push({ output: decoded }); + } + return results; + } + + // Handle body field if (typeof parsed.body === "string") { try { - return JSON.parse(parsed.body); + results.push(JSON.parse(parsed.body)); } catch (e) { - return { output: parsed.body }; + results.push({ output: parsed.body }); } + return results; } - return parsed.body || parsed; + + results.push(parsed.body || parsed); + return results; } catch (e) { try { - // Handle base64 encoded responses - const base64Match = text.match(/:"([A-Za-z0-9+/=]+)"/); - if (base64Match) { - const decoded = Buffer.from(base64Match[1], "base64").toString("utf-8"); + // Handle event-stream format + const eventRegex = /:event-type[^\{]+({.*?})/g; + let match; + while ((match = eventRegex.exec(text)) !== null) { try { - return JSON.parse(decoded); + const eventData = match[1]; + const parsed = JSON.parse(eventData); + if (parsed.bytes) { + const decoded = decodeBase64(parsed.bytes); + try { + results.push(JSON.parse(decoded)); + } catch (e) { + results.push({ output: decoded }); + } + } else { + results.push(parsed); + } } catch (e) { - return { output: decoded }; + results.push({ output: match[1] }); } } - // Handle event-type responses - const eventMatch = text.match(/:event-type[^\{]+({.*})/); - if (eventMatch) { - try { - return JSON.parse(eventMatch[1]); - } catch (e) { - return { output: eventMatch[1] }; - } + if (results.length > 0) { + return results; } // Handle plain text responses if (text.trim()) { const cleanText = text.replace(/[\x00-\x1F\x7F-\x9F]/g, ""); - return { output: cleanText }; + results.push({ output: cleanText.trim() }); + return results; } } catch (innerError) { console.error("[AWS Parse Error] Inner parsing failed:", innerError); } } - return null; + return []; +} + +export function processMessage( + data: any, + remainText: string, + runTools: any[], + index: number, +): { remainText: string; index: number } { + if (!data) return { remainText, index }; + + try { + // Handle message_start event + if (data.type === "message_start") { + // Keep existing text but mark the start of a new message + console.debug("[Message Start] Current text:", remainText); + return { remainText, index }; + } + + // Handle content_block_start event + if (data.type === "content_block_start") { + if (data.content_block?.type === "tool_use") { + index += 1; + runTools.push({ + id: data.content_block.id, + type: "function", + function: { + name: data.content_block.name, + arguments: "", + }, + }); + } + return { remainText, index }; + } + + // Handle content_block_delta event + if (data.type === "content_block_delta") { + if (data.delta?.type === "input_json_delta" && runTools[index]) { + runTools[index].function.arguments += data.delta.partial_json; + } else if (data.delta?.type === "text_delta") { + const newText = data.delta.text || ""; + // console.debug("[Text Delta] Adding:", newText); + remainText += newText; + } + return { remainText, index }; + } + + // Handle tool calls + if (data.choices?.[0]?.message?.tool_calls) { + for (const toolCall of data.choices[0].message.tool_calls) { + index += 1; + runTools.push({ + id: toolCall.id || `tool-${Date.now()}`, + type: "function", + function: { + name: toolCall.function?.name, + arguments: toolCall.function?.arguments || "", + }, + }); + } + return { remainText, index }; + } + + // Handle various response formats + let newText = ""; + if (data.delta?.text) { + newText = data.delta.text; + } else if (data.choices?.[0]?.message?.content) { + newText = data.choices[0].message.content; + } else if (data.content?.[0]?.text) { + newText = data.content[0].text; + } else if (data.generation) { + newText = data.generation; + } else if (data.outputText) { + newText = data.outputText; + } else if (data.response) { + newText = data.response; + } else if (data.output) { + newText = data.output; + } + + // Only append if we have new text + if (newText) { + // console.debug("[New Text] Adding:", newText); + remainText += newText; + } + } catch (e) { + console.error("[Bedrock Request] parse error", e); + } + + return { remainText, index }; +} + +export function processChunks( + chunks: Uint8Array[], + pendingChunk: Uint8Array | null, + remainText: string, + runTools: any[], + index: number, +): { + chunks: Uint8Array[]; + pendingChunk: Uint8Array | null; + remainText: string; + index: number; +} { + let currentText = remainText; + let currentIndex = index; + + while (chunks.length > 0) { + const chunk = chunks[0]; + try { + // Try to process the chunk + const parsedEvents = parseEventData(chunk); + if (parsedEvents.length > 0) { + // Process each event in the chunk + for (const parsed of parsedEvents) { + const result = processMessage( + parsed, + currentText, + runTools, + currentIndex, + ); + currentText = result.remainText; + currentIndex = result.index; + } + chunks.shift(); // Remove processed chunk + + // If there's a pending chunk, try to process it now + if (pendingChunk) { + const pendingEvents = parseEventData(pendingChunk); + if (pendingEvents.length > 0) { + for (const pendingParsed of pendingEvents) { + const pendingResult = processMessage( + pendingParsed, + currentText, + runTools, + currentIndex, + ); + currentText = pendingResult.remainText; + currentIndex = pendingResult.index; + } + pendingChunk = null; + } + } + } else { + // If parsing fails, it might be an incomplete chunk + if (pendingChunk) { + // Merge with pending chunk + const mergedChunk = new Uint8Array( + pendingChunk.length + chunk.length, + ); + mergedChunk.set(pendingChunk); + mergedChunk.set(chunk, pendingChunk.length); + pendingChunk = mergedChunk; + } else { + pendingChunk = chunk; + } + chunks.shift(); + } + } catch (e) { + console.error("[Chunk Process Error]:", e); + chunks.shift(); // Remove error chunk + } + } + + // Try to process any remaining pending chunk one last time + if (pendingChunk) { + const finalEvents = parseEventData(pendingChunk); + if (finalEvents.length > 0) { + for (const finalParsed of finalEvents) { + const finalResult = processMessage( + finalParsed, + currentText, + runTools, + currentIndex, + ); + currentText = finalResult.remainText; + currentIndex = finalResult.index; + } + pendingChunk = null; + } + } + + return { + chunks, + pendingChunk, + remainText: currentText, + index: currentIndex, + }; } export function getBedrockEndpoint( @@ -309,150 +528,32 @@ export function extractMessage(res: any, modelId: string = ""): string { return ""; } + let message = ""; + // Handle Mistral model response format if (modelId.toLowerCase().includes("mistral")) { if (res.choices?.[0]?.message?.content) { - return res.choices[0].message.content; + message = res.choices[0].message.content; + } else { + message = res.output || ""; } - return res.output || ""; } - // Handle Llama model response format - if (modelId.toLowerCase().includes("llama")) { - return res?.generation || ""; + else if (modelId.toLowerCase().includes("llama")) { + message = res?.generation || ""; } - // Handle Titan model response format - if (modelId.toLowerCase().includes("titan")) { - return res?.outputText || ""; + else if (modelId.toLowerCase().includes("titan")) { + message = res?.outputText || ""; } - // Handle Claude and other models - return res?.content?.[0]?.text || ""; -} - -export async function* transformBedrockStream( - stream: ReadableStream, - modelId: string, -) { - const reader = stream.getReader(); - - try { - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - const parsed = parseEventData(value); - if (!parsed) continue; - - // console.log("parseEventData========================="); - // console.log(parsed); - // Handle Claude 3 models - if (modelId.startsWith("anthropic.claude")) { - if (parsed.type === "message_start") { - // Initialize message - continue; - } else if (parsed.type === "content_block_start") { - if (parsed.content_block?.type === "tool_use") { - yield `data: ${JSON.stringify(parsed)}\n\n`; - } - continue; - } else if (parsed.type === "content_block_delta") { - if (parsed.delta?.type === "text_delta") { - yield `data: ${JSON.stringify({ - delta: { text: parsed.delta.text }, - })}\n\n`; - } else if (parsed.delta?.type === "input_json_delta") { - yield `data: ${JSON.stringify(parsed)}\n\n`; - } - } else if (parsed.type === "content_block_stop") { - yield `data: ${JSON.stringify(parsed)}\n\n`; - } else if ( - parsed.type === "message_delta" && - parsed.delta?.stop_reason - ) { - yield `data: ${JSON.stringify({ - delta: { stop_reason: parsed.delta.stop_reason }, - })}\n\n`; - } - } - // Handle Mistral models - else if (modelId.toLowerCase().includes("mistral")) { - if (parsed.choices?.[0]?.message?.tool_calls) { - const toolCalls = parsed.choices[0].message.tool_calls; - for (const toolCall of toolCalls) { - yield `data: ${JSON.stringify({ - type: "content_block_start", - content_block: { - type: "tool_use", - id: toolCall.id || `tool-${Date.now()}`, - name: toolCall.function?.name, - }, - })}\n\n`; - - if (toolCall.function?.arguments) { - yield `data: ${JSON.stringify({ - type: "content_block_delta", - delta: { - type: "input_json_delta", - partial_json: toolCall.function.arguments, - }, - })}\n\n`; - } - - yield `data: ${JSON.stringify({ - type: "content_block_stop", - })}\n\n`; - } - } else if (parsed.choices?.[0]?.message?.content) { - yield `data: ${JSON.stringify({ - delta: { text: parsed.choices[0].message.content }, - })}\n\n`; - } - - if (parsed.choices?.[0]?.finish_reason) { - yield `data: ${JSON.stringify({ - delta: { stop_reason: parsed.choices[0].finish_reason }, - })}\n\n`; - } - } - // Handle Llama models - else if (modelId.toLowerCase().includes("llama")) { - if (parsed.generation) { - yield `data: ${JSON.stringify({ - delta: { text: parsed.generation }, - })}\n\n`; - } - if (parsed.stop_reason) { - yield `data: ${JSON.stringify({ - delta: { stop_reason: parsed.stop_reason }, - })}\n\n`; - } - } - // Handle Titan models - else if (modelId.toLowerCase().includes("titan")) { - if (parsed.outputText) { - yield `data: ${JSON.stringify({ - delta: { text: parsed.outputText }, - })}\n\n`; - } - if (parsed.completionReason) { - yield `data: ${JSON.stringify({ - delta: { stop_reason: parsed.completionReason }, - })}\n\n`; - } - } - // Handle other models with basic text output - else { - const text = parsed.response || parsed.output || ""; - if (text) { - yield `data: ${JSON.stringify({ - delta: { text }, - })}\n\n`; - } - } - } - } finally { - reader.releaseLock(); + else if (res.content?.[0]?.text) { + message = res.content[0].text; } + // Handle other response formats + else { + message = res.output || res.response || res.message || ""; + } + + return message; }