Merge pull request #4930 from ConnectAI-E/feature-azure

Support Azure deployment names
Dogtiti 2024-07-06 10:50:33 +08:00 committed by GitHub
commit 2d1f522aaf
17 changed files with 204 additions and 95 deletions
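
A note on what the change amounts to (not part of the diff): Azure OpenAI addresses a model by its deployment name rather than the model name, so the request URL has the shape reflected in the ExampleEndpoint constant and the rewrite comment later in this diff. A minimal sketch in TypeScript, with illustrative placeholder values:

// Sketch only; resource and deployment names below are placeholders, not values from this PR.
const resourceName = "my-resource";          // assumption: your Azure OpenAI resource name
const deployName = "my-gpt-4o-deployment";   // assumption: a deployment name, which this PR takes from the model's displayName (falling back to its name)
const apiVersion = "2023-08-01-preview";
const upstream = `https://${resourceName}.openai.azure.com/openai/deployments/${deployName}/chat/completions?api-version=${apiVersion}`;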

View File

@ -75,7 +75,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
break;
case ModelProvider.GPT:
default:
if (serverConfig.isAzure) {
if (req.nextUrl.pathname.includes("azure/deployments")) {
systemApiKey = serverConfig.azureApiKey;
} else {
systemApiKey = serverConfig.apiKey;
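
In other words, the server-side key is now picked per request from the URL instead of from a global isAzure flag, so OpenAI and Azure keys can coexist. A minimal sketch of the resulting behaviour (paths are illustrative):

// Sketch, not part of the diff: which server-side key the auth middleware selects.
function pickSystemApiKey(pathname: string, cfg: { apiKey: string; azureApiKey: string }) {
  // "/api/azure/deployments/my-gpt-4o/chat/completions" -> cfg.azureApiKey
  // "/api/openai/v1/chat/completions"                   -> cfg.apiKey
  return pathname.includes("azure/deployments") ? cfg.azureApiKey : cfg.apiKey;
}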

View File

@ -0,0 +1,57 @@
import { getServerSideConfig } from "@/app/config/server";
import { ModelProvider } from "@/app/constant";
import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server";
import { auth } from "../../auth";
import { requestOpenai } from "../../common";
async function handle(
req: NextRequest,
{ params }: { params: { path: string[] } },
) {
console.log("[Azure Route] params ", params);
if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}
const subpath = params.path.join("/");
const authResult = auth(req, ModelProvider.GPT);
if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
});
}
try {
return await requestOpenai(req);
} catch (e) {
console.error("[Azure] ", e);
return NextResponse.json(prettyObject(e));
}
}
export const GET = handle;
export const POST = handle;
export const runtime = "edge";
export const preferredRegion = [
"arn1",
"bom1",
"cdg1",
"cle1",
"cpt1",
"dub1",
"fra1",
"gru1",
"hnd1",
"iad1",
"icn1",
"kix1",
"lhr1",
"pdx1",
"sfo1",
"sin1",
"syd1",
];
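
For context, a hedged sketch of the kind of client request this new catch-all edge handler receives and forwards to requestOpenai. The deployment name, key, and body below are illustrative only:

// Sketch (not part of the diff).
const deployName = "my-gpt-4o-deployment"; // placeholder Azure deployment name
const res = await fetch(
  `/api/azure/deployments/${deployName}/chat/completions?api-version=2023-08-01-preview`,
  {
    method: "POST",
    headers: { "Content-Type": "application/json", "api-key": "<azure-api-key>" },
    body: JSON.stringify({ messages: [{ role: "user", content: "hello" }], stream: false }),
  },
);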

View File

@ -7,16 +7,17 @@ import {
ServiceProvider,
} from "../constant";
import { isModelAvailableInServer } from "../utils/model";
import { makeAzurePath } from "../azure";
const serverConfig = getServerSideConfig();
export async function requestOpenai(req: NextRequest) {
const controller = new AbortController();
const isAzure = req.nextUrl.pathname.includes("azure/deployments");
var authValue,
authHeaderName = "";
if (serverConfig.isAzure) {
if (isAzure) {
authValue =
req.headers
.get("Authorization")
@ -56,14 +57,15 @@ export async function requestOpenai(req: NextRequest) {
10 * 60 * 1000,
);
if (serverConfig.isAzure) {
if (!serverConfig.azureApiVersion) {
return NextResponse.json({
error: true,
message: `missing AZURE_API_VERSION in server env vars`,
});
}
path = makeAzurePath(path, serverConfig.azureApiVersion);
if (isAzure) {
const azureApiVersion =
req?.nextUrl?.searchParams?.get("api-version") ||
serverConfig.azureApiVersion;
baseUrl = baseUrl.split("/deployments").shift() as string;
path = `${req.nextUrl.pathname.replaceAll(
"/api/azure/",
"",
)}?api-version=${azureApiVersion}`;
}
const fetchUrl = `${baseUrl}/${path}`;
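
Put together, the Azure branch strips everything from "/deployments" onward off the configured base URL and rebuilds the full deployment path from the incoming request. A walk-through with illustrative values, assuming AZURE_URL is set in the ExampleEndpoint form shown later in this diff:

// Sketch (not part of the diff).
const incomingPathname = "/api/azure/deployments/my-gpt-4o/chat/completions";
let baseUrl = "https://my-resource.openai.azure.com/openai/deployments/my-gpt-4o"; // assumption: AZURE_URL in ExampleEndpoint form
const azureApiVersion = "2023-08-01-preview"; // from ?api-version=... or serverConfig.azureApiVersion

baseUrl = baseUrl.split("/deployments").shift() as string; // "https://my-resource.openai.azure.com/openai"
const path = `${incomingPathname.replaceAll("/api/azure/", "")}?api-version=${azureApiVersion}`;
const fetchUrl = `${baseUrl}/${path}`;
// -> "https://my-resource.openai.azure.com/openai/deployments/my-gpt-4o/chat/completions?api-version=2023-08-01-preview"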

View File

@ -1,9 +0,0 @@
export function makeAzurePath(path: string, apiVersion: string) {
// should omit /v1 prefix
path = path.replaceAll("v1/", "");
// should add api-key to query string
path += `${path.includes("?") ? "&" : "?"}api-version=${apiVersion}`;
return path;
}

View File

@ -30,6 +30,7 @@ export interface RequestMessage {
export interface LLMConfig {
model: string;
providerName?: string;
temperature?: number;
top_p?: number;
stream?: boolean;
@ -54,6 +55,7 @@ export interface LLMUsage {
export interface LLMModel {
name: string;
displayName?: string;
available: boolean;
provider: LLMModelProvider;
}
@ -160,10 +162,14 @@ export function getHeaders() {
Accept: "application/json",
};
const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
const isGoogle = modelConfig.model.startsWith("gemini");
const isAzure = accessStore.provider === ServiceProvider.Azure;
const isAnthropic = accessStore.provider === ServiceProvider.Anthropic;
const authHeader = isAzure ? "api-key" : isAnthropic ? 'x-api-key' : "Authorization";
const isGoogle = modelConfig.providerName == ServiceProvider.Google;
const isAzure = modelConfig.providerName === ServiceProvider.Azure;
const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic;
const authHeader = isAzure
? "api-key"
: isAnthropic
? "x-api-key"
: "Authorization";
const apiKey = isGoogle
? accessStore.googleApiKey
: isAzure
@ -172,7 +178,8 @@ export function getHeaders() {
? accessStore.anthropicApiKey
: accessStore.openaiApiKey;
const clientConfig = getClientConfig();
const makeBearer = (s: string) => `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
const makeBearer = (s: string) =>
`${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
const validString = (x: string) => x && x.length > 0;
// when using the Google API in the app, do not set the auth header
@ -185,7 +192,7 @@ export function getHeaders() {
validString(accessStore.accessCode)
) {
// the access code must be sent in the `Authorization` header; the auth middleware reads it from there
headers['Authorization'] = makeBearer(
headers["Authorization"] = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
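
The auth header is now derived from the session's providerName rather than the globally selected provider. A minimal sketch of the mapping (OpenAI and Azure differ only in the header name and the Bearer prefix):

// Sketch (not part of the diff).
import { ServiceProvider } from "@/app/constant";

function authHeaderFor(provider: ServiceProvider): string {
  return provider === ServiceProvider.Azure
    ? "api-key"
    : provider === ServiceProvider.Anthropic
      ? "x-api-key"
      : "Authorization"; // OpenAI and others; the value gets a "Bearer " prefix
}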

View File

@ -1,13 +1,16 @@
"use client";
// Azure and OpenAI use the same models, so they share the same LLMApi.
import {
ApiPath,
DEFAULT_API_HOST,
DEFAULT_MODELS,
OpenaiPath,
Azure,
REQUEST_TIMEOUT_MS,
ServiceProvider,
} from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { collectModelsWithDefaultModel } from "@/app/utils/model";
import {
ChatOptions,
@ -24,7 +27,6 @@ import {
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import { getClientConfig } from "@/app/config/client";
import { makeAzurePath } from "@/app/azure";
import {
getMessageTextContent,
getMessageImages,
@ -62,33 +64,31 @@ export class ChatGPTApi implements LLMApi {
let baseUrl = "";
const isAzure = path.includes("deployments");
if (accessStore.useCustomConfig) {
const isAzure = accessStore.provider === ServiceProvider.Azure;
if (isAzure && !accessStore.isValidAzure()) {
throw Error(
"incomplete azure config, please check it in your settings page",
);
}
if (isAzure) {
path = makeAzurePath(path, accessStore.azureApiVersion);
}
baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl;
}
if (baseUrl.length === 0) {
const isApp = !!getClientConfig()?.isApp;
baseUrl = isApp
? DEFAULT_API_HOST + "/proxy" + ApiPath.OpenAI
: ApiPath.OpenAI;
const apiPath = isAzure ? ApiPath.Azure : ApiPath.OpenAI;
baseUrl = isApp ? DEFAULT_API_HOST + "/proxy" + apiPath : apiPath;
}
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, baseUrl.length - 1);
}
if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.OpenAI)) {
if (
!baseUrl.startsWith("http") &&
!isAzure &&
!baseUrl.startsWith(ApiPath.OpenAI)
) {
baseUrl = "https://" + baseUrl;
}
@ -113,6 +113,7 @@ export class ChatGPTApi implements LLMApi {
...useChatStore.getState().currentSession().mask.modelConfig,
...{
model: options.config.model,
providerName: options.config.providerName,
},
};
@ -140,7 +141,35 @@ export class ChatGPTApi implements LLMApi {
options.onController?.(controller);
try {
const chatPath = this.path(OpenaiPath.ChatPath);
let chatPath = "";
if (modelConfig.providerName === ServiceProvider.Azure) {
// find the model entry and use its displayName (falling back to its name) as the Azure deployment name
const { models: configModels, customModels: configCustomModels } =
useAppConfig.getState();
const {
defaultModel,
customModels: accessCustomModels,
useCustomConfig,
} = useAccessStore.getState();
const models = collectModelsWithDefaultModel(
configModels,
[configCustomModels, accessCustomModels].join(","),
defaultModel,
);
const model = models.find(
(model) =>
model.name === modelConfig.model &&
model?.provider?.providerName === ServiceProvider.Azure,
);
chatPath = this.path(
Azure.ChatPath(
(model?.displayName ?? model?.name) as string,
useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
),
);
} else {
chatPath = this.path(OpenaiPath.ChatPath);
}
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
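
So for an Azure session the deployment name is resolved from the collected model list, preferring displayName over name, and fed into Azure.ChatPath. A sketch with an illustrative model entry:

// Sketch (not part of the diff).
const model = {
  name: "gpt-4o",
  displayName: "my-gpt-4o-deployment", // placeholder: the Azure deployment name
  provider: { providerName: "Azure" },
};
const deployName = model.displayName ?? model.name; // "my-gpt-4o-deployment"
// Azure.ChatPath(deployName, "2023-08-01-preview")
//   -> "deployments/my-gpt-4o-deployment/chat/completions?api-version=2023-08-01-preview"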

View File

@ -88,6 +88,7 @@ import {
Path,
REQUEST_TIMEOUT_MS,
UNFINISHED_INPUT,
ServiceProvider,
} from "../constant";
import { Avatar } from "./emoji";
import { ContextPrompts, MaskAvatar, MaskConfig } from "./mask";
@ -448,6 +449,9 @@ export function ChatActions(props: {
// switch model
const currentModel = chatStore.currentSession().mask.modelConfig.model;
const currentProviderName =
chatStore.currentSession().mask.modelConfig?.providerName ||
ServiceProvider.OpenAI;
const allModels = useAllModels();
const models = useMemo(() => {
const filteredModels = allModels.filter((m) => m.available);
@ -479,13 +483,13 @@ export function ChatActions(props: {
const isUnavaliableModel = !models.some((m) => m.name === currentModel);
if (isUnavaliableModel && models.length > 0) {
// fall back to the default model if one exists, otherwise to the first available model
let nextModel: ModelType = (
models.find((model) => model.isDefault) || models[0]
).name;
chatStore.updateCurrentSession(
(session) => (session.mask.modelConfig.model = nextModel),
);
showToast(nextModel);
let nextModel = models.find((model) => model.isDefault) || models[0];
chatStore.updateCurrentSession((session) => {
session.mask.modelConfig.model = nextModel.name;
session.mask.modelConfig.providerName = nextModel?.provider
?.providerName as ServiceProvider;
});
showToast(nextModel.name);
}
}, [chatStore, currentModel, models]);
@ -573,19 +577,26 @@ export function ChatActions(props: {
{showModelSelector && (
<Selector
defaultSelectedValue={currentModel}
defaultSelectedValue={`${currentModel}@${currentProviderName}`}
items={models.map((m) => ({
title: m.displayName,
value: m.name,
title: `${m.displayName}${
m?.provider?.providerName
? "(" + m?.provider?.providerName + ")"
: ""
}`,
value: `${m.name}@${m?.provider?.providerName}`,
}))}
onClose={() => setShowModelSelector(false)}
onSelection={(s) => {
if (s.length === 0) return;
const [model, providerName] = s[0].split("@");
chatStore.updateCurrentSession((session) => {
session.mask.modelConfig.model = s[0] as ModelType;
session.mask.modelConfig.model = model as ModelType;
session.mask.modelConfig.providerName =
providerName as ServiceProvider;
session.mask.syncGlobalConfig = false;
});
showToast(s[0]);
showToast(model);
}}
/>
)}
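
The selector now encodes both the model and its provider in a single value, so the same model name offered by different providers can be told apart. A sketch of the round trip (values illustrative):

// Sketch (not part of the diff).
const value = "gpt-4o@Azure";                   // what each <Selector> item stores
const [model, providerName] = value.split("@"); // what onSelection reads back: ["gpt-4o", "Azure"]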

View File

@ -36,11 +36,14 @@ import { toBlob, toPng } from "html-to-image";
import { DEFAULT_MASK_AVATAR } from "../store/mask";
import { prettyObject } from "../utils/format";
import { EXPORT_MESSAGE_CLASS_NAME, ModelProvider } from "../constant";
import {
EXPORT_MESSAGE_CLASS_NAME,
ModelProvider,
ServiceProvider,
} from "../constant";
import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api";
import { getMessageTextContent } from "../utils";
import { identifyDefaultClaudeModel } from "../utils/checkers";
const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
loading: () => <LoadingIcon />,
@ -314,9 +317,9 @@ export function PreviewActions(props: {
setShouldExport(false);
var api: ClientApi;
if (config.modelConfig.model.startsWith("gemini")) {
if (config.modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(config.modelConfig.model)) {
} else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);
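
Client selection now keys off providerName instead of sniffing the model name; the same pattern recurs in later hunks of this diff. A minimal sketch (import paths follow the surrounding file):

// Sketch (not part of the diff).
import { ClientApi } from "../client/api";
import { ModelProvider, ServiceProvider } from "../constant";

function clientFor(providerName: ServiceProvider): ClientApi {
  if (providerName === ServiceProvider.Google) return new ClientApi(ModelProvider.GeminiPro);
  if (providerName === ServiceProvider.Anthropic) return new ClientApi(ModelProvider.Claude);
  return new ClientApi(ModelProvider.GPT); // OpenAI and Azure share the same LLMApi
}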

View File

@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg";
import { getCSSVar, useMobileScreen } from "../utils";
import dynamic from "next/dynamic";
import { ModelProvider, Path, SlotID } from "../constant";
import { ServiceProvider, ModelProvider, Path, SlotID } from "../constant";
import { ErrorBoundary } from "./error";
import { getISOLang, getLang } from "../locales";
@ -29,7 +29,6 @@ import { AuthPage } from "./auth";
import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api";
import { useAccessStore } from "../store";
import { identifyDefaultClaudeModel } from "../utils/checkers";
export function Loading(props: { noLogo?: boolean }) {
return (
@ -172,9 +171,9 @@ export function useLoadData() {
const config = useAppConfig();
var api: ClientApi;
if (config.modelConfig.model.startsWith("gemini")) {
if (config.modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(config.modelConfig.model)) {
} else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);

View File

@ -1,3 +1,4 @@
import { ServiceProvider } from "@/app/constant";
import { ModalConfigValidator, ModelConfig } from "../store";
import Locale from "../locales";
@ -10,25 +11,25 @@ export function ModelConfigList(props: {
updateConfig: (updater: (config: ModelConfig) => void) => void;
}) {
const allModels = useAllModels();
const value = `${props.modelConfig.model}@${props.modelConfig?.providerName}`;
return (
<>
<ListItem title={Locale.Settings.Model}>
<Select
value={props.modelConfig.model}
value={value}
onChange={(e) => {
props.updateConfig(
(config) =>
(config.model = ModalConfigValidator.model(
e.currentTarget.value,
)),
);
const [model, providerName] = e.currentTarget.value.split("@");
props.updateConfig((config) => {
config.model = ModalConfigValidator.model(model);
config.providerName = providerName as ServiceProvider;
});
}}
>
{allModels
.filter((v) => v.available)
.map((v, i) => (
<option value={v.name} key={i}>
<option value={`${v.name}@${v.provider?.providerName}`} key={i}>
{v.displayName}({v.provider?.providerName})
</option>
))}
@ -92,7 +93,7 @@ export function ModelConfigList(props: {
></input>
</ListItem>
{props.modelConfig.model.startsWith("gemini") ? null : (
{props.modelConfig?.providerName == ServiceProvider.Google ? null : (
<>
<ListItem
title={Locale.Settings.PresencePenalty.Title}

View File

@ -25,6 +25,7 @@ export enum Path {
export enum ApiPath {
Cors = "",
Azure = "/api/azure",
OpenAI = "/api/openai",
Anthropic = "/api/anthropic",
}
@ -93,6 +94,8 @@ export const OpenaiPath = {
};
export const Azure = {
ChatPath: (deployName: string, apiVersion: string) =>
`deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
};
@ -150,6 +153,7 @@ const openaiModels = [
"gpt-4o-2024-05-13",
"gpt-4-vision-preview",
"gpt-4-turbo-2024-04-09",
"gpt-4-1106-preview",
];
const googleModels = [
@ -179,6 +183,15 @@ export const DEFAULT_MODELS = [
providerType: "openai",
},
})),
...openaiModels.map((name) => ({
name,
available: true,
provider: {
id: "azure",
providerName: "Azure",
providerType: "azure",
},
})),
...googleModels.map((name) => ({
name,
available: true,

View File

@ -17,6 +17,11 @@ const DEFAULT_OPENAI_URL =
? DEFAULT_API_HOST + "/api/proxy/openai"
: ApiPath.OpenAI;
const DEFAULT_AZURE_URL =
getClientConfig()?.buildMode === "export"
? DEFAULT_API_HOST + "/api/proxy/azure/{resource_name}"
: ApiPath.Azure;
const DEFAULT_ACCESS_STATE = {
accessCode: "",
useCustomConfig: false,
@ -28,7 +33,7 @@ const DEFAULT_ACCESS_STATE = {
openaiApiKey: "",
// azure
azureUrl: "",
azureUrl: DEFAULT_AZURE_URL,
azureApiKey: "",
azureApiVersion: "2023-08-01-preview",

View File

@ -9,6 +9,7 @@ import {
DEFAULT_MODELS,
DEFAULT_SYSTEM_TEMPLATE,
KnowledgeCutOffDate,
ServiceProvider,
ModelProvider,
StoreKey,
SUMMARIZE_MODEL,
@ -20,7 +21,6 @@ import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token";
import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
import { identifyDefaultClaudeModel } from "../utils/checkers";
import { collectModelsWithDefaultModel } from "../utils/model";
import { useAccessStore } from "./access";
@ -364,9 +364,9 @@ export const useChatStore = createPersistStore(
});
var api: ClientApi;
if (modelConfig.model.startsWith("gemini")) {
if (modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(modelConfig.model)) {
} else if (modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);
@ -548,9 +548,9 @@ export const useChatStore = createPersistStore(
const modelConfig = session.mask.modelConfig;
var api: ClientApi;
if (modelConfig.model.startsWith("gemini")) {
if (modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(modelConfig.model)) {
} else if (modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);

View File

@ -5,6 +5,7 @@ import {
DEFAULT_MODELS,
DEFAULT_SIDEBAR_WIDTH,
StoreKey,
ServiceProvider,
} from "../constant";
import { createPersistStore } from "../utils/store";
@ -48,6 +49,7 @@ export const DEFAULT_CONFIG = {
modelConfig: {
model: "gpt-3.5-turbo" as ModelType,
providerName: "Openai" as ServiceProvider,
temperature: 0.5,
top_p: 1,
max_tokens: 4000,

View File

@ -1,21 +0,0 @@
import { useAccessStore } from "../store/access";
import { useAppConfig } from "../store/config";
import { collectModels } from "./model";
export function identifyDefaultClaudeModel(modelName: string) {
const accessStore = useAccessStore.getState();
const configStore = useAppConfig.getState();
const allModals = collectModels(
configStore.models,
[configStore.customModels, accessStore.customModels].join(","),
);
const modelMeta = allModals.find((m) => m.name === modelName);
return (
modelName.startsWith("claude") &&
modelMeta &&
modelMeta.provider?.providerType === "anthropic"
);
}

View File

@ -11,7 +11,12 @@ export function useAllModels() {
[configStore.customModels, accessStore.customModels].join(","),
accessStore.defaultModel,
);
}, [accessStore.customModels, configStore.customModels, configStore.models]);
}, [
accessStore.customModels,
accessStore.defaultModel,
configStore.customModels,
configStore.models,
]);
return models;
}

View File

@ -69,6 +69,11 @@ if (mode !== "export") {
source: "/api/proxy/v1/:path*",
destination: "https://api.openai.com/v1/:path*",
},
{
// https://{resource_name}.openai.azure.com/openai/deployments/{deploy_name}/chat/completions
source: "/api/proxy/azure/:resource_name/deployments/:deploy_name/:path*",
destination: "https://:resource_name.openai.azure.com/openai/deployments/:deploy_name/:path*",
},
{
source: "/api/proxy/google/:path*",
destination: "https://generativelanguage.googleapis.com/:path*",