From c96e4b79667cc3335bf5ee225914f43b5918c62f Mon Sep 17 00:00:00 2001 From: Wayland Zhan Date: Fri, 19 Apr 2024 06:57:15 +0000 Subject: [PATCH] feat: Support a way to define default model by adding DEFAULT_MODEL env. --- app/api/config/route.ts | 1 + app/components/chat.tsx | 29 +++++++++++++++++++------ app/config/server.ts | 4 ++++ app/store/access.ts | 9 ++++++++ app/utils/hooks.ts | 5 +++-- app/utils/model.ts | 48 +++++++++++++++++++++++++++++++++++------ 6 files changed, 81 insertions(+), 15 deletions(-) diff --git a/app/api/config/route.ts b/app/api/config/route.ts index db84fba17..b0d9da031 100644 --- a/app/api/config/route.ts +++ b/app/api/config/route.ts @@ -13,6 +13,7 @@ const DANGER_CONFIG = { hideBalanceQuery: serverConfig.hideBalanceQuery, disableFastLink: serverConfig.disableFastLink, customModels: serverConfig.customModels, + defaultModel: serverConfig.defaultModel, }; declare global { diff --git a/app/components/chat.tsx b/app/components/chat.tsx index b9750f285..85df5b9a8 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -448,10 +448,20 @@ export function ChatActions(props: { // switch model const currentModel = chatStore.currentSession().mask.modelConfig.model; const allModels = useAllModels(); - const models = useMemo( - () => allModels.filter((m) => m.available), - [allModels], - ); + const models = useMemo(() => { + const filteredModels = allModels.filter((m) => m.available); + const defaultModel = filteredModels.find((m) => m.isDefault); + + if (defaultModel) { + const arr = [ + defaultModel, + ...filteredModels.filter((m) => m !== defaultModel), + ]; + return arr; + } else { + return filteredModels; + } + }, [allModels]); const [showModelSelector, setShowModelSelector] = useState(false); const [showUploadImage, setShowUploadImage] = useState(false); @@ -467,7 +477,10 @@ export function ChatActions(props: { // switch to first available model const isUnavaliableModel = !models.some((m) => m.name === currentModel); if (isUnavaliableModel && models.length > 0) { - const nextModel = models[0].name as ModelType; + // show next model to default model if exist + let nextModel: ModelType = ( + models.find((model) => model.isDefault) || models[0] + ).name; chatStore.updateCurrentSession( (session) => (session.mask.modelConfig.model = nextModel), ); @@ -1102,11 +1115,13 @@ function _Chat() { }; // eslint-disable-next-line react-hooks/exhaustive-deps }, []); - + const handlePaste = useCallback( async (event: React.ClipboardEvent) => { const currentModel = chatStore.currentSession().mask.modelConfig.model; - if(!isVisionModel(currentModel)){return;} + if (!isVisionModel(currentModel)) { + return; + } const items = (event.clipboardData || window.clipboardData).items; for (const item of items) { if (item.kind === "file" && item.type.startsWith("image/")) { diff --git a/app/config/server.ts b/app/config/server.ts index c27ef5e44..618112172 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -21,6 +21,7 @@ declare global { ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not DISABLE_FAST_LINK?: string; // disallow parse settings from url or not CUSTOM_MODELS?: string; // to control custom models + DEFAULT_MODEL?: string; // to cnntrol default model in every new chat window // azure only AZURE_URL?: string; // https://{azure-url}/openai/deployments/{deploy-name} @@ -59,12 +60,14 @@ export const getServerSideConfig = () => { const disableGPT4 = !!process.env.DISABLE_GPT4; let customModels = process.env.CUSTOM_MODELS ?? ""; + let defaultModel = process.env.DEFAULT_MODEL ?? ""; if (disableGPT4) { if (customModels) customModels += ","; customModels += DEFAULT_MODELS.filter((m) => m.name.startsWith("gpt-4")) .map((m) => "-" + m.name) .join(","); + if (defaultModel.startsWith("gpt-4")) defaultModel = ""; } const isAzure = !!process.env.AZURE_URL; @@ -116,6 +119,7 @@ export const getServerSideConfig = () => { hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY, disableFastLink: !!process.env.DISABLE_FAST_LINK, customModels, + defaultModel, whiteWebDevEndpoints, }; }; diff --git a/app/store/access.ts b/app/store/access.ts index 163666402..64909609e 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -8,6 +8,7 @@ import { getHeaders } from "../client/api"; import { getClientConfig } from "../config/client"; import { createPersistStore } from "../utils/store"; import { ensure } from "../utils/clone"; +import { DEFAULT_CONFIG } from "./config"; let fetchState = 0; // 0 not fetch, 1 fetching, 2 done @@ -48,6 +49,7 @@ const DEFAULT_ACCESS_STATE = { disableGPT4: false, disableFastLink: false, customModels: "", + defaultModel: "", }; export const useAccessStore = createPersistStore( @@ -100,6 +102,13 @@ export const useAccessStore = createPersistStore( }, }) .then((res) => res.json()) + .then((res) => { + // Set default model from env request + let defaultModel = res.defaultModel ?? ""; + DEFAULT_CONFIG.modelConfig.model = + defaultModel !== "" ? defaultModel : "gpt-3.5-turbo"; + return res; + }) .then((res: DangerConfig) => { console.log("[Config] got config from server", res); set(() => ({ ...res })); diff --git a/app/utils/hooks.ts b/app/utils/hooks.ts index 35d1f53a4..55d5d4fca 100644 --- a/app/utils/hooks.ts +++ b/app/utils/hooks.ts @@ -1,14 +1,15 @@ import { useMemo } from "react"; import { useAccessStore, useAppConfig } from "../store"; -import { collectModels } from "./model"; +import { collectModels, collectModelsWithDefaultModel } from "./model"; export function useAllModels() { const accessStore = useAccessStore(); const configStore = useAppConfig(); const models = useMemo(() => { - return collectModels( + return collectModelsWithDefaultModel( configStore.models, [configStore.customModels, accessStore.customModels].join(","), + accessStore.defaultModel, ); }, [accessStore.customModels, configStore.customModels, configStore.models]); diff --git a/app/utils/model.ts b/app/utils/model.ts index 378fc498e..6477640aa 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -1,5 +1,11 @@ import { LLMModel } from "../client/api"; +const customProvider = (modelName: string) => ({ + id: modelName, + providerName: "", + providerType: "custom", +}); + export function collectModelTable( models: readonly LLMModel[], customModels: string, @@ -11,6 +17,7 @@ export function collectModelTable( name: string; displayName: string; provider?: LLMModel["provider"]; // Marked as optional + isDefault?: boolean; } > = {}; @@ -22,12 +29,6 @@ export function collectModelTable( }; }); - const customProvider = (modelName: string) => ({ - id: modelName, - providerName: "", - providerType: "custom", - }); - // server custom models customModels .split(",") @@ -52,6 +53,27 @@ export function collectModelTable( }; } }); + + return modelTable; +} + +export function collectModelTableWithDefaultModel( + models: readonly LLMModel[], + customModels: string, + defaultModel: string, +) { + let modelTable = collectModelTable(models, customModels); + if (defaultModel && defaultModel !== "") { + delete modelTable[defaultModel]; + modelTable[defaultModel] = { + name: defaultModel, + displayName: defaultModel, + available: true, + provider: + modelTable[defaultModel]?.provider ?? customProvider(defaultModel), + isDefault: true, + }; + } return modelTable; } @@ -67,3 +89,17 @@ export function collectModels( return allModels; } + +export function collectModelsWithDefaultModel( + models: readonly LLMModel[], + customModels: string, + defaultModel: string, +) { + const modelTable = collectModelTableWithDefaultModel( + models, + customModels, + defaultModel, + ); + const allModels = Object.values(modelTable); + return allModels; +}