add processToolMessage callback

This commit is contained in:
lloydzhou 2024-08-29 17:28:15 +08:00
parent 7fc0d11931
commit d2cb984ced
2 changed files with 27 additions and 16 deletions

View File

@ -240,6 +240,7 @@ export class ChatGPTApi implements LLMApi {
); );
} }
if (shouldStream) { if (shouldStream) {
// TODO mock tools and funcs
const tools = [ const tools = [
{ {
type: "function", type: "function",
@ -278,8 +279,9 @@ export class ChatGPTApi implements LLMApi {
tools, tools,
funcs, funcs,
controller, controller,
// parseSSE
(text: string, runTools: ChatMessageTool[]) => { (text: string, runTools: ChatMessageTool[]) => {
console.log("parseSSE", text, runTools); // console.log("parseSSE", text, runTools);
const json = JSON.parse(text); const json = JSON.parse(text);
const choices = json.choices as Array<{ const choices = json.choices as Array<{
delta: { delta: {
@ -306,10 +308,23 @@ export class ChatGPTApi implements LLMApi {
runTools[index]["function"]["arguments"] += args; runTools[index]["function"]["arguments"] += args;
} }
} }
console.log("runTools", runTools);
return choices[0]?.delta?.content; return choices[0]?.delta?.content;
}, },
// processToolMessage, include tool_calls message and tool call results
(
requestPayload: RequestPayload,
toolCallMessage: any,
toolCallResult: any[],
) => {
// @ts-ignore
requestPayload?.messages?.splice(
// @ts-ignore
requestPayload?.messages?.length,
0,
toolCallMessage,
...toolCallResult,
);
},
options, options,
); );
} else { } else {

View File

@ -161,6 +161,11 @@ export function stream(
funcs: any, funcs: any,
controller: AbortController, controller: AbortController,
parseSSE: (text: string, runTools: any[]) => string | undefined, parseSSE: (text: string, runTools: any[]) => string | undefined,
processToolMessage: (
requestPayload: any,
toolCallMessage: any,
toolCallResult: any[],
) => void,
options: any, options: any,
) { ) {
let responseText = ""; let responseText = "";
@ -196,7 +201,6 @@ export function stream(
const finish = () => { const finish = () => {
if (!finished) { if (!finished) {
console.log("try run tools", runTools.length, finished, running);
if (!running && runTools.length > 0) { if (!running && runTools.length > 0) {
const toolCallMessage = { const toolCallMessage = {
role: "assistant", role: "assistant",
@ -233,28 +237,20 @@ export function stream(
})); }));
}), }),
).then((toolCallResult) => { ).then((toolCallResult) => {
console.log("end runTools", toolCallMessage, toolCallResult); processToolMessage(requestPayload, toolCallMessage, toolCallResult);
// @ts-ignore
requestPayload?.messages?.splice(
// @ts-ignore
requestPayload?.messages?.length,
0,
toolCallMessage,
...toolCallResult,
);
setTimeout(() => { setTimeout(() => {
// call again // call again
console.log("start again"); console.debug("[ChatAPI] restart");
running = false; running = false;
chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource
}, 60); }, 60);
}); });
console.log("try run tools", runTools.length, finished);
return; return;
} }
if (running) { if (running) {
return; return;
} }
console.debug("[ChatAPI] end");
finished = true; finished = true;
options.onFinish(responseText + remainText); options.onFinish(responseText + remainText);
} }
@ -343,7 +339,7 @@ export function stream(
}, },
openWhenHidden: true, openWhenHidden: true,
}); });
console.log("chatApi", chatPath, requestPayload, tools);
} }
console.debug("[ChatAPI] start");
chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource
} }