From b023a00445682fcb336fe231ffe7c667632c0d15 Mon Sep 17 00:00:00 2001 From: frostime Date: Mon, 5 Aug 2024 16:37:22 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20refactor(model):=20=E6=9B=B4?= =?UTF-8?q?=E6=94=B9=E5=8E=9F=E5=85=88=E7=9A=84=E5=AE=9E=E7=8E=B0=E6=96=B9?= =?UTF-8?q?=E6=B3=95=EF=BC=8C=E5=9C=A8=20collect=20table=20=E5=87=BD?= =?UTF-8?q?=E6=95=B0=E5=90=8E=E9=9D=A2=E5=A2=9E=E5=8A=A0=E9=A2=9D=E5=A4=96?= =?UTF-8?q?=E7=9A=84=20sort=20=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/utils/model.ts | 50 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/app/utils/model.ts b/app/utils/model.ts index 6b1485e32..b117b5eb6 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -7,6 +7,29 @@ const customProvider = (providerName: string) => ({ providerType: "custom", }); +const sortModelTable = ( + models: ReturnType, + rule: "custom-first" | "default-first", +) => + models.sort((a, b) => { + if (a.provider === undefined && b.provider === undefined) { + return 0; + } + + let aIsCustom = a.provider?.providerType === "custom"; + let bIsCustom = b.provider?.providerType === "custom"; + + if (aIsCustom === bIsCustom) { + return 0; + } + + if (aIsCustom) { + return rule === "custom-first" ? -1 : 1; + } else { + return rule === "custom-first" ? 1 : -1; + } + }); + export function collectModelTable( models: readonly LLMModel[], customModels: string, @@ -22,6 +45,15 @@ export function collectModelTable( } > = {}; + // default models + models.forEach((m) => { + // using @ as fullName + modelTable[`${m.name}@${m?.provider?.id}`] = { + ...m, + displayName: m.name, // 'provider' is copied over if it exists + }; + }); + // server custom models customModels .split(",") @@ -80,15 +112,6 @@ export function collectModelTable( } }); - // default models - models.forEach((m) => { - // using @ as fullName - modelTable[`${m.name}@${m?.provider?.id}`] = { - ...m, - displayName: m.name, // 'provider' is copied over if it exists - }; - }); - return modelTable; } @@ -126,7 +149,9 @@ export function collectModels( customModels: string, ) { const modelTable = collectModelTable(models, customModels); - const allModels = Object.values(modelTable); + let allModels = Object.values(modelTable); + + allModels = sortModelTable(allModels, "custom-first"); return allModels; } @@ -141,7 +166,10 @@ export function collectModelsWithDefaultModel( customModels, defaultModel, ); - const allModels = Object.values(modelTable); + let allModels = Object.values(modelTable); + + allModels = sortModelTable(allModels, "custom-first"); + return allModels; }