stash code

This commit is contained in:
lloydzhou 2024-08-28 23:58:46 +08:00
parent c99cd31b6b
commit f5209fc344
6 changed files with 276 additions and 85 deletions

View File

@ -32,10 +32,7 @@ export async function requestOpenai(req: NextRequest) {
authHeaderName = "Authorization"; authHeaderName = "Authorization";
} }
let path = `${req.nextUrl.pathname}${req.nextUrl.search}`.replaceAll( let path = `${req.nextUrl.pathname}`.replaceAll("/api/openai/", "");
"/api/openai/",
"",
);
let baseUrl = let baseUrl =
(isAzure ? serverConfig.azureUrl : serverConfig.baseUrl) || OPENAI_BASE_URL; (isAzure ? serverConfig.azureUrl : serverConfig.baseUrl) || OPENAI_BASE_URL;

View File

@ -5,7 +5,13 @@ import {
ModelProvider, ModelProvider,
ServiceProvider, ServiceProvider,
} from "../constant"; } from "../constant";
import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store"; import {
ChatMessageTool,
ChatMessage,
ModelType,
useAccessStore,
useChatStore,
} from "../store";
import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai"; import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai";
import { GeminiProApi } from "./platforms/google"; import { GeminiProApi } from "./platforms/google";
import { ClaudeApi } from "./platforms/anthropic"; import { ClaudeApi } from "./platforms/anthropic";
@ -56,6 +62,8 @@ export interface ChatOptions {
onFinish: (message: string) => void; onFinish: (message: string) => void;
onError?: (err: Error) => void; onError?: (err: Error) => void;
onController?: (controller: AbortController) => void; onController?: (controller: AbortController) => void;
onBeforeTool?: (tool: ChatMessageTool) => void;
onAfterTool?: (tool: ChatMessageTool) => void;
} }
export interface LLMUsage { export interface LLMUsage {

View File

@ -250,6 +250,8 @@ export class ChatGPTApi implements LLMApi {
let responseText = ""; let responseText = "";
let remainText = ""; let remainText = "";
let finished = false; let finished = false;
let running = false;
let runTools = [];
// animate response to make it looks smooth // animate response to make it looks smooth
function animateResponseText() { function animateResponseText() {
@ -276,8 +278,70 @@ export class ChatGPTApi implements LLMApi {
// start animaion // start animaion
animateResponseText(); animateResponseText();
// TODO 后面这里是从选择的plugins中获取function列表
const funcs = {
get_current_weather: (args) => {
console.log("call get_current_weather", args);
return "30";
},
};
const finish = () => { const finish = () => {
if (!finished) { if (!finished) {
console.log("try run tools", runTools.length, finished, running);
if (!running && runTools.length > 0) {
const toolCallMessage = {
role: "assistant",
tool_calls: [...runTools],
};
running = true;
runTools.splice(0, runTools.length); // empty runTools
return Promise.all(
toolCallMessage.tool_calls.map((tool) => {
options?.onBeforeTool(tool);
return Promise.resolve(
funcs[tool.function.name](
JSON.parse(tool.function.arguments),
),
)
.then((content) => {
options?.onAfterTool({
...tool,
content,
isError: false,
});
return content;
})
.catch((e) => {
options?.onAfterTool({ ...tool, isError: true });
return e.toString();
})
.then((content) => ({
role: "tool",
content,
tool_call_id: tool.id,
}));
}),
).then((toolCallResult) => {
console.log("end runTools", toolCallMessage, toolCallResult);
requestPayload["messages"].splice(
requestPayload["messages"].length,
0,
toolCallMessage,
...toolCallResult,
);
setTimeout(() => {
// call again
console.log("start again");
running = false;
chatApi(chatPath, requestPayload); // call fetchEventSource
}, 0);
});
console.log("try run tools", runTools.length, finished);
return;
}
if (running) {
return;
}
finished = true; finished = true;
options.onFinish(responseText + remainText); options.onFinish(responseText + remainText);
} }
@ -285,90 +349,148 @@ export class ChatGPTApi implements LLMApi {
controller.signal.onabort = finish; controller.signal.onabort = finish;
fetchEventSource(chatPath, { function chatApi(chatPath, requestPayload) {
...chatPayload, const chatPayload = {
async onopen(res) { method: "POST",
clearTimeout(requestTimeoutId); body: JSON.stringify({
const contentType = res.headers.get("content-type"); ...requestPayload,
console.log( // TODO 这里暂时写死的后面从store.tools中按照当前session中选择的获取
"[OpenAI] request response content type: ", tools: [
contentType, {
); type: "function",
function: {
name: "get_current_weather",
description: "Get the current weather",
parameters: {
type: "object",
properties: {
location: {
type: "string",
description:
"The city and country, eg. San Francisco, USA",
},
format: {
type: "string",
enum: ["celsius", "fahrenheit"],
},
},
required: ["location", "format"],
},
},
},
],
}),
signal: controller.signal,
headers: getHeaders(),
};
console.log("chatApi", chatPath, requestPayload, chatPayload);
fetchEventSource(chatPath, {
...chatPayload,
async onopen(res) {
clearTimeout(requestTimeoutId);
const contentType = res.headers.get("content-type");
console.log(
"[OpenAI] request response content type: ",
contentType,
);
if (contentType?.startsWith("text/plain")) { if (contentType?.startsWith("text/plain")) {
responseText = await res.clone().text(); responseText = await res.clone().text();
return finish(); return finish();
}
if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
if (extraInfo) {
responseTexts.push(extraInfo);
}
responseText = responseTexts.join("\n\n");
return finish();
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || finished) {
return finish();
}
const text = msg.data;
try {
const json = JSON.parse(text);
const choices = json.choices as Array<{
delta: { content: string };
}>;
const delta = choices[0]?.delta?.content;
const textmoderation = json?.prompt_filter_results;
if (delta) {
remainText += delta;
} }
if ( if (
textmoderation && !res.ok ||
textmoderation.length > 0 && !res.headers
ServiceProvider.Azure .get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) { ) {
const contentFilterResults = const responseTexts = [responseText];
textmoderation[0]?.content_filter_results; let extraInfo = await res.clone().text();
console.log( try {
`[${ServiceProvider.Azure}] [Text Moderation] flagged categories result:`, const resJson = await res.clone().json();
contentFilterResults, extraInfo = prettyObject(resJson);
); } catch {}
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
if (extraInfo) {
responseTexts.push(extraInfo);
}
responseText = responseTexts.join("\n\n");
return finish();
} }
} catch (e) { },
console.error("[Request] parse error", text, msg); onmessage(msg) {
} if (msg.data === "[DONE]" || finished) {
}, return finish();
onclose() { }
finish(); const text = msg.data;
}, try {
onerror(e) { const json = JSON.parse(text);
options.onError?.(e); const choices = json.choices as Array<{
throw e; delta: { content: string };
}, }>;
openWhenHidden: true, console.log("choices", choices);
}); const delta = choices[0]?.delta?.content;
const tool_calls = choices[0]?.delta?.tool_calls;
const textmoderation = json?.prompt_filter_results;
if (delta) {
remainText += delta;
}
if (tool_calls?.length > 0) {
const index = tool_calls[0]?.index;
const id = tool_calls[0]?.id;
const args = tool_calls[0]?.function?.arguments;
if (id) {
runTools.push({
id,
type: tool_calls[0]?.type,
function: {
name: tool_calls[0]?.function?.name,
arguments: args,
},
});
} else {
runTools[index]["function"]["arguments"] += args;
}
}
console.log("runTools", runTools);
if (
textmoderation &&
textmoderation.length > 0 &&
ServiceProvider.Azure
) {
const contentFilterResults =
textmoderation[0]?.content_filter_results;
console.log(
`[${ServiceProvider.Azure}] [Text Moderation] flagged categories result:`,
contentFilterResults,
);
}
} catch (e) {
console.error("[Request] parse error", text, msg);
}
},
onclose() {
finish();
},
onerror(e) {
options.onError?.(e);
throw e;
},
openWhenHidden: true,
});
}
chatApi(chatPath, requestPayload); // call fetchEventSource
} else { } else {
const res = await fetch(chatPath, chatPayload); const res = await fetch(chatPath, chatPayload);
clearTimeout(requestTimeoutId); clearTimeout(requestTimeoutId);

View File

@ -413,6 +413,21 @@
margin-top: 5px; margin-top: 5px;
} }
.chat-message-tools {
font-size: 12px;
color: #aaa;
line-height: 1.5;
margin-top: 5px;
.chat-message-tool {
display: inline-flex;
align-items: end;
svg {
margin-left: 5px;
margin-right: 5px;
}
}
}
.chat-message-item { .chat-message-item {
box-sizing: border-box; box-sizing: border-box;
max-width: 100%; max-width: 100%;
@ -630,4 +645,4 @@
.chat-input-send { .chat-input-send {
bottom: 30px; bottom: 30px;
} }
} }

View File

@ -28,6 +28,7 @@ import DeleteIcon from "../icons/clear.svg";
import PinIcon from "../icons/pin.svg"; import PinIcon from "../icons/pin.svg";
import EditIcon from "../icons/rename.svg"; import EditIcon from "../icons/rename.svg";
import ConfirmIcon from "../icons/confirm.svg"; import ConfirmIcon from "../icons/confirm.svg";
import CloseIcon from "../icons/close.svg";
import CancelIcon from "../icons/cancel.svg"; import CancelIcon from "../icons/cancel.svg";
import ImageIcon from "../icons/image.svg"; import ImageIcon from "../icons/image.svg";
@ -1573,11 +1574,30 @@ function _Chat() {
</div> </div>
)} )}
</div> </div>
{showTyping && ( {message?.tools?.length == 0 && showTyping && (
<div className={styles["chat-message-status"]}> <div className={styles["chat-message-status"]}>
{Locale.Chat.Typing} {Locale.Chat.Typing}
</div> </div>
)} )}
{message?.tools?.length > 0 && (
<div className={styles["chat-message-tools"]}>
{message?.tools?.map((tool) => (
<div
key={tool.id}
className={styles["chat-message-tool"]}
>
{tool.isError === false ? (
<ConfirmIcon />
) : tool.isError === true ? (
<CloseIcon />
) : (
<LoadingButtonIcon />
)}
<span>{tool.function.name}</span>
</div>
))}
</div>
)}
<div className={styles["chat-message-item"]}> <div className={styles["chat-message-item"]}>
<Markdown <Markdown
key={message.streaming ? "loading" : "done"} key={message.streaming ? "loading" : "done"}

View File

@ -28,12 +28,24 @@ import { collectModelsWithDefaultModel } from "../utils/model";
import { useAccessStore } from "./access"; import { useAccessStore } from "./access";
import { isDalle3 } from "../utils"; import { isDalle3 } from "../utils";
export type ChatMessageTool = {
id: string;
type?: string;
function?: {
name: string;
arguments?: string;
};
content?: string;
isError?: boolean;
};
export type ChatMessage = RequestMessage & { export type ChatMessage = RequestMessage & {
date: string; date: string;
streaming?: boolean; streaming?: boolean;
isError?: boolean; isError?: boolean;
id: string; id: string;
model?: ModelType; model?: ModelType;
tools?: ChatMessageTool[];
}; };
export function createMessage(override: Partial<ChatMessage>): ChatMessage { export function createMessage(override: Partial<ChatMessage>): ChatMessage {
@ -389,6 +401,23 @@ export const useChatStore = createPersistStore(
} }
ChatControllerPool.remove(session.id, botMessage.id); ChatControllerPool.remove(session.id, botMessage.id);
}, },
onBeforeTool(tool: ChatMessageTool) {
(botMessage.tools = botMessage?.tools || []).push(tool);
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
},
onAfterTool(tool: ChatMessageTool) {
console.log("onAfterTool", botMessage);
botMessage?.tools?.forEach((t, i, tools) => {
if (tool.id == t.id) {
tools[i] = { ...tool };
}
});
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
},
onError(error) { onError(error) {
const isAborted = error.message.includes("aborted"); const isAborted = error.message.includes("aborted");
botMessage.content += botMessage.content +=