feat: Support a way to define default model by adding DEFAULT_MODEL env.

This commit is contained in:
Wayland Zhan
2024-04-19 06:57:15 +00:00
parent 9b2cb1e1c3
commit c96e4b7966
6 changed files with 81 additions and 15 deletions

View File

@@ -1,14 +1,15 @@
import { useMemo } from "react";
import { useAccessStore, useAppConfig } from "../store";
import { collectModels } from "./model";
import { collectModels, collectModelsWithDefaultModel } from "./model";
export function useAllModels() {
const accessStore = useAccessStore();
const configStore = useAppConfig();
const models = useMemo(() => {
return collectModels(
return collectModelsWithDefaultModel(
configStore.models,
[configStore.customModels, accessStore.customModels].join(","),
accessStore.defaultModel,
);
}, [accessStore.customModels, configStore.customModels, configStore.models]);

View File

@@ -1,5 +1,11 @@
import { LLMModel } from "../client/api";
const customProvider = (modelName: string) => ({
id: modelName,
providerName: "",
providerType: "custom",
});
export function collectModelTable(
models: readonly LLMModel[],
customModels: string,
@@ -11,6 +17,7 @@ export function collectModelTable(
name: string;
displayName: string;
provider?: LLMModel["provider"]; // Marked as optional
isDefault?: boolean;
}
> = {};
@@ -22,12 +29,6 @@ export function collectModelTable(
};
});
const customProvider = (modelName: string) => ({
id: modelName,
providerName: "",
providerType: "custom",
});
// server custom models
customModels
.split(",")
@@ -52,6 +53,27 @@ export function collectModelTable(
};
}
});
return modelTable;
}
export function collectModelTableWithDefaultModel(
models: readonly LLMModel[],
customModels: string,
defaultModel: string,
) {
let modelTable = collectModelTable(models, customModels);
if (defaultModel && defaultModel !== "") {
delete modelTable[defaultModel];
modelTable[defaultModel] = {
name: defaultModel,
displayName: defaultModel,
available: true,
provider:
modelTable[defaultModel]?.provider ?? customProvider(defaultModel),
isDefault: true,
};
}
return modelTable;
}
@@ -67,3 +89,17 @@ export function collectModels(
return allModels;
}
export function collectModelsWithDefaultModel(
models: readonly LLMModel[],
customModels: string,
defaultModel: string,
) {
const modelTable = collectModelTableWithDefaultModel(
models,
customModels,
defaultModel,
);
const allModels = Object.values(modelTable);
return allModels;
}