feat: Claude function call

This commit is contained in:
Hk-Gosuto 2024-08-17 13:05:49 +00:00
parent 0a643dc71d
commit 8c5e92d66a
5 changed files with 119 additions and 80 deletions

View File

@ -18,7 +18,7 @@ export class EdgeTool {
private model: BaseLanguageModel;
private embeddings: Embeddings;
private embeddings: Embeddings | null;
private callback?: (data: string) => Promise<void>;
@ -26,7 +26,7 @@ export class EdgeTool {
apiKey: string | undefined,
baseUrl: string,
model: BaseLanguageModel,
embeddings: Embeddings,
embeddings: Embeddings | null,
callback?: (data: string) => Promise<void>,
) {
this.apiKey = apiKey;
@ -37,10 +37,6 @@ export class EdgeTool {
}
async getCustomTools(): Promise<any[]> {
const webBrowserTool = new WebBrowser({
model: this.model,
embeddings: this.embeddings,
});
const calculatorTool = new Calculator();
const dallEAPITool = new DallEAPIWrapper(
this.apiKey,
@ -56,7 +52,7 @@ export class EdgeTool {
const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool();
let tools = [
calculatorTool,
webBrowserTool,
// webBrowserTool,
dallEAPITool,
stableDiffusionTool,
arxivAPITool,
@ -66,6 +62,13 @@ export class EdgeTool {
bilibiliMusicRecognitionTool,
bilibiliVideoConclusionTool,
];
if (this.embeddings != null) {
const webBrowserTool = new WebBrowser({
model: this.model,
embeddings: this.embeddings,
});
tools.push(webBrowserTool);
}
return tools;
}
}

View File

@ -17,7 +17,7 @@ export class NodeJSTool {
private apiKey: string | undefined;
private baseUrl: string;
private model: BaseLanguageModel;
private embeddings: Embeddings;
private embeddings: Embeddings | null;
private sessionId: string;
private ragEmbeddings: Embeddings;
private callback?: (data: string) => Promise<void>;
@ -26,7 +26,7 @@ export class NodeJSTool {
apiKey: string | undefined,
baseUrl: string,
model: BaseLanguageModel,
embeddings: Embeddings,
embeddings: Embeddings | null,
sessionId: string,
ragEmbeddings: Embeddings,
callback?: (data: string) => Promise<void>,
@ -41,10 +41,6 @@ export class NodeJSTool {
}
async getCustomTools(): Promise<any[]> {
const webBrowserTool = new WebBrowser({
model: this.model,
embeddings: this.embeddings,
});
const calculatorTool = new Calculator();
const dallEAPITool = new DallEAPINodeWrapper(
this.apiKey,
@ -54,24 +50,32 @@ export class NodeJSTool {
const stableDiffusionTool = new StableDiffusionNodeWrapper();
const arxivAPITool = new ArxivAPIWrapper();
const wolframAlphaTool = new WolframAlphaTool();
const pdfBrowserTool = new PDFBrowser(this.model, this.embeddings);
const bilibiliVideoInfoTool = new BilibiliVideoInfoTool();
const bilibiliVideoSearchTool = new BilibiliVideoSearchTool();
const bilibiliVideoConclusionTool = new BilibiliVideoConclusionTool();
const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool();
let tools: any = [
// webBrowserTool,
// pdfBrowserTool,
calculatorTool,
webBrowserTool,
dallEAPITool,
stableDiffusionTool,
arxivAPITool,
wolframAlphaTool,
pdfBrowserTool,
bilibiliVideoInfoTool,
bilibiliVideoSearchTool,
bilibiliMusicRecognitionTool,
bilibiliVideoConclusionTool,
];
if (this.embeddings != null) {
const webBrowserTool = new WebBrowser({
model: this.model,
embeddings: this.embeddings,
});
const pdfBrowserTool = new PDFBrowser(this.model, this.embeddings);
tools.push(webBrowserTool);
tools.push(pdfBrowserTool);
}
if (!!process.env.ENABLE_RAG) {
tools.push(
new MyFilesBrowser(this.sessionId, this.model, this.ragEmbeddings),

View File

@ -10,7 +10,12 @@ import {
createToolCallingAgent,
createReactAgent,
} from "langchain/agents";
import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
import {
ACCESS_CODE_PREFIX,
ANTHROPIC_BASE_URL,
OPENAI_BASE_URL,
ServiceProvider,
} from "@/app/constant";
// import * as langchainTools from "langchain/tools";
import * as langchainTools from "@/app/api/langchain-tools/langchian-tool-index";
@ -33,7 +38,7 @@ import {
ChatPromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { ChatOpenAI } from "@langchain/openai";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { ChatAnthropic } from "@langchain/anthropic";
import {
BaseMessage,
@ -45,6 +50,7 @@ import {
} from "@langchain/core/messages";
import { MultimodalContent } from "@/app/client/api";
import { GoogleCustomSearch } from "@/app/api/langchain-tools/langchian-tool-index";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
export interface RequestMessage {
role: string;
@ -202,29 +208,81 @@ export class AgentApi {
});
}
async getOpenAIApiKey(token: string) {
getApiKey(token: string, provider: ServiceProvider) {
const serverConfig = getServerSideConfig();
const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let apiKey = serverConfig.apiKey;
if (isApiKey && token) {
apiKey = token;
return token;
}
return apiKey;
if (provider === ServiceProvider.OpenAI) return serverConfig.apiKey;
if (provider === ServiceProvider.Anthropic)
return serverConfig.anthropicApiKey;
throw new Error("Unsupported provider");
}
async getOpenAIBaseUrl(reqBaseUrl: string | undefined) {
getBaseUrl(reqBaseUrl: string | undefined, provider: ServiceProvider) {
const serverConfig = getServerSideConfig();
let baseUrl = "https://api.openai.com/v1";
let baseUrl = "";
if (provider === ServiceProvider.OpenAI) {
baseUrl = OPENAI_BASE_URL;
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
}
if (provider === ServiceProvider.Anthropic) {
baseUrl = ANTHROPIC_BASE_URL;
if (serverConfig.anthropicUrl) baseUrl = serverConfig.anthropicUrl;
}
if (reqBaseUrl?.startsWith("http://") || reqBaseUrl?.startsWith("https://"))
baseUrl = reqBaseUrl;
if (!baseUrl.endsWith("/v1"))
if (!baseUrl.endsWith("/v1") && provider === ServiceProvider.OpenAI)
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
console.log("[openai baseUrl]", baseUrl);
return baseUrl;
}
getToolBaseLanguageModel(
reqBody: RequestBody,
apiKey: string,
baseUrl: string,
) {
if (reqBody.provider === ServiceProvider.Anthropic) {
return new ChatAnthropic({
temperature: 0,
modelName: reqBody.model,
apiKey: apiKey,
clientOptions: {
baseURL: baseUrl,
},
});
}
return new ChatOpenAI(
{
temperature: 0,
modelName: reqBody.model,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
}
getToolEmbeddings(reqBody: RequestBody, apiKey: string, baseUrl: string) {
if (reqBody.provider === ServiceProvider.Anthropic) {
if (process.env.OLLAMA_BASE_URL) {
return new OllamaEmbeddings({
model: process.env.RAG_EMBEDDING_MODEL,
baseUrl: process.env.OLLAMA_BASE_URL,
});
} else {
return null;
}
}
return new OpenAIEmbeddings(
{
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
}
getLLM(reqBody: RequestBody, apiKey: string, baseUrl: string) {
const serverConfig = getServerSideConfig();
if (reqBody.isAzure || serverConfig.isAzure) {
@ -266,7 +324,6 @@ export class AgentApi {
temperature: reqBody.temperature,
streaming: reqBody.stream,
topP: reqBody.top_p,
// maxTokens: 1024,
clientOptions: {
baseURL: baseUrl,
},
@ -300,22 +357,9 @@ export class AgentApi {
const authToken = req.headers.get(authHeaderName) ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim();
let apiKey = await this.getOpenAIApiKey(token);
let apiKey = this.getApiKey(token, reqBody.provider);
if (isAzure) apiKey = token;
let baseUrl = "https://api.openai.com/v1";
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
if (
reqBody.baseUrl?.startsWith("http://") ||
reqBody.baseUrl?.startsWith("https://")
) {
baseUrl = reqBody.baseUrl;
}
if (
reqBody.provider === ServiceProvider.OpenAI &&
!baseUrl.endsWith("/v1")
) {
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
}
let baseUrl = this.getBaseUrl(reqBody.baseUrl, reqBody.provider);
if (!reqBody.isAzure && serverConfig.isAzure) {
baseUrl = serverConfig.azureUrl || baseUrl;
}

View File

@ -3,7 +3,8 @@ import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
import { auth } from "@/app/api/auth";
import { EdgeTool } from "../../../../langchain-tools/edge_tools";
import { ModelProvider } from "@/app/constant";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { Embeddings } from "@langchain/core/embeddings";
import { BaseLanguageModel } from "@langchain/core/language_models/base";
async function handle(req: NextRequest) {
if (req.method === "OPTIONS") {
@ -27,23 +28,13 @@ async function handle(req: NextRequest) {
const authToken = req.headers.get("Authorization") ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim();
const apiKey = await agentApi.getOpenAIApiKey(token);
const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl);
const apiKey = agentApi.getApiKey(token, reqBody.provider);
const baseUrl = agentApi.getBaseUrl(reqBody.baseUrl, reqBody.provider);
const model = new ChatOpenAI(
{
temperature: 0,
modelName: reqBody.model,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
const embeddings = new OpenAIEmbeddings(
{
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
let model: BaseLanguageModel;
let embeddings: Embeddings | null;
model = agentApi.getToolBaseLanguageModel(reqBody, apiKey, baseUrl);
embeddings = agentApi.getToolEmbeddings(reqBody, apiKey, baseUrl);
var dalleCallback = async (data: string) => {
var response = new ResponseBody();

View File

@ -2,17 +2,23 @@ import { NextRequest, NextResponse } from "next/server";
import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
import { auth } from "@/app/api/auth";
import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools";
import { ModelProvider } from "@/app/constant";
import { ModelProvider, ServiceProvider } from "@/app/constant";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
import { Embeddings } from "@langchain/core/embeddings";
import { BaseLanguageModel } from "@langchain/core/language_models/base";
async function handle(req: NextRequest) {
if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}
try {
const authResult = auth(req, ModelProvider.GPT);
const reqBody: RequestBody = await req.json();
const modelProvider =
reqBody.provider === ServiceProvider.Anthropic
? ModelProvider.Claude
: ModelProvider.GPT;
const authResult = auth(req, modelProvider);
if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
@ -25,27 +31,17 @@ async function handle(req: NextRequest) {
const controller = new AbortController();
const agentApi = new AgentApi(encoder, transformStream, writer, controller);
const reqBody: RequestBody = await req.json();
const authToken = req.headers.get("Authorization") ?? "";
const authToken =
(req.headers.get("Authorization") || req.headers.get("x-api-key")) ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim();
const apiKey = await agentApi.getOpenAIApiKey(token);
const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl);
const apiKey = agentApi.getApiKey(token, reqBody.provider);
const baseUrl = agentApi.getBaseUrl(reqBody.baseUrl, reqBody.provider);
let model: BaseLanguageModel;
let embeddings: Embeddings | null;
model = agentApi.getToolBaseLanguageModel(reqBody, apiKey, baseUrl);
embeddings = agentApi.getToolEmbeddings(reqBody, apiKey, baseUrl);
const model = new ChatOpenAI(
{
temperature: 0,
modelName: reqBody.model,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
const embeddings = new OpenAIEmbeddings(
{
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
let ragEmbeddings: Embeddings;
if (process.env.OLLAMA_BASE_URL) {
ragEmbeddings = new OllamaEmbeddings({
@ -98,6 +94,7 @@ async function handle(req: NextRequest) {
export const GET = handle;
export const POST = handle;
export const maxDuration = 60;
export const runtime = "nodejs";
export const preferredRegion = [
"arn1",