From e88685fc98bf104c283d731ddf4b35cf9f87def7 Mon Sep 17 00:00:00 2001 From: Hk-Gosuto Date: Wed, 29 Nov 2023 13:09:03 +0800 Subject: [PATCH] refactor: refactor code --- .env.template | 5 ++++ app/api/langchain/tool/agent/agentapi.ts | 31 +++++++++++++++----- app/api/langchain/tool/agent/edge/route.ts | 27 ++++------------- app/api/langchain/tool/agent/nodejs/route.ts | 29 ++++-------------- 4 files changed, 38 insertions(+), 54 deletions(-) diff --git a/.env.template b/.env.template index 3e3290369..0410768f7 100644 --- a/.env.template +++ b/.env.template @@ -36,3 +36,8 @@ ENABLE_BALANCE_QUERY= # Default: Empty # If you want to disable parse settings from url, set this value to 1. DISABLE_FAST_LINK= + +# (optional) +# Default: 1 +# If your project is not deployed on Vercel, set this value to 1. +NEXT_PUBLIC_ENABLE_NODEJS_PLUGIN=1 \ No newline at end of file diff --git a/app/api/langchain/tool/agent/agentapi.ts b/app/api/langchain/tool/agent/agentapi.ts index e4ca14840..c5a61e0d8 100644 --- a/app/api/langchain/tool/agent/agentapi.ts +++ b/app/api/langchain/tool/agent/agentapi.ts @@ -9,21 +9,13 @@ import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema"; import { BufferMemory, ChatMessageHistory } from "langchain/memory"; import { initializeAgentExecutorWithOptions } from "langchain/agents"; import { ACCESS_CODE_PREFIX } from "@/app/constant"; -import { OpenAI } from "langchain/llms/openai"; -import { OpenAIEmbeddings } from "langchain/embeddings/openai"; import * as langchainTools from "langchain/tools"; import { HttpGetTool } from "@/app/api/langchain-tools/http_get"; import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search"; -import { WebBrowser } from "langchain/tools/webbrowser"; -import { Calculator } from "langchain/tools/calculator"; import { DynamicTool, Tool } from "langchain/tools"; -import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator"; import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search"; import { GoogleSearch } from "@/app/api/langchain-tools/google_search"; -import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator"; -import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv"; -import { PDFBrowser } from "@/app/api/langchain-tools/pdf_browser"; export interface RequestMessage { role: string; @@ -158,6 +150,29 @@ export class AgentApi { }); } + async getOpenAIApiKey(token: string) { + const serverConfig = getServerSideConfig(); + const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX); + + let apiKey = serverConfig.apiKey; + if (isOpenAiKey && token) { + apiKey = token; + } + return apiKey; + } + + async getOpenAIBaseUrl(reqBaseUrl: string | undefined) { + const serverConfig = getServerSideConfig(); + let baseUrl = "https://api.openai.com/v1"; + if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; + if (reqBaseUrl?.startsWith("http://") || reqBaseUrl?.startsWith("https://")) + baseUrl = reqBaseUrl; + if (!baseUrl.endsWith("/v1")) + baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; + console.log("[baseUrl]", baseUrl); + return baseUrl; + } + async getApiHandler( req: NextRequest, reqBody: RequestBody, diff --git a/app/api/langchain/tool/agent/edge/route.ts b/app/api/langchain/tool/agent/edge/route.ts index 43f7c1b1c..7e9c8f06f 100644 --- a/app/api/langchain/tool/agent/edge/route.ts +++ b/app/api/langchain/tool/agent/edge/route.ts @@ -2,7 +2,6 @@ import { NextRequest, NextResponse } from "next/server"; import { AgentApi, RequestBody, ResponseBody } from "../agentapi"; import { auth } from "@/app/api/auth"; import { EdgeTool } from "../../../../langchain-tools/edge_tools"; -import { ACCESS_CODE_PREFIX } from "@/app/constant"; import { getServerSideConfig } from "@/app/config/server"; import { OpenAI } from "langchain/llms/openai"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; @@ -19,33 +18,17 @@ async function handle(req: NextRequest) { }); } - const serverConfig = getServerSideConfig(); - const encoder = new TextEncoder(); const transformStream = new TransformStream(); const writer = transformStream.writable.getWriter(); + const agentApi = new AgentApi(encoder, transformStream, writer); const reqBody: RequestBody = await req.json(); const authToken = req.headers.get("Authorization") ?? ""; const token = authToken.trim().replaceAll("Bearer ", "").trim(); - const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX); - let apiKey = serverConfig.apiKey; - if (isOpenAiKey && token) { - apiKey = token; - } - - // support base url - let baseUrl = "https://api.openai.com/v1"; - if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; - if ( - reqBody.baseUrl?.startsWith("http://") || - reqBody.baseUrl?.startsWith("https://") - ) - baseUrl = reqBody.baseUrl; - if (!baseUrl.endsWith("/v1")) - baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; - console.log("[baseUrl]", baseUrl); + const apiKey = await agentApi.getOpenAIApiKey(token); + const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl); const model = new OpenAI( { @@ -78,8 +61,8 @@ async function handle(req: NextRequest) { embeddings, dalleCallback, ); - var tools = await edgeTool.getCustomTools(); - var agentApi = new AgentApi(encoder, transformStream, writer); + var edgeTools = await edgeTool.getCustomTools(); + var tools = [...edgeTools]; return await agentApi.getApiHandler(req, reqBody, tools); } catch (e) { return new Response(JSON.stringify({ error: (e as any).message }), { diff --git a/app/api/langchain/tool/agent/nodejs/route.ts b/app/api/langchain/tool/agent/nodejs/route.ts index debbd6c21..201dbd49c 100644 --- a/app/api/langchain/tool/agent/nodejs/route.ts +++ b/app/api/langchain/tool/agent/nodejs/route.ts @@ -2,8 +2,6 @@ import { NextRequest, NextResponse } from "next/server"; import { AgentApi, RequestBody, ResponseBody } from "../agentapi"; import { auth } from "@/app/api/auth"; import { EdgeTool } from "../../../../langchain-tools/edge_tools"; -import { ACCESS_CODE_PREFIX } from "@/app/constant"; -import { getServerSideConfig } from "@/app/config/server"; import { OpenAI } from "langchain/llms/openai"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools"; @@ -20,33 +18,17 @@ async function handle(req: NextRequest) { }); } - const serverConfig = getServerSideConfig(); - const encoder = new TextEncoder(); const transformStream = new TransformStream(); const writer = transformStream.writable.getWriter(); + const agentApi = new AgentApi(encoder, transformStream, writer); const reqBody: RequestBody = await req.json(); const authToken = req.headers.get("Authorization") ?? ""; const token = authToken.trim().replaceAll("Bearer ", "").trim(); - const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX); - let apiKey = serverConfig.apiKey; - if (isOpenAiKey && token) { - apiKey = token; - } - - // support base url - let baseUrl = "https://api.openai.com/v1"; - if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; - if ( - reqBody.baseUrl?.startsWith("http://") || - reqBody.baseUrl?.startsWith("https://") - ) - baseUrl = reqBody.baseUrl; - if (!baseUrl.endsWith("/v1")) - baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; - console.log("[baseUrl]", baseUrl); + const apiKey = await agentApi.getOpenAIApiKey(token); + const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl); const model = new OpenAI( { @@ -88,9 +70,8 @@ async function handle(req: NextRequest) { ); var edgeTools = await edgeTool.getCustomTools(); var nodejsTools = await nodejsTool.getCustomTools(); - edgeTools.push(nodejsTools); - var agentApi = new AgentApi(encoder, transformStream, writer); - return await agentApi.getApiHandler(req, reqBody, nodejsTools); + var tools = [...edgeTools, ...nodejsTools]; + return await agentApi.getApiHandler(req, reqBody, tools); } catch (e) { return new Response(JSON.stringify({ error: (e as any).message }), { status: 500,