fix: support azure
This commit is contained in:
parent
fa2e046285
commit
24de1bb77a
19
README.md
19
README.md
|
@ -199,6 +199,25 @@ Google Gemini Pro Api Key.
|
|||
|
||||
Google Gemini Pro Api Url.
|
||||
|
||||
### `AZURE_URL` (可选)
|
||||
|
||||
> 形如:https://{azure-resource-url}/openai/deployments
|
||||
>
|
||||
> ⚠️ 注意:这里与原项目配置不同,不需要指定 {deploy-name},将模型名修改为 {deploy-name} 即可切换不同的模型
|
||||
>
|
||||
> ⚠️ DALL-E 等需要 openai 密钥的插件暂不支持 Azure
|
||||
|
||||
Azure 部署地址。
|
||||
|
||||
### `AZURE_API_KEY` (可选)
|
||||
|
||||
Azure 密钥。
|
||||
|
||||
### `AZURE_API_VERSION` (可选)
|
||||
|
||||
Azure Api 版本,你可以在这里找到:[Azure 文档](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions)。
|
||||
|
||||
|
||||
## 部署
|
||||
|
||||
### 容器部署 (推荐)
|
||||
|
|
|
@ -94,7 +94,7 @@ OpenAI 接口代理 URL,如果你手动配置了 openai 接口代理,请填
|
|||
|
||||
### `AZURE_URL` (可选)
|
||||
|
||||
> 形如:https://{azure-resource-url}/openai/deployments/{deploy-name}
|
||||
> 形如:https://{azure-resource-url}/openai/deployments
|
||||
|
||||
Azure 部署地址。
|
||||
|
||||
|
|
|
@ -60,7 +60,6 @@ export async function requestOpenai(req: NextRequest) {
|
|||
path = makeAzurePath(path, serverConfig.azureApiVersion);
|
||||
}
|
||||
|
||||
const fetchUrl = `${baseUrl}/${path}`;
|
||||
const fetchOptions: RequestInit = {
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
|
@ -78,6 +77,12 @@ export async function requestOpenai(req: NextRequest) {
|
|||
duplex: "half",
|
||||
signal: controller.signal,
|
||||
};
|
||||
const clonedBody = await req.text();
|
||||
const jsonBody = JSON.parse(clonedBody) as { model?: string };
|
||||
if (serverConfig.isAzure) {
|
||||
baseUrl = `${baseUrl}/${jsonBody.model}`;
|
||||
}
|
||||
const fetchUrl = `${baseUrl}/${path}`;
|
||||
|
||||
// #1815 try to refuse gpt4 request
|
||||
if (serverConfig.customModels && req.body) {
|
||||
|
@ -86,11 +91,10 @@ export async function requestOpenai(req: NextRequest) {
|
|||
DEFAULT_MODELS,
|
||||
serverConfig.customModels,
|
||||
);
|
||||
const clonedBody = await req.text();
|
||||
// const clonedBody = await req.text();
|
||||
// const jsonBody = JSON.parse(clonedBody) as { model?: string };
|
||||
fetchOptions.body = clonedBody;
|
||||
|
||||
const jsonBody = JSON.parse(clonedBody) as { model?: string };
|
||||
|
||||
// not undefined and is false
|
||||
if (modelTable[jsonBody?.model ?? ""].available === false) {
|
||||
return NextResponse.json(
|
||||
|
|
|
@ -8,7 +8,7 @@ import { BaseCallbackHandler } from "langchain/callbacks";
|
|||
import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
|
||||
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
|
||||
import { initializeAgentExecutorWithOptions } from "langchain/agents";
|
||||
import { ACCESS_CODE_PREFIX } from "@/app/constant";
|
||||
import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
|
||||
|
||||
import * as langchainTools from "langchain/tools";
|
||||
import { HttpGetTool } from "@/app/api/langchain-tools/http_get";
|
||||
|
@ -16,6 +16,7 @@ import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
|
|||
import { DynamicTool, Tool } from "langchain/tools";
|
||||
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
|
||||
import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
|
||||
import { useAccessStore } from "@/app/store";
|
||||
|
||||
export interface RequestMessage {
|
||||
role: string;
|
||||
|
@ -24,6 +25,8 @@ export interface RequestMessage {
|
|||
|
||||
export interface RequestBody {
|
||||
messages: RequestMessage[];
|
||||
isAzure: boolean;
|
||||
azureApiVersion?: string;
|
||||
model: string;
|
||||
stream?: boolean;
|
||||
temperature: number;
|
||||
|
@ -152,10 +155,10 @@ export class AgentApi {
|
|||
|
||||
async getOpenAIApiKey(token: string) {
|
||||
const serverConfig = getServerSideConfig();
|
||||
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);
|
||||
const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX);
|
||||
|
||||
let apiKey = serverConfig.apiKey;
|
||||
if (isOpenAiKey && token) {
|
||||
if (isApiKey && token) {
|
||||
apiKey = token;
|
||||
}
|
||||
return apiKey;
|
||||
|
@ -179,27 +182,31 @@ export class AgentApi {
|
|||
customTools: any[],
|
||||
) {
|
||||
try {
|
||||
let useTools = reqBody.useTools ?? [];
|
||||
const serverConfig = getServerSideConfig();
|
||||
|
||||
// const reqBody: RequestBody = await req.json();
|
||||
const authToken = req.headers.get("Authorization") ?? "";
|
||||
const isAzure = reqBody.isAzure || serverConfig.isAzure;
|
||||
const authHeaderName = isAzure ? "api-key" : "Authorization";
|
||||
const authToken = req.headers.get(authHeaderName) ?? "";
|
||||
const token = authToken.trim().replaceAll("Bearer ", "").trim();
|
||||
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);
|
||||
let useTools = reqBody.useTools ?? [];
|
||||
let apiKey = serverConfig.apiKey;
|
||||
if (isOpenAiKey && token) {
|
||||
apiKey = token;
|
||||
}
|
||||
|
||||
let apiKey = await this.getOpenAIApiKey(token);
|
||||
if (isAzure) apiKey = token;
|
||||
let baseUrl = "https://api.openai.com/v1";
|
||||
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
|
||||
if (
|
||||
reqBody.baseUrl?.startsWith("http://") ||
|
||||
reqBody.baseUrl?.startsWith("https://")
|
||||
)
|
||||
) {
|
||||
baseUrl = reqBody.baseUrl;
|
||||
if (!baseUrl.endsWith("/v1"))
|
||||
}
|
||||
if (!isAzure && !baseUrl.endsWith("/v1")) {
|
||||
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
|
||||
}
|
||||
if (!reqBody.isAzure && serverConfig.isAzure) {
|
||||
baseUrl = serverConfig.azureUrl || baseUrl;
|
||||
}
|
||||
console.log("[baseUrl]", baseUrl);
|
||||
|
||||
var handler = await this.getHandler(reqBody);
|
||||
|
@ -281,7 +288,7 @@ export class AgentApi {
|
|||
chatHistory: new ChatMessageHistory(pastMessages),
|
||||
});
|
||||
|
||||
const llm = new ChatOpenAI(
|
||||
let llm = new ChatOpenAI(
|
||||
{
|
||||
modelName: reqBody.model,
|
||||
openAIApiKey: apiKey,
|
||||
|
@ -293,6 +300,23 @@ export class AgentApi {
|
|||
},
|
||||
{ basePath: baseUrl },
|
||||
);
|
||||
|
||||
if (reqBody.isAzure || serverConfig.isAzure) {
|
||||
llm = new ChatOpenAI({
|
||||
temperature: reqBody.temperature,
|
||||
streaming: reqBody.stream,
|
||||
topP: reqBody.top_p,
|
||||
presencePenalty: reqBody.presence_penalty,
|
||||
frequencyPenalty: reqBody.frequency_penalty,
|
||||
azureOpenAIApiKey: apiKey,
|
||||
azureOpenAIApiVersion: reqBody.isAzure
|
||||
? reqBody.azureApiVersion
|
||||
: serverConfig.azureApiVersion,
|
||||
azureOpenAIApiDeploymentName: reqBody.model,
|
||||
azureOpenAIBasePath: baseUrl,
|
||||
});
|
||||
}
|
||||
|
||||
const executor = await initializeAgentExecutorWithOptions(tools, llm, {
|
||||
agentType: "openai-functions",
|
||||
returnIntermediateSteps: reqBody.returnIntermediateSteps,
|
||||
|
|
|
@ -38,7 +38,7 @@ export interface OpenAIListModelResponse {
|
|||
export class ChatGPTApi implements LLMApi {
|
||||
private disableListModels = true;
|
||||
|
||||
path(path: string): string {
|
||||
path(path: string, model?: string): string {
|
||||
const accessStore = useAccessStore.getState();
|
||||
|
||||
const isAzure = accessStore.provider === ServiceProvider.Azure;
|
||||
|
@ -65,6 +65,7 @@ export class ChatGPTApi implements LLMApi {
|
|||
|
||||
if (isAzure) {
|
||||
path = makeAzurePath(path, accessStore.azureApiVersion);
|
||||
return [baseUrl, model, path].join("/");
|
||||
}
|
||||
|
||||
return [baseUrl, path].join("/");
|
||||
|
@ -136,7 +137,7 @@ export class ChatGPTApi implements LLMApi {
|
|||
options.onController?.(controller);
|
||||
|
||||
try {
|
||||
const chatPath = this.path(OpenaiPath.ChatPath);
|
||||
const chatPath = this.path(OpenaiPath.ChatPath, modelConfig.model);
|
||||
const chatPayload = {
|
||||
method: "POST",
|
||||
body: JSON.stringify(requestPayload),
|
||||
|
@ -284,16 +285,20 @@ export class ChatGPTApi implements LLMApi {
|
|||
model: options.config.model,
|
||||
},
|
||||
};
|
||||
|
||||
const accessStore = useAccessStore.getState();
|
||||
const isAzure = accessStore.provider === ServiceProvider.Azure;
|
||||
let baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl;
|
||||
const requestPayload = {
|
||||
messages,
|
||||
isAzure,
|
||||
azureApiVersion: accessStore.azureApiVersion,
|
||||
stream: options.config.stream,
|
||||
model: modelConfig.model,
|
||||
temperature: modelConfig.temperature,
|
||||
presence_penalty: modelConfig.presence_penalty,
|
||||
frequency_penalty: modelConfig.frequency_penalty,
|
||||
top_p: modelConfig.top_p,
|
||||
baseUrl: useAccessStore.getState().openaiUrl,
|
||||
baseUrl: baseUrl,
|
||||
maxIterations: options.agentConfig.maxIterations,
|
||||
returnIntermediateSteps: options.agentConfig.returnIntermediateSteps,
|
||||
useTools: options.agentConfig.useTools,
|
||||
|
@ -321,7 +326,7 @@ export class ChatGPTApi implements LLMApi {
|
|||
() => controller.abort(),
|
||||
REQUEST_TIMEOUT_MS,
|
||||
);
|
||||
console.log("shouldStream", shouldStream);
|
||||
// console.log("shouldStream", shouldStream);
|
||||
|
||||
if (shouldStream) {
|
||||
let responseText = "";
|
||||
|
|
|
@ -23,7 +23,7 @@ declare global {
|
|||
CUSTOM_MODELS?: string; // to control custom models
|
||||
|
||||
// azure only
|
||||
AZURE_URL?: string; // https://{azure-url}/openai/deployments/{deploy-name}
|
||||
AZURE_URL?: string; // https://{azure-url}/openai/deployments
|
||||
AZURE_API_KEY?: string;
|
||||
AZURE_API_VERSION?: string;
|
||||
|
||||
|
|
|
@ -88,13 +88,8 @@ export const OpenaiPath = {
|
|||
ListModelPath: "v1/models",
|
||||
};
|
||||
|
||||
export const GooglePath = {
|
||||
ChatPath: "v1/models/{{model}}:streamGenerateContent",
|
||||
ListModelPath: "v1/models",
|
||||
};
|
||||
|
||||
export const Azure = {
|
||||
ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
|
||||
ExampleEndpoint: "https://{resource-url}/openai/deployments",
|
||||
};
|
||||
|
||||
export const Google = {
|
||||
|
|
Loading…
Reference in New Issue