Merge pull request #4930 from ConnectAI-E/feature-azure

support azure deployment name
This commit is contained in:
Dogtiti 2024-07-06 10:50:33 +08:00 committed by GitHub
commit 2d1f522aaf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 204 additions and 95 deletions

View File

@ -75,7 +75,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
break; break;
case ModelProvider.GPT: case ModelProvider.GPT:
default: default:
if (serverConfig.isAzure) { if (req.nextUrl.pathname.includes("azure/deployments")) {
systemApiKey = serverConfig.azureApiKey; systemApiKey = serverConfig.azureApiKey;
} else { } else {
systemApiKey = serverConfig.apiKey; systemApiKey = serverConfig.apiKey;

View File

@ -0,0 +1,57 @@
import { getServerSideConfig } from "@/app/config/server";
import { ModelProvider } from "@/app/constant";
import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server";
import { auth } from "../../auth";
import { requestOpenai } from "../../common";
// Proxies Azure OpenAI requests: answers CORS preflight, authenticates the
// caller, then forwards the request to the Azure deployment via requestOpenai.
async function handle(
  req: NextRequest,
  { params }: { params: { path: string[] } },
) {
  console.log("[Azure Route] params ", params);

  // CORS preflight — reply OK without running auth.
  if (req.method === "OPTIONS") {
    return NextResponse.json({ body: "OK" }, { status: 200 });
  }

  // Azure reuses the GPT provider's auth path (same key resolution rules).
  const authResult = auth(req, ModelProvider.GPT);
  if (authResult.error) {
    return NextResponse.json(authResult, {
      status: 401,
    });
  }

  try {
    // requestOpenai derives the Azure deployment path from req.nextUrl itself,
    // so the parsed `params.path` segments are not needed here.
    return await requestOpenai(req);
  } catch (e) {
    console.error("[Azure] ", e);
    // NOTE(review): this responds with HTTP 200 even on failure; the client
    // presumably inspects the error payload rather than the status — confirm.
    return NextResponse.json(prettyObject(e));
  }
}
// Both GET and POST are served by the same proxy handler.
export const GET = handle;
export const POST = handle;

// Next.js route-segment config: run this route on the Edge runtime.
export const runtime = "edge";
// Regions this edge route may execute in — presumably Vercel region IDs;
// TODO(review): confirm the list matches the deployment platform.
export const preferredRegion = [
  "arn1",
  "bom1",
  "cdg1",
  "cle1",
  "cpt1",
  "dub1",
  "fra1",
  "gru1",
  "hnd1",
  "iad1",
  "icn1",
  "kix1",
  "lhr1",
  "pdx1",
  "sfo1",
  "sin1",
  "syd1",
];

View File

@ -7,16 +7,17 @@ import {
ServiceProvider, ServiceProvider,
} from "../constant"; } from "../constant";
import { isModelAvailableInServer } from "../utils/model"; import { isModelAvailableInServer } from "../utils/model";
import { makeAzurePath } from "../azure";
const serverConfig = getServerSideConfig(); const serverConfig = getServerSideConfig();
export async function requestOpenai(req: NextRequest) { export async function requestOpenai(req: NextRequest) {
const controller = new AbortController(); const controller = new AbortController();
const isAzure = req.nextUrl.pathname.includes("azure/deployments");
var authValue, var authValue,
authHeaderName = ""; authHeaderName = "";
if (serverConfig.isAzure) { if (isAzure) {
authValue = authValue =
req.headers req.headers
.get("Authorization") .get("Authorization")
@ -56,14 +57,15 @@ export async function requestOpenai(req: NextRequest) {
10 * 60 * 1000, 10 * 60 * 1000,
); );
if (serverConfig.isAzure) { if (isAzure) {
if (!serverConfig.azureApiVersion) { const azureApiVersion =
return NextResponse.json({ req?.nextUrl?.searchParams?.get("api-version") ||
error: true, serverConfig.azureApiVersion;
message: `missing AZURE_API_VERSION in server env vars`, baseUrl = baseUrl.split("/deployments").shift() as string;
}); path = `${req.nextUrl.pathname.replaceAll(
} "/api/azure/",
path = makeAzurePath(path, serverConfig.azureApiVersion); "",
)}?api-version=${azureApiVersion}`;
} }
const fetchUrl = `${baseUrl}/${path}`; const fetchUrl = `${baseUrl}/${path}`;

View File

@ -1,9 +0,0 @@
// Rewrites an OpenAI-style API path into its Azure OpenAI equivalent:
// Azure endpoints carry no "/v1" prefix and require an explicit
// api-version query parameter on every request.
export function makeAzurePath(path: string, apiVersion: string) {
  // Strip every "v1/" segment — Azure routes do not use it.
  const stripped = path.replaceAll("v1/", "");
  // Append api-version, respecting any query string already present.
  const separator = stripped.includes("?") ? "&" : "?";
  return `${stripped}${separator}api-version=${apiVersion}`;
}

View File

@ -30,6 +30,7 @@ export interface RequestMessage {
export interface LLMConfig { export interface LLMConfig {
model: string; model: string;
providerName?: string;
temperature?: number; temperature?: number;
top_p?: number; top_p?: number;
stream?: boolean; stream?: boolean;
@ -54,6 +55,7 @@ export interface LLMUsage {
export interface LLMModel { export interface LLMModel {
name: string; name: string;
displayName?: string;
available: boolean; available: boolean;
provider: LLMModelProvider; provider: LLMModelProvider;
} }
@ -160,10 +162,14 @@ export function getHeaders() {
Accept: "application/json", Accept: "application/json",
}; };
const modelConfig = useChatStore.getState().currentSession().mask.modelConfig; const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
const isGoogle = modelConfig.model.startsWith("gemini"); const isGoogle = modelConfig.providerName == ServiceProvider.Google;
const isAzure = accessStore.provider === ServiceProvider.Azure; const isAzure = modelConfig.providerName === ServiceProvider.Azure;
const isAnthropic = accessStore.provider === ServiceProvider.Anthropic; const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic;
const authHeader = isAzure ? "api-key" : isAnthropic ? 'x-api-key' : "Authorization"; const authHeader = isAzure
? "api-key"
: isAnthropic
? "x-api-key"
: "Authorization";
const apiKey = isGoogle const apiKey = isGoogle
? accessStore.googleApiKey ? accessStore.googleApiKey
: isAzure : isAzure
@ -172,7 +178,8 @@ export function getHeaders() {
? accessStore.anthropicApiKey ? accessStore.anthropicApiKey
: accessStore.openaiApiKey; : accessStore.openaiApiKey;
const clientConfig = getClientConfig(); const clientConfig = getClientConfig();
const makeBearer = (s: string) => `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`; const makeBearer = (s: string) =>
`${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
const validString = (x: string) => x && x.length > 0; const validString = (x: string) => x && x.length > 0;
// when using google api in app, not set auth header // when using google api in app, not set auth header
@ -185,7 +192,7 @@ export function getHeaders() {
validString(accessStore.accessCode) validString(accessStore.accessCode)
) { ) {
// access_code must send with header named `Authorization`, will using in auth middleware. // access_code must send with header named `Authorization`, will using in auth middleware.
headers['Authorization'] = makeBearer( headers["Authorization"] = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode, ACCESS_CODE_PREFIX + accessStore.accessCode,
); );
} }

View File

@ -1,13 +1,16 @@
"use client"; "use client";
// azure and openai, using same models. so using same LLMApi.
import { import {
ApiPath, ApiPath,
DEFAULT_API_HOST, DEFAULT_API_HOST,
DEFAULT_MODELS, DEFAULT_MODELS,
OpenaiPath, OpenaiPath,
Azure,
REQUEST_TIMEOUT_MS, REQUEST_TIMEOUT_MS,
ServiceProvider, ServiceProvider,
} from "@/app/constant"; } from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { collectModelsWithDefaultModel } from "@/app/utils/model";
import { import {
ChatOptions, ChatOptions,
@ -24,7 +27,6 @@ import {
} from "@fortaine/fetch-event-source"; } from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format"; import { prettyObject } from "@/app/utils/format";
import { getClientConfig } from "@/app/config/client"; import { getClientConfig } from "@/app/config/client";
import { makeAzurePath } from "@/app/azure";
import { import {
getMessageTextContent, getMessageTextContent,
getMessageImages, getMessageImages,
@ -62,33 +64,31 @@ export class ChatGPTApi implements LLMApi {
let baseUrl = ""; let baseUrl = "";
const isAzure = path.includes("deployments");
if (accessStore.useCustomConfig) { if (accessStore.useCustomConfig) {
const isAzure = accessStore.provider === ServiceProvider.Azure;
if (isAzure && !accessStore.isValidAzure()) { if (isAzure && !accessStore.isValidAzure()) {
throw Error( throw Error(
"incomplete azure config, please check it in your settings page", "incomplete azure config, please check it in your settings page",
); );
} }
if (isAzure) {
path = makeAzurePath(path, accessStore.azureApiVersion);
}
baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl; baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl;
} }
if (baseUrl.length === 0) { if (baseUrl.length === 0) {
const isApp = !!getClientConfig()?.isApp; const isApp = !!getClientConfig()?.isApp;
baseUrl = isApp const apiPath = isAzure ? ApiPath.Azure : ApiPath.OpenAI;
? DEFAULT_API_HOST + "/proxy" + ApiPath.OpenAI baseUrl = isApp ? DEFAULT_API_HOST + "/proxy" + apiPath : apiPath;
: ApiPath.OpenAI;
} }
if (baseUrl.endsWith("/")) { if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, baseUrl.length - 1); baseUrl = baseUrl.slice(0, baseUrl.length - 1);
} }
if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.OpenAI)) { if (
!baseUrl.startsWith("http") &&
!isAzure &&
!baseUrl.startsWith(ApiPath.OpenAI)
) {
baseUrl = "https://" + baseUrl; baseUrl = "https://" + baseUrl;
} }
@ -113,6 +113,7 @@ export class ChatGPTApi implements LLMApi {
...useChatStore.getState().currentSession().mask.modelConfig, ...useChatStore.getState().currentSession().mask.modelConfig,
...{ ...{
model: options.config.model, model: options.config.model,
providerName: options.config.providerName,
}, },
}; };
@ -140,7 +141,35 @@ export class ChatGPTApi implements LLMApi {
options.onController?.(controller); options.onController?.(controller);
try { try {
const chatPath = this.path(OpenaiPath.ChatPath); let chatPath = "";
if (modelConfig.providerName === ServiceProvider.Azure) {
// find model, and get displayName as deployName
const { models: configModels, customModels: configCustomModels } =
useAppConfig.getState();
const {
defaultModel,
customModels: accessCustomModels,
useCustomConfig,
} = useAccessStore.getState();
const models = collectModelsWithDefaultModel(
configModels,
[configCustomModels, accessCustomModels].join(","),
defaultModel,
);
const model = models.find(
(model) =>
model.name === modelConfig.model &&
model?.provider?.providerName === ServiceProvider.Azure,
);
chatPath = this.path(
Azure.ChatPath(
(model?.displayName ?? model?.name) as string,
useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
),
);
} else {
chatPath = this.path(OpenaiPath.ChatPath);
}
const chatPayload = { const chatPayload = {
method: "POST", method: "POST",
body: JSON.stringify(requestPayload), body: JSON.stringify(requestPayload),

View File

@ -88,6 +88,7 @@ import {
Path, Path,
REQUEST_TIMEOUT_MS, REQUEST_TIMEOUT_MS,
UNFINISHED_INPUT, UNFINISHED_INPUT,
ServiceProvider,
} from "../constant"; } from "../constant";
import { Avatar } from "./emoji"; import { Avatar } from "./emoji";
import { ContextPrompts, MaskAvatar, MaskConfig } from "./mask"; import { ContextPrompts, MaskAvatar, MaskConfig } from "./mask";
@ -448,6 +449,9 @@ export function ChatActions(props: {
// switch model // switch model
const currentModel = chatStore.currentSession().mask.modelConfig.model; const currentModel = chatStore.currentSession().mask.modelConfig.model;
const currentProviderName =
chatStore.currentSession().mask.modelConfig?.providerName ||
ServiceProvider.OpenAI;
const allModels = useAllModels(); const allModels = useAllModels();
const models = useMemo(() => { const models = useMemo(() => {
const filteredModels = allModels.filter((m) => m.available); const filteredModels = allModels.filter((m) => m.available);
@ -479,13 +483,13 @@ export function ChatActions(props: {
const isUnavaliableModel = !models.some((m) => m.name === currentModel); const isUnavaliableModel = !models.some((m) => m.name === currentModel);
if (isUnavaliableModel && models.length > 0) { if (isUnavaliableModel && models.length > 0) {
// show next model to default model if exist // show next model to default model if exist
let nextModel: ModelType = ( let nextModel = models.find((model) => model.isDefault) || models[0];
models.find((model) => model.isDefault) || models[0] chatStore.updateCurrentSession((session) => {
).name; session.mask.modelConfig.model = nextModel.name;
chatStore.updateCurrentSession( session.mask.modelConfig.providerName = nextModel?.provider
(session) => (session.mask.modelConfig.model = nextModel), ?.providerName as ServiceProvider;
); });
showToast(nextModel); showToast(nextModel.name);
} }
}, [chatStore, currentModel, models]); }, [chatStore, currentModel, models]);
@ -573,19 +577,26 @@ export function ChatActions(props: {
{showModelSelector && ( {showModelSelector && (
<Selector <Selector
defaultSelectedValue={currentModel} defaultSelectedValue={`${currentModel}@${currentProviderName}`}
items={models.map((m) => ({ items={models.map((m) => ({
title: m.displayName, title: `${m.displayName}${
value: m.name, m?.provider?.providerName
? "(" + m?.provider?.providerName + ")"
: ""
}`,
value: `${m.name}@${m?.provider?.providerName}`,
}))} }))}
onClose={() => setShowModelSelector(false)} onClose={() => setShowModelSelector(false)}
onSelection={(s) => { onSelection={(s) => {
if (s.length === 0) return; if (s.length === 0) return;
const [model, providerName] = s[0].split("@");
chatStore.updateCurrentSession((session) => { chatStore.updateCurrentSession((session) => {
session.mask.modelConfig.model = s[0] as ModelType; session.mask.modelConfig.model = model as ModelType;
session.mask.modelConfig.providerName =
providerName as ServiceProvider;
session.mask.syncGlobalConfig = false; session.mask.syncGlobalConfig = false;
}); });
showToast(s[0]); showToast(model);
}} }}
/> />
)} )}

View File

@ -36,11 +36,14 @@ import { toBlob, toPng } from "html-to-image";
import { DEFAULT_MASK_AVATAR } from "../store/mask"; import { DEFAULT_MASK_AVATAR } from "../store/mask";
import { prettyObject } from "../utils/format"; import { prettyObject } from "../utils/format";
import { EXPORT_MESSAGE_CLASS_NAME, ModelProvider } from "../constant"; import {
EXPORT_MESSAGE_CLASS_NAME,
ModelProvider,
ServiceProvider,
} from "../constant";
import { getClientConfig } from "../config/client"; import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api"; import { ClientApi } from "../client/api";
import { getMessageTextContent } from "../utils"; import { getMessageTextContent } from "../utils";
import { identifyDefaultClaudeModel } from "../utils/checkers";
const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
loading: () => <LoadingIcon />, loading: () => <LoadingIcon />,
@ -314,9 +317,9 @@ export function PreviewActions(props: {
setShouldExport(false); setShouldExport(false);
var api: ClientApi; var api: ClientApi;
if (config.modelConfig.model.startsWith("gemini")) { if (config.modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(config.modelConfig.model)) { } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);

View File

@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg";
import { getCSSVar, useMobileScreen } from "../utils"; import { getCSSVar, useMobileScreen } from "../utils";
import dynamic from "next/dynamic"; import dynamic from "next/dynamic";
import { ModelProvider, Path, SlotID } from "../constant"; import { ServiceProvider, ModelProvider, Path, SlotID } from "../constant";
import { ErrorBoundary } from "./error"; import { ErrorBoundary } from "./error";
import { getISOLang, getLang } from "../locales"; import { getISOLang, getLang } from "../locales";
@ -29,7 +29,6 @@ import { AuthPage } from "./auth";
import { getClientConfig } from "../config/client"; import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api"; import { ClientApi } from "../client/api";
import { useAccessStore } from "../store"; import { useAccessStore } from "../store";
import { identifyDefaultClaudeModel } from "../utils/checkers";
export function Loading(props: { noLogo?: boolean }) { export function Loading(props: { noLogo?: boolean }) {
return ( return (
@ -172,9 +171,9 @@ export function useLoadData() {
const config = useAppConfig(); const config = useAppConfig();
var api: ClientApi; var api: ClientApi;
if (config.modelConfig.model.startsWith("gemini")) { if (config.modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(config.modelConfig.model)) { } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);

View File

@ -1,3 +1,4 @@
import { ServiceProvider } from "@/app/constant";
import { ModalConfigValidator, ModelConfig } from "../store"; import { ModalConfigValidator, ModelConfig } from "../store";
import Locale from "../locales"; import Locale from "../locales";
@ -10,25 +11,25 @@ export function ModelConfigList(props: {
updateConfig: (updater: (config: ModelConfig) => void) => void; updateConfig: (updater: (config: ModelConfig) => void) => void;
}) { }) {
const allModels = useAllModels(); const allModels = useAllModels();
const value = `${props.modelConfig.model}@${props.modelConfig?.providerName}`;
return ( return (
<> <>
<ListItem title={Locale.Settings.Model}> <ListItem title={Locale.Settings.Model}>
<Select <Select
value={props.modelConfig.model} value={value}
onChange={(e) => { onChange={(e) => {
props.updateConfig( const [model, providerName] = e.currentTarget.value.split("@");
(config) => props.updateConfig((config) => {
(config.model = ModalConfigValidator.model( config.model = ModalConfigValidator.model(model);
e.currentTarget.value, config.providerName = providerName as ServiceProvider;
)), });
);
}} }}
> >
{allModels {allModels
.filter((v) => v.available) .filter((v) => v.available)
.map((v, i) => ( .map((v, i) => (
<option value={v.name} key={i}> <option value={`${v.name}@${v.provider?.providerName}`} key={i}>
{v.displayName}({v.provider?.providerName}) {v.displayName}({v.provider?.providerName})
</option> </option>
))} ))}
@ -92,7 +93,7 @@ export function ModelConfigList(props: {
></input> ></input>
</ListItem> </ListItem>
{props.modelConfig.model.startsWith("gemini") ? null : ( {props.modelConfig?.providerName == ServiceProvider.Google ? null : (
<> <>
<ListItem <ListItem
title={Locale.Settings.PresencePenalty.Title} title={Locale.Settings.PresencePenalty.Title}

View File

@ -25,6 +25,7 @@ export enum Path {
export enum ApiPath { export enum ApiPath {
Cors = "", Cors = "",
Azure = "/api/azure",
OpenAI = "/api/openai", OpenAI = "/api/openai",
Anthropic = "/api/anthropic", Anthropic = "/api/anthropic",
} }
@ -93,6 +94,8 @@ export const OpenaiPath = {
}; };
export const Azure = { export const Azure = {
ChatPath: (deployName: string, apiVersion: string) =>
`deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}", ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
}; };
@ -150,6 +153,7 @@ const openaiModels = [
"gpt-4o-2024-05-13", "gpt-4o-2024-05-13",
"gpt-4-vision-preview", "gpt-4-vision-preview",
"gpt-4-turbo-2024-04-09", "gpt-4-turbo-2024-04-09",
"gpt-4-1106-preview",
]; ];
const googleModels = [ const googleModels = [
@ -179,6 +183,15 @@ export const DEFAULT_MODELS = [
providerType: "openai", providerType: "openai",
}, },
})), })),
...openaiModels.map((name) => ({
name,
available: true,
provider: {
id: "azure",
providerName: "Azure",
providerType: "azure",
},
})),
...googleModels.map((name) => ({ ...googleModels.map((name) => ({
name, name,
available: true, available: true,

View File

@ -17,6 +17,11 @@ const DEFAULT_OPENAI_URL =
? DEFAULT_API_HOST + "/api/proxy/openai" ? DEFAULT_API_HOST + "/api/proxy/openai"
: ApiPath.OpenAI; : ApiPath.OpenAI;
const DEFAULT_AZURE_URL =
getClientConfig()?.buildMode === "export"
? DEFAULT_API_HOST + "/api/proxy/azure/{resource_name}"
: ApiPath.Azure;
const DEFAULT_ACCESS_STATE = { const DEFAULT_ACCESS_STATE = {
accessCode: "", accessCode: "",
useCustomConfig: false, useCustomConfig: false,
@ -28,7 +33,7 @@ const DEFAULT_ACCESS_STATE = {
openaiApiKey: "", openaiApiKey: "",
// azure // azure
azureUrl: "", azureUrl: DEFAULT_AZURE_URL,
azureApiKey: "", azureApiKey: "",
azureApiVersion: "2023-08-01-preview", azureApiVersion: "2023-08-01-preview",

View File

@ -9,6 +9,7 @@ import {
DEFAULT_MODELS, DEFAULT_MODELS,
DEFAULT_SYSTEM_TEMPLATE, DEFAULT_SYSTEM_TEMPLATE,
KnowledgeCutOffDate, KnowledgeCutOffDate,
ServiceProvider,
ModelProvider, ModelProvider,
StoreKey, StoreKey,
SUMMARIZE_MODEL, SUMMARIZE_MODEL,
@ -20,7 +21,6 @@ import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token"; import { estimateTokenLength } from "../utils/token";
import { nanoid } from "nanoid"; import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store"; import { createPersistStore } from "../utils/store";
import { identifyDefaultClaudeModel } from "../utils/checkers";
import { collectModelsWithDefaultModel } from "../utils/model"; import { collectModelsWithDefaultModel } from "../utils/model";
import { useAccessStore } from "./access"; import { useAccessStore } from "./access";
@ -364,9 +364,9 @@ export const useChatStore = createPersistStore(
}); });
var api: ClientApi; var api: ClientApi;
if (modelConfig.model.startsWith("gemini")) { if (modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(modelConfig.model)) { } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);
@ -548,9 +548,9 @@ export const useChatStore = createPersistStore(
const modelConfig = session.mask.modelConfig; const modelConfig = session.mask.modelConfig;
var api: ClientApi; var api: ClientApi;
if (modelConfig.model.startsWith("gemini")) { if (modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(modelConfig.model)) { } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);

View File

@ -5,6 +5,7 @@ import {
DEFAULT_MODELS, DEFAULT_MODELS,
DEFAULT_SIDEBAR_WIDTH, DEFAULT_SIDEBAR_WIDTH,
StoreKey, StoreKey,
ServiceProvider,
} from "../constant"; } from "../constant";
import { createPersistStore } from "../utils/store"; import { createPersistStore } from "../utils/store";
@ -48,6 +49,7 @@ export const DEFAULT_CONFIG = {
modelConfig: { modelConfig: {
model: "gpt-3.5-turbo" as ModelType, model: "gpt-3.5-turbo" as ModelType,
providerName: "Openai" as ServiceProvider,
temperature: 0.5, temperature: 0.5,
top_p: 1, top_p: 1,
max_tokens: 4000, max_tokens: 4000,

View File

@ -1,21 +0,0 @@
import { useAccessStore } from "../store/access";
import { useAppConfig } from "../store/config";
import { collectModels } from "./model";
export function identifyDefaultClaudeModel(modelName: string) {
const accessStore = useAccessStore.getState();
const configStore = useAppConfig.getState();
const allModals = collectModels(
configStore.models,
[configStore.customModels, accessStore.customModels].join(","),
);
const modelMeta = allModals.find((m) => m.name === modelName);
return (
modelName.startsWith("claude") &&
modelMeta &&
modelMeta.provider?.providerType === "anthropic"
);
}

View File

@ -11,7 +11,12 @@ export function useAllModels() {
[configStore.customModels, accessStore.customModels].join(","), [configStore.customModels, accessStore.customModels].join(","),
accessStore.defaultModel, accessStore.defaultModel,
); );
}, [accessStore.customModels, configStore.customModels, configStore.models]); }, [
accessStore.customModels,
accessStore.defaultModel,
configStore.customModels,
configStore.models,
]);
return models; return models;
} }

View File

@ -69,6 +69,11 @@ if (mode !== "export") {
source: "/api/proxy/v1/:path*", source: "/api/proxy/v1/:path*",
destination: "https://api.openai.com/v1/:path*", destination: "https://api.openai.com/v1/:path*",
}, },
{
// https://{resource_name}.openai.azure.com/openai/deployments/{deploy_name}/chat/completions
source: "/api/proxy/azure/:resource_name/deployments/:deploy_name/:path*",
destination: "https://:resource_name.openai.azure.com/openai/deployments/:deploy_name/:path*",
},
{ {
source: "/api/proxy/google/:path*", source: "/api/proxy/google/:path*",
destination: "https://generativelanguage.googleapis.com/:path*", destination: "https://generativelanguage.googleapis.com/:path*",