diff --git a/app/api/common.ts b/app/api/common.ts index 24453dd96..25decbf62 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -32,10 +32,7 @@ export async function requestOpenai(req: NextRequest) { authHeaderName = "Authorization"; } - let path = `${req.nextUrl.pathname}${req.nextUrl.search}`.replaceAll( - "/api/openai/", - "", - ); + let path = `${req.nextUrl.pathname}`.replaceAll("/api/openai/", ""); let baseUrl = (isAzure ? serverConfig.azureUrl : serverConfig.baseUrl) || OPENAI_BASE_URL; diff --git a/app/client/api.ts b/app/client/api.ts index d7fb023a2..cecc453ba 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -5,7 +5,13 @@ import { ModelProvider, ServiceProvider, } from "../constant"; -import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store"; +import { + ChatMessageTool, + ChatMessage, + ModelType, + useAccessStore, + useChatStore, +} from "../store"; import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai"; import { GeminiProApi } from "./platforms/google"; import { ClaudeApi } from "./platforms/anthropic"; @@ -56,6 +62,8 @@ export interface ChatOptions { onFinish: (message: string) => void; onError?: (err: Error) => void; onController?: (controller: AbortController) => void; + onBeforeTool?: (tool: ChatMessageTool) => void; + onAfterTool?: (tool: ChatMessageTool) => void; } export interface LLMUsage { diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index d4e262c16..03bc3e09f 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -250,6 +250,8 @@ export class ChatGPTApi implements LLMApi { let responseText = ""; let remainText = ""; let finished = false; + let running = false; + let runTools = []; // animate response to make it looks smooth function animateResponseText() { @@ -276,8 +278,70 @@ export class ChatGPTApi implements LLMApi { // start animaion animateResponseText(); + // TODO 后面这里是从选择的plugins中获取function列表 + const funcs = { + get_current_weather: (args) => { + console.log("call get_current_weather", args); + return "30"; + }, + }; const finish = () => { if (!finished) { + console.log("try run tools", runTools.length, finished, running); + if (!running && runTools.length > 0) { + const toolCallMessage = { + role: "assistant", + tool_calls: [...runTools], + }; + running = true; + runTools.splice(0, runTools.length); // empty runTools + return Promise.all( + toolCallMessage.tool_calls.map((tool) => { + options?.onBeforeTool(tool); + return Promise.resolve( + funcs[tool.function.name]( + JSON.parse(tool.function.arguments), + ), + ) + .then((content) => { + options?.onAfterTool({ + ...tool, + content, + isError: false, + }); + return content; + }) + .catch((e) => { + options?.onAfterTool({ ...tool, isError: true }); + return e.toString(); + }) + .then((content) => ({ + role: "tool", + content, + tool_call_id: tool.id, + })); + }), + ).then((toolCallResult) => { + console.log("end runTools", toolCallMessage, toolCallResult); + requestPayload["messages"].splice( + requestPayload["messages"].length, + 0, + toolCallMessage, + ...toolCallResult, + ); + setTimeout(() => { + // call again + console.log("start again"); + running = false; + chatApi(chatPath, requestPayload); // call fetchEventSource + }, 0); + }); + console.log("try run tools", runTools.length, finished); + return; + } + if (running) { + return; + } finished = true; options.onFinish(responseText + remainText); } @@ -285,90 +349,148 @@ export class ChatGPTApi implements LLMApi { controller.signal.onabort = finish; - fetchEventSource(chatPath, { - ...chatPayload, - async onopen(res) { - clearTimeout(requestTimeoutId); - const contentType = res.headers.get("content-type"); - console.log( - "[OpenAI] request response content type: ", - contentType, - ); + function chatApi(chatPath, requestPayload) { + const chatPayload = { + method: "POST", + body: JSON.stringify({ + ...requestPayload, + // TODO 这里暂时写死的,后面从store.tools中按照当前session中选择的获取 + tools: [ + { + type: "function", + function: { + name: "get_current_weather", + description: "Get the current weather", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: + "The city and country, eg. San Francisco, USA", + }, + format: { + type: "string", + enum: ["celsius", "fahrenheit"], + }, + }, + required: ["location", "format"], + }, + }, + }, + ], + }), + signal: controller.signal, + headers: getHeaders(), + }; + console.log("chatApi", chatPath, requestPayload, chatPayload); + fetchEventSource(chatPath, { + ...chatPayload, + async onopen(res) { + clearTimeout(requestTimeoutId); + const contentType = res.headers.get("content-type"); + console.log( + "[OpenAI] request response content type: ", + contentType, + ); - if (contentType?.startsWith("text/plain")) { - responseText = await res.clone().text(); - return finish(); - } - - if ( - !res.ok || - !res.headers - .get("content-type") - ?.startsWith(EventStreamContentType) || - res.status !== 200 - ) { - const responseTexts = [responseText]; - let extraInfo = await res.clone().text(); - try { - const resJson = await res.clone().json(); - extraInfo = prettyObject(resJson); - } catch {} - - if (res.status === 401) { - responseTexts.push(Locale.Error.Unauthorized); - } - - if (extraInfo) { - responseTexts.push(extraInfo); - } - - responseText = responseTexts.join("\n\n"); - - return finish(); - } - }, - onmessage(msg) { - if (msg.data === "[DONE]" || finished) { - return finish(); - } - const text = msg.data; - try { - const json = JSON.parse(text); - const choices = json.choices as Array<{ - delta: { content: string }; - }>; - const delta = choices[0]?.delta?.content; - const textmoderation = json?.prompt_filter_results; - - if (delta) { - remainText += delta; + if (contentType?.startsWith("text/plain")) { + responseText = await res.clone().text(); + return finish(); } if ( - textmoderation && - textmoderation.length > 0 && - ServiceProvider.Azure + !res.ok || + !res.headers + .get("content-type") + ?.startsWith(EventStreamContentType) || + res.status !== 200 ) { - const contentFilterResults = - textmoderation[0]?.content_filter_results; - console.log( - `[${ServiceProvider.Azure}] [Text Moderation] flagged categories result:`, - contentFilterResults, - ); + const responseTexts = [responseText]; + let extraInfo = await res.clone().text(); + try { + const resJson = await res.clone().json(); + extraInfo = prettyObject(resJson); + } catch {} + + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + + if (extraInfo) { + responseTexts.push(extraInfo); + } + + responseText = responseTexts.join("\n\n"); + + return finish(); } - } catch (e) { - console.error("[Request] parse error", text, msg); - } - }, - onclose() { - finish(); - }, - onerror(e) { - options.onError?.(e); - throw e; - }, - openWhenHidden: true, - }); + }, + onmessage(msg) { + if (msg.data === "[DONE]" || finished) { + return finish(); + } + const text = msg.data; + try { + const json = JSON.parse(text); + const choices = json.choices as Array<{ + delta: { content: string }; + }>; + console.log("choices", choices); + const delta = choices[0]?.delta?.content; + const tool_calls = choices[0]?.delta?.tool_calls; + const textmoderation = json?.prompt_filter_results; + + if (delta) { + remainText += delta; + } + if (tool_calls?.length > 0) { + const index = tool_calls[0]?.index; + const id = tool_calls[0]?.id; + const args = tool_calls[0]?.function?.arguments; + if (id) { + runTools.push({ + id, + type: tool_calls[0]?.type, + function: { + name: tool_calls[0]?.function?.name, + arguments: args, + }, + }); + } else { + runTools[index]["function"]["arguments"] += args; + } + } + + console.log("runTools", runTools); + + if ( + textmoderation && + textmoderation.length > 0 && + ServiceProvider.Azure + ) { + const contentFilterResults = + textmoderation[0]?.content_filter_results; + console.log( + `[${ServiceProvider.Azure}] [Text Moderation] flagged categories result:`, + contentFilterResults, + ); + } + } catch (e) { + console.error("[Request] parse error", text, msg); + } + }, + onclose() { + finish(); + }, + onerror(e) { + options.onError?.(e); + throw e; + }, + openWhenHidden: true, + }); + } + chatApi(chatPath, requestPayload); // call fetchEventSource } else { const res = await fetch(chatPath, chatPayload); clearTimeout(requestTimeoutId); diff --git a/app/components/chat.module.scss b/app/components/chat.module.scss index 3b5c143b9..33ccaf523 100644 --- a/app/components/chat.module.scss +++ b/app/components/chat.module.scss @@ -413,6 +413,21 @@ margin-top: 5px; } +.chat-message-tools { + font-size: 12px; + color: #aaa; + line-height: 1.5; + margin-top: 5px; + .chat-message-tool { + display: inline-flex; + align-items: end; + svg { + margin-left: 5px; + margin-right: 5px; + } + } +} + .chat-message-item { box-sizing: border-box; max-width: 100%; @@ -630,4 +645,4 @@ .chat-input-send { bottom: 30px; } -} \ No newline at end of file +} diff --git a/app/components/chat.tsx b/app/components/chat.tsx index ed5b06799..3ad8cd5c9 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -28,6 +28,7 @@ import DeleteIcon from "../icons/clear.svg"; import PinIcon from "../icons/pin.svg"; import EditIcon from "../icons/rename.svg"; import ConfirmIcon from "../icons/confirm.svg"; +import CloseIcon from "../icons/close.svg"; import CancelIcon from "../icons/cancel.svg"; import ImageIcon from "../icons/image.svg"; @@ -1573,11 +1574,30 @@ function _Chat() { )} - {showTyping && ( + {message?.tools?.length == 0 && showTyping && (
{Locale.Chat.Typing}
)} + {message?.tools?.length > 0 && ( +
+ {message?.tools?.map((tool) => ( +
+ {tool.isError === false ? ( + + ) : tool.isError === true ? ( + + ) : ( + + )} + {tool.function.name} +
+ ))} +
+ )}
): ChatMessage { @@ -389,6 +401,23 @@ export const useChatStore = createPersistStore( } ChatControllerPool.remove(session.id, botMessage.id); }, + onBeforeTool(tool: ChatMessageTool) { + (botMessage.tools = botMessage?.tools || []).push(tool); + get().updateCurrentSession((session) => { + session.messages = session.messages.concat(); + }); + }, + onAfterTool(tool: ChatMessageTool) { + console.log("onAfterTool", botMessage); + botMessage?.tools?.forEach((t, i, tools) => { + if (tool.id == t.id) { + tools[i] = { ...tool }; + } + }); + get().updateCurrentSession((session) => { + session.messages = session.messages.concat(); + }); + }, onError(error) { const isAborted = error.message.includes("aborted"); botMessage.content +=