diff --git a/app/client/api.ts b/app/client/api.ts index ee43fc7cc..d2eeca46a 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -162,46 +162,70 @@ export class ClientApi { export function getHeaders() { const accessStore = useAccessStore.getState(); + const chatStore = useChatStore.getState(); const headers: Record = { "Content-Type": "application/json", Accept: "application/json", }; - const modelConfig = useChatStore.getState().currentSession().mask.modelConfig; - const isGoogle = modelConfig.providerName == ServiceProvider.Google; - const isAzure = modelConfig.providerName === ServiceProvider.Azure; - const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic; - const authHeader = isAzure - ? "api-key" - : isAnthropic - ? "x-api-key" - : "Authorization"; - const apiKey = isGoogle - ? accessStore.googleApiKey - : isAzure - ? accessStore.azureApiKey - : isAnthropic - ? accessStore.anthropicApiKey - : accessStore.openaiApiKey; - const clientConfig = getClientConfig(); - const makeBearer = (s: string) => - `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`; - const validString = (x: string) => x && x.length > 0; + const clientConfig = getClientConfig(); + + function getConfig() { + const modelConfig = chatStore.currentSession().mask.modelConfig; + const isGoogle = modelConfig.providerName == ServiceProvider.Google; + const isAzure = modelConfig.providerName === ServiceProvider.Azure; + const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic; + const isEnabledAccessControl = accessStore.enabledAccessControl(); + const apiKey = isGoogle + ? accessStore.googleApiKey + : isAzure + ? accessStore.azureApiKey + : isAnthropic + ? accessStore.anthropicApiKey + : accessStore.openaiApiKey; + return { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl }; + } + + function getAuthHeader(): string { + return isAzure ? "api-key" : isAnthropic ? "x-api-key" : "Authorization"; + } + + function getBearerToken(apiKey: string, noBearer: boolean = false): string { + return validString(apiKey) + ? `${noBearer ? "" : "Bearer "}${apiKey.trim()}` + : ""; + } + + function validString(x: string): boolean { + return x?.length > 0; + } + const { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl } = + getConfig(); // when using google api in app, not set auth header - if (!(isGoogle && clientConfig?.isApp)) { - // use user's api key first - if (validString(apiKey)) { - headers[authHeader] = makeBearer(apiKey); - } else if ( - accessStore.enabledAccessControl() && - validString(accessStore.accessCode) - ) { - // access_code must send with header named `Authorization`, will using in auth middleware. - headers["Authorization"] = makeBearer( - ACCESS_CODE_PREFIX + accessStore.accessCode, - ); - } + if (isGoogle && clientConfig?.isApp) return headers; + + const authHeader = getAuthHeader(); + + const bearerToken = getBearerToken(apiKey, isAzure || isAnthropic); + + if (bearerToken) { + headers[authHeader] = bearerToken; + } else if (isEnabledAccessControl && validString(accessStore.accessCode)) { + headers["Authorization"] = getBearerToken( + ACCESS_CODE_PREFIX + accessStore.accessCode, + ); } return headers; } + +export function getClientApi(provider: ServiceProvider): ClientApi { + switch (provider) { + case ServiceProvider.Google: + return new ClientApi(ModelProvider.GeminiPro); + case ServiceProvider.Anthropic: + return new ClientApi(ModelProvider.Claude); + default: + return new ClientApi(ModelProvider.GPT); + } +} diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx index 1cc531eb8..948807d4c 100644 --- a/app/components/exporter.tsx +++ b/app/components/exporter.tsx @@ -36,13 +36,9 @@ import { toBlob, toPng } from "html-to-image"; import { DEFAULT_MASK_AVATAR } from "../store/mask"; import { prettyObject } from "../utils/format"; -import { - EXPORT_MESSAGE_CLASS_NAME, - ModelProvider, - ServiceProvider, -} from "../constant"; +import { EXPORT_MESSAGE_CLASS_NAME } from "../constant"; import { getClientConfig } from "../config/client"; -import { ClientApi } from "../client/api"; +import { type ClientApi, getClientApi } from "../client/api"; import { getMessageTextContent } from "../utils"; const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { @@ -316,16 +312,7 @@ export function PreviewActions(props: { const onRenderMsgs = (msgs: ChatMessage[]) => { setShouldExport(false); - var api: ClientApi; - if (config.modelConfig.providerName == ServiceProvider.Google) { - api = new ClientApi(ModelProvider.GeminiPro); - } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { - api = new ClientApi(ModelProvider.Claude); - } else if (config.modelConfig.providerName == ServiceProvider.ByteDance) { - api = new ClientApi(ModelProvider.Doubao); - } else { - api = new ClientApi(ModelProvider.GPT); - } + const api: ClientApi = getClientApi(config.modelConfig.providerName); api .share(msgs) diff --git a/app/components/home.tsx b/app/components/home.tsx index 7da20df22..e127c65f8 100644 --- a/app/components/home.tsx +++ b/app/components/home.tsx @@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg"; import { getCSSVar, useMobileScreen } from "../utils"; import dynamic from "next/dynamic"; -import { ServiceProvider, ModelProvider, Path, SlotID } from "../constant"; +import { Path, SlotID } from "../constant"; import { ErrorBoundary } from "./error"; import { getISOLang, getLang } from "../locales"; @@ -27,7 +27,7 @@ import { SideBar } from "./sidebar"; import { useAppConfig } from "../store/config"; import { AuthPage } from "./auth"; import { getClientConfig } from "../config/client"; -import { ClientApi } from "../client/api"; +import { type ClientApi, getClientApi } from "../client/api"; import { useAccessStore } from "../store"; export function Loading(props: { noLogo?: boolean }) { @@ -170,16 +170,8 @@ function Screen() { export function useLoadData() { const config = useAppConfig(); - var api: ClientApi; - if (config.modelConfig.providerName == ServiceProvider.Google) { - api = new ClientApi(ModelProvider.GeminiPro); - } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { - api = new ClientApi(ModelProvider.Claude); - } else if (config.modelConfig.providerName == ServiceProvider.ByteDance) { - api = new ClientApi(ModelProvider.Doubao); - } else { - api = new ClientApi(ModelProvider.GPT); - } + const api: ClientApi = getClientApi(config.modelConfig.providerName); + useEffect(() => { (async () => { const models = await api.llm.models(); diff --git a/app/store/chat.ts b/app/store/chat.ts index 475d436d9..d14bd82d8 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -15,7 +15,12 @@ import { SUMMARIZE_MODEL, GEMINI_SUMMARIZE_MODEL, } from "../constant"; -import { ClientApi, RequestMessage, MultimodalContent } from "../client/api"; +import { getClientApi } from "../client/api"; +import type { + ClientApi, + RequestMessage, + MultimodalContent, +} from "../client/api"; import { ChatControllerPool } from "../client/controller"; import { prettyObject } from "../utils/format"; import { estimateTokenLength } from "../utils/token"; @@ -363,17 +368,7 @@ export const useChatStore = createPersistStore( ]); }); - var api: ClientApi; - if (modelConfig.providerName == ServiceProvider.Google) { - api = new ClientApi(ModelProvider.GeminiPro); - } else if (modelConfig.providerName == ServiceProvider.Anthropic) { - api = new ClientApi(ModelProvider.Claude); - } else if (modelConfig.providerName == ServiceProvider.ByteDance) { - api = new ClientApi(ModelProvider.Doubao); - } else { - api = new ClientApi(ModelProvider.GPT); - } - + const api: ClientApi = getClientApi(modelConfig.providerName); // make request api.llm.chat({ messages: sendMessages, @@ -549,16 +544,7 @@ export const useChatStore = createPersistStore( const session = get().currentSession(); const modelConfig = session.mask.modelConfig; - var api: ClientApi; - if (modelConfig.providerName == ServiceProvider.Google) { - api = new ClientApi(ModelProvider.GeminiPro); - } else if (modelConfig.providerName == ServiceProvider.Anthropic) { - api = new ClientApi(ModelProvider.Claude); - } else if (modelConfig.providerName == ServiceProvider.ByteDance) { - api = new ClientApi(ModelProvider.Doubao); - } else { - api = new ClientApi(ModelProvider.GPT); - } + const api: ClientApi = getClientApi(modelConfig.providerName); // remove error messages if any const messages = session.messages;