claude support function call

This commit is contained in:
lloydzhou 2024-09-02 21:45:47 +08:00
parent 877668b629
commit 801b62543a
6 changed files with 145 additions and 112 deletions

View File

@ -38,6 +38,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
console.log("[Auth] hashed access code:", hashedCode); console.log("[Auth] hashed access code:", hashedCode);
console.log("[User IP] ", getIP(req)); console.log("[User IP] ", getIP(req));
console.log("[Time] ", new Date().toLocaleString()); console.log("[Time] ", new Date().toLocaleString());
console.log("[ModelProvider] ", modelProvider);
if (serverConfig.needCode && !serverConfig.codes.has(hashedCode) && !apiKey) { if (serverConfig.needCode && !serverConfig.codes.has(hashedCode) && !apiKey) {
return { return {

View File

@ -1,6 +1,12 @@
import { ACCESS_CODE_PREFIX, Anthropic, ApiPath } from "@/app/constant"; import { ACCESS_CODE_PREFIX, Anthropic, ApiPath } from "@/app/constant";
import { ChatOptions, getHeaders, LLMApi, MultimodalContent } from "../api"; import { ChatOptions, getHeaders, LLMApi, MultimodalContent } from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import {
useAccessStore,
useAppConfig,
useChatStore,
usePluginStore,
ChatMessageTool,
} from "@/app/store";
import { getClientConfig } from "@/app/config/client"; import { getClientConfig } from "@/app/config/client";
import { DEFAULT_API_HOST } from "@/app/constant"; import { DEFAULT_API_HOST } from "@/app/constant";
import { import {
@ -11,8 +17,9 @@ import {
import Locale from "../../locales"; import Locale from "../../locales";
import { prettyObject } from "@/app/utils/format"; import { prettyObject } from "@/app/utils/format";
import { getMessageTextContent, isVisionModel } from "@/app/utils"; import { getMessageTextContent, isVisionModel } from "@/app/utils";
import { preProcessImageContent } from "@/app/utils/chat"; import { preProcessImageContent, stream } from "@/app/utils/chat";
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
import { RequestPayload } from "./openai";
export type MultiBlockContent = { export type MultiBlockContent = {
type: "image" | "text"; type: "image" | "text";
@ -191,112 +198,123 @@ export class ClaudeApi implements LLMApi {
const controller = new AbortController(); const controller = new AbortController();
options.onController?.(controller); options.onController?.(controller);
const payload = {
method: "POST",
body: JSON.stringify(requestBody),
signal: controller.signal,
headers: {
...getHeaders(), // get common headers
"anthropic-version": accessStore.anthropicApiVersion,
// do not send `anthropicApiKey` in browser!!!
// Authorization: getAuthKey(accessStore.anthropicApiKey),
},
};
if (shouldStream) { if (shouldStream) {
try { let index = -1;
const context = { const [tools, funcs] = usePluginStore
text: "", .getState()
finished: false, .getAsTools(
}; useChatStore.getState().currentSession().mask?.plugin as string[],
);
const finish = () => { console.log("getAsTools", tools, funcs);
if (!context.finished) { return stream(
options.onFinish(context.text); path,
context.finished = true; requestBody,
} {
}; ...getHeaders(),
"anthropic-version": accessStore.anthropicApiVersion,
controller.signal.onabort = finish; },
fetchEventSource(path, { // @ts-ignore
...payload, tools.map((tool) => ({
async onopen(res) { name: tool?.function?.name,
const contentType = res.headers.get("content-type"); description: tool?.function?.description,
console.log("response content type: ", contentType); input_schema: tool?.function?.parameters,
})),
if (contentType?.startsWith("text/plain")) { funcs,
context.text = await res.clone().text(); controller,
return finish(); // parseSSE
} (text: string, runTools: ChatMessageTool[]) => {
// console.log("parseSSE", text, runTools);
if ( let chunkJson:
!res.ok || | undefined
!res.headers | {
.get("content-type") type: "content_block_delta" | "content_block_stop";
?.startsWith(EventStreamContentType) || content_block?: {
res.status !== 200 type: "tool_use";
) { id: string;
const responseTexts = [context.text]; name: string;
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
if (extraInfo) {
responseTexts.push(extraInfo);
}
context.text = responseTexts.join("\n\n");
return finish();
}
},
onmessage(msg) {
let chunkJson:
| undefined
| {
type: "content_block_delta" | "content_block_stop";
delta?: {
type: "text_delta";
text: string;
};
index: number;
}; };
try { delta?: {
chunkJson = JSON.parse(msg.data); type: "text_delta" | "input_json_delta";
} catch (e) { text?: string;
console.error("[Response] parse error", msg.data); partial_json?: string;
} };
index: number;
};
chunkJson = JSON.parse(text);
if (!chunkJson || chunkJson.type === "content_block_stop") { if (chunkJson?.content_block?.type == "tool_use") {
return finish(); index += 1;
} const id = chunkJson?.content_block.id;
const name = chunkJson?.content_block.name;
const { delta } = chunkJson; runTools.push({
if (delta?.text) { id,
context.text += delta.text; type: "function",
options.onUpdate?.(context.text, delta.text); function: {
} name,
}, arguments: "",
onclose() { },
finish(); });
}, }
onerror(e) { if (
options.onError?.(e); chunkJson?.delta?.type == "input_json_delta" &&
throw e; chunkJson?.delta?.partial_json
}, ) {
openWhenHidden: true, // @ts-ignore
}); runTools[index]["function"]["arguments"] +=
} catch (e) { chunkJson?.delta?.partial_json;
console.error("failed to chat", e); }
options.onError?.(e as Error); return chunkJson?.delta?.text;
} },
// processToolMessage, include tool_calls message and tool call results
(
requestPayload: RequestPayload,
toolCallMessage: any,
toolCallResult: any[],
) => {
// @ts-ignore
requestPayload?.messages?.splice(
// @ts-ignore
requestPayload?.messages?.length,
0,
{
role: "assistant",
content: toolCallMessage.tool_calls.map(
(tool: ChatMessageTool) => ({
type: "tool_use",
id: tool.id,
name: tool?.function?.name,
input: JSON.parse(tool?.function?.arguments as string),
}),
),
},
// @ts-ignore
...toolCallResult.map((result) => ({
role: "user",
content: [
{
type: "tool_result",
tool_use_id: result.tool_call_id,
content: result.content,
},
],
})),
);
},
options,
);
} else { } else {
const payload = {
method: "POST",
body: JSON.stringify(requestBody),
signal: controller.signal,
headers: {
...getHeaders(), // get common headers
"anthropic-version": accessStore.anthropicApiVersion,
// do not send `anthropicApiKey` in browser!!!
// Authorization: getAuthKey(accessStore.anthropicApiKey),
},
};
try { try {
controller.signal.onabort = () => options.onFinish(""); controller.signal.onabort = () => options.onFinish("");

View File

@ -246,7 +246,7 @@ export class ChatGPTApi implements LLMApi {
.getAsTools( .getAsTools(
useChatStore.getState().currentSession().mask?.plugin as string[], useChatStore.getState().currentSession().mask?.plugin as string[],
); );
console.log("getAsTools", tools, funcs); // console.log("getAsTools", tools, funcs);
stream( stream(
chatPath, chatPath,
requestPayload, requestPayload,

View File

@ -66,6 +66,7 @@ import {
getMessageImages, getMessageImages,
isVisionModel, isVisionModel,
isDalle3, isDalle3,
showPlugins,
} from "../utils"; } from "../utils";
import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
@ -741,12 +742,14 @@ export function ChatActions(props: {
value: ArtifactsPlugin.Artifacts as string, value: ArtifactsPlugin.Artifacts as string,
}, },
].concat( ].concat(
pluginStore.getAll().map((item) => ({ showPlugins(currentProviderName, currentModel)
// @ts-ignore ? pluginStore.getAll().map((item) => ({
title: `${item?.title}@${item?.version}`, // @ts-ignore
// @ts-ignore title: `${item?.title}@${item?.version}`,
value: item?.id, // @ts-ignore
})), value: item?.id,
}))
: [],
)} )}
onClose={() => setShowPluginSelector(false)} onClose={() => setShowPluginSelector(false)}
onSelection={(s) => { onSelection={(s) => {

View File

@ -2,6 +2,7 @@ import { useEffect, useState } from "react";
import { showToast } from "./components/ui-lib"; import { showToast } from "./components/ui-lib";
import Locale from "./locales"; import Locale from "./locales";
import { RequestMessage } from "./client/api"; import { RequestMessage } from "./client/api";
import { ServiceProvider } from "./constant";
export function trimTopic(topic: string) { export function trimTopic(topic: string) {
// Fix an issue where double quotes still show in the Indonesian language // Fix an issue where double quotes still show in the Indonesian language
@ -270,3 +271,13 @@ export function isVisionModel(model: string) {
export function isDalle3(model: string) { export function isDalle3(model: string) {
return "dall-e-3" === model; return "dall-e-3" === model;
} }
export function showPlugins(provider: ServiceProvider, model: string) {
if (provider == ServiceProvider.OpenAI || provider == ServiceProvider.Azure) {
return true;
}
if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) {
return true;
}
return false;
}

View File

@ -334,7 +334,7 @@ export function stream(
remainText += chunk; remainText += chunk;
} }
} catch (e) { } catch (e) {
console.error("[Request] parse error", text, msg); console.error("[Request] parse error", text, msg, e);
} }
}, },
onclose() { onclose() {