feat: claude function call

This commit is contained in:
Hk-Gosuto 2024-08-17 13:05:49 +00:00
parent 0a643dc71d
commit 8c5e92d66a
5 changed files with 119 additions and 80 deletions

View File

@ -18,7 +18,7 @@ export class EdgeTool {
private model: BaseLanguageModel; private model: BaseLanguageModel;
private embeddings: Embeddings; private embeddings: Embeddings | null;
private callback?: (data: string) => Promise<void>; private callback?: (data: string) => Promise<void>;
@ -26,7 +26,7 @@ export class EdgeTool {
apiKey: string | undefined, apiKey: string | undefined,
baseUrl: string, baseUrl: string,
model: BaseLanguageModel, model: BaseLanguageModel,
embeddings: Embeddings, embeddings: Embeddings | null,
callback?: (data: string) => Promise<void>, callback?: (data: string) => Promise<void>,
) { ) {
this.apiKey = apiKey; this.apiKey = apiKey;
@ -37,10 +37,6 @@ export class EdgeTool {
} }
async getCustomTools(): Promise<any[]> { async getCustomTools(): Promise<any[]> {
const webBrowserTool = new WebBrowser({
model: this.model,
embeddings: this.embeddings,
});
const calculatorTool = new Calculator(); const calculatorTool = new Calculator();
const dallEAPITool = new DallEAPIWrapper( const dallEAPITool = new DallEAPIWrapper(
this.apiKey, this.apiKey,
@ -56,7 +52,7 @@ export class EdgeTool {
const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool(); const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool();
let tools = [ let tools = [
calculatorTool, calculatorTool,
webBrowserTool, // webBrowserTool,
dallEAPITool, dallEAPITool,
stableDiffusionTool, stableDiffusionTool,
arxivAPITool, arxivAPITool,
@ -66,6 +62,13 @@ export class EdgeTool {
bilibiliMusicRecognitionTool, bilibiliMusicRecognitionTool,
bilibiliVideoConclusionTool, bilibiliVideoConclusionTool,
]; ];
if (this.embeddings != null) {
const webBrowserTool = new WebBrowser({
model: this.model,
embeddings: this.embeddings,
});
tools.push(webBrowserTool);
}
return tools; return tools;
} }
} }

View File

@ -17,7 +17,7 @@ export class NodeJSTool {
private apiKey: string | undefined; private apiKey: string | undefined;
private baseUrl: string; private baseUrl: string;
private model: BaseLanguageModel; private model: BaseLanguageModel;
private embeddings: Embeddings; private embeddings: Embeddings | null;
private sessionId: string; private sessionId: string;
private ragEmbeddings: Embeddings; private ragEmbeddings: Embeddings;
private callback?: (data: string) => Promise<void>; private callback?: (data: string) => Promise<void>;
@ -26,7 +26,7 @@ export class NodeJSTool {
apiKey: string | undefined, apiKey: string | undefined,
baseUrl: string, baseUrl: string,
model: BaseLanguageModel, model: BaseLanguageModel,
embeddings: Embeddings, embeddings: Embeddings | null,
sessionId: string, sessionId: string,
ragEmbeddings: Embeddings, ragEmbeddings: Embeddings,
callback?: (data: string) => Promise<void>, callback?: (data: string) => Promise<void>,
@ -41,10 +41,6 @@ export class NodeJSTool {
} }
async getCustomTools(): Promise<any[]> { async getCustomTools(): Promise<any[]> {
const webBrowserTool = new WebBrowser({
model: this.model,
embeddings: this.embeddings,
});
const calculatorTool = new Calculator(); const calculatorTool = new Calculator();
const dallEAPITool = new DallEAPINodeWrapper( const dallEAPITool = new DallEAPINodeWrapper(
this.apiKey, this.apiKey,
@ -54,24 +50,32 @@ export class NodeJSTool {
const stableDiffusionTool = new StableDiffusionNodeWrapper(); const stableDiffusionTool = new StableDiffusionNodeWrapper();
const arxivAPITool = new ArxivAPIWrapper(); const arxivAPITool = new ArxivAPIWrapper();
const wolframAlphaTool = new WolframAlphaTool(); const wolframAlphaTool = new WolframAlphaTool();
const pdfBrowserTool = new PDFBrowser(this.model, this.embeddings);
const bilibiliVideoInfoTool = new BilibiliVideoInfoTool(); const bilibiliVideoInfoTool = new BilibiliVideoInfoTool();
const bilibiliVideoSearchTool = new BilibiliVideoSearchTool(); const bilibiliVideoSearchTool = new BilibiliVideoSearchTool();
const bilibiliVideoConclusionTool = new BilibiliVideoConclusionTool(); const bilibiliVideoConclusionTool = new BilibiliVideoConclusionTool();
const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool(); const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool();
let tools: any = [ let tools: any = [
// webBrowserTool,
// pdfBrowserTool,
calculatorTool, calculatorTool,
webBrowserTool,
dallEAPITool, dallEAPITool,
stableDiffusionTool, stableDiffusionTool,
arxivAPITool, arxivAPITool,
wolframAlphaTool, wolframAlphaTool,
pdfBrowserTool,
bilibiliVideoInfoTool, bilibiliVideoInfoTool,
bilibiliVideoSearchTool, bilibiliVideoSearchTool,
bilibiliMusicRecognitionTool, bilibiliMusicRecognitionTool,
bilibiliVideoConclusionTool, bilibiliVideoConclusionTool,
]; ];
if (this.embeddings != null) {
const webBrowserTool = new WebBrowser({
model: this.model,
embeddings: this.embeddings,
});
const pdfBrowserTool = new PDFBrowser(this.model, this.embeddings);
tools.push(webBrowserTool);
tools.push(pdfBrowserTool);
}
if (!!process.env.ENABLE_RAG) { if (!!process.env.ENABLE_RAG) {
tools.push( tools.push(
new MyFilesBrowser(this.sessionId, this.model, this.ragEmbeddings), new MyFilesBrowser(this.sessionId, this.model, this.ragEmbeddings),

View File

@ -10,7 +10,12 @@ import {
createToolCallingAgent, createToolCallingAgent,
createReactAgent, createReactAgent,
} from "langchain/agents"; } from "langchain/agents";
import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant"; import {
ACCESS_CODE_PREFIX,
ANTHROPIC_BASE_URL,
OPENAI_BASE_URL,
ServiceProvider,
} from "@/app/constant";
// import * as langchainTools from "langchain/tools"; // import * as langchainTools from "langchain/tools";
import * as langchainTools from "@/app/api/langchain-tools/langchian-tool-index"; import * as langchainTools from "@/app/api/langchain-tools/langchian-tool-index";
@ -33,7 +38,7 @@ import {
ChatPromptTemplate, ChatPromptTemplate,
MessagesPlaceholder, MessagesPlaceholder,
} from "@langchain/core/prompts"; } from "@langchain/core/prompts";
import { ChatOpenAI } from "@langchain/openai"; import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { ChatAnthropic } from "@langchain/anthropic"; import { ChatAnthropic } from "@langchain/anthropic";
import { import {
BaseMessage, BaseMessage,
@ -45,6 +50,7 @@ import {
} from "@langchain/core/messages"; } from "@langchain/core/messages";
import { MultimodalContent } from "@/app/client/api"; import { MultimodalContent } from "@/app/client/api";
import { GoogleCustomSearch } from "@/app/api/langchain-tools/langchian-tool-index"; import { GoogleCustomSearch } from "@/app/api/langchain-tools/langchian-tool-index";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
export interface RequestMessage { export interface RequestMessage {
role: string; role: string;
@ -202,29 +208,81 @@ export class AgentApi {
}); });
} }
async getOpenAIApiKey(token: string) { getApiKey(token: string, provider: ServiceProvider) {
const serverConfig = getServerSideConfig(); const serverConfig = getServerSideConfig();
const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX); const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let apiKey = serverConfig.apiKey;
if (isApiKey && token) { if (isApiKey && token) {
apiKey = token; return token;
} }
return apiKey; if (provider === ServiceProvider.OpenAI) return serverConfig.apiKey;
if (provider === ServiceProvider.Anthropic)
return serverConfig.anthropicApiKey;
throw new Error("Unsupported provider");
} }
async getOpenAIBaseUrl(reqBaseUrl: string | undefined) { getBaseUrl(reqBaseUrl: string | undefined, provider: ServiceProvider) {
const serverConfig = getServerSideConfig(); const serverConfig = getServerSideConfig();
let baseUrl = "https://api.openai.com/v1"; let baseUrl = "";
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; if (provider === ServiceProvider.OpenAI) {
baseUrl = OPENAI_BASE_URL;
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
}
if (provider === ServiceProvider.Anthropic) {
baseUrl = ANTHROPIC_BASE_URL;
if (serverConfig.anthropicUrl) baseUrl = serverConfig.anthropicUrl;
}
if (reqBaseUrl?.startsWith("http://") || reqBaseUrl?.startsWith("https://")) if (reqBaseUrl?.startsWith("http://") || reqBaseUrl?.startsWith("https://"))
baseUrl = reqBaseUrl; baseUrl = reqBaseUrl;
if (!baseUrl.endsWith("/v1")) if (!baseUrl.endsWith("/v1") && provider === ServiceProvider.OpenAI)
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
console.log("[openai baseUrl]", baseUrl);
return baseUrl; return baseUrl;
} }
getToolBaseLanguageModel(
reqBody: RequestBody,
apiKey: string,
baseUrl: string,
) {
if (reqBody.provider === ServiceProvider.Anthropic) {
return new ChatAnthropic({
temperature: 0,
modelName: reqBody.model,
apiKey: apiKey,
clientOptions: {
baseURL: baseUrl,
},
});
}
return new ChatOpenAI(
{
temperature: 0,
modelName: reqBody.model,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
}
getToolEmbeddings(reqBody: RequestBody, apiKey: string, baseUrl: string) {
if (reqBody.provider === ServiceProvider.Anthropic) {
if (process.env.OLLAMA_BASE_URL) {
return new OllamaEmbeddings({
model: process.env.RAG_EMBEDDING_MODEL,
baseUrl: process.env.OLLAMA_BASE_URL,
});
} else {
return null;
}
}
return new OpenAIEmbeddings(
{
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
}
getLLM(reqBody: RequestBody, apiKey: string, baseUrl: string) { getLLM(reqBody: RequestBody, apiKey: string, baseUrl: string) {
const serverConfig = getServerSideConfig(); const serverConfig = getServerSideConfig();
if (reqBody.isAzure || serverConfig.isAzure) { if (reqBody.isAzure || serverConfig.isAzure) {
@ -266,7 +324,6 @@ export class AgentApi {
temperature: reqBody.temperature, temperature: reqBody.temperature,
streaming: reqBody.stream, streaming: reqBody.stream,
topP: reqBody.top_p, topP: reqBody.top_p,
// maxTokens: 1024,
clientOptions: { clientOptions: {
baseURL: baseUrl, baseURL: baseUrl,
}, },
@ -300,22 +357,9 @@ export class AgentApi {
const authToken = req.headers.get(authHeaderName) ?? ""; const authToken = req.headers.get(authHeaderName) ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim(); const token = authToken.trim().replaceAll("Bearer ", "").trim();
let apiKey = await this.getOpenAIApiKey(token); let apiKey = this.getApiKey(token, reqBody.provider);
if (isAzure) apiKey = token; if (isAzure) apiKey = token;
let baseUrl = "https://api.openai.com/v1"; let baseUrl = this.getBaseUrl(reqBody.baseUrl, reqBody.provider);
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
if (
reqBody.baseUrl?.startsWith("http://") ||
reqBody.baseUrl?.startsWith("https://")
) {
baseUrl = reqBody.baseUrl;
}
if (
reqBody.provider === ServiceProvider.OpenAI &&
!baseUrl.endsWith("/v1")
) {
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
}
if (!reqBody.isAzure && serverConfig.isAzure) { if (!reqBody.isAzure && serverConfig.isAzure) {
baseUrl = serverConfig.azureUrl || baseUrl; baseUrl = serverConfig.azureUrl || baseUrl;
} }

View File

@ -3,7 +3,8 @@ import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
import { auth } from "@/app/api/auth"; import { auth } from "@/app/api/auth";
import { EdgeTool } from "../../../../langchain-tools/edge_tools"; import { EdgeTool } from "../../../../langchain-tools/edge_tools";
import { ModelProvider } from "@/app/constant"; import { ModelProvider } from "@/app/constant";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai"; import { Embeddings } from "@langchain/core/embeddings";
import { BaseLanguageModel } from "@langchain/core/language_models/base";
async function handle(req: NextRequest) { async function handle(req: NextRequest) {
if (req.method === "OPTIONS") { if (req.method === "OPTIONS") {
@ -27,23 +28,13 @@ async function handle(req: NextRequest) {
const authToken = req.headers.get("Authorization") ?? ""; const authToken = req.headers.get("Authorization") ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim(); const token = authToken.trim().replaceAll("Bearer ", "").trim();
const apiKey = await agentApi.getOpenAIApiKey(token); const apiKey = agentApi.getApiKey(token, reqBody.provider);
const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl); const baseUrl = agentApi.getBaseUrl(reqBody.baseUrl, reqBody.provider);
const model = new ChatOpenAI( let model: BaseLanguageModel;
{ let embeddings: Embeddings | null;
temperature: 0, model = agentApi.getToolBaseLanguageModel(reqBody, apiKey, baseUrl);
modelName: reqBody.model, embeddings = agentApi.getToolEmbeddings(reqBody, apiKey, baseUrl);
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
const embeddings = new OpenAIEmbeddings(
{
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
var dalleCallback = async (data: string) => { var dalleCallback = async (data: string) => {
var response = new ResponseBody(); var response = new ResponseBody();

View File

@ -2,17 +2,23 @@ import { NextRequest, NextResponse } from "next/server";
import { AgentApi, RequestBody, ResponseBody } from "../agentapi"; import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
import { auth } from "@/app/api/auth"; import { auth } from "@/app/api/auth";
import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools"; import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools";
import { ModelProvider } from "@/app/constant"; import { ModelProvider, ServiceProvider } from "@/app/constant";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai"; import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"; import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
import { Embeddings } from "@langchain/core/embeddings"; import { Embeddings } from "@langchain/core/embeddings";
import { BaseLanguageModel } from "@langchain/core/language_models/base";
async function handle(req: NextRequest) { async function handle(req: NextRequest) {
if (req.method === "OPTIONS") { if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 }); return NextResponse.json({ body: "OK" }, { status: 200 });
} }
try { try {
const authResult = auth(req, ModelProvider.GPT); const reqBody: RequestBody = await req.json();
const modelProvider =
reqBody.provider === ServiceProvider.Anthropic
? ModelProvider.Claude
: ModelProvider.GPT;
const authResult = auth(req, modelProvider);
if (authResult.error) { if (authResult.error) {
return NextResponse.json(authResult, { return NextResponse.json(authResult, {
status: 401, status: 401,
@ -25,27 +31,17 @@ async function handle(req: NextRequest) {
const controller = new AbortController(); const controller = new AbortController();
const agentApi = new AgentApi(encoder, transformStream, writer, controller); const agentApi = new AgentApi(encoder, transformStream, writer, controller);
const reqBody: RequestBody = await req.json(); const authToken =
const authToken = req.headers.get("Authorization") ?? ""; (req.headers.get("Authorization") || req.headers.get("x-api-key")) ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim(); const token = authToken.trim().replaceAll("Bearer ", "").trim();
const apiKey = await agentApi.getOpenAIApiKey(token); const apiKey = agentApi.getApiKey(token, reqBody.provider);
const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl); const baseUrl = agentApi.getBaseUrl(reqBody.baseUrl, reqBody.provider);
let model: BaseLanguageModel;
let embeddings: Embeddings | null;
model = agentApi.getToolBaseLanguageModel(reqBody, apiKey, baseUrl);
embeddings = agentApi.getToolEmbeddings(reqBody, apiKey, baseUrl);
const model = new ChatOpenAI(
{
temperature: 0,
modelName: reqBody.model,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
const embeddings = new OpenAIEmbeddings(
{
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
let ragEmbeddings: Embeddings; let ragEmbeddings: Embeddings;
if (process.env.OLLAMA_BASE_URL) { if (process.env.OLLAMA_BASE_URL) {
ragEmbeddings = new OllamaEmbeddings({ ragEmbeddings = new OllamaEmbeddings({
@ -98,6 +94,7 @@ async function handle(req: NextRequest) {
export const GET = handle; export const GET = handle;
export const POST = handle; export const POST = handle;
export const maxDuration = 60;
export const runtime = "nodejs"; export const runtime = "nodejs";
export const preferredRegion = [ export const preferredRegion = [
"arn1", "arn1",