完善llama和mistral模型的推理功能
This commit is contained in:
parent
2ccdd1706a
commit
6f7a635030
|
@ -4,7 +4,6 @@ import {
|
||||||
sign,
|
sign,
|
||||||
decrypt,
|
decrypt,
|
||||||
getBedrockEndpoint,
|
getBedrockEndpoint,
|
||||||
getModelHeaders,
|
|
||||||
transformBedrockStream,
|
transformBedrockStream,
|
||||||
parseEventData,
|
parseEventData,
|
||||||
BedrockCredentials,
|
BedrockCredentials,
|
||||||
|
@ -83,6 +82,10 @@ async function requestBedrock(req: NextRequest) {
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
throw new Error(`Invalid JSON in request body: ${e}`);
|
throw new Error(`Invalid JSON in request body: ${e}`);
|
||||||
}
|
}
|
||||||
|
console.log(
|
||||||
|
"[Bedrock Request] original Body:",
|
||||||
|
JSON.stringify(bodyJson, null, 2),
|
||||||
|
);
|
||||||
|
|
||||||
// Extract tool configuration if present
|
// Extract tool configuration if present
|
||||||
let tools: any[] | undefined;
|
let tools: any[] | undefined;
|
||||||
|
@ -97,18 +100,44 @@ async function requestBedrock(req: NextRequest) {
|
||||||
modelId,
|
modelId,
|
||||||
shouldStream,
|
shouldStream,
|
||||||
);
|
);
|
||||||
const additionalHeaders = getModelHeaders(modelId);
|
|
||||||
|
|
||||||
console.log("[Bedrock Request] Endpoint:", endpoint);
|
console.log("[Bedrock Request] Endpoint:", endpoint);
|
||||||
console.log("[Bedrock Request] Model ID:", modelId);
|
console.log("[Bedrock Request] Model ID:", modelId);
|
||||||
|
|
||||||
// Only include tools for Claude models
|
// Handle tools for different models
|
||||||
const isClaudeModel = modelId.toLowerCase().includes("claude3");
|
const isMistralModel = modelId.toLowerCase().includes("mistral");
|
||||||
|
const isClaudeModel = modelId.toLowerCase().includes("claude");
|
||||||
|
|
||||||
const requestBody = {
|
const requestBody = {
|
||||||
...bodyJson,
|
...bodyJson,
|
||||||
...(isClaudeModel && tools && { tools }),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (tools && tools.length > 0) {
|
||||||
|
if (isClaudeModel) {
|
||||||
|
// Claude models already have correct tool format
|
||||||
|
requestBody.tools = tools;
|
||||||
|
} else if (isMistralModel) {
|
||||||
|
// Format messages for Mistral
|
||||||
|
if (typeof requestBody.prompt === "string") {
|
||||||
|
requestBody.messages = [
|
||||||
|
{ role: "user", content: requestBody.prompt },
|
||||||
|
];
|
||||||
|
delete requestBody.prompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tools in Mistral's format
|
||||||
|
requestBody.tool_choice = "auto";
|
||||||
|
requestBody.tools = tools.map((tool) => ({
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: tool.name,
|
||||||
|
description: tool.description,
|
||||||
|
parameters: tool.input_schema,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Sign request
|
// Sign request
|
||||||
const headers = await sign({
|
const headers = await sign({
|
||||||
method: "POST",
|
method: "POST",
|
||||||
|
@ -119,12 +148,11 @@ async function requestBedrock(req: NextRequest) {
|
||||||
body: JSON.stringify(requestBody),
|
body: JSON.stringify(requestBody),
|
||||||
service: "bedrock",
|
service: "bedrock",
|
||||||
isStreaming: shouldStream,
|
isStreaming: shouldStream,
|
||||||
additionalHeaders,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Make request to AWS Bedrock
|
// Make request to AWS Bedrock
|
||||||
console.log(
|
console.log(
|
||||||
"[Bedrock Request] Body:",
|
"[Bedrock Request] Final Body:",
|
||||||
JSON.stringify(requestBody, null, 2),
|
JSON.stringify(requestBody, null, 2),
|
||||||
);
|
);
|
||||||
const res = await fetch(endpoint, {
|
const res = await fetch(endpoint, {
|
||||||
|
@ -173,11 +201,15 @@ async function requestBedrock(req: NextRequest) {
|
||||||
|
|
||||||
// Handle streaming response
|
// Handle streaming response
|
||||||
const transformedStream = transformBedrockStream(res.body, modelId);
|
const transformedStream = transformBedrockStream(res.body, modelId);
|
||||||
|
const encoder = new TextEncoder();
|
||||||
const stream = new ReadableStream({
|
const stream = new ReadableStream({
|
||||||
async start(controller) {
|
async start(controller) {
|
||||||
try {
|
try {
|
||||||
for await (const chunk of transformedStream) {
|
for await (const chunk of transformedStream) {
|
||||||
controller.enqueue(new TextEncoder().encode(chunk));
|
// Ensure we're sending non-empty chunks
|
||||||
|
if (chunk && chunk.trim()) {
|
||||||
|
controller.enqueue(encoder.encode(chunk));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
controller.close();
|
controller.close();
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|
|
@ -37,7 +37,7 @@ export class BedrockApi implements LLMApi {
|
||||||
if (model.startsWith("amazon.titan")) {
|
if (model.startsWith("amazon.titan")) {
|
||||||
const inputText = messages
|
const inputText = messages
|
||||||
.map((message) => {
|
.map((message) => {
|
||||||
return `${message.role}: ${message.content}`;
|
return `${message.role}: ${getMessageTextContent(message)}`;
|
||||||
})
|
})
|
||||||
.join("\n\n");
|
.join("\n\n");
|
||||||
|
|
||||||
|
@ -52,32 +52,59 @@ export class BedrockApi implements LLMApi {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle LLaMA models
|
// Handle LLaMA models
|
||||||
if (model.startsWith("us.meta.llama")) {
|
if (model.includes("meta.llama")) {
|
||||||
const prompt = messages
|
// Format conversation for Llama models
|
||||||
.map((message) => {
|
let prompt = "";
|
||||||
return `${message.role}: ${message.content}`;
|
let systemPrompt = "";
|
||||||
})
|
|
||||||
.join("\n\n");
|
// Extract system message if present
|
||||||
|
const systemMessage = messages.find((m) => m.role === "system");
|
||||||
|
if (systemMessage) {
|
||||||
|
systemPrompt = getMessageTextContent(systemMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format the conversation
|
||||||
|
const conversationMessages = messages.filter((m) => m.role !== "system");
|
||||||
|
prompt = `<s>[INST] <<SYS>>\n${
|
||||||
|
systemPrompt || "You are a helpful, respectful and honest assistant."
|
||||||
|
}\n<</SYS>>\n\n`;
|
||||||
|
|
||||||
|
for (let i = 0; i < conversationMessages.length; i++) {
|
||||||
|
const message = conversationMessages[i];
|
||||||
|
const content = getMessageTextContent(message);
|
||||||
|
if (i === 0 && message.role === "user") {
|
||||||
|
// First user message goes in the same [INST] block as system prompt
|
||||||
|
prompt += `${content} [/INST]`;
|
||||||
|
} else {
|
||||||
|
if (message.role === "user") {
|
||||||
|
prompt += `\n\n<s>[INST] ${content} [/INST]`;
|
||||||
|
} else {
|
||||||
|
prompt += ` ${content} </s>`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
prompt,
|
prompt,
|
||||||
max_gen_len: modelConfig.max_tokens || 512,
|
max_gen_len: modelConfig.max_tokens || 512,
|
||||||
temperature: modelConfig.temperature || 0.6,
|
temperature: modelConfig.temperature || 0.7,
|
||||||
top_p: modelConfig.top_p || 0.9,
|
top_p: modelConfig.top_p || 0.9,
|
||||||
stop: ["User:", "System:", "Assistant:", "\n\n"],
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle Mistral models
|
// Handle Mistral models
|
||||||
if (model.startsWith("mistral.mistral")) {
|
if (model.startsWith("mistral.mistral")) {
|
||||||
const prompt = messages
|
// Format messages for Mistral's chat format
|
||||||
.map((message) => {
|
const formattedMessages = messages.map((message) => ({
|
||||||
return `${message.role}: ${message.content}`;
|
role: message.role,
|
||||||
})
|
content: getMessageTextContent(message),
|
||||||
.join("\n\n");
|
}));
|
||||||
|
|
||||||
return {
|
return {
|
||||||
prompt,
|
messages: formattedMessages,
|
||||||
max_tokens: modelConfig.max_tokens || 4096,
|
max_tokens: modelConfig.max_tokens || 4096,
|
||||||
temperature: modelConfig.temperature || 0.7,
|
temperature: modelConfig.temperature || 0.7,
|
||||||
|
top_p: modelConfig.top_p || 0.9,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -292,7 +292,10 @@ export function showPlugins(provider: ServiceProvider, model: string) {
|
||||||
if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) {
|
if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (provider == ServiceProvider.Bedrock && model.includes("claude-3")) {
|
if (
|
||||||
|
(provider == ServiceProvider.Bedrock && model.includes("claude-3")) ||
|
||||||
|
model.includes("mistral-large")
|
||||||
|
) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (provider == ServiceProvider.Google && !model.includes("vision")) {
|
if (provider == ServiceProvider.Google && !model.includes("vision")) {
|
||||||
|
|
183
app/utils/aws.ts
183
app/utils/aws.ts
|
@ -75,7 +75,6 @@ export interface SignParams {
|
||||||
body: string;
|
body: string;
|
||||||
service: string;
|
service: string;
|
||||||
isStreaming?: boolean;
|
isStreaming?: boolean;
|
||||||
additionalHeaders?: Record<string, string>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function hmac(
|
function hmac(
|
||||||
|
@ -160,7 +159,6 @@ export async function sign({
|
||||||
body,
|
body,
|
||||||
service,
|
service,
|
||||||
isStreaming = true,
|
isStreaming = true,
|
||||||
additionalHeaders = {},
|
|
||||||
}: SignParams): Promise<Record<string, string>> {
|
}: SignParams): Promise<Record<string, string>> {
|
||||||
try {
|
try {
|
||||||
const endpoint = new URL(url);
|
const endpoint = new URL(url);
|
||||||
|
@ -181,7 +179,6 @@ export async function sign({
|
||||||
host: endpoint.host,
|
host: endpoint.host,
|
||||||
"x-amz-content-sha256": payloadHash,
|
"x-amz-content-sha256": payloadHash,
|
||||||
"x-amz-date": amzDate,
|
"x-amz-date": amzDate,
|
||||||
...additionalHeaders,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (isStreaming) {
|
if (isStreaming) {
|
||||||
|
@ -311,32 +308,25 @@ export function getBedrockEndpoint(
|
||||||
return endpoint;
|
return endpoint;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getModelHeaders(modelId: string): Record<string, string> {
|
|
||||||
if (!modelId) {
|
|
||||||
throw new Error("Model ID is required for headers");
|
|
||||||
}
|
|
||||||
|
|
||||||
const headers: Record<string, string> = {};
|
|
||||||
|
|
||||||
if (
|
|
||||||
modelId.startsWith("us.meta.llama") ||
|
|
||||||
modelId.startsWith("mistral.mistral")
|
|
||||||
) {
|
|
||||||
headers["content-type"] = "application/json";
|
|
||||||
headers["accept"] = "application/json";
|
|
||||||
}
|
|
||||||
|
|
||||||
return headers;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function extractMessage(res: any, modelId: string = ""): string {
|
export function extractMessage(res: any, modelId: string = ""): string {
|
||||||
if (!res) {
|
if (!res) {
|
||||||
console.error("[AWS Extract Error] extractMessage Empty response");
|
console.error("[AWS Extract Error] extractMessage Empty response");
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
console.log("[Response] extractMessage response: ", res);
|
console.log("[Response] extractMessage response: ", res);
|
||||||
return res?.content?.[0]?.text;
|
|
||||||
return "";
|
// Handle Mistral model response format
|
||||||
|
if (modelId.toLowerCase().includes("mistral")) {
|
||||||
|
return res?.outputs?.[0]?.text || "";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Llama model response format
|
||||||
|
if (modelId.toLowerCase().includes("llama")) {
|
||||||
|
return res?.generation || "";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Claude and other models
|
||||||
|
return res?.content?.[0]?.text || "";
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function* transformBedrockStream(
|
export async function* transformBedrockStream(
|
||||||
|
@ -344,58 +334,105 @@ export async function* transformBedrockStream(
|
||||||
modelId: string,
|
modelId: string,
|
||||||
) {
|
) {
|
||||||
const reader = stream.getReader();
|
const reader = stream.getReader();
|
||||||
let buffer = "";
|
let accumulatedText = "";
|
||||||
|
let toolCallStarted = false;
|
||||||
|
let currentToolCall = null;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
while (true) {
|
while (true) {
|
||||||
const { done, value } = await reader.read();
|
const { done, value } = await reader.read();
|
||||||
if (done) {
|
|
||||||
if (buffer) {
|
if (done) break;
|
||||||
yield `data: ${JSON.stringify({
|
|
||||||
delta: { text: buffer },
|
|
||||||
})}\n\n`;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const parsed = parseEventData(value);
|
const parsed = parseEventData(value);
|
||||||
if (!parsed) continue;
|
if (!parsed) continue;
|
||||||
|
|
||||||
// Handle Titan models
|
console.log("parseEventData=========================");
|
||||||
if (modelId.startsWith("amazon.titan")) {
|
console.log(parsed);
|
||||||
const text = parsed.outputText || "";
|
|
||||||
if (text) {
|
|
||||||
yield `data: ${JSON.stringify({
|
|
||||||
delta: { text },
|
|
||||||
})}\n\n`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Handle LLaMA3 models
|
|
||||||
else if (modelId.startsWith("us.meta.llama3")) {
|
|
||||||
let text = "";
|
|
||||||
if (parsed.generation) {
|
|
||||||
text = parsed.generation;
|
|
||||||
} else if (parsed.output) {
|
|
||||||
text = parsed.output;
|
|
||||||
} else if (typeof parsed === "string") {
|
|
||||||
text = parsed;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (text) {
|
|
||||||
// Clean up any control characters or invalid JSON characters
|
|
||||||
text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
|
|
||||||
yield `data: ${JSON.stringify({
|
|
||||||
delta: { text },
|
|
||||||
})}\n\n`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Handle Mistral models
|
// Handle Mistral models
|
||||||
else if (modelId.startsWith("mistral.mistral")) {
|
if (modelId.toLowerCase().includes("mistral")) {
|
||||||
const text =
|
// If we have content, accumulate it
|
||||||
parsed.output || parsed.outputs?.[0]?.text || parsed.completion || "";
|
if (
|
||||||
if (text) {
|
parsed.choices?.[0]?.message?.role === "assistant" &&
|
||||||
|
parsed.choices?.[0]?.message?.content
|
||||||
|
) {
|
||||||
|
accumulatedText += parsed.choices?.[0]?.message?.content;
|
||||||
|
console.log("accumulatedText=========================");
|
||||||
|
console.log(accumulatedText);
|
||||||
|
// Check for tool call in the accumulated text
|
||||||
|
if (!toolCallStarted && accumulatedText.includes("```json")) {
|
||||||
|
const jsonMatch = accumulatedText.match(
|
||||||
|
/```json\s*({[\s\S]*?})\s*```/,
|
||||||
|
);
|
||||||
|
if (jsonMatch) {
|
||||||
|
try {
|
||||||
|
const toolData = JSON.parse(jsonMatch[1]);
|
||||||
|
currentToolCall = {
|
||||||
|
id: `tool-${Date.now()}`,
|
||||||
|
name: toolData.name,
|
||||||
|
arguments: toolData.arguments,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Emit tool call start
|
||||||
|
yield `data: ${JSON.stringify({
|
||||||
|
type: "content_block_start",
|
||||||
|
content_block: {
|
||||||
|
type: "tool_use",
|
||||||
|
id: currentToolCall.id,
|
||||||
|
name: currentToolCall.name,
|
||||||
|
},
|
||||||
|
})}\n\n`;
|
||||||
|
|
||||||
|
// Emit tool arguments
|
||||||
|
yield `data: ${JSON.stringify({
|
||||||
|
type: "content_block_delta",
|
||||||
|
delta: {
|
||||||
|
type: "input_json_delta",
|
||||||
|
partial_json: JSON.stringify(currentToolCall.arguments),
|
||||||
|
},
|
||||||
|
})}\n\n`;
|
||||||
|
|
||||||
|
// Emit tool call stop
|
||||||
|
yield `data: ${JSON.stringify({
|
||||||
|
type: "content_block_stop",
|
||||||
|
})}\n\n`;
|
||||||
|
|
||||||
|
// Clear the accumulated text after processing the tool call
|
||||||
|
accumulatedText = accumulatedText.replace(
|
||||||
|
/```json\s*{[\s\S]*?}\s*```/,
|
||||||
|
"",
|
||||||
|
);
|
||||||
|
toolCallStarted = false;
|
||||||
|
currentToolCall = null;
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Failed to parse tool JSON:", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// emit the text content if it's not empty
|
||||||
|
if (parsed.choices?.[0]?.message?.content.trim()) {
|
||||||
|
yield `data: ${JSON.stringify({
|
||||||
|
delta: { text: parsed.choices?.[0]?.message?.content },
|
||||||
|
})}\n\n`;
|
||||||
|
}
|
||||||
|
// Handle stop reason if present
|
||||||
|
if (parsed.choices?.[0]?.stop_reason) {
|
||||||
|
yield `data: ${JSON.stringify({
|
||||||
|
delta: { stop_reason: parsed.choices[0].stop_reason },
|
||||||
|
})}\n\n`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Handle Llama models
|
||||||
|
else if (modelId.toLowerCase().includes("llama")) {
|
||||||
|
if (parsed.generation) {
|
||||||
yield `data: ${JSON.stringify({
|
yield `data: ${JSON.stringify({
|
||||||
delta: { text },
|
delta: { text: parsed.generation },
|
||||||
|
})}\n\n`;
|
||||||
|
}
|
||||||
|
if (parsed.stop_reason) {
|
||||||
|
yield `data: ${JSON.stringify({
|
||||||
|
delta: { stop_reason: parsed.stop_reason },
|
||||||
})}\n\n`;
|
})}\n\n`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -423,6 +460,22 @@ export async function* transformBedrockStream(
|
||||||
yield `data: ${JSON.stringify(parsed)}\n\n`;
|
yield `data: ${JSON.stringify(parsed)}\n\n`;
|
||||||
} else if (parsed.type === "content_block_stop") {
|
} else if (parsed.type === "content_block_stop") {
|
||||||
yield `data: ${JSON.stringify(parsed)}\n\n`;
|
yield `data: ${JSON.stringify(parsed)}\n\n`;
|
||||||
|
} else {
|
||||||
|
// Handle regular text responses
|
||||||
|
const text = parsed.response || parsed.output || "";
|
||||||
|
if (text) {
|
||||||
|
yield `data: ${JSON.stringify({
|
||||||
|
delta: { text },
|
||||||
|
})}\n\n`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Handle other model text responses
|
||||||
|
const text = parsed.outputText || parsed.generation || "";
|
||||||
|
if (text) {
|
||||||
|
yield `data: ${JSON.stringify({
|
||||||
|
delta: { text },
|
||||||
|
})}\n\n`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue