feat: claude function call
This commit is contained in:
parent
0a643dc71d
commit
8c5e92d66a
|
@ -18,7 +18,7 @@ export class EdgeTool {
|
|||
|
||||
private model: BaseLanguageModel;
|
||||
|
||||
private embeddings: Embeddings;
|
||||
private embeddings: Embeddings | null;
|
||||
|
||||
private callback?: (data: string) => Promise<void>;
|
||||
|
||||
|
@ -26,7 +26,7 @@ export class EdgeTool {
|
|||
apiKey: string | undefined,
|
||||
baseUrl: string,
|
||||
model: BaseLanguageModel,
|
||||
embeddings: Embeddings,
|
||||
embeddings: Embeddings | null,
|
||||
callback?: (data: string) => Promise<void>,
|
||||
) {
|
||||
this.apiKey = apiKey;
|
||||
|
@ -37,10 +37,6 @@ export class EdgeTool {
|
|||
}
|
||||
|
||||
async getCustomTools(): Promise<any[]> {
|
||||
const webBrowserTool = new WebBrowser({
|
||||
model: this.model,
|
||||
embeddings: this.embeddings,
|
||||
});
|
||||
const calculatorTool = new Calculator();
|
||||
const dallEAPITool = new DallEAPIWrapper(
|
||||
this.apiKey,
|
||||
|
@ -56,7 +52,7 @@ export class EdgeTool {
|
|||
const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool();
|
||||
let tools = [
|
||||
calculatorTool,
|
||||
webBrowserTool,
|
||||
// webBrowserTool,
|
||||
dallEAPITool,
|
||||
stableDiffusionTool,
|
||||
arxivAPITool,
|
||||
|
@ -66,6 +62,13 @@ export class EdgeTool {
|
|||
bilibiliMusicRecognitionTool,
|
||||
bilibiliVideoConclusionTool,
|
||||
];
|
||||
if (this.embeddings != null) {
|
||||
const webBrowserTool = new WebBrowser({
|
||||
model: this.model,
|
||||
embeddings: this.embeddings,
|
||||
});
|
||||
tools.push(webBrowserTool);
|
||||
}
|
||||
return tools;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ export class NodeJSTool {
|
|||
private apiKey: string | undefined;
|
||||
private baseUrl: string;
|
||||
private model: BaseLanguageModel;
|
||||
private embeddings: Embeddings;
|
||||
private embeddings: Embeddings | null;
|
||||
private sessionId: string;
|
||||
private ragEmbeddings: Embeddings;
|
||||
private callback?: (data: string) => Promise<void>;
|
||||
|
@ -26,7 +26,7 @@ export class NodeJSTool {
|
|||
apiKey: string | undefined,
|
||||
baseUrl: string,
|
||||
model: BaseLanguageModel,
|
||||
embeddings: Embeddings,
|
||||
embeddings: Embeddings | null,
|
||||
sessionId: string,
|
||||
ragEmbeddings: Embeddings,
|
||||
callback?: (data: string) => Promise<void>,
|
||||
|
@ -41,10 +41,6 @@ export class NodeJSTool {
|
|||
}
|
||||
|
||||
async getCustomTools(): Promise<any[]> {
|
||||
const webBrowserTool = new WebBrowser({
|
||||
model: this.model,
|
||||
embeddings: this.embeddings,
|
||||
});
|
||||
const calculatorTool = new Calculator();
|
||||
const dallEAPITool = new DallEAPINodeWrapper(
|
||||
this.apiKey,
|
||||
|
@ -54,24 +50,32 @@ export class NodeJSTool {
|
|||
const stableDiffusionTool = new StableDiffusionNodeWrapper();
|
||||
const arxivAPITool = new ArxivAPIWrapper();
|
||||
const wolframAlphaTool = new WolframAlphaTool();
|
||||
const pdfBrowserTool = new PDFBrowser(this.model, this.embeddings);
|
||||
const bilibiliVideoInfoTool = new BilibiliVideoInfoTool();
|
||||
const bilibiliVideoSearchTool = new BilibiliVideoSearchTool();
|
||||
const bilibiliVideoConclusionTool = new BilibiliVideoConclusionTool();
|
||||
const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool();
|
||||
let tools: any = [
|
||||
// webBrowserTool,
|
||||
// pdfBrowserTool,
|
||||
calculatorTool,
|
||||
webBrowserTool,
|
||||
dallEAPITool,
|
||||
stableDiffusionTool,
|
||||
arxivAPITool,
|
||||
wolframAlphaTool,
|
||||
pdfBrowserTool,
|
||||
bilibiliVideoInfoTool,
|
||||
bilibiliVideoSearchTool,
|
||||
bilibiliMusicRecognitionTool,
|
||||
bilibiliVideoConclusionTool,
|
||||
];
|
||||
if (this.embeddings != null) {
|
||||
const webBrowserTool = new WebBrowser({
|
||||
model: this.model,
|
||||
embeddings: this.embeddings,
|
||||
});
|
||||
const pdfBrowserTool = new PDFBrowser(this.model, this.embeddings);
|
||||
tools.push(webBrowserTool);
|
||||
tools.push(pdfBrowserTool);
|
||||
}
|
||||
if (!!process.env.ENABLE_RAG) {
|
||||
tools.push(
|
||||
new MyFilesBrowser(this.sessionId, this.model, this.ragEmbeddings),
|
||||
|
|
|
@ -10,7 +10,12 @@ import {
|
|||
createToolCallingAgent,
|
||||
createReactAgent,
|
||||
} from "langchain/agents";
|
||||
import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
|
||||
import {
|
||||
ACCESS_CODE_PREFIX,
|
||||
ANTHROPIC_BASE_URL,
|
||||
OPENAI_BASE_URL,
|
||||
ServiceProvider,
|
||||
} from "@/app/constant";
|
||||
|
||||
// import * as langchainTools from "langchain/tools";
|
||||
import * as langchainTools from "@/app/api/langchain-tools/langchian-tool-index";
|
||||
|
@ -33,7 +38,7 @@ import {
|
|||
ChatPromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
} from "@langchain/core/prompts";
|
||||
import { ChatOpenAI } from "@langchain/openai";
|
||||
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
|
||||
import { ChatAnthropic } from "@langchain/anthropic";
|
||||
import {
|
||||
BaseMessage,
|
||||
|
@ -45,6 +50,7 @@ import {
|
|||
} from "@langchain/core/messages";
|
||||
import { MultimodalContent } from "@/app/client/api";
|
||||
import { GoogleCustomSearch } from "@/app/api/langchain-tools/langchian-tool-index";
|
||||
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
|
||||
|
||||
export interface RequestMessage {
|
||||
role: string;
|
||||
|
@ -202,29 +208,81 @@ export class AgentApi {
|
|||
});
|
||||
}
|
||||
|
||||
async getOpenAIApiKey(token: string) {
|
||||
getApiKey(token: string, provider: ServiceProvider) {
|
||||
const serverConfig = getServerSideConfig();
|
||||
const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX);
|
||||
|
||||
let apiKey = serverConfig.apiKey;
|
||||
if (isApiKey && token) {
|
||||
apiKey = token;
|
||||
return token;
|
||||
}
|
||||
return apiKey;
|
||||
if (provider === ServiceProvider.OpenAI) return serverConfig.apiKey;
|
||||
if (provider === ServiceProvider.Anthropic)
|
||||
return serverConfig.anthropicApiKey;
|
||||
throw new Error("Unsupported provider");
|
||||
}
|
||||
|
||||
async getOpenAIBaseUrl(reqBaseUrl: string | undefined) {
|
||||
getBaseUrl(reqBaseUrl: string | undefined, provider: ServiceProvider) {
|
||||
const serverConfig = getServerSideConfig();
|
||||
let baseUrl = "https://api.openai.com/v1";
|
||||
let baseUrl = "";
|
||||
if (provider === ServiceProvider.OpenAI) {
|
||||
baseUrl = OPENAI_BASE_URL;
|
||||
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
|
||||
}
|
||||
if (provider === ServiceProvider.Anthropic) {
|
||||
baseUrl = ANTHROPIC_BASE_URL;
|
||||
if (serverConfig.anthropicUrl) baseUrl = serverConfig.anthropicUrl;
|
||||
}
|
||||
if (reqBaseUrl?.startsWith("http://") || reqBaseUrl?.startsWith("https://"))
|
||||
baseUrl = reqBaseUrl;
|
||||
if (!baseUrl.endsWith("/v1"))
|
||||
if (!baseUrl.endsWith("/v1") && provider === ServiceProvider.OpenAI)
|
||||
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
|
||||
console.log("[openai baseUrl]", baseUrl);
|
||||
return baseUrl;
|
||||
}
|
||||
|
||||
getToolBaseLanguageModel(
|
||||
reqBody: RequestBody,
|
||||
apiKey: string,
|
||||
baseUrl: string,
|
||||
) {
|
||||
if (reqBody.provider === ServiceProvider.Anthropic) {
|
||||
return new ChatAnthropic({
|
||||
temperature: 0,
|
||||
modelName: reqBody.model,
|
||||
apiKey: apiKey,
|
||||
clientOptions: {
|
||||
baseURL: baseUrl,
|
||||
},
|
||||
});
|
||||
}
|
||||
return new ChatOpenAI(
|
||||
{
|
||||
temperature: 0,
|
||||
modelName: reqBody.model,
|
||||
openAIApiKey: apiKey,
|
||||
},
|
||||
{ basePath: baseUrl },
|
||||
);
|
||||
}
|
||||
|
||||
getToolEmbeddings(reqBody: RequestBody, apiKey: string, baseUrl: string) {
|
||||
if (reqBody.provider === ServiceProvider.Anthropic) {
|
||||
if (process.env.OLLAMA_BASE_URL) {
|
||||
return new OllamaEmbeddings({
|
||||
model: process.env.RAG_EMBEDDING_MODEL,
|
||||
baseUrl: process.env.OLLAMA_BASE_URL,
|
||||
});
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return new OpenAIEmbeddings(
|
||||
{
|
||||
openAIApiKey: apiKey,
|
||||
},
|
||||
{ basePath: baseUrl },
|
||||
);
|
||||
}
|
||||
|
||||
getLLM(reqBody: RequestBody, apiKey: string, baseUrl: string) {
|
||||
const serverConfig = getServerSideConfig();
|
||||
if (reqBody.isAzure || serverConfig.isAzure) {
|
||||
|
@ -266,7 +324,6 @@ export class AgentApi {
|
|||
temperature: reqBody.temperature,
|
||||
streaming: reqBody.stream,
|
||||
topP: reqBody.top_p,
|
||||
// maxTokens: 1024,
|
||||
clientOptions: {
|
||||
baseURL: baseUrl,
|
||||
},
|
||||
|
@ -300,22 +357,9 @@ export class AgentApi {
|
|||
const authToken = req.headers.get(authHeaderName) ?? "";
|
||||
const token = authToken.trim().replaceAll("Bearer ", "").trim();
|
||||
|
||||
let apiKey = await this.getOpenAIApiKey(token);
|
||||
let apiKey = this.getApiKey(token, reqBody.provider);
|
||||
if (isAzure) apiKey = token;
|
||||
let baseUrl = "https://api.openai.com/v1";
|
||||
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
|
||||
if (
|
||||
reqBody.baseUrl?.startsWith("http://") ||
|
||||
reqBody.baseUrl?.startsWith("https://")
|
||||
) {
|
||||
baseUrl = reqBody.baseUrl;
|
||||
}
|
||||
if (
|
||||
reqBody.provider === ServiceProvider.OpenAI &&
|
||||
!baseUrl.endsWith("/v1")
|
||||
) {
|
||||
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
|
||||
}
|
||||
let baseUrl = this.getBaseUrl(reqBody.baseUrl, reqBody.provider);
|
||||
if (!reqBody.isAzure && serverConfig.isAzure) {
|
||||
baseUrl = serverConfig.azureUrl || baseUrl;
|
||||
}
|
||||
|
|
|
@ -3,7 +3,8 @@ import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
|
|||
import { auth } from "@/app/api/auth";
|
||||
import { EdgeTool } from "../../../../langchain-tools/edge_tools";
|
||||
import { ModelProvider } from "@/app/constant";
|
||||
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
|
||||
import { Embeddings } from "@langchain/core/embeddings";
|
||||
import { BaseLanguageModel } from "@langchain/core/language_models/base";
|
||||
|
||||
async function handle(req: NextRequest) {
|
||||
if (req.method === "OPTIONS") {
|
||||
|
@ -27,23 +28,13 @@ async function handle(req: NextRequest) {
|
|||
const authToken = req.headers.get("Authorization") ?? "";
|
||||
const token = authToken.trim().replaceAll("Bearer ", "").trim();
|
||||
|
||||
const apiKey = await agentApi.getOpenAIApiKey(token);
|
||||
const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl);
|
||||
const apiKey = agentApi.getApiKey(token, reqBody.provider);
|
||||
const baseUrl = agentApi.getBaseUrl(reqBody.baseUrl, reqBody.provider);
|
||||
|
||||
const model = new ChatOpenAI(
|
||||
{
|
||||
temperature: 0,
|
||||
modelName: reqBody.model,
|
||||
openAIApiKey: apiKey,
|
||||
},
|
||||
{ basePath: baseUrl },
|
||||
);
|
||||
const embeddings = new OpenAIEmbeddings(
|
||||
{
|
||||
openAIApiKey: apiKey,
|
||||
},
|
||||
{ basePath: baseUrl },
|
||||
);
|
||||
let model: BaseLanguageModel;
|
||||
let embeddings: Embeddings | null;
|
||||
model = agentApi.getToolBaseLanguageModel(reqBody, apiKey, baseUrl);
|
||||
embeddings = agentApi.getToolEmbeddings(reqBody, apiKey, baseUrl);
|
||||
|
||||
var dalleCallback = async (data: string) => {
|
||||
var response = new ResponseBody();
|
||||
|
|
|
@ -2,17 +2,23 @@ import { NextRequest, NextResponse } from "next/server";
|
|||
import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
|
||||
import { auth } from "@/app/api/auth";
|
||||
import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools";
|
||||
import { ModelProvider } from "@/app/constant";
|
||||
import { ModelProvider, ServiceProvider } from "@/app/constant";
|
||||
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
|
||||
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
|
||||
import { Embeddings } from "@langchain/core/embeddings";
|
||||
import { BaseLanguageModel } from "@langchain/core/language_models/base";
|
||||
|
||||
async function handle(req: NextRequest) {
|
||||
if (req.method === "OPTIONS") {
|
||||
return NextResponse.json({ body: "OK" }, { status: 200 });
|
||||
}
|
||||
try {
|
||||
const authResult = auth(req, ModelProvider.GPT);
|
||||
const reqBody: RequestBody = await req.json();
|
||||
const modelProvider =
|
||||
reqBody.provider === ServiceProvider.Anthropic
|
||||
? ModelProvider.Claude
|
||||
: ModelProvider.GPT;
|
||||
const authResult = auth(req, modelProvider);
|
||||
if (authResult.error) {
|
||||
return NextResponse.json(authResult, {
|
||||
status: 401,
|
||||
|
@ -25,27 +31,17 @@ async function handle(req: NextRequest) {
|
|||
const controller = new AbortController();
|
||||
const agentApi = new AgentApi(encoder, transformStream, writer, controller);
|
||||
|
||||
const reqBody: RequestBody = await req.json();
|
||||
const authToken = req.headers.get("Authorization") ?? "";
|
||||
const authToken =
|
||||
(req.headers.get("Authorization") || req.headers.get("x-api-key")) ?? "";
|
||||
const token = authToken.trim().replaceAll("Bearer ", "").trim();
|
||||
|
||||
const apiKey = await agentApi.getOpenAIApiKey(token);
|
||||
const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl);
|
||||
const apiKey = agentApi.getApiKey(token, reqBody.provider);
|
||||
const baseUrl = agentApi.getBaseUrl(reqBody.baseUrl, reqBody.provider);
|
||||
let model: BaseLanguageModel;
|
||||
let embeddings: Embeddings | null;
|
||||
model = agentApi.getToolBaseLanguageModel(reqBody, apiKey, baseUrl);
|
||||
embeddings = agentApi.getToolEmbeddings(reqBody, apiKey, baseUrl);
|
||||
|
||||
const model = new ChatOpenAI(
|
||||
{
|
||||
temperature: 0,
|
||||
modelName: reqBody.model,
|
||||
openAIApiKey: apiKey,
|
||||
},
|
||||
{ basePath: baseUrl },
|
||||
);
|
||||
const embeddings = new OpenAIEmbeddings(
|
||||
{
|
||||
openAIApiKey: apiKey,
|
||||
},
|
||||
{ basePath: baseUrl },
|
||||
);
|
||||
let ragEmbeddings: Embeddings;
|
||||
if (process.env.OLLAMA_BASE_URL) {
|
||||
ragEmbeddings = new OllamaEmbeddings({
|
||||
|
@ -98,6 +94,7 @@ async function handle(req: NextRequest) {
|
|||
export const GET = handle;
|
||||
export const POST = handle;
|
||||
|
||||
export const maxDuration = 60;
|
||||
export const runtime = "nodejs";
|
||||
export const preferredRegion = [
|
||||
"arn1",
|
||||
|
|
Loading…
Reference in New Issue