refactor: refactor code

This commit is contained in:
Hk-Gosuto 2023-11-29 13:09:03 +08:00
parent 293679fa64
commit e88685fc98
4 changed files with 38 additions and 54 deletions

View File

@ -36,3 +36,8 @@ ENABLE_BALANCE_QUERY=
# Default: Empty # Default: Empty
# If you want to disable parse settings from url, set this value to 1. # If you want to disable parse settings from url, set this value to 1.
DISABLE_FAST_LINK= DISABLE_FAST_LINK=
# (optional)
# Default: 1
# If your project is not deployed on Vercel, set this value to 1.
NEXT_PUBLIC_ENABLE_NODEJS_PLUGIN=1

View File

@ -9,21 +9,13 @@ import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
import { BufferMemory, ChatMessageHistory } from "langchain/memory"; import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import { initializeAgentExecutorWithOptions } from "langchain/agents"; import { initializeAgentExecutorWithOptions } from "langchain/agents";
import { ACCESS_CODE_PREFIX } from "@/app/constant"; import { ACCESS_CODE_PREFIX } from "@/app/constant";
import { OpenAI } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import * as langchainTools from "langchain/tools"; import * as langchainTools from "langchain/tools";
import { HttpGetTool } from "@/app/api/langchain-tools/http_get"; import { HttpGetTool } from "@/app/api/langchain-tools/http_get";
import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search"; import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
import { WebBrowser } from "langchain/tools/webbrowser";
import { Calculator } from "langchain/tools/calculator";
import { DynamicTool, Tool } from "langchain/tools"; import { DynamicTool, Tool } from "langchain/tools";
import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search"; import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search"; import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator";
import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv";
import { PDFBrowser } from "@/app/api/langchain-tools/pdf_browser";
export interface RequestMessage { export interface RequestMessage {
role: string; role: string;
@ -158,6 +150,29 @@ export class AgentApi {
}); });
} }
async getOpenAIApiKey(token: string) {
const serverConfig = getServerSideConfig();
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let apiKey = serverConfig.apiKey;
if (isOpenAiKey && token) {
apiKey = token;
}
return apiKey;
}
async getOpenAIBaseUrl(reqBaseUrl: string | undefined) {
const serverConfig = getServerSideConfig();
let baseUrl = "https://api.openai.com/v1";
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
if (reqBaseUrl?.startsWith("http://") || reqBaseUrl?.startsWith("https://"))
baseUrl = reqBaseUrl;
if (!baseUrl.endsWith("/v1"))
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
console.log("[baseUrl]", baseUrl);
return baseUrl;
}
async getApiHandler( async getApiHandler(
req: NextRequest, req: NextRequest,
reqBody: RequestBody, reqBody: RequestBody,

View File

@ -2,7 +2,6 @@ import { NextRequest, NextResponse } from "next/server";
import { AgentApi, RequestBody, ResponseBody } from "../agentapi"; import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
import { auth } from "@/app/api/auth"; import { auth } from "@/app/api/auth";
import { EdgeTool } from "../../../../langchain-tools/edge_tools"; import { EdgeTool } from "../../../../langchain-tools/edge_tools";
import { ACCESS_CODE_PREFIX } from "@/app/constant";
import { getServerSideConfig } from "@/app/config/server"; import { getServerSideConfig } from "@/app/config/server";
import { OpenAI } from "langchain/llms/openai"; import { OpenAI } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai"; import { OpenAIEmbeddings } from "langchain/embeddings/openai";
@ -19,33 +18,17 @@ async function handle(req: NextRequest) {
}); });
} }
const serverConfig = getServerSideConfig();
const encoder = new TextEncoder(); const encoder = new TextEncoder();
const transformStream = new TransformStream(); const transformStream = new TransformStream();
const writer = transformStream.writable.getWriter(); const writer = transformStream.writable.getWriter();
const agentApi = new AgentApi(encoder, transformStream, writer);
const reqBody: RequestBody = await req.json(); const reqBody: RequestBody = await req.json();
const authToken = req.headers.get("Authorization") ?? ""; const authToken = req.headers.get("Authorization") ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim(); const token = authToken.trim().replaceAll("Bearer ", "").trim();
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let apiKey = serverConfig.apiKey; const apiKey = await agentApi.getOpenAIApiKey(token);
if (isOpenAiKey && token) { const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl);
apiKey = token;
}
// support base url
let baseUrl = "https://api.openai.com/v1";
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
if (
reqBody.baseUrl?.startsWith("http://") ||
reqBody.baseUrl?.startsWith("https://")
)
baseUrl = reqBody.baseUrl;
if (!baseUrl.endsWith("/v1"))
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
console.log("[baseUrl]", baseUrl);
const model = new OpenAI( const model = new OpenAI(
{ {
@ -78,8 +61,8 @@ async function handle(req: NextRequest) {
embeddings, embeddings,
dalleCallback, dalleCallback,
); );
var tools = await edgeTool.getCustomTools(); var edgeTools = await edgeTool.getCustomTools();
var agentApi = new AgentApi(encoder, transformStream, writer); var tools = [...edgeTools];
return await agentApi.getApiHandler(req, reqBody, tools); return await agentApi.getApiHandler(req, reqBody, tools);
} catch (e) { } catch (e) {
return new Response(JSON.stringify({ error: (e as any).message }), { return new Response(JSON.stringify({ error: (e as any).message }), {

View File

@ -2,8 +2,6 @@ import { NextRequest, NextResponse } from "next/server";
import { AgentApi, RequestBody, ResponseBody } from "../agentapi"; import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
import { auth } from "@/app/api/auth"; import { auth } from "@/app/api/auth";
import { EdgeTool } from "../../../../langchain-tools/edge_tools"; import { EdgeTool } from "../../../../langchain-tools/edge_tools";
import { ACCESS_CODE_PREFIX } from "@/app/constant";
import { getServerSideConfig } from "@/app/config/server";
import { OpenAI } from "langchain/llms/openai"; import { OpenAI } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai"; import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools"; import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools";
@ -20,33 +18,17 @@ async function handle(req: NextRequest) {
}); });
} }
const serverConfig = getServerSideConfig();
const encoder = new TextEncoder(); const encoder = new TextEncoder();
const transformStream = new TransformStream(); const transformStream = new TransformStream();
const writer = transformStream.writable.getWriter(); const writer = transformStream.writable.getWriter();
const agentApi = new AgentApi(encoder, transformStream, writer);
const reqBody: RequestBody = await req.json(); const reqBody: RequestBody = await req.json();
const authToken = req.headers.get("Authorization") ?? ""; const authToken = req.headers.get("Authorization") ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim(); const token = authToken.trim().replaceAll("Bearer ", "").trim();
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let apiKey = serverConfig.apiKey; const apiKey = await agentApi.getOpenAIApiKey(token);
if (isOpenAiKey && token) { const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl);
apiKey = token;
}
// support base url
let baseUrl = "https://api.openai.com/v1";
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
if (
reqBody.baseUrl?.startsWith("http://") ||
reqBody.baseUrl?.startsWith("https://")
)
baseUrl = reqBody.baseUrl;
if (!baseUrl.endsWith("/v1"))
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
console.log("[baseUrl]", baseUrl);
const model = new OpenAI( const model = new OpenAI(
{ {
@ -88,9 +70,8 @@ async function handle(req: NextRequest) {
); );
var edgeTools = await edgeTool.getCustomTools(); var edgeTools = await edgeTool.getCustomTools();
var nodejsTools = await nodejsTool.getCustomTools(); var nodejsTools = await nodejsTool.getCustomTools();
edgeTools.push(nodejsTools); var tools = [...edgeTools, ...nodejsTools];
var agentApi = new AgentApi(encoder, transformStream, writer); return await agentApi.getApiHandler(req, reqBody, tools);
return await agentApi.getApiHandler(req, reqBody, nodejsTools);
} catch (e) { } catch (e) {
return new Response(JSON.stringify({ error: (e as any).message }), { return new Response(JSON.stringify({ error: (e as any).message }), {
status: 500, status: 500,