From b844045d231658b9e40fa0582936c6746e7a7ef4 Mon Sep 17 00:00:00 2001 From: ryanhex53 Date: Tue, 5 Nov 2024 07:44:12 +0000 Subject: [PATCH] Custom model names can include the `@` symbol by itself. To specify the model's provider, append it after the model name using `@` as before. This format supports cases like `google vertex ai` with a model name like `claude-3-5-sonnet@20240620`. For instance, `claude-3-5-sonnet@20240620@vertex-ai` will be split by `split(/@(?!.*@)/)` into: `[ 'claude-3-5-sonnet@20240620', 'vertex-ai' ]`, where the former is the model name and the latter is the custom provider. --- app/api/common.ts | 2 +- app/components/chat.tsx | 2 +- app/components/model-config.tsx | 6 ++++-- app/store/access.ts | 2 +- app/utils/model.ts | 8 ++++---- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/app/api/common.ts b/app/api/common.ts index b4c792d6f..322dedeed 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) { .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName)) .forEach((m) => { const [fullName, displayName] = m.split("="); - const [_, providerName] = fullName.split("@"); + const [_, providerName] = fullName.split(/@(?!.*@)/); if (providerName === "azure" && !displayName) { const [_, deployId] = (serverConfig?.azureUrl ?? "").split( "deployments/", diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 3d5b6a4f2..2ff08253a 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -645,7 +645,7 @@ export function ChatActions(props: { onClose={() => setShowModelSelector(false)} onSelection={(s) => { if (s.length === 0) return; - const [model, providerName] = s[0].split("@"); + const [model, providerName] = s[0].split(/@(?!.*@)/); chatStore.updateCurrentSession((session) => { session.mask.modelConfig.model = model as ModelType; session.mask.modelConfig.providerName = diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index f2297e10b..0eac916eb 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -28,7 +28,8 @@ export function ModelConfigList(props: { value={value} align="left" onChange={(e) => { - const [model, providerName] = e.currentTarget.value.split("@"); + const [model, providerName] = + e.currentTarget.value.split(/@(?!.*@)/); props.updateConfig((config) => { config.model = ModalConfigValidator.model(model); config.providerName = providerName as ServiceProvider; @@ -247,7 +248,8 @@ export function ModelConfigList(props: { aria-label={Locale.Settings.CompressModel.Title} value={compressModelValue} onChange={(e) => { - const [model, providerName] = e.currentTarget.value.split("@"); + const [model, providerName] = + e.currentTarget.value.split(/@(?!.*@)/); props.updateConfig((config) => { config.compressModel = ModalConfigValidator.model(model); config.compressProviderName = providerName as ServiceProvider; diff --git a/app/store/access.ts b/app/store/access.ts index 3b0e6357b..4e2cb1603 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -226,7 +226,7 @@ export const useAccessStore = createPersistStore( .then((res) => { const defaultModel = res.defaultModel ?? ""; if (defaultModel !== "") { - const [model, providerName] = defaultModel.split("@"); + const [model, providerName] = defaultModel.split(/@(?!.*@)/); DEFAULT_CONFIG.modelConfig.model = model; DEFAULT_CONFIG.modelConfig.providerName = providerName; } diff --git a/app/utils/model.ts b/app/utils/model.ts index 0b62b53be..0b95713e1 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -79,10 +79,10 @@ export function collectModelTable( ); } else { // 1. find model by name, and set available value - const [customModelName, customProviderName] = name.split("@"); + const [customModelName, customProviderName] = name.split(/@(?!.*@)/); let count = 0; for (const fullName in modelTable) { - const [modelName, providerName] = fullName.split("@"); + const [modelName, providerName] = fullName.split(/@(?!.*@)/); if ( customModelName == modelName && (customProviderName === undefined || @@ -102,7 +102,7 @@ export function collectModelTable( } // 2. if model not exists, create new model with available value if (count === 0) { - let [customModelName, customProviderName] = name.split("@"); + let [customModelName, customProviderName] = name.split(/@(?!.*@)/); const provider = customProvider( customProviderName || customModelName, ); @@ -139,7 +139,7 @@ export function collectModelTableWithDefaultModel( for (const key of Object.keys(modelTable)) { if ( modelTable[key].available && - key.split("@").shift() == defaultModel + key.split(/@(?!.*@)/).shift() == defaultModel ) { modelTable[key].isDefault = true; break;