Improve inference support for the llama and mistral models

glay 2024-11-24 23:54:04 +08:00
parent 2ccdd1706a
commit 6f7a635030
4 changed files with 204 additions and 89 deletions

View File

@@ -4,7 +4,6 @@ import {
   sign,
   decrypt,
   getBedrockEndpoint,
-  getModelHeaders,
   transformBedrockStream,
   parseEventData,
   BedrockCredentials,
@@ -83,6 +82,10 @@ async function requestBedrock(req: NextRequest) {
   } catch (e) {
     throw new Error(`Invalid JSON in request body: ${e}`);
   }
+  console.log(
+    "[Bedrock Request] original Body:",
+    JSON.stringify(bodyJson, null, 2),
+  );
 
   // Extract tool configuration if present
   let tools: any[] | undefined;
@@ -97,18 +100,44 @@ async function requestBedrock(req: NextRequest) {
     modelId,
     shouldStream,
   );
-  const additionalHeaders = getModelHeaders(modelId);
 
   console.log("[Bedrock Request] Endpoint:", endpoint);
   console.log("[Bedrock Request] Model ID:", modelId);
 
-  // Only include tools for Claude models
-  const isClaudeModel = modelId.toLowerCase().includes("claude3");
+  // Handle tools for different models
+  const isMistralModel = modelId.toLowerCase().includes("mistral");
+  const isClaudeModel = modelId.toLowerCase().includes("claude");
+
   const requestBody = {
     ...bodyJson,
-    ...(isClaudeModel && tools && { tools }),
   };
+
+  if (tools && tools.length > 0) {
+    if (isClaudeModel) {
+      // Claude models already have correct tool format
+      requestBody.tools = tools;
+    } else if (isMistralModel) {
+      // Format messages for Mistral
+      if (typeof requestBody.prompt === "string") {
+        requestBody.messages = [
+          { role: "user", content: requestBody.prompt },
+        ];
+        delete requestBody.prompt;
+      }
+
+      // Add tools in Mistral's format
+      requestBody.tool_choice = "auto";
+      requestBody.tools = tools.map((tool) => ({
+        type: "function",
+        function: {
+          name: tool.name,
+          description: tool.description,
+          parameters: tool.input_schema,
+        },
+      }));
+    }
+  }
+
   // Sign request
   const headers = await sign({
     method: "POST",
@@ -119,12 +148,11 @@ async function requestBedrock(req: NextRequest) {
     body: JSON.stringify(requestBody),
     service: "bedrock",
     isStreaming: shouldStream,
-    additionalHeaders,
   });
 
   // Make request to AWS Bedrock
   console.log(
-    "[Bedrock Request] Body:",
+    "[Bedrock Request] Final Body:",
     JSON.stringify(requestBody, null, 2),
   );
   const res = await fetch(endpoint, {
@@ -173,11 +201,15 @@ async function requestBedrock(req: NextRequest) {
       // Handle streaming response
       const transformedStream = transformBedrockStream(res.body, modelId);
+      const encoder = new TextEncoder();
       const stream = new ReadableStream({
         async start(controller) {
           try {
             for await (const chunk of transformedStream) {
-              controller.enqueue(new TextEncoder().encode(chunk));
+              // Ensure we're sending non-empty chunks
+              if (chunk && chunk.trim()) {
+                controller.enqueue(encoder.encode(chunk));
+              }
             }
             controller.close();
           } catch (err) {
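
For reference, the new branch rewrites Claude-style tool definitions into Mistral's function-calling shape. A minimal sketch of that mapping, assuming the client sends tools with name, description, and input_schema fields (the weather tool below is invented):

// Sketch of the Claude-to-Mistral tool mapping performed above.
// The ClaudeTool shape is inferred from the fields the handler reads.
interface ClaudeTool {
  name: string;
  description: string;
  input_schema: Record<string, unknown>;
}

function toMistralTool(tool: ClaudeTool) {
  return {
    type: "function" as const,
    function: {
      name: tool.name,
      description: tool.description,
      parameters: tool.input_schema, // JSON Schema passes through unchanged
    },
  };
}

// Usage with a hypothetical tool:
toMistralTool({
  name: "get_weather",
  description: "Look up the current weather for a city",
  input_schema: {
    type: "object",
    properties: { city: { type: "string" } },
    required: ["city"],
  },
});
// -> { type: "function", function: { name, description, parameters } }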

View File

@@ -37,7 +37,7 @@ export class BedrockApi implements LLMApi {
     if (model.startsWith("amazon.titan")) {
       const inputText = messages
         .map((message) => {
-          return `${message.role}: ${message.content}`;
+          return `${message.role}: ${getMessageTextContent(message)}`;
         })
         .join("\n\n");
@@ -52,32 +52,59 @@
       }
     }
     // Handle LLaMA models
-    if (model.startsWith("us.meta.llama")) {
-      const prompt = messages
-        .map((message) => {
-          return `${message.role}: ${message.content}`;
-        })
-        .join("\n\n");
+    if (model.includes("meta.llama")) {
+      // Format conversation for Llama models
+      let prompt = "";
+      let systemPrompt = "";
+
+      // Extract system message if present
+      const systemMessage = messages.find((m) => m.role === "system");
+      if (systemMessage) {
+        systemPrompt = getMessageTextContent(systemMessage);
+      }
+
+      // Format the conversation
+      const conversationMessages = messages.filter((m) => m.role !== "system");
+      prompt = `<s>[INST] <<SYS>>\n${
+        systemPrompt || "You are a helpful, respectful and honest assistant."
+      }\n<</SYS>>\n\n`;
+
+      for (let i = 0; i < conversationMessages.length; i++) {
+        const message = conversationMessages[i];
+        const content = getMessageTextContent(message);
+        if (i === 0 && message.role === "user") {
+          // First user message goes in the same [INST] block as system prompt
+          prompt += `${content} [/INST]`;
+        } else {
+          if (message.role === "user") {
+            prompt += `\n\n<s>[INST] ${content} [/INST]`;
+          } else {
+            prompt += ` ${content} </s>`;
+          }
+        }
+      }
+
       return {
         prompt,
         max_gen_len: modelConfig.max_tokens || 512,
-        temperature: modelConfig.temperature || 0.6,
+        temperature: modelConfig.temperature || 0.7,
         top_p: modelConfig.top_p || 0.9,
+        stop: ["User:", "System:", "Assistant:", "\n\n"],
       };
     }
 
     // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
-      const prompt = messages
-        .map((message) => {
-          return `${message.role}: ${message.content}`;
-        })
-        .join("\n\n");
+      // Format messages for Mistral's chat format
+      const formattedMessages = messages.map((message) => ({
+        role: message.role,
+        content: getMessageTextContent(message),
+      }));
+
       return {
-        prompt,
+        messages: formattedMessages,
         max_tokens: modelConfig.max_tokens || 4096,
         temperature: modelConfig.temperature || 0.7,
-        top_p: modelConfig.top_p || 0.9,
       };
     }
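
To make the new Llama template concrete, here is the prompt string the branch above would build for a short conversation (message contents are invented for illustration):

// Invented three-turn conversation:
const example = [
  { role: "system", content: "Answer briefly." },
  { role: "user", content: "Hi!" },
  { role: "assistant", content: "Hello! How can I help?" },
  { role: "user", content: "Name a planet." },
];
// Prompt produced by the loop above:
//
//   <s>[INST] <<SYS>>
//   Answer briefly.
//   <</SYS>>
//
//   Hi! [/INST] Hello! How can I help? </s>
//
//   <s>[INST] Name a planet. [/INST]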

View File

@@ -292,7 +292,10 @@ export function showPlugins(provider: ServiceProvider, model: string) {
   if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) {
     return true;
   }
-  if (provider == ServiceProvider.Bedrock && model.includes("claude-3")) {
+  if (
+    (provider == ServiceProvider.Bedrock && model.includes("claude-3")) ||
+    model.includes("mistral-large")
+  ) {
     return true;
   }
   if (provider == ServiceProvider.Google && !model.includes("vision")) {
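
Note that the added model.includes("mistral-large") clause sits outside the Bedrock check, so the plugin UI is enabled for any provider whose model id contains "mistral-large" (model ids below are representative examples):

showPlugins(ServiceProvider.Bedrock, "anthropic.claude-3-sonnet-20240229-v1:0"); // true
showPlugins(ServiceProvider.Bedrock, "mistral.mistral-large-2402-v1:0"); // true
showPlugins(ServiceProvider.OpenAI, "mistral-large"); // also true, via the ungated clause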

View File

@@ -75,7 +75,6 @@ export interface SignParams {
   body: string;
   service: string;
   isStreaming?: boolean;
-  additionalHeaders?: Record<string, string>;
 }
 
 function hmac(
@@ -160,7 +159,6 @@ export async function sign({
   body,
   service,
   isStreaming = true,
-  additionalHeaders = {},
 }: SignParams): Promise<Record<string, string>> {
   try {
     const endpoint = new URL(url);
@@ -181,7 +179,6 @@ export async function sign({
       host: endpoint.host,
       "x-amz-content-sha256": payloadHash,
       "x-amz-date": amzDate,
-      ...additionalHeaders,
     };
 
     if (isStreaming) {
@@ -311,32 +308,25 @@ export function getBedrockEndpoint(
   return endpoint;
 }
 
-export function getModelHeaders(modelId: string): Record<string, string> {
-  if (!modelId) {
-    throw new Error("Model ID is required for headers");
-  }
-
-  const headers: Record<string, string> = {};
-
-  if (
-    modelId.startsWith("us.meta.llama") ||
-    modelId.startsWith("mistral.mistral")
-  ) {
-    headers["content-type"] = "application/json";
-    headers["accept"] = "application/json";
-  }
-
-  return headers;
-}
-
 export function extractMessage(res: any, modelId: string = ""): string {
   if (!res) {
     console.error("[AWS Extract Error] extractMessage Empty response");
     return "";
   }
   console.log("[Response] extractMessage response: ", res);
-  return res?.content?.[0]?.text;
-  return "";
+
+  // Handle Mistral model response format
+  if (modelId.toLowerCase().includes("mistral")) {
+    return res?.outputs?.[0]?.text || "";
+  }
+  // Handle Llama model response format
+  if (modelId.toLowerCase().includes("llama")) {
+    return res?.generation || "";
+  }
+  // Handle Claude and other models
+  return res?.content?.[0]?.text || "";
 }
 
 export async function* transformBedrockStream(
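
The rewritten extractMessage branches on the model family because each Bedrock model family returns a different response shape. Roughly, with invented payloads and representative model ids:

// Invented payloads in the three shapes the accessors above expect:
extractMessage({ outputs: [{ text: "Hi!" }] }, "mistral.mistral-large-2402-v1:0"); // "Hi!"
extractMessage({ generation: "Hi!" }, "meta.llama3-8b-instruct-v1:0"); // "Hi!"
extractMessage({ content: [{ type: "text", text: "Hi!" }] }, "anthropic.claude-3-haiku-20240307-v1:0"); // "Hi!"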
@@ -344,58 +334,105 @@
   modelId: string,
 ) {
   const reader = stream.getReader();
-  let buffer = "";
+  let accumulatedText = "";
+  let toolCallStarted = false;
+  let currentToolCall = null;
 
   try {
     while (true) {
       const { done, value } = await reader.read();
-      if (done) {
-        if (buffer) {
-          yield `data: ${JSON.stringify({
-            delta: { text: buffer },
-          })}\n\n`;
-        }
-        break;
-      }
+      if (done) break;
 
       const parsed = parseEventData(value);
       if (!parsed) continue;
 
-      // Handle Titan models
-      if (modelId.startsWith("amazon.titan")) {
-        const text = parsed.outputText || "";
-        if (text) {
-          yield `data: ${JSON.stringify({
-            delta: { text },
-          })}\n\n`;
-        }
-      }
-      // Handle LLaMA3 models
-      else if (modelId.startsWith("us.meta.llama3")) {
-        let text = "";
-        if (parsed.generation) {
-          text = parsed.generation;
-        } else if (parsed.output) {
-          text = parsed.output;
-        } else if (typeof parsed === "string") {
-          text = parsed;
-        }
-        if (text) {
-          // Clean up any control characters or invalid JSON characters
-          text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
-          yield `data: ${JSON.stringify({
-            delta: { text },
-          })}\n\n`;
-        }
-      }
+      console.log("parseEventData=========================");
+      console.log(parsed);
+
       // Handle Mistral models
-      else if (modelId.startsWith("mistral.mistral")) {
-        const text =
-          parsed.output || parsed.outputs?.[0]?.text || parsed.completion || "";
-        if (text) {
+      if (modelId.toLowerCase().includes("mistral")) {
+        // If we have content, accumulate it
+        if (
+          parsed.choices?.[0]?.message?.role === "assistant" &&
+          parsed.choices?.[0]?.message?.content
+        ) {
+          accumulatedText += parsed.choices?.[0]?.message?.content;
+          console.log("accumulatedText=========================");
+          console.log(accumulatedText);
+
+          // Check for tool call in the accumulated text
+          if (!toolCallStarted && accumulatedText.includes("```json")) {
+            const jsonMatch = accumulatedText.match(
+              /```json\s*({[\s\S]*?})\s*```/,
+            );
+            if (jsonMatch) {
+              try {
+                const toolData = JSON.parse(jsonMatch[1]);
+                currentToolCall = {
+                  id: `tool-${Date.now()}`,
+                  name: toolData.name,
+                  arguments: toolData.arguments,
+                };
+
+                // Emit tool call start
+                yield `data: ${JSON.stringify({
+                  type: "content_block_start",
+                  content_block: {
+                    type: "tool_use",
+                    id: currentToolCall.id,
+                    name: currentToolCall.name,
+                  },
+                })}\n\n`;
+
+                // Emit tool arguments
+                yield `data: ${JSON.stringify({
+                  type: "content_block_delta",
+                  delta: {
+                    type: "input_json_delta",
+                    partial_json: JSON.stringify(currentToolCall.arguments),
+                  },
+                })}\n\n`;
+
+                // Emit tool call stop
+                yield `data: ${JSON.stringify({
+                  type: "content_block_stop",
+                })}\n\n`;
+
+                // Clear the accumulated text after processing the tool call
+                accumulatedText = accumulatedText.replace(
+                  /```json\s*{[\s\S]*?}\s*```/,
+                  "",
+                );
+                toolCallStarted = false;
+                currentToolCall = null;
+              } catch (e) {
+                console.error("Failed to parse tool JSON:", e);
+              }
+            }
+          }
+
+          // emit the text content if it's not empty
+          if (parsed.choices?.[0]?.message?.content.trim()) {
+            yield `data: ${JSON.stringify({
+              delta: { text: parsed.choices?.[0]?.message?.content },
+            })}\n\n`;
+          }
+
+          // Handle stop reason if present
+          if (parsed.choices?.[0]?.stop_reason) {
+            yield `data: ${JSON.stringify({
+              delta: { stop_reason: parsed.choices[0].stop_reason },
+            })}\n\n`;
+          }
+        }
+      }
+      // Handle Llama models
+      else if (modelId.toLowerCase().includes("llama")) {
+        if (parsed.generation) {
           yield `data: ${JSON.stringify({
-            delta: { text },
+            delta: { text: parsed.generation },
+          })}\n\n`;
+        }
+        if (parsed.stop_reason) {
+          yield `data: ${JSON.stringify({
+            delta: { stop_reason: parsed.stop_reason },
           })}\n\n`;
         }
       }
@@ -423,6 +460,22 @@
           yield `data: ${JSON.stringify(parsed)}\n\n`;
         } else if (parsed.type === "content_block_stop") {
           yield `data: ${JSON.stringify(parsed)}\n\n`;
+        } else {
+          // Handle regular text responses
+          const text = parsed.response || parsed.output || "";
+          if (text) {
+            yield `data: ${JSON.stringify({
+              delta: { text },
+            })}\n\n`;
+          }
+        }
+      } else {
+        // Handle other model text responses
+        const text = parsed.outputText || parsed.generation || "";
+        if (text) {
+          yield `data: ${JSON.stringify({
+            delta: { text },
+          })}\n\n`;
         }
       }
     }
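
The core of the new Mistral branch is extracting a tool call from a fenced JSON block in the accumulated assistant text. A standalone sketch of just that step, with an invented assistant reply:

// Invented assistant text containing a fenced JSON tool call:
const accumulated =
  'Let me check.\n```json\n{ "name": "get_weather", "arguments": { "city": "Oslo" } }\n```';

// Same regex as above: capture the first {...} object inside a ```json fence.
const jsonMatch = accumulated.match(/```json\s*({[\s\S]*?})\s*```/);
if (jsonMatch) {
  const toolData = JSON.parse(jsonMatch[1]);
  console.log(toolData.name); // "get_weather"
  console.log(toolData.arguments); // { city: "Oslo" }
}

The lazy [\s\S]*? keeps the match to the first fenced block, and nested braces still parse because the closing fence anchors the end of the capture before JSON.parse validates the payload.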