ChatGPT-Next-Web/app/utils/model.ts

194 lines
5.4 KiB
TypeScript

import { DEFAULT_MODELS } from "../constant";
import { LLMModel } from "../client/api";
const CustomSeq = {
val: -1000, //To ensure the custom model located at front, start from -1000, refer to constant.ts
cache: new Map<string, number>(),
next: (id: string) => {
if (CustomSeq.cache.has(id)) {
return CustomSeq.cache.get(id) as number;
} else {
let seq = CustomSeq.val++;
CustomSeq.cache.set(id, seq);
return seq;
}
},
};
const customProvider = (providerName: string) => ({
id: providerName.toLowerCase(),
providerName: providerName,
providerType: "custom",
sorted: CustomSeq.next(providerName),
});
/**
* Sorts an array of models based on specified rules.
*
* First, sorted by provider; if the same, sorted by model
*/
const sortModelTable = (models: ReturnType<typeof collectModels>) =>
models.sort((a, b) => {
if (a.provider && b.provider) {
let cmp = a.provider.sorted - b.provider.sorted;
return cmp === 0 ? a.sorted - b.sorted : cmp;
} else {
return a.sorted - b.sorted;
}
});
export function collectModelTable(
models: readonly LLMModel[],
customModels: string,
) {
const modelTable: Record<
string,
{
available: boolean;
name: string;
displayName: string;
sorted: number;
provider?: LLMModel["provider"]; // Marked as optional
isDefault?: boolean;
}
> = {};
// default models
models.forEach((m) => {
// using <modelName>@<providerId> as fullName
modelTable[`${m.name}@${m?.provider?.id}`] = {
...m,
displayName: m.name, // 'provider' is copied over if it exists
};
});
// server custom models
customModels
.split(",")
.filter((v) => !!v && v.length > 0)
.forEach((m) => {
const available = !m.startsWith("-");
const nameConfig =
m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m;
let [name, displayName] = nameConfig.split("=");
// enable or disable all models
if (name === "all") {
Object.values(modelTable).forEach(
(model) => (model.available = available),
);
} else {
// 1. find model by name, and set available value
const [customModelName, customProviderName] = name.split("@");
let count = 0;
for (const fullName in modelTable) {
const [modelName, providerName] = fullName.split("@");
if (
customModelName == modelName &&
(customProviderName === undefined ||
customProviderName === providerName)
) {
count += 1;
modelTable[fullName]["available"] = available;
// swap name and displayName for bytedance
if (providerName === "bytedance") {
[name, displayName] = [displayName, modelName];
modelTable[fullName]["name"] = name;
}
if (displayName) {
modelTable[fullName]["displayName"] = displayName;
}
}
}
// 2. if model not exists, create new model with available value
if (count === 0) {
let [customModelName, customProviderName] = name.split("@");
const provider = customProvider(
customProviderName || customModelName,
);
// swap name and displayName for bytedance
if (displayName && provider.providerName == "ByteDance") {
[customModelName, displayName] = [displayName, customModelName];
}
modelTable[`${customModelName}@${provider?.id}`] = {
name: customModelName,
displayName: displayName || customModelName,
available,
provider, // Use optional chaining
sorted: CustomSeq.next(`${customModelName}@${provider?.id}`),
};
}
}
});
return modelTable;
}
export function collectModelTableWithDefaultModel(
models: readonly LLMModel[],
customModels: string,
defaultModel: string,
) {
let modelTable = collectModelTable(models, customModels);
if (defaultModel && defaultModel !== "") {
if (defaultModel.includes("@")) {
if (defaultModel in modelTable) {
modelTable[defaultModel].isDefault = true;
}
} else {
for (const key of Object.keys(modelTable)) {
if (
modelTable[key].available &&
key.split("@").shift() == defaultModel
) {
modelTable[key].isDefault = true;
break;
}
}
}
}
return modelTable;
}
/**
* Generate full model table.
*/
export function collectModels(
models: readonly LLMModel[],
customModels: string,
) {
const modelTable = collectModelTable(models, customModels);
let allModels = Object.values(modelTable);
allModels = sortModelTable(allModels);
return allModels;
}
export function collectModelsWithDefaultModel(
models: readonly LLMModel[],
customModels: string,
defaultModel: string,
) {
const modelTable = collectModelTableWithDefaultModel(
models,
customModels,
defaultModel,
);
let allModels = Object.values(modelTable);
allModels = sortModelTable(allModels);
return allModels;
}
export function isModelAvailableInServer(
customModels: string,
modelName: string,
providerName: string,
) {
const fullName = `${modelName}@${providerName}`;
const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
return modelTable[fullName]?.available === false;
}