fix: support azure

This commit is contained in:
Hk-Gosuto 2023-12-25 19:32:18 +08:00
parent fa2e046285
commit 24de1bb77a
7 changed files with 77 additions and 30 deletions

View File

@ -199,6 +199,25 @@ Google Gemini Pro Api Key.
Google Gemini Pro Api Url. Google Gemini Pro Api Url.
### `AZURE_URL` (可选)
> 形如:https://{azure-resource-url}/openai/deployments
>
> ⚠️ 注意:这里与原项目配置不同,不需要指定 {deploy-name},将模型名修改为 {deploy-name} 即可切换不同的模型
>
> ⚠️ DALL-E 等需要 openai 密钥的插件暂不支持 Azure
Azure 部署地址。
### `AZURE_API_KEY` (可选)
Azure 密钥。
### `AZURE_API_VERSION` (可选)
Azure Api 版本,你可以在这里找到:[Azure 文档](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions)。
## 部署 ## 部署
### 容器部署 (推荐) ### 容器部署 (推荐)

View File

@ -94,7 +94,7 @@ OpenAI 接口代理 URL,如果你手动配置了 openai 接口代理,请填
### `AZURE_URL` (可选) ### `AZURE_URL` (可选)
> 形如https://{azure-resource-url}/openai/deployments/{deploy-name} > 形如https://{azure-resource-url}/openai/deployments
Azure 部署地址。 Azure 部署地址。

View File

@ -60,7 +60,6 @@ export async function requestOpenai(req: NextRequest) {
path = makeAzurePath(path, serverConfig.azureApiVersion); path = makeAzurePath(path, serverConfig.azureApiVersion);
} }
const fetchUrl = `${baseUrl}/${path}`;
const fetchOptions: RequestInit = { const fetchOptions: RequestInit = {
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -78,6 +77,12 @@ export async function requestOpenai(req: NextRequest) {
duplex: "half", duplex: "half",
signal: controller.signal, signal: controller.signal,
}; };
const clonedBody = await req.text();
const jsonBody = JSON.parse(clonedBody) as { model?: string };
if (serverConfig.isAzure) {
baseUrl = `${baseUrl}/${jsonBody.model}`;
}
const fetchUrl = `${baseUrl}/${path}`;
// #1815 try to refuse gpt4 request // #1815 try to refuse gpt4 request
if (serverConfig.customModels && req.body) { if (serverConfig.customModels && req.body) {
@ -86,11 +91,10 @@ export async function requestOpenai(req: NextRequest) {
DEFAULT_MODELS, DEFAULT_MODELS,
serverConfig.customModels, serverConfig.customModels,
); );
const clonedBody = await req.text(); // const clonedBody = await req.text();
// const jsonBody = JSON.parse(clonedBody) as { model?: string };
fetchOptions.body = clonedBody; fetchOptions.body = clonedBody;
const jsonBody = JSON.parse(clonedBody) as { model?: string };
// not undefined and is false // not undefined and is false
if (modelTable[jsonBody?.model ?? ""].available === false) { if (modelTable[jsonBody?.model ?? ""].available === false) {
return NextResponse.json( return NextResponse.json(

View File

@ -8,7 +8,7 @@ import { BaseCallbackHandler } from "langchain/callbacks";
import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema"; import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
import { BufferMemory, ChatMessageHistory } from "langchain/memory"; import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import { initializeAgentExecutorWithOptions } from "langchain/agents"; import { initializeAgentExecutorWithOptions } from "langchain/agents";
import { ACCESS_CODE_PREFIX } from "@/app/constant"; import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
import * as langchainTools from "langchain/tools"; import * as langchainTools from "langchain/tools";
import { HttpGetTool } from "@/app/api/langchain-tools/http_get"; import { HttpGetTool } from "@/app/api/langchain-tools/http_get";
@ -16,6 +16,7 @@ import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
import { DynamicTool, Tool } from "langchain/tools"; import { DynamicTool, Tool } from "langchain/tools";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search"; import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search"; import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { useAccessStore } from "@/app/store";
export interface RequestMessage { export interface RequestMessage {
role: string; role: string;
@ -24,6 +25,8 @@ export interface RequestMessage {
export interface RequestBody { export interface RequestBody {
messages: RequestMessage[]; messages: RequestMessage[];
isAzure: boolean;
azureApiVersion?: string;
model: string; model: string;
stream?: boolean; stream?: boolean;
temperature: number; temperature: number;
@ -152,10 +155,10 @@ export class AgentApi {
async getOpenAIApiKey(token: string) { async getOpenAIApiKey(token: string) {
const serverConfig = getServerSideConfig(); const serverConfig = getServerSideConfig();
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX); const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let apiKey = serverConfig.apiKey; let apiKey = serverConfig.apiKey;
if (isOpenAiKey && token) { if (isApiKey && token) {
apiKey = token; apiKey = token;
} }
return apiKey; return apiKey;
@ -179,27 +182,31 @@ export class AgentApi {
customTools: any[], customTools: any[],
) { ) {
try { try {
let useTools = reqBody.useTools ?? [];
const serverConfig = getServerSideConfig(); const serverConfig = getServerSideConfig();
// const reqBody: RequestBody = await req.json(); // const reqBody: RequestBody = await req.json();
const authToken = req.headers.get("Authorization") ?? ""; const isAzure = reqBody.isAzure || serverConfig.isAzure;
const authHeaderName = isAzure ? "api-key" : "Authorization";
const authToken = req.headers.get(authHeaderName) ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim(); const token = authToken.trim().replaceAll("Bearer ", "").trim();
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let useTools = reqBody.useTools ?? [];
let apiKey = serverConfig.apiKey;
if (isOpenAiKey && token) {
apiKey = token;
}
let apiKey = await this.getOpenAIApiKey(token);
if (isAzure) apiKey = token;
let baseUrl = "https://api.openai.com/v1"; let baseUrl = "https://api.openai.com/v1";
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
if ( if (
reqBody.baseUrl?.startsWith("http://") || reqBody.baseUrl?.startsWith("http://") ||
reqBody.baseUrl?.startsWith("https://") reqBody.baseUrl?.startsWith("https://")
) ) {
baseUrl = reqBody.baseUrl; baseUrl = reqBody.baseUrl;
if (!baseUrl.endsWith("/v1")) }
if (!isAzure && !baseUrl.endsWith("/v1")) {
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
}
if (!reqBody.isAzure && serverConfig.isAzure) {
baseUrl = serverConfig.azureUrl || baseUrl;
}
console.log("[baseUrl]", baseUrl); console.log("[baseUrl]", baseUrl);
var handler = await this.getHandler(reqBody); var handler = await this.getHandler(reqBody);
@ -281,7 +288,7 @@ export class AgentApi {
chatHistory: new ChatMessageHistory(pastMessages), chatHistory: new ChatMessageHistory(pastMessages),
}); });
const llm = new ChatOpenAI( let llm = new ChatOpenAI(
{ {
modelName: reqBody.model, modelName: reqBody.model,
openAIApiKey: apiKey, openAIApiKey: apiKey,
@ -293,6 +300,23 @@ export class AgentApi {
}, },
{ basePath: baseUrl }, { basePath: baseUrl },
); );
if (reqBody.isAzure || serverConfig.isAzure) {
llm = new ChatOpenAI({
temperature: reqBody.temperature,
streaming: reqBody.stream,
topP: reqBody.top_p,
presencePenalty: reqBody.presence_penalty,
frequencyPenalty: reqBody.frequency_penalty,
azureOpenAIApiKey: apiKey,
azureOpenAIApiVersion: reqBody.isAzure
? reqBody.azureApiVersion
: serverConfig.azureApiVersion,
azureOpenAIApiDeploymentName: reqBody.model,
azureOpenAIBasePath: baseUrl,
});
}
const executor = await initializeAgentExecutorWithOptions(tools, llm, { const executor = await initializeAgentExecutorWithOptions(tools, llm, {
agentType: "openai-functions", agentType: "openai-functions",
returnIntermediateSteps: reqBody.returnIntermediateSteps, returnIntermediateSteps: reqBody.returnIntermediateSteps,

View File

@ -38,7 +38,7 @@ export interface OpenAIListModelResponse {
export class ChatGPTApi implements LLMApi { export class ChatGPTApi implements LLMApi {
private disableListModels = true; private disableListModels = true;
path(path: string): string { path(path: string, model?: string): string {
const accessStore = useAccessStore.getState(); const accessStore = useAccessStore.getState();
const isAzure = accessStore.provider === ServiceProvider.Azure; const isAzure = accessStore.provider === ServiceProvider.Azure;
@ -65,6 +65,7 @@ export class ChatGPTApi implements LLMApi {
if (isAzure) { if (isAzure) {
path = makeAzurePath(path, accessStore.azureApiVersion); path = makeAzurePath(path, accessStore.azureApiVersion);
return [baseUrl, model, path].join("/");
} }
return [baseUrl, path].join("/"); return [baseUrl, path].join("/");
@ -136,7 +137,7 @@ export class ChatGPTApi implements LLMApi {
options.onController?.(controller); options.onController?.(controller);
try { try {
const chatPath = this.path(OpenaiPath.ChatPath); const chatPath = this.path(OpenaiPath.ChatPath, modelConfig.model);
const chatPayload = { const chatPayload = {
method: "POST", method: "POST",
body: JSON.stringify(requestPayload), body: JSON.stringify(requestPayload),
@ -284,16 +285,20 @@ export class ChatGPTApi implements LLMApi {
model: options.config.model, model: options.config.model,
}, },
}; };
const accessStore = useAccessStore.getState();
const isAzure = accessStore.provider === ServiceProvider.Azure;
let baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl;
const requestPayload = { const requestPayload = {
messages, messages,
isAzure,
azureApiVersion: accessStore.azureApiVersion,
stream: options.config.stream, stream: options.config.stream,
model: modelConfig.model, model: modelConfig.model,
temperature: modelConfig.temperature, temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty, presence_penalty: modelConfig.presence_penalty,
frequency_penalty: modelConfig.frequency_penalty, frequency_penalty: modelConfig.frequency_penalty,
top_p: modelConfig.top_p, top_p: modelConfig.top_p,
baseUrl: useAccessStore.getState().openaiUrl, baseUrl: baseUrl,
maxIterations: options.agentConfig.maxIterations, maxIterations: options.agentConfig.maxIterations,
returnIntermediateSteps: options.agentConfig.returnIntermediateSteps, returnIntermediateSteps: options.agentConfig.returnIntermediateSteps,
useTools: options.agentConfig.useTools, useTools: options.agentConfig.useTools,
@ -321,7 +326,7 @@ export class ChatGPTApi implements LLMApi {
() => controller.abort(), () => controller.abort(),
REQUEST_TIMEOUT_MS, REQUEST_TIMEOUT_MS,
); );
console.log("shouldStream", shouldStream); // console.log("shouldStream", shouldStream);
if (shouldStream) { if (shouldStream) {
let responseText = ""; let responseText = "";

View File

@ -23,7 +23,7 @@ declare global {
CUSTOM_MODELS?: string; // to control custom models CUSTOM_MODELS?: string; // to control custom models
// azure only // azure only
AZURE_URL?: string; // https://{azure-url}/openai/deployments/{deploy-name} AZURE_URL?: string; // https://{azure-url}/openai/deployments
AZURE_API_KEY?: string; AZURE_API_KEY?: string;
AZURE_API_VERSION?: string; AZURE_API_VERSION?: string;

View File

@ -88,13 +88,8 @@ export const OpenaiPath = {
ListModelPath: "v1/models", ListModelPath: "v1/models",
}; };
export const GooglePath = {
ChatPath: "v1/models/{{model}}:streamGenerateContent",
ListModelPath: "v1/models",
};
export const Azure = { export const Azure = {
ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}", ExampleEndpoint: "https://{resource-url}/openai/deployments",
}; };
export const Google = { export const Google = {