完善llama和mistral模型的推理功能
This commit is contained in:
parent
2ccdd1706a
commit
6f7a635030
|
@ -4,7 +4,6 @@ import {
|
|||
sign,
|
||||
decrypt,
|
||||
getBedrockEndpoint,
|
||||
getModelHeaders,
|
||||
transformBedrockStream,
|
||||
parseEventData,
|
||||
BedrockCredentials,
|
||||
|
@ -83,6 +82,10 @@ async function requestBedrock(req: NextRequest) {
|
|||
} catch (e) {
|
||||
throw new Error(`Invalid JSON in request body: ${e}`);
|
||||
}
|
||||
console.log(
|
||||
"[Bedrock Request] original Body:",
|
||||
JSON.stringify(bodyJson, null, 2),
|
||||
);
|
||||
|
||||
// Extract tool configuration if present
|
||||
let tools: any[] | undefined;
|
||||
|
@ -97,18 +100,44 @@ async function requestBedrock(req: NextRequest) {
|
|||
modelId,
|
||||
shouldStream,
|
||||
);
|
||||
const additionalHeaders = getModelHeaders(modelId);
|
||||
|
||||
console.log("[Bedrock Request] Endpoint:", endpoint);
|
||||
console.log("[Bedrock Request] Model ID:", modelId);
|
||||
|
||||
// Only include tools for Claude models
|
||||
const isClaudeModel = modelId.toLowerCase().includes("claude3");
|
||||
// Handle tools for different models
|
||||
const isMistralModel = modelId.toLowerCase().includes("mistral");
|
||||
const isClaudeModel = modelId.toLowerCase().includes("claude");
|
||||
|
||||
const requestBody = {
|
||||
...bodyJson,
|
||||
...(isClaudeModel && tools && { tools }),
|
||||
};
|
||||
|
||||
if (tools && tools.length > 0) {
|
||||
if (isClaudeModel) {
|
||||
// Claude models already have correct tool format
|
||||
requestBody.tools = tools;
|
||||
} else if (isMistralModel) {
|
||||
// Format messages for Mistral
|
||||
if (typeof requestBody.prompt === "string") {
|
||||
requestBody.messages = [
|
||||
{ role: "user", content: requestBody.prompt },
|
||||
];
|
||||
delete requestBody.prompt;
|
||||
}
|
||||
|
||||
// Add tools in Mistral's format
|
||||
requestBody.tool_choice = "auto";
|
||||
requestBody.tools = tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.input_schema,
|
||||
},
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Sign request
|
||||
const headers = await sign({
|
||||
method: "POST",
|
||||
|
@ -119,12 +148,11 @@ async function requestBedrock(req: NextRequest) {
|
|||
body: JSON.stringify(requestBody),
|
||||
service: "bedrock",
|
||||
isStreaming: shouldStream,
|
||||
additionalHeaders,
|
||||
});
|
||||
|
||||
// Make request to AWS Bedrock
|
||||
console.log(
|
||||
"[Bedrock Request] Body:",
|
||||
"[Bedrock Request] Final Body:",
|
||||
JSON.stringify(requestBody, null, 2),
|
||||
);
|
||||
const res = await fetch(endpoint, {
|
||||
|
@ -173,11 +201,15 @@ async function requestBedrock(req: NextRequest) {
|
|||
|
||||
// Handle streaming response
|
||||
const transformedStream = transformBedrockStream(res.body, modelId);
|
||||
const encoder = new TextEncoder();
|
||||
const stream = new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of transformedStream) {
|
||||
controller.enqueue(new TextEncoder().encode(chunk));
|
||||
// Ensure we're sending non-empty chunks
|
||||
if (chunk && chunk.trim()) {
|
||||
controller.enqueue(encoder.encode(chunk));
|
||||
}
|
||||
}
|
||||
controller.close();
|
||||
} catch (err) {
|
||||
|
|
|
@ -37,7 +37,7 @@ export class BedrockApi implements LLMApi {
|
|||
if (model.startsWith("amazon.titan")) {
|
||||
const inputText = messages
|
||||
.map((message) => {
|
||||
return `${message.role}: ${message.content}`;
|
||||
return `${message.role}: ${getMessageTextContent(message)}`;
|
||||
})
|
||||
.join("\n\n");
|
||||
|
||||
|
@ -52,32 +52,59 @@ export class BedrockApi implements LLMApi {
|
|||
}
|
||||
|
||||
// Handle LLaMA models
|
||||
if (model.startsWith("us.meta.llama")) {
|
||||
const prompt = messages
|
||||
.map((message) => {
|
||||
return `${message.role}: ${message.content}`;
|
||||
})
|
||||
.join("\n\n");
|
||||
if (model.includes("meta.llama")) {
|
||||
// Format conversation for Llama models
|
||||
let prompt = "";
|
||||
let systemPrompt = "";
|
||||
|
||||
// Extract system message if present
|
||||
const systemMessage = messages.find((m) => m.role === "system");
|
||||
if (systemMessage) {
|
||||
systemPrompt = getMessageTextContent(systemMessage);
|
||||
}
|
||||
|
||||
// Format the conversation
|
||||
const conversationMessages = messages.filter((m) => m.role !== "system");
|
||||
prompt = `<s>[INST] <<SYS>>\n${
|
||||
systemPrompt || "You are a helpful, respectful and honest assistant."
|
||||
}\n<</SYS>>\n\n`;
|
||||
|
||||
for (let i = 0; i < conversationMessages.length; i++) {
|
||||
const message = conversationMessages[i];
|
||||
const content = getMessageTextContent(message);
|
||||
if (i === 0 && message.role === "user") {
|
||||
// First user message goes in the same [INST] block as system prompt
|
||||
prompt += `${content} [/INST]`;
|
||||
} else {
|
||||
if (message.role === "user") {
|
||||
prompt += `\n\n<s>[INST] ${content} [/INST]`;
|
||||
} else {
|
||||
prompt += ` ${content} </s>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
prompt,
|
||||
max_gen_len: modelConfig.max_tokens || 512,
|
||||
temperature: modelConfig.temperature || 0.6,
|
||||
temperature: modelConfig.temperature || 0.7,
|
||||
top_p: modelConfig.top_p || 0.9,
|
||||
stop: ["User:", "System:", "Assistant:", "\n\n"],
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Mistral models
|
||||
if (model.startsWith("mistral.mistral")) {
|
||||
const prompt = messages
|
||||
.map((message) => {
|
||||
return `${message.role}: ${message.content}`;
|
||||
})
|
||||
.join("\n\n");
|
||||
// Format messages for Mistral's chat format
|
||||
const formattedMessages = messages.map((message) => ({
|
||||
role: message.role,
|
||||
content: getMessageTextContent(message),
|
||||
}));
|
||||
|
||||
return {
|
||||
prompt,
|
||||
messages: formattedMessages,
|
||||
max_tokens: modelConfig.max_tokens || 4096,
|
||||
temperature: modelConfig.temperature || 0.7,
|
||||
top_p: modelConfig.top_p || 0.9,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -292,7 +292,10 @@ export function showPlugins(provider: ServiceProvider, model: string) {
|
|||
if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) {
|
||||
return true;
|
||||
}
|
||||
if (provider == ServiceProvider.Bedrock && model.includes("claude-3")) {
|
||||
if (
|
||||
(provider == ServiceProvider.Bedrock && model.includes("claude-3")) ||
|
||||
model.includes("mistral-large")
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
if (provider == ServiceProvider.Google && !model.includes("vision")) {
|
||||
|
|
183
app/utils/aws.ts
183
app/utils/aws.ts
|
@ -75,7 +75,6 @@ export interface SignParams {
|
|||
body: string;
|
||||
service: string;
|
||||
isStreaming?: boolean;
|
||||
additionalHeaders?: Record<string, string>;
|
||||
}
|
||||
|
||||
function hmac(
|
||||
|
@ -160,7 +159,6 @@ export async function sign({
|
|||
body,
|
||||
service,
|
||||
isStreaming = true,
|
||||
additionalHeaders = {},
|
||||
}: SignParams): Promise<Record<string, string>> {
|
||||
try {
|
||||
const endpoint = new URL(url);
|
||||
|
@ -181,7 +179,6 @@ export async function sign({
|
|||
host: endpoint.host,
|
||||
"x-amz-content-sha256": payloadHash,
|
||||
"x-amz-date": amzDate,
|
||||
...additionalHeaders,
|
||||
};
|
||||
|
||||
if (isStreaming) {
|
||||
|
@ -311,32 +308,25 @@ export function getBedrockEndpoint(
|
|||
return endpoint;
|
||||
}
|
||||
|
||||
export function getModelHeaders(modelId: string): Record<string, string> {
|
||||
if (!modelId) {
|
||||
throw new Error("Model ID is required for headers");
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
|
||||
if (
|
||||
modelId.startsWith("us.meta.llama") ||
|
||||
modelId.startsWith("mistral.mistral")
|
||||
) {
|
||||
headers["content-type"] = "application/json";
|
||||
headers["accept"] = "application/json";
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
export function extractMessage(res: any, modelId: string = ""): string {
|
||||
if (!res) {
|
||||
console.error("[AWS Extract Error] extractMessage Empty response");
|
||||
return "";
|
||||
}
|
||||
console.log("[Response] extractMessage response: ", res);
|
||||
return res?.content?.[0]?.text;
|
||||
return "";
|
||||
|
||||
// Handle Mistral model response format
|
||||
if (modelId.toLowerCase().includes("mistral")) {
|
||||
return res?.outputs?.[0]?.text || "";
|
||||
}
|
||||
|
||||
// Handle Llama model response format
|
||||
if (modelId.toLowerCase().includes("llama")) {
|
||||
return res?.generation || "";
|
||||
}
|
||||
|
||||
// Handle Claude and other models
|
||||
return res?.content?.[0]?.text || "";
|
||||
}
|
||||
|
||||
export async function* transformBedrockStream(
|
||||
|
@ -344,58 +334,105 @@ export async function* transformBedrockStream(
|
|||
modelId: string,
|
||||
) {
|
||||
const reader = stream.getReader();
|
||||
let buffer = "";
|
||||
let accumulatedText = "";
|
||||
let toolCallStarted = false;
|
||||
let currentToolCall = null;
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
if (buffer) {
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { text: buffer },
|
||||
})}\n\n`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (done) break;
|
||||
|
||||
const parsed = parseEventData(value);
|
||||
if (!parsed) continue;
|
||||
|
||||
// Handle Titan models
|
||||
if (modelId.startsWith("amazon.titan")) {
|
||||
const text = parsed.outputText || "";
|
||||
if (text) {
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { text },
|
||||
})}\n\n`;
|
||||
}
|
||||
}
|
||||
// Handle LLaMA3 models
|
||||
else if (modelId.startsWith("us.meta.llama3")) {
|
||||
let text = "";
|
||||
if (parsed.generation) {
|
||||
text = parsed.generation;
|
||||
} else if (parsed.output) {
|
||||
text = parsed.output;
|
||||
} else if (typeof parsed === "string") {
|
||||
text = parsed;
|
||||
}
|
||||
|
||||
if (text) {
|
||||
// Clean up any control characters or invalid JSON characters
|
||||
text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { text },
|
||||
})}\n\n`;
|
||||
}
|
||||
}
|
||||
console.log("parseEventData=========================");
|
||||
console.log(parsed);
|
||||
// Handle Mistral models
|
||||
else if (modelId.startsWith("mistral.mistral")) {
|
||||
const text =
|
||||
parsed.output || parsed.outputs?.[0]?.text || parsed.completion || "";
|
||||
if (text) {
|
||||
if (modelId.toLowerCase().includes("mistral")) {
|
||||
// If we have content, accumulate it
|
||||
if (
|
||||
parsed.choices?.[0]?.message?.role === "assistant" &&
|
||||
parsed.choices?.[0]?.message?.content
|
||||
) {
|
||||
accumulatedText += parsed.choices?.[0]?.message?.content;
|
||||
console.log("accumulatedText=========================");
|
||||
console.log(accumulatedText);
|
||||
// Check for tool call in the accumulated text
|
||||
if (!toolCallStarted && accumulatedText.includes("```json")) {
|
||||
const jsonMatch = accumulatedText.match(
|
||||
/```json\s*({[\s\S]*?})\s*```/,
|
||||
);
|
||||
if (jsonMatch) {
|
||||
try {
|
||||
const toolData = JSON.parse(jsonMatch[1]);
|
||||
currentToolCall = {
|
||||
id: `tool-${Date.now()}`,
|
||||
name: toolData.name,
|
||||
arguments: toolData.arguments,
|
||||
};
|
||||
|
||||
// Emit tool call start
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { text },
|
||||
type: "content_block_start",
|
||||
content_block: {
|
||||
type: "tool_use",
|
||||
id: currentToolCall.id,
|
||||
name: currentToolCall.name,
|
||||
},
|
||||
})}\n\n`;
|
||||
|
||||
// Emit tool arguments
|
||||
yield `data: ${JSON.stringify({
|
||||
type: "content_block_delta",
|
||||
delta: {
|
||||
type: "input_json_delta",
|
||||
partial_json: JSON.stringify(currentToolCall.arguments),
|
||||
},
|
||||
})}\n\n`;
|
||||
|
||||
// Emit tool call stop
|
||||
yield `data: ${JSON.stringify({
|
||||
type: "content_block_stop",
|
||||
})}\n\n`;
|
||||
|
||||
// Clear the accumulated text after processing the tool call
|
||||
accumulatedText = accumulatedText.replace(
|
||||
/```json\s*{[\s\S]*?}\s*```/,
|
||||
"",
|
||||
);
|
||||
toolCallStarted = false;
|
||||
currentToolCall = null;
|
||||
} catch (e) {
|
||||
console.error("Failed to parse tool JSON:", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
// emit the text content if it's not empty
|
||||
if (parsed.choices?.[0]?.message?.content.trim()) {
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { text: parsed.choices?.[0]?.message?.content },
|
||||
})}\n\n`;
|
||||
}
|
||||
// Handle stop reason if present
|
||||
if (parsed.choices?.[0]?.stop_reason) {
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { stop_reason: parsed.choices[0].stop_reason },
|
||||
})}\n\n`;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle Llama models
|
||||
else if (modelId.toLowerCase().includes("llama")) {
|
||||
if (parsed.generation) {
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { text: parsed.generation },
|
||||
})}\n\n`;
|
||||
}
|
||||
if (parsed.stop_reason) {
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { stop_reason: parsed.stop_reason },
|
||||
})}\n\n`;
|
||||
}
|
||||
}
|
||||
|
@ -423,6 +460,22 @@ export async function* transformBedrockStream(
|
|||
yield `data: ${JSON.stringify(parsed)}\n\n`;
|
||||
} else if (parsed.type === "content_block_stop") {
|
||||
yield `data: ${JSON.stringify(parsed)}\n\n`;
|
||||
} else {
|
||||
// Handle regular text responses
|
||||
const text = parsed.response || parsed.output || "";
|
||||
if (text) {
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { text },
|
||||
})}\n\n`;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle other model text responses
|
||||
const text = parsed.outputText || parsed.generation || "";
|
||||
if (text) {
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { text },
|
||||
})}\n\n`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue