google gemini support function call

This commit is contained in:
lloydzhou 2024-10-03 20:28:15 +08:00
parent cea5b91f96
commit 450766a44b
3 changed files with 87 additions and 112 deletions

View File

@ -7,21 +7,25 @@ import {
LLMUsage, LLMUsage,
SpeechOptions, SpeechOptions,
} from "../api"; } from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import {
useAccessStore,
useAppConfig,
useChatStore,
usePluginStore,
ChatMessageTool,
} from "@/app/store";
import { stream } from "@/app/utils/chat";
import { getClientConfig } from "@/app/config/client"; import { getClientConfig } from "@/app/config/client";
import { DEFAULT_API_HOST } from "@/app/constant"; import { DEFAULT_API_HOST } from "@/app/constant";
import Locale from "../../locales";
import {
EventStreamContentType,
fetchEventSource,
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import { import {
getMessageTextContent, getMessageTextContent,
getMessageImages, getMessageImages,
isVisionModel, isVisionModel,
} from "@/app/utils"; } from "@/app/utils";
import { preProcessImageContent } from "@/app/utils/chat"; import { preProcessImageContent } from "@/app/utils/chat";
import { nanoid } from "nanoid";
import { RequestPayload } from "./openai";
export class GeminiProApi implements LLMApi { export class GeminiProApi implements LLMApi {
path(path: string): string { path(path: string): string {
@ -177,114 +181,81 @@ export class GeminiProApi implements LLMApi {
); );
if (shouldStream) { if (shouldStream) {
let responseText = ""; const [tools, funcs] = usePluginStore
let remainText = ""; .getState()
let finished = false; .getAsTools(
useChatStore.getState().currentSession().mask?.plugin || [],
);
return stream(
chatPath,
requestPayload,
getHeaders(),
// @ts-ignore
[{ functionDeclarations: tools.map((tool) => tool.function) }],
funcs,
controller,
// parseSSE
(text: string, runTools: ChatMessageTool[]) => {
// console.log("parseSSE", text, runTools);
const chunkJson = JSON.parse(text);
const finish = () => { const functionCall = chunkJson?.candidates
if (!finished) { ?.at(0)
finished = true; ?.content.parts.at(0)?.functionCall;
options.onFinish(responseText + remainText); if (functionCall) {
} const { name, args } = functionCall;
}; runTools.push({
id: nanoid(),
// animate response to make it looks smooth type: "function",
function animateResponseText() { function: {
if (finished || controller.signal.aborted) { name,
responseText += remainText; arguments: JSON.stringify(args), // utils.chat call function, using JSON.parse
finish(); },
return; });
} }
return chunkJson?.candidates?.at(0)?.content.parts.at(0)?.text;
if (remainText.length > 0) { },
const fetchCount = Math.max(1, Math.round(remainText.length / 60)); // processToolMessage, include tool_calls message and tool call results
const fetchText = remainText.slice(0, fetchCount); (
responseText += fetchText; requestPayload: RequestPayload,
remainText = remainText.slice(fetchCount); toolCallMessage: any,
options.onUpdate?.(responseText, fetchText); toolCallResult: any[],
} ) => {
// @ts-ignore
requestAnimationFrame(animateResponseText); requestPayload?.contents?.splice(
} // @ts-ignore
requestPayload?.contents?.length,
// start animaion 0,
animateResponseText(); {
role: "model",
controller.signal.onabort = finish; parts: toolCallMessage.tool_calls.map(
(tool: ChatMessageTool) => ({
fetchEventSource(chatPath, { functionCall: {
...chatPayload, name: tool?.function?.name,
async onopen(res) { args: JSON.parse(tool?.function?.arguments as string),
clearTimeout(requestTimeoutId); },
const contentType = res.headers.get("content-type"); }),
console.log( ),
"[Gemini] request response content type: ", },
contentType, // @ts-ignore
...toolCallResult.map((result) => ({
role: "function",
parts: [
{
functionResponse: {
name: result.name,
response: {
name: result.name,
content: result.content, // TODO just text content...
},
},
},
],
})),
); );
if (contentType?.startsWith("text/plain")) {
responseText = await res.clone().text();
return finish();
}
if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
if (extraInfo) {
responseTexts.push(extraInfo);
}
responseText = responseTexts.join("\n\n");
return finish();
}
}, },
onmessage(msg) { options,
if (msg.data === "[DONE]" || finished) { );
return finish();
}
const text = msg.data;
try {
const json = JSON.parse(text);
const delta = apiClient.extractMessage(json);
if (delta) {
remainText += delta;
}
const blockReason = json?.promptFeedback?.blockReason;
if (blockReason) {
// being blocked
console.log(`[Google] [Safety Ratings] result:`, blockReason);
}
} catch (e) {
console.error("[Request] parse error", text, msg);
}
},
onclose() {
finish();
},
onerror(e) {
options.onError?.(e);
throw e;
},
openWhenHidden: true,
});
} else { } else {
const res = await fetch(chatPath, chatPayload); const res = await fetch(chatPath, chatPayload);
clearTimeout(requestTimeoutId); clearTimeout(requestTimeoutId);

View File

@ -284,6 +284,9 @@ export function showPlugins(provider: ServiceProvider, model: string) {
if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) { if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) {
return true; return true;
} }
if (provider == ServiceProvider.Google && !model.includes("vision")) {
return true;
}
return false; return false;
} }

View File

@ -240,6 +240,7 @@ export function stream(
return e.toString(); return e.toString();
}) })
.then((content) => ({ .then((content) => ({
name: tool.function.name,
role: "tool", role: "tool",
content, content,
tool_call_id: tool.id, tool_call_id: tool.id,