From 8c5e92d66ad65e9f1352d1fce599df82daca694f Mon Sep 17 00:00:00 2001 From: Hk-Gosuto Date: Sat, 17 Aug 2024 13:05:49 +0000 Subject: [PATCH] feat: claude function call --- app/api/langchain-tools/edge_tools.ts | 17 ++-- app/api/langchain-tools/nodejs_tools.ts | 22 +++-- app/api/langchain/tool/agent/agentapi.ts | 98 ++++++++++++++------ app/api/langchain/tool/agent/edge/route.ts | 25 ++--- app/api/langchain/tool/agent/nodejs/route.ts | 37 ++++---- 5 files changed, 119 insertions(+), 80 deletions(-) diff --git a/app/api/langchain-tools/edge_tools.ts b/app/api/langchain-tools/edge_tools.ts index 35d4eba3b..4388a1feb 100644 --- a/app/api/langchain-tools/edge_tools.ts +++ b/app/api/langchain-tools/edge_tools.ts @@ -18,7 +18,7 @@ export class EdgeTool { private model: BaseLanguageModel; - private embeddings: Embeddings; + private embeddings: Embeddings | null; private callback?: (data: string) => Promise; @@ -26,7 +26,7 @@ export class EdgeTool { apiKey: string | undefined, baseUrl: string, model: BaseLanguageModel, - embeddings: Embeddings, + embeddings: Embeddings | null, callback?: (data: string) => Promise, ) { this.apiKey = apiKey; @@ -37,10 +37,6 @@ export class EdgeTool { } async getCustomTools(): Promise { - const webBrowserTool = new WebBrowser({ - model: this.model, - embeddings: this.embeddings, - }); const calculatorTool = new Calculator(); const dallEAPITool = new DallEAPIWrapper( this.apiKey, @@ -56,7 +52,7 @@ export class EdgeTool { const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool(); let tools = [ calculatorTool, - webBrowserTool, + // webBrowserTool, dallEAPITool, stableDiffusionTool, arxivAPITool, @@ -66,6 +62,13 @@ export class EdgeTool { bilibiliMusicRecognitionTool, bilibiliVideoConclusionTool, ]; + if (this.embeddings != null) { + const webBrowserTool = new WebBrowser({ + model: this.model, + embeddings: this.embeddings, + }); + tools.push(webBrowserTool); + } return tools; } } diff --git a/app/api/langchain-tools/nodejs_tools.ts b/app/api/langchain-tools/nodejs_tools.ts index 6553318c7..419eb5715 100644 --- a/app/api/langchain-tools/nodejs_tools.ts +++ b/app/api/langchain-tools/nodejs_tools.ts @@ -17,7 +17,7 @@ export class NodeJSTool { private apiKey: string | undefined; private baseUrl: string; private model: BaseLanguageModel; - private embeddings: Embeddings; + private embeddings: Embeddings | null; private sessionId: string; private ragEmbeddings: Embeddings; private callback?: (data: string) => Promise; @@ -26,7 +26,7 @@ export class NodeJSTool { apiKey: string | undefined, baseUrl: string, model: BaseLanguageModel, - embeddings: Embeddings, + embeddings: Embeddings | null, sessionId: string, ragEmbeddings: Embeddings, callback?: (data: string) => Promise, @@ -41,10 +41,6 @@ export class NodeJSTool { } async getCustomTools(): Promise { - const webBrowserTool = new WebBrowser({ - model: this.model, - embeddings: this.embeddings, - }); const calculatorTool = new Calculator(); const dallEAPITool = new DallEAPINodeWrapper( this.apiKey, @@ -54,24 +50,32 @@ export class NodeJSTool { const stableDiffusionTool = new StableDiffusionNodeWrapper(); const arxivAPITool = new ArxivAPIWrapper(); const wolframAlphaTool = new WolframAlphaTool(); - const pdfBrowserTool = new PDFBrowser(this.model, this.embeddings); const bilibiliVideoInfoTool = new BilibiliVideoInfoTool(); const bilibiliVideoSearchTool = new BilibiliVideoSearchTool(); const bilibiliVideoConclusionTool = new BilibiliVideoConclusionTool(); const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool(); let tools: any = [ + // webBrowserTool, + // pdfBrowserTool, calculatorTool, - webBrowserTool, dallEAPITool, stableDiffusionTool, arxivAPITool, wolframAlphaTool, - pdfBrowserTool, bilibiliVideoInfoTool, bilibiliVideoSearchTool, bilibiliMusicRecognitionTool, bilibiliVideoConclusionTool, ]; + if (this.embeddings != null) { + const webBrowserTool = new WebBrowser({ + model: this.model, + embeddings: this.embeddings, + }); + const pdfBrowserTool = new PDFBrowser(this.model, this.embeddings); + tools.push(webBrowserTool); + tools.push(pdfBrowserTool); + } if (!!process.env.ENABLE_RAG) { tools.push( new MyFilesBrowser(this.sessionId, this.model, this.ragEmbeddings), diff --git a/app/api/langchain/tool/agent/agentapi.ts b/app/api/langchain/tool/agent/agentapi.ts index 2a20faa53..4e9016b26 100644 --- a/app/api/langchain/tool/agent/agentapi.ts +++ b/app/api/langchain/tool/agent/agentapi.ts @@ -10,7 +10,12 @@ import { createToolCallingAgent, createReactAgent, } from "langchain/agents"; -import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant"; +import { + ACCESS_CODE_PREFIX, + ANTHROPIC_BASE_URL, + OPENAI_BASE_URL, + ServiceProvider, +} from "@/app/constant"; // import * as langchainTools from "langchain/tools"; import * as langchainTools from "@/app/api/langchain-tools/langchian-tool-index"; @@ -33,7 +38,7 @@ import { ChatPromptTemplate, MessagesPlaceholder, } from "@langchain/core/prompts"; -import { ChatOpenAI } from "@langchain/openai"; +import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai"; import { ChatAnthropic } from "@langchain/anthropic"; import { BaseMessage, @@ -45,6 +50,7 @@ import { } from "@langchain/core/messages"; import { MultimodalContent } from "@/app/client/api"; import { GoogleCustomSearch } from "@/app/api/langchain-tools/langchian-tool-index"; +import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"; export interface RequestMessage { role: string; @@ -202,29 +208,81 @@ export class AgentApi { }); } - async getOpenAIApiKey(token: string) { + getApiKey(token: string, provider: ServiceProvider) { const serverConfig = getServerSideConfig(); const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX); - let apiKey = serverConfig.apiKey; if (isApiKey && token) { - apiKey = token; + return token; } - return apiKey; + if (provider === ServiceProvider.OpenAI) return serverConfig.apiKey; + if (provider === ServiceProvider.Anthropic) + return serverConfig.anthropicApiKey; + throw new Error("Unsupported provider"); } - async getOpenAIBaseUrl(reqBaseUrl: string | undefined) { + getBaseUrl(reqBaseUrl: string | undefined, provider: ServiceProvider) { const serverConfig = getServerSideConfig(); - let baseUrl = "https://api.openai.com/v1"; - if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; + let baseUrl = ""; + if (provider === ServiceProvider.OpenAI) { + baseUrl = OPENAI_BASE_URL; + if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; + } + if (provider === ServiceProvider.Anthropic) { + baseUrl = ANTHROPIC_BASE_URL; + if (serverConfig.anthropicUrl) baseUrl = serverConfig.anthropicUrl; + } if (reqBaseUrl?.startsWith("http://") || reqBaseUrl?.startsWith("https://")) baseUrl = reqBaseUrl; - if (!baseUrl.endsWith("/v1")) + if (!baseUrl.endsWith("/v1") && provider === ServiceProvider.OpenAI) baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; - console.log("[openai baseUrl]", baseUrl); return baseUrl; } + getToolBaseLanguageModel( + reqBody: RequestBody, + apiKey: string, + baseUrl: string, + ) { + if (reqBody.provider === ServiceProvider.Anthropic) { + return new ChatAnthropic({ + temperature: 0, + modelName: reqBody.model, + apiKey: apiKey, + clientOptions: { + baseURL: baseUrl, + }, + }); + } + return new ChatOpenAI( + { + temperature: 0, + modelName: reqBody.model, + openAIApiKey: apiKey, + }, + { basePath: baseUrl }, + ); + } + + getToolEmbeddings(reqBody: RequestBody, apiKey: string, baseUrl: string) { + if (reqBody.provider === ServiceProvider.Anthropic) { + if (process.env.OLLAMA_BASE_URL) { + return new OllamaEmbeddings({ + model: process.env.RAG_EMBEDDING_MODEL, + baseUrl: process.env.OLLAMA_BASE_URL, + }); + } else { + return null; + } + } + return new OpenAIEmbeddings( + { + openAIApiKey: apiKey, + }, + { basePath: baseUrl }, + ); + } + getLLM(reqBody: RequestBody, apiKey: string, baseUrl: string) { const serverConfig = getServerSideConfig(); if (reqBody.isAzure || serverConfig.isAzure) { @@ -266,7 +324,6 @@ export class AgentApi { temperature: reqBody.temperature, streaming: reqBody.stream, topP: reqBody.top_p, - // maxTokens: 1024, clientOptions: { baseURL: baseUrl, }, @@ -300,22 +357,9 @@ export class AgentApi { const authToken = req.headers.get(authHeaderName) ?? ""; const token = authToken.trim().replaceAll("Bearer ", "").trim(); - let apiKey = await this.getOpenAIApiKey(token); + let apiKey = this.getApiKey(token, reqBody.provider); if (isAzure) apiKey = token; - let baseUrl = "https://api.openai.com/v1"; - if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; - if ( - reqBody.baseUrl?.startsWith("http://") || - reqBody.baseUrl?.startsWith("https://") - ) { - baseUrl = reqBody.baseUrl; - } - if ( - reqBody.provider === ServiceProvider.OpenAI && - !baseUrl.endsWith("/v1") - ) { - baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; - } + let baseUrl = this.getBaseUrl(reqBody.baseUrl, reqBody.provider); if (!reqBody.isAzure && serverConfig.isAzure) { baseUrl = serverConfig.azureUrl || baseUrl; } diff --git a/app/api/langchain/tool/agent/edge/route.ts b/app/api/langchain/tool/agent/edge/route.ts index f3b6343c9..cd21540e0 100644 --- a/app/api/langchain/tool/agent/edge/route.ts +++ b/app/api/langchain/tool/agent/edge/route.ts @@ -3,7 +3,8 @@ import { AgentApi, RequestBody, ResponseBody } from "../agentapi"; import { auth } from "@/app/api/auth"; import { EdgeTool } from "../../../../langchain-tools/edge_tools"; import { ModelProvider } from "@/app/constant"; -import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai"; +import { Embeddings } from "@langchain/core/embeddings"; +import { BaseLanguageModel } from "@langchain/core/language_models/base"; async function handle(req: NextRequest) { if (req.method === "OPTIONS") { @@ -27,23 +28,13 @@ async function handle(req: NextRequest) { const authToken = req.headers.get("Authorization") ?? ""; const token = authToken.trim().replaceAll("Bearer ", "").trim(); - const apiKey = await agentApi.getOpenAIApiKey(token); - const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl); + const apiKey = agentApi.getApiKey(token, reqBody.provider); + const baseUrl = agentApi.getBaseUrl(reqBody.baseUrl, reqBody.provider); - const model = new ChatOpenAI( - { - temperature: 0, - modelName: reqBody.model, - openAIApiKey: apiKey, - }, - { basePath: baseUrl }, - ); - const embeddings = new OpenAIEmbeddings( - { - openAIApiKey: apiKey, - }, - { basePath: baseUrl }, - ); + let model: BaseLanguageModel; + let embeddings: Embeddings | null; + model = agentApi.getToolBaseLanguageModel(reqBody, apiKey, baseUrl); + embeddings = agentApi.getToolEmbeddings(reqBody, apiKey, baseUrl); var dalleCallback = async (data: string) => { var response = new ResponseBody(); diff --git a/app/api/langchain/tool/agent/nodejs/route.ts b/app/api/langchain/tool/agent/nodejs/route.ts index c9bb728f3..1f4364b8b 100644 --- a/app/api/langchain/tool/agent/nodejs/route.ts +++ b/app/api/langchain/tool/agent/nodejs/route.ts @@ -2,17 +2,23 @@ import { NextRequest, NextResponse } from "next/server"; import { AgentApi, RequestBody, ResponseBody } from "../agentapi"; import { auth } from "@/app/api/auth"; import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools"; -import { ModelProvider } from "@/app/constant"; +import { ModelProvider, ServiceProvider } from "@/app/constant"; import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai"; import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"; import { Embeddings } from "@langchain/core/embeddings"; +import { BaseLanguageModel } from "@langchain/core/language_models/base"; async function handle(req: NextRequest) { if (req.method === "OPTIONS") { return NextResponse.json({ body: "OK" }, { status: 200 }); } try { - const authResult = auth(req, ModelProvider.GPT); + const reqBody: RequestBody = await req.json(); + const modelProvider = + reqBody.provider === ServiceProvider.Anthropic + ? ModelProvider.Claude + : ModelProvider.GPT; + const authResult = auth(req, modelProvider); if (authResult.error) { return NextResponse.json(authResult, { status: 401, @@ -25,27 +31,17 @@ async function handle(req: NextRequest) { const controller = new AbortController(); const agentApi = new AgentApi(encoder, transformStream, writer, controller); - const reqBody: RequestBody = await req.json(); - const authToken = req.headers.get("Authorization") ?? ""; + const authToken = + (req.headers.get("Authorization") || req.headers.get("x-api-key")) ?? ""; const token = authToken.trim().replaceAll("Bearer ", "").trim(); - const apiKey = await agentApi.getOpenAIApiKey(token); - const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl); + const apiKey = agentApi.getApiKey(token, reqBody.provider); + const baseUrl = agentApi.getBaseUrl(reqBody.baseUrl, reqBody.provider); + let model: BaseLanguageModel; + let embeddings: Embeddings | null; + model = agentApi.getToolBaseLanguageModel(reqBody, apiKey, baseUrl); + embeddings = agentApi.getToolEmbeddings(reqBody, apiKey, baseUrl); - const model = new ChatOpenAI( - { - temperature: 0, - modelName: reqBody.model, - openAIApiKey: apiKey, - }, - { basePath: baseUrl }, - ); - const embeddings = new OpenAIEmbeddings( - { - openAIApiKey: apiKey, - }, - { basePath: baseUrl }, - ); let ragEmbeddings: Embeddings; if (process.env.OLLAMA_BASE_URL) { ragEmbeddings = new OllamaEmbeddings({ @@ -98,6 +94,7 @@ async function handle(req: NextRequest) { export const GET = handle; export const POST = handle; +export const maxDuration = 60; export const runtime = "nodejs"; export const preferredRegion = [ "arn1",