diff --git a/app/api/langchain/tool/agent/route.ts b/app/api/langchain/tool/agent/route.ts index 2fd163428..64f4c45de 100644 --- a/app/api/langchain/tool/agent/route.ts +++ b/app/api/langchain/tool/agent/route.ts @@ -5,24 +5,19 @@ import { auth } from "../../../auth"; import { ChatOpenAI } from "langchain/chat_models/openai"; import { BaseCallbackHandler } from "langchain/callbacks"; -import { - BingSerpAPI, - DynamicTool, - RequestsGetTool, - RequestsPostTool, - Tool, -} from "langchain/tools"; import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema"; import { BufferMemory, ChatMessageHistory } from "langchain/memory"; import { initializeAgentExecutorWithOptions } from "langchain/agents"; -import { SerpAPI } from "langchain/tools"; -import { Calculator } from "langchain/tools/calculator"; -import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search"; -import { HttpGetTool } from "@/app/api/langchain-tools/http_get"; import { ACCESS_CODE_PREFIX } from "@/app/constant"; import { OpenAI } from "langchain/llms/openai"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; + +import * as langchainTools from "langchain/tools"; +import { HttpGetTool } from "@/app/api/langchain-tools/http_get"; +import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search"; import { WebBrowser } from "langchain/tools/webbrowser"; +import { Calculator } from "langchain/tools/calculator"; +import { DynamicTool, Tool } from "langchain/tools"; const serverConfig = getServerSideConfig(); @@ -76,6 +71,7 @@ async function handle(req: NextRequest) { const authToken = req.headers.get("Authorization") ?? ""; const token = authToken.trim().replaceAll("Bearer ", "").trim(); const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX); + let useTools = reqBody.useTools ?? []; let apiKey = serverConfig.apiKey; if (isOpenAiKey && token) { apiKey = token; @@ -177,14 +173,18 @@ async function handle(req: NextRequest) { let searchTool: Tool = new DuckDuckGo(); if (process.env.BING_SEARCH_API_KEY) { - let bingSearchTool = new BingSerpAPI(process.env.BING_SEARCH_API_KEY); + let bingSearchTool = new langchainTools["BingSerpAPI"]( + process.env.BING_SEARCH_API_KEY, + ); searchTool = new DynamicTool({ name: "bing_search", description: bingSearchTool.description, func: async (input: string) => bingSearchTool.call(input), }); } else if (process.env.SERPAPI_API_KEY) { - let serpAPITool = new SerpAPI(process.env.SERPAPI_API_KEY); + let serpAPITool = new langchainTools["SerpAPI"]( + process.env.SERPAPI_API_KEY, + ); searchTool = new DynamicTool({ name: "google_search", description: serpAPITool.description, @@ -213,11 +213,20 @@ async function handle(req: NextRequest) { ]; const webBrowserTool = new WebBrowser({ model, embeddings }); const calculatorTool = new Calculator(); - if (reqBody.useTools.includes("web-search")) tools.push(searchTool); - if (reqBody.useTools.includes(webBrowserTool.name)) - tools.push(webBrowserTool); - if (reqBody.useTools.includes(calculatorTool.name)) - tools.push(calculatorTool); + if (useTools.includes("web-search")) tools.push(searchTool); + if (useTools.includes(webBrowserTool.name)) tools.push(webBrowserTool); + if (useTools.includes(calculatorTool.name)) tools.push(calculatorTool); + + useTools.forEach((toolName) => { + if (toolName) { + var tool = langchainTools[ + toolName as keyof typeof langchainTools + ] as any; + if (tool) { + tools.push(new tool()); + } + } + }); const pastMessages = new Array(); diff --git a/app/components/plugin.tsx b/app/components/plugin.tsx index fbaed29cc..599204a2f 100644 --- a/app/components/plugin.tsx +++ b/app/components/plugin.tsx @@ -423,7 +423,6 @@ export function PluginPage() { - {/* 操作按钮 */}