feat(model): 增加 sorted 字段,并使用该字段对模型列表进行排序

1. 在 Model 和 Provider 类型中增加 sorted 字段(api.ts)
2. 默认模型在初始化的时候,自动设置默认 sorted 字段,从 1000 开始自增长(constant.ts)
3. 自定义模型更新的时候,自动分配 sorted 字段(model.ts)
This commit is contained in:
frostime 2024-08-05 19:43:32 +08:00
parent b023a00445
commit 150fc84b9b
3 changed files with 50 additions and 20 deletions

View File

@ -64,12 +64,14 @@ export interface LLMModel {
displayName?: string; displayName?: string;
available: boolean; available: boolean;
provider: LLMModelProvider; provider: LLMModelProvider;
sorted: number;
} }
export interface LLMModelProvider { export interface LLMModelProvider {
id: string; id: string;
providerName: string; providerName: string;
providerType: string; providerType: string;
sorted: number;
} }
export abstract class LLMApi { export abstract class LLMApi {

View File

@ -320,86 +320,105 @@ const tencentModels = [
const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]; const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"];
let seq = 1000; // 内置的模型序号生成器从1000开始
export const DEFAULT_MODELS = [ export const DEFAULT_MODELS = [
...openaiModels.map((name) => ({ ...openaiModels.map((name) => ({
name, name,
available: true, available: true,
sorted: seq++, // Global sequence sort(index)
provider: { provider: {
id: "openai", id: "openai",
providerName: "OpenAI", providerName: "OpenAI",
providerType: "openai", providerType: "openai",
sorted: 1, // 这里是固定的,确保顺序与之前内置的版本一致
}, },
})), })),
...openaiModels.map((name) => ({ ...openaiModels.map((name) => ({
name, name,
available: true, available: true,
sorted: seq++,
provider: { provider: {
id: "azure", id: "azure",
providerName: "Azure", providerName: "Azure",
providerType: "azure", providerType: "azure",
sorted: 2,
}, },
})), })),
...googleModels.map((name) => ({ ...googleModels.map((name) => ({
name, name,
available: true, available: true,
sorted: seq++,
provider: { provider: {
id: "google", id: "google",
providerName: "Google", providerName: "Google",
providerType: "google", providerType: "google",
sorted: 3,
}, },
})), })),
...anthropicModels.map((name) => ({ ...anthropicModels.map((name) => ({
name, name,
available: true, available: true,
sorted: seq++,
provider: { provider: {
id: "anthropic", id: "anthropic",
providerName: "Anthropic", providerName: "Anthropic",
providerType: "anthropic", providerType: "anthropic",
sorted: 4,
}, },
})), })),
...baiduModels.map((name) => ({ ...baiduModels.map((name) => ({
name, name,
available: true, available: true,
sorted: seq++,
provider: { provider: {
id: "baidu", id: "baidu",
providerName: "Baidu", providerName: "Baidu",
providerType: "baidu", providerType: "baidu",
sorted: 5,
}, },
})), })),
...bytedanceModels.map((name) => ({ ...bytedanceModels.map((name) => ({
name, name,
available: true, available: true,
sorted: seq++,
provider: { provider: {
id: "bytedance", id: "bytedance",
providerName: "ByteDance", providerName: "ByteDance",
providerType: "bytedance", providerType: "bytedance",
sorted: 6,
}, },
})), })),
...alibabaModes.map((name) => ({ ...alibabaModes.map((name) => ({
name, name,
available: true, available: true,
sorted: seq++,
provider: { provider: {
id: "alibaba", id: "alibaba",
providerName: "Alibaba", providerName: "Alibaba",
providerType: "alibaba", providerType: "alibaba",
sorted: 7,
}, },
})), })),
...tencentModels.map((name) => ({ ...tencentModels.map((name) => ({
name, name,
available: true, available: true,
sorted: seq++,
provider: { provider: {
id: "tencent", id: "tencent",
providerName: "Tencent", providerName: "Tencent",
providerType: "tencent", providerType: "tencent",
sorted: 8,
}, },
})), })),
...moonshotModes.map((name) => ({ ...moonshotModes.map((name) => ({
name, name,
available: true, available: true,
sorted: seq++,
provider: { provider: {
id: "moonshot", id: "moonshot",
providerName: "Moonshot", providerName: "Moonshot",
providerType: "moonshot", providerType: "moonshot",
sorted: 9,
}, },
})), })),
] as const; ] as const;

View File

@ -1,32 +1,39 @@
import { DEFAULT_MODELS } from "../constant"; import { DEFAULT_MODELS } from "../constant";
import { LLMModel } from "../client/api"; import { LLMModel } from "../client/api";
const CustomSeq = {
val: -1000, //To ensure the custom model located at front, start from -1000, refer to constant.ts
cache: new Map<string, number>(),
next: (id: string) => {
if (CustomSeq.cache.has(id)) {
return CustomSeq.cache.get(id) as number;
} else {
let seq = CustomSeq.val++;
CustomSeq.cache.set(id, seq);
return seq;
}
},
};
const customProvider = (providerName: string) => ({ const customProvider = (providerName: string) => ({
id: providerName.toLowerCase(), id: providerName.toLowerCase(),
providerName: providerName, providerName: providerName,
providerType: "custom", providerType: "custom",
sorted: CustomSeq.next(providerName),
}); });
const sortModelTable = ( /**
models: ReturnType<typeof collectModels>, * Sorts an array of models based on specified rules.
rule: "custom-first" | "default-first", *
) => * First, sorted by provider; if the same, sorted by model
*/
const sortModelTable = (models: ReturnType<typeof collectModels>) =>
models.sort((a, b) => { models.sort((a, b) => {
if (a.provider === undefined && b.provider === undefined) { if (a.provider && b.provider) {
return 0; let cmp = a.provider.sorted - b.provider.sorted;
} return cmp === 0 ? a.sorted - b.sorted : cmp;
let aIsCustom = a.provider?.providerType === "custom";
let bIsCustom = b.provider?.providerType === "custom";
if (aIsCustom === bIsCustom) {
return 0;
}
if (aIsCustom) {
return rule === "custom-first" ? -1 : 1;
} else { } else {
return rule === "custom-first" ? 1 : -1; return a.sorted - b.sorted;
} }
}); });
@ -40,6 +47,7 @@ export function collectModelTable(
available: boolean; available: boolean;
name: string; name: string;
displayName: string; displayName: string;
sorted: number;
provider?: LLMModel["provider"]; // Marked as optional provider?: LLMModel["provider"]; // Marked as optional
isDefault?: boolean; isDefault?: boolean;
} }
@ -107,6 +115,7 @@ export function collectModelTable(
displayName: displayName || customModelName, displayName: displayName || customModelName,
available, available,
provider, // Use optional chaining provider, // Use optional chaining
sorted: CustomSeq.next(`${customModelName}@${provider?.id}`),
}; };
} }
} }
@ -151,7 +160,7 @@ export function collectModels(
const modelTable = collectModelTable(models, customModels); const modelTable = collectModelTable(models, customModels);
let allModels = Object.values(modelTable); let allModels = Object.values(modelTable);
allModels = sortModelTable(allModels, "custom-first"); allModels = sortModelTable(allModels);
return allModels; return allModels;
} }
@ -168,7 +177,7 @@ export function collectModelsWithDefaultModel(
); );
let allModels = Object.values(modelTable); let allModels = Object.values(modelTable);
allModels = sortModelTable(allModels, "custom-first"); allModels = sortModelTable(allModels);
return allModels; return allModels;
} }