feat: Support a way to define default model by adding DEFAULT_MODEL env.
This commit is contained in:
parent
9b2cb1e1c3
commit
c96e4b7966
|
@ -13,6 +13,7 @@ const DANGER_CONFIG = {
|
|||
hideBalanceQuery: serverConfig.hideBalanceQuery,
|
||||
disableFastLink: serverConfig.disableFastLink,
|
||||
customModels: serverConfig.customModels,
|
||||
defaultModel: serverConfig.defaultModel,
|
||||
};
|
||||
|
||||
declare global {
|
||||
|
|
|
@ -448,10 +448,20 @@ export function ChatActions(props: {
|
|||
// switch model
|
||||
const currentModel = chatStore.currentSession().mask.modelConfig.model;
|
||||
const allModels = useAllModels();
|
||||
const models = useMemo(
|
||||
() => allModels.filter((m) => m.available),
|
||||
[allModels],
|
||||
);
|
||||
const models = useMemo(() => {
|
||||
const filteredModels = allModels.filter((m) => m.available);
|
||||
const defaultModel = filteredModels.find((m) => m.isDefault);
|
||||
|
||||
if (defaultModel) {
|
||||
const arr = [
|
||||
defaultModel,
|
||||
...filteredModels.filter((m) => m !== defaultModel),
|
||||
];
|
||||
return arr;
|
||||
} else {
|
||||
return filteredModels;
|
||||
}
|
||||
}, [allModels]);
|
||||
const [showModelSelector, setShowModelSelector] = useState(false);
|
||||
const [showUploadImage, setShowUploadImage] = useState(false);
|
||||
|
||||
|
@ -467,7 +477,10 @@ export function ChatActions(props: {
|
|||
// switch to first available model
|
||||
const isUnavaliableModel = !models.some((m) => m.name === currentModel);
|
||||
if (isUnavaliableModel && models.length > 0) {
|
||||
const nextModel = models[0].name as ModelType;
|
||||
// show next model to default model if exist
|
||||
let nextModel: ModelType = (
|
||||
models.find((model) => model.isDefault) || models[0]
|
||||
).name;
|
||||
chatStore.updateCurrentSession(
|
||||
(session) => (session.mask.modelConfig.model = nextModel),
|
||||
);
|
||||
|
@ -1106,7 +1119,9 @@ function _Chat() {
|
|||
const handlePaste = useCallback(
|
||||
async (event: React.ClipboardEvent<HTMLTextAreaElement>) => {
|
||||
const currentModel = chatStore.currentSession().mask.modelConfig.model;
|
||||
if(!isVisionModel(currentModel)){return;}
|
||||
if (!isVisionModel(currentModel)) {
|
||||
return;
|
||||
}
|
||||
const items = (event.clipboardData || window.clipboardData).items;
|
||||
for (const item of items) {
|
||||
if (item.kind === "file" && item.type.startsWith("image/")) {
|
||||
|
|
|
@ -21,6 +21,7 @@ declare global {
|
|||
ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not
|
||||
DISABLE_FAST_LINK?: string; // disallow parse settings from url or not
|
||||
CUSTOM_MODELS?: string; // to control custom models
|
||||
DEFAULT_MODEL?: string; // to cnntrol default model in every new chat window
|
||||
|
||||
// azure only
|
||||
AZURE_URL?: string; // https://{azure-url}/openai/deployments/{deploy-name}
|
||||
|
@ -59,12 +60,14 @@ export const getServerSideConfig = () => {
|
|||
|
||||
const disableGPT4 = !!process.env.DISABLE_GPT4;
|
||||
let customModels = process.env.CUSTOM_MODELS ?? "";
|
||||
let defaultModel = process.env.DEFAULT_MODEL ?? "";
|
||||
|
||||
if (disableGPT4) {
|
||||
if (customModels) customModels += ",";
|
||||
customModels += DEFAULT_MODELS.filter((m) => m.name.startsWith("gpt-4"))
|
||||
.map((m) => "-" + m.name)
|
||||
.join(",");
|
||||
if (defaultModel.startsWith("gpt-4")) defaultModel = "";
|
||||
}
|
||||
|
||||
const isAzure = !!process.env.AZURE_URL;
|
||||
|
@ -116,6 +119,7 @@ export const getServerSideConfig = () => {
|
|||
hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY,
|
||||
disableFastLink: !!process.env.DISABLE_FAST_LINK,
|
||||
customModels,
|
||||
defaultModel,
|
||||
whiteWebDevEndpoints,
|
||||
};
|
||||
};
|
||||
|
|
|
@ -8,6 +8,7 @@ import { getHeaders } from "../client/api";
|
|||
import { getClientConfig } from "../config/client";
|
||||
import { createPersistStore } from "../utils/store";
|
||||
import { ensure } from "../utils/clone";
|
||||
import { DEFAULT_CONFIG } from "./config";
|
||||
|
||||
let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
|
||||
|
||||
|
@ -48,6 +49,7 @@ const DEFAULT_ACCESS_STATE = {
|
|||
disableGPT4: false,
|
||||
disableFastLink: false,
|
||||
customModels: "",
|
||||
defaultModel: "",
|
||||
};
|
||||
|
||||
export const useAccessStore = createPersistStore(
|
||||
|
@ -100,6 +102,13 @@ export const useAccessStore = createPersistStore(
|
|||
},
|
||||
})
|
||||
.then((res) => res.json())
|
||||
.then((res) => {
|
||||
// Set default model from env request
|
||||
let defaultModel = res.defaultModel ?? "";
|
||||
DEFAULT_CONFIG.modelConfig.model =
|
||||
defaultModel !== "" ? defaultModel : "gpt-3.5-turbo";
|
||||
return res;
|
||||
})
|
||||
.then((res: DangerConfig) => {
|
||||
console.log("[Config] got config from server", res);
|
||||
set(() => ({ ...res }));
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
import { useMemo } from "react";
|
||||
import { useAccessStore, useAppConfig } from "../store";
|
||||
import { collectModels } from "./model";
|
||||
import { collectModels, collectModelsWithDefaultModel } from "./model";
|
||||
|
||||
export function useAllModels() {
|
||||
const accessStore = useAccessStore();
|
||||
const configStore = useAppConfig();
|
||||
const models = useMemo(() => {
|
||||
return collectModels(
|
||||
return collectModelsWithDefaultModel(
|
||||
configStore.models,
|
||||
[configStore.customModels, accessStore.customModels].join(","),
|
||||
accessStore.defaultModel,
|
||||
);
|
||||
}, [accessStore.customModels, configStore.customModels, configStore.models]);
|
||||
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
import { LLMModel } from "../client/api";
|
||||
|
||||
const customProvider = (modelName: string) => ({
|
||||
id: modelName,
|
||||
providerName: "",
|
||||
providerType: "custom",
|
||||
});
|
||||
|
||||
export function collectModelTable(
|
||||
models: readonly LLMModel[],
|
||||
customModels: string,
|
||||
|
@ -11,6 +17,7 @@ export function collectModelTable(
|
|||
name: string;
|
||||
displayName: string;
|
||||
provider?: LLMModel["provider"]; // Marked as optional
|
||||
isDefault?: boolean;
|
||||
}
|
||||
> = {};
|
||||
|
||||
|
@ -22,12 +29,6 @@ export function collectModelTable(
|
|||
};
|
||||
});
|
||||
|
||||
const customProvider = (modelName: string) => ({
|
||||
id: modelName,
|
||||
providerName: "",
|
||||
providerType: "custom",
|
||||
});
|
||||
|
||||
// server custom models
|
||||
customModels
|
||||
.split(",")
|
||||
|
@ -52,6 +53,27 @@ export function collectModelTable(
|
|||
};
|
||||
}
|
||||
});
|
||||
|
||||
return modelTable;
|
||||
}
|
||||
|
||||
export function collectModelTableWithDefaultModel(
|
||||
models: readonly LLMModel[],
|
||||
customModels: string,
|
||||
defaultModel: string,
|
||||
) {
|
||||
let modelTable = collectModelTable(models, customModels);
|
||||
if (defaultModel && defaultModel !== "") {
|
||||
delete modelTable[defaultModel];
|
||||
modelTable[defaultModel] = {
|
||||
name: defaultModel,
|
||||
displayName: defaultModel,
|
||||
available: true,
|
||||
provider:
|
||||
modelTable[defaultModel]?.provider ?? customProvider(defaultModel),
|
||||
isDefault: true,
|
||||
};
|
||||
}
|
||||
return modelTable;
|
||||
}
|
||||
|
||||
|
@ -67,3 +89,17 @@ export function collectModels(
|
|||
|
||||
return allModels;
|
||||
}
|
||||
|
||||
export function collectModelsWithDefaultModel(
|
||||
models: readonly LLMModel[],
|
||||
customModels: string,
|
||||
defaultModel: string,
|
||||
) {
|
||||
const modelTable = collectModelTableWithDefaultModel(
|
||||
models,
|
||||
customModels,
|
||||
defaultModel,
|
||||
);
|
||||
const allModels = Object.values(modelTable);
|
||||
return allModels;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue