🔨 refactor(model): 更改原先的实现方法,在 collect table 函数后面增加额外的 sort 处理

This commit is contained in:
frostime 2024-08-05 16:37:22 +08:00
parent 8a4b8a84d6
commit b023a00445
1 changed files with 39 additions and 11 deletions

View File

@ -7,6 +7,29 @@ const customProvider = (providerName: string) => ({
providerType: "custom", providerType: "custom",
}); });
const sortModelTable = (
models: ReturnType<typeof collectModels>,
rule: "custom-first" | "default-first",
) =>
models.sort((a, b) => {
if (a.provider === undefined && b.provider === undefined) {
return 0;
}
let aIsCustom = a.provider?.providerType === "custom";
let bIsCustom = b.provider?.providerType === "custom";
if (aIsCustom === bIsCustom) {
return 0;
}
if (aIsCustom) {
return rule === "custom-first" ? -1 : 1;
} else {
return rule === "custom-first" ? 1 : -1;
}
});
export function collectModelTable( export function collectModelTable(
models: readonly LLMModel[], models: readonly LLMModel[],
customModels: string, customModels: string,
@ -22,6 +45,15 @@ export function collectModelTable(
} }
> = {}; > = {};
// default models
models.forEach((m) => {
// using <modelName>@<providerId> as fullName
modelTable[`${m.name}@${m?.provider?.id}`] = {
...m,
displayName: m.name, // 'provider' is copied over if it exists
};
});
// server custom models // server custom models
customModels customModels
.split(",") .split(",")
@ -80,15 +112,6 @@ export function collectModelTable(
} }
}); });
// default models
models.forEach((m) => {
// using <modelName>@<providerId> as fullName
modelTable[`${m.name}@${m?.provider?.id}`] = {
...m,
displayName: m.name, // 'provider' is copied over if it exists
};
});
return modelTable; return modelTable;
} }
@ -126,7 +149,9 @@ export function collectModels(
customModels: string, customModels: string,
) { ) {
const modelTable = collectModelTable(models, customModels); const modelTable = collectModelTable(models, customModels);
const allModels = Object.values(modelTable); let allModels = Object.values(modelTable);
allModels = sortModelTable(allModels, "custom-first");
return allModels; return allModels;
} }
@ -141,7 +166,10 @@ export function collectModelsWithDefaultModel(
customModels, customModels,
defaultModel, defaultModel,
); );
const allModels = Object.values(modelTable); let allModels = Object.values(modelTable);
allModels = sortModelTable(allModels, "custom-first");
return allModels; return allModels;
} }