From a925b424a8b02399d22ed05b3dc28631dbc03bc5 Mon Sep 17 00:00:00 2001
From: lloydzhou
Date: Wed, 9 Oct 2024 13:42:25 +0800
Subject: [PATCH] fix compressModel, related #5426, fix #5606 #5603 #5575

---
 app/store/chat.ts   | 50 ++++++++++++++++++++++++++++++++++++++++++++++----
 app/store/config.ts |  8 ++++----
 2 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/app/store/chat.ts b/app/store/chat.ts
index 931cad768..98163981c 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -16,6 +16,8 @@ import {
   DEFAULT_SYSTEM_TEMPLATE,
   KnowledgeCutOffDate,
   StoreKey,
+  SUMMARIZE_MODEL,
+  GEMINI_SUMMARIZE_MODEL,
 } from "../constant";
 import Locale, { getLang } from "../locales";
 import { isDalle3, safeLocalStorage } from "../utils";
@@ -23,6 +25,8 @@ import { prettyObject } from "../utils/format";
 import { createPersistStore } from "../utils/store";
 import { estimateTokenLength } from "../utils/token";
 import { ModelConfig, ModelType, useAppConfig } from "./config";
+import { useAccessStore } from "./access";
+import { collectModelsWithDefaultModel } from "../utils/model";
 import { createEmptyMask, Mask } from "./mask";
 
 const localStorage = safeLocalStorage();
@@ -103,6 +107,29 @@ function createEmptySession(): ChatSession {
   };
 }
 
+function getSummarizeModel(currentModel: string, providerName: string) {
+  // if it is using gpt-* models, force to use 4o-mini to summarize
+  if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
+    const configStore = useAppConfig.getState();
+    const accessStore = useAccessStore.getState();
+    const allModel = collectModelsWithDefaultModel(
+      configStore.models,
+      [configStore.customModels, accessStore.customModels].join(","),
+      accessStore.defaultModel,
+    );
+    const summarizeModel = allModel.find(
+      (m) => m.name === SUMMARIZE_MODEL && m.available,
+    );
+    if (summarizeModel) {
+      return [summarizeModel.name, summarizeModel.providerName];
+    }
+  }
+  if (currentModel.startsWith("gemini")) {
+    return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
+  }
+  return [currentModel, providerName];
+}
+
 function countMessages(msgs: ChatMessage[]) {
   return msgs.reduce(
     (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
@@ -579,7 +606,13 @@ export const useChatStore = createPersistStore(
         return;
       }
 
-      const providerName = modelConfig.compressProviderName;
+      // if not config compressModel, then using getSummarizeModel
+      const [model, providerName] = modelConfig.compressModel
+        ? [modelConfig.compressModel, modelConfig.compressProviderName]
+        : getSummarizeModel(
+            session.mask.modelConfig.model,
+            session.mask.modelConfig.providerName,
+          );
       const api: ClientApi = getClientApi(providerName);
 
       // remove error messages if any
@@ -611,7 +644,7 @@ export const useChatStore = createPersistStore(
       api.llm.chat({
         messages: topicMessages,
         config: {
-          model: modelConfig.compressModel,
+          model,
           stream: false,
           providerName,
         },
@@ -675,7 +708,8 @@ export const useChatStore = createPersistStore(
         config: {
           ...modelcfg,
           stream: true,
-          model: modelConfig.compressModel,
+          model,
+          providerName,
         },
         onUpdate(message) {
           session.memoryPrompt = message;
@@ -728,7 +762,7 @@ export const useChatStore = createPersistStore(
   },
   {
     name: StoreKey.Chat,
-    version: 3.2,
+    version: 3.3,
     migrate(persistedState, version) {
       const state = persistedState as any;
       const newState = JSON.parse(
@@ -784,6 +818,14 @@ export const useChatStore = createPersistStore(
             config.modelConfig.compressProviderName;
         });
       }
+      // revert default summarize model for every session
+      if (version < 3.3) {
+        newState.sessions.forEach((s) => {
+          const config = useAppConfig.getState();
+          s.mask.modelConfig.compressModel = undefined;
+          s.mask.modelConfig.compressProviderName = undefined;
+        });
+      }
 
       return newState as any;
     },
diff --git a/app/store/config.ts b/app/store/config.ts
index 3dcd4d86b..c52ee3915 100644
--- a/app/store/config.ts
+++ b/app/store/config.ts
@@ -71,8 +71,8 @@ export const DEFAULT_CONFIG = {
     sendMemory: true,
     historyMessageCount: 4,
     compressMessageLengthThreshold: 1000,
-    compressModel: "gpt-4o-mini" as ModelType,
-    compressProviderName: "OpenAI" as ServiceProvider,
+    compressModel: undefined,
+    compressProviderName: undefined,
     enableInjectSystemPrompts: true,
     template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
     size: "1024x1024" as DalleSize,
@@ -178,7 +178,7 @@ export const useAppConfig = createPersistStore(
   }),
   {
     name: StoreKey.Config,
-    version: 4,
+    version: 4.1,
 
     merge(persistedState, currentState) {
       const state = persistedState as ChatConfig | undefined;
@@ -231,7 +231,7 @@ export const useAppConfig = createPersistStore(
           : config?.template ?? DEFAULT_INPUT_TEMPLATE;
       }
 
-      if (version < 4) {
+      if (version < 4.1) {
        state.modelConfig.compressModel =
          DEFAULT_CONFIG.modelConfig.compressModel;
        state.modelConfig.compressProviderName =
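
Note for reviewers: the snippet below is a minimal standalone sketch of the fallback that getSummarizeModel() introduces, assuming SUMMARIZE_MODEL resolves to "gpt-4o-mini" and GEMINI_SUMMARIZE_MODEL to "gemini-pro" in app/constant.ts; it deliberately omits the availability check against collectModelsWithDefaultModel() that the real function performs, and the function/constant names here are illustrative only.

// Hypothetical, simplified version of the decision tree added in app/store/chat.ts.
// The constant values are assumptions, not read from app/constant.ts.
const SUMMARIZE_MODEL = "gpt-4o-mini";
const GEMINI_SUMMARIZE_MODEL = "gemini-pro";

function pickSummarizeModel(
  currentModel: string,
  providerName: string,
): [string, string] {
  // gpt-*/chatgpt-* sessions summarize with the cheaper OpenAI model
  // (the real code only switches when that model is listed as available).
  if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
    return [SUMMARIZE_MODEL, "OpenAI"];
  }
  // gemini-* sessions stay on Google with the dedicated summarize model.
  if (currentModel.startsWith("gemini")) {
    return [GEMINI_SUMMARIZE_MODEL, "Google"];
  }
  // anything else keeps summarizing with the session's own model/provider.
  return [currentModel, providerName];
}

// Examples:
//   pickSummarizeModel("gpt-4o", "OpenAI")               -> ["gpt-4o-mini", "OpenAI"]
//   pickSummarizeModel("gemini-1.5-pro", "Google")       -> ["gemini-pro", "Google"]
//   pickSummarizeModel("claude-3-5-sonnet", "Anthropic") -> unchanged

Since the config.ts change above leaves compressModel undefined by default, summarizeSession() now reaches this fallback unless the user has explicitly configured a compress model.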