diff --git a/README.md b/README.md index dd2c5b1ee..fce62ba37 100644 --- a/README.md +++ b/README.md @@ -301,6 +301,14 @@ iflytek Api Key. iflytek Api Secret. +### `CHATGLM_API_KEY` (optional) + +ChatGLM Api Key. + +### `CHATGLM_URL` (optional) + +ChatGLM Api Url. + ### `HIDE_USER_API_KEY` (optional) > Default: Empty diff --git a/README_CN.md b/README_CN.md index ccdcf28ff..d4da8b9da 100644 --- a/README_CN.md +++ b/README_CN.md @@ -184,6 +184,13 @@ ByteDance Api Url. 讯飞星火Api Secret. +### `CHATGLM_API_KEY` (可选) + +ChatGLM Api Key. + +### `CHATGLM_URL` (可选) + +ChatGLM Api Url. ### `HIDE_USER_API_KEY` (可选) diff --git a/app/api/common.ts b/app/api/common.ts index b4c792d6f..495a12ccd 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -1,8 +1,8 @@ import { NextRequest, NextResponse } from "next/server"; import { getServerSideConfig } from "../config/server"; import { OPENAI_BASE_URL, ServiceProvider } from "../constant"; -import { isModelAvailableInServer } from "../utils/model"; import { cloudflareAIGatewayUrl } from "../utils/cloudflare"; +import { getModelProvider, isModelAvailableInServer } from "../utils/model"; const serverConfig = getServerSideConfig(); @@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) { .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName)) .forEach((m) => { const [fullName, displayName] = m.split("="); - const [_, providerName] = fullName.split("@"); + const [_, providerName] = getModelProvider(fullName); if (providerName === "azure" && !displayName) { const [_, deployId] = (serverConfig?.azureUrl ?? "").split( "deployments/", diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 3d5b6a4f2..82d6c6e39 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio"; import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts"; import { isEmpty } from "lodash-es"; +import { getModelProvider } from "../utils/model"; const localStorage = safeLocalStorage(); @@ -148,7 +149,8 @@ export function SessionConfigModel(props: { onClose: () => void }) { text={Locale.Chat.Config.Reset} onClick={async () => { if (await showConfirm(Locale.Memory.ResetConfirm)) { - chatStore.updateCurrentSession( + chatStore.updateTargetSession( + session, (session) => (session.memoryPrompt = ""), ); } @@ -173,7 +175,10 @@ export function SessionConfigModel(props: { onClose: () => void }) { updateMask={(updater) => { const mask = { ...session.mask }; updater(mask); - chatStore.updateCurrentSession((session) => (session.mask = mask)); + chatStore.updateTargetSession( + session, + (session) => (session.mask = mask), + ); }} shouldSyncFromGlobal extraListItems={ @@ -345,12 +350,14 @@ export function PromptHints(props: { function ClearContextDivider() { const chatStore = useChatStore(); + const session = chatStore.currentSession(); return (
- chatStore.updateCurrentSession( + chatStore.updateTargetSession( + session, (session) => (session.clearContextIndex = undefined), ) } @@ -460,6 +467,7 @@ export function ChatActions(props: { const navigate = useNavigate(); const chatStore = useChatStore(); const pluginStore = usePluginStore(); + const session = chatStore.currentSession(); // switch themes const theme = config.theme; @@ -476,10 +484,9 @@ export function ChatActions(props: { const stopAll = () => ChatControllerPool.stopAll(); // switch model - const currentModel = chatStore.currentSession().mask.modelConfig.model; + const currentModel = session.mask.modelConfig.model; const currentProviderName = - chatStore.currentSession().mask.modelConfig?.providerName || - ServiceProvider.OpenAI; + session.mask.modelConfig?.providerName || ServiceProvider.OpenAI; const allModels = useAllModels(); const models = useMemo(() => { const filteredModels = allModels.filter((m) => m.available); @@ -513,12 +520,9 @@ export function ChatActions(props: { const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"]; const dalle3Qualitys: DalleQuality[] = ["standard", "hd"]; const dalle3Styles: DalleStyle[] = ["vivid", "natural"]; - const currentSize = - chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024"; - const currentQuality = - chatStore.currentSession().mask.modelConfig?.quality ?? "standard"; - const currentStyle = - chatStore.currentSession().mask.modelConfig?.style ?? "vivid"; + const currentSize = session.mask.modelConfig?.size ?? "1024x1024"; + const currentQuality = session.mask.modelConfig?.quality ?? "standard"; + const currentStyle = session.mask.modelConfig?.style ?? "vivid"; const isMobileScreen = useMobileScreen(); @@ -536,7 +540,7 @@ export function ChatActions(props: { if (isUnavailableModel && models.length > 0) { // show next model to default model if exist let nextModel = models.find((model) => model.isDefault) || models[0]; - chatStore.updateCurrentSession((session) => { + chatStore.updateTargetSession(session, (session) => { session.mask.modelConfig.model = nextModel.name; session.mask.modelConfig.providerName = nextModel?.provider ?.providerName as ServiceProvider; @@ -547,7 +551,7 @@ export function ChatActions(props: { : nextModel.name, ); } - }, [chatStore, currentModel, models]); + }, [chatStore, currentModel, models, session]); return (
@@ -614,7 +618,7 @@ export function ChatActions(props: { text={Locale.Chat.InputActions.Clear} icon={} onClick={() => { - chatStore.updateCurrentSession((session) => { + chatStore.updateTargetSession(session, (session) => { if (session.clearContextIndex === session.messages.length) { session.clearContextIndex = undefined; } else { @@ -645,8 +649,8 @@ export function ChatActions(props: { onClose={() => setShowModelSelector(false)} onSelection={(s) => { if (s.length === 0) return; - const [model, providerName] = s[0].split("@"); - chatStore.updateCurrentSession((session) => { + const [model, providerName] = getModelProvider(s[0]); + chatStore.updateTargetSession(session, (session) => { session.mask.modelConfig.model = model as ModelType; session.mask.modelConfig.providerName = providerName as ServiceProvider; @@ -684,7 +688,7 @@ export function ChatActions(props: { onSelection={(s) => { if (s.length === 0) return; const size = s[0]; - chatStore.updateCurrentSession((session) => { + chatStore.updateTargetSession(session, (session) => { session.mask.modelConfig.size = size; }); showToast(size); @@ -711,7 +715,7 @@ export function ChatActions(props: { onSelection={(q) => { if (q.length === 0) return; const quality = q[0]; - chatStore.updateCurrentSession((session) => { + chatStore.updateTargetSession(session, (session) => { session.mask.modelConfig.quality = quality; }); showToast(quality); @@ -738,7 +742,7 @@ export function ChatActions(props: { onSelection={(s) => { if (s.length === 0) return; const style = s[0]; - chatStore.updateCurrentSession((session) => { + chatStore.updateTargetSession(session, (session) => { session.mask.modelConfig.style = style; }); showToast(style); @@ -769,7 +773,7 @@ export function ChatActions(props: { }))} onClose={() => setShowPluginSelector(false)} onSelection={(s) => { - chatStore.updateCurrentSession((session) => { + chatStore.updateTargetSession(session, (session) => { session.mask.plugin = s as string[]; }); }} @@ -812,7 +816,8 @@ export function EditMessageModal(props: { onClose: () => void }) { icon={} key="ok" onClick={() => { - chatStore.updateCurrentSession( + chatStore.updateTargetSession( + session, (session) => (session.messages = messages), ); props.onClose(); @@ -829,7 +834,8 @@ export function EditMessageModal(props: { onClose: () => void }) { type="text" value={session.topic} onInput={(e) => - chatStore.updateCurrentSession( + chatStore.updateTargetSession( + session, (session) => (session.topic = e.currentTarget.value), ) } @@ -990,7 +996,8 @@ function _Chat() { prev: () => chatStore.nextSession(-1), next: () => chatStore.nextSession(1), clear: () => - chatStore.updateCurrentSession( + chatStore.updateTargetSession( + session, (session) => (session.clearContextIndex = session.messages.length), ), fork: () => chatStore.forkSession(), @@ -1061,7 +1068,7 @@ function _Chat() { }; useEffect(() => { - chatStore.updateCurrentSession((session) => { + chatStore.updateTargetSession(session, (session) => { const stopTiming = Date.now() - REQUEST_TIMEOUT_MS; session.messages.forEach((m) => { // check if should stop all stale messages @@ -1087,7 +1094,7 @@ function _Chat() { } }); // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); + }, [session]); // check if should send message const onInputKeyDown = (e: React.KeyboardEvent) => { @@ -1118,7 +1125,8 @@ function _Chat() { }; const deleteMessage = (msgId?: string) => { - chatStore.updateCurrentSession( + chatStore.updateTargetSession( + session, (session) => (session.messages = session.messages.filter((m) => m.id !== msgId)), ); @@ -1185,7 +1193,7 @@ function _Chat() { }; const onPinMessage = (message: ChatMessage) => { - chatStore.updateCurrentSession((session) => + chatStore.updateTargetSession(session, (session) => session.mask.context.push(message), ); @@ -1607,7 +1615,7 @@ function _Chat() { title={Locale.Chat.Actions.RefreshTitle} onClick={() => { showToast(Locale.Chat.Actions.RefreshToast); - chatStore.summarizeSession(true); + chatStore.summarizeSession(true, session); }} />
@@ -1711,14 +1719,17 @@ function _Chat() { }); } } - chatStore.updateCurrentSession((session) => { - const m = session.mask.context - .concat(session.messages) - .find((m) => m.id === message.id); - if (m) { - m.content = newContent; - } - }); + chatStore.updateTargetSession( + session, + (session) => { + const m = session.mask.context + .concat(session.messages) + .find((m) => m.id === message.id); + if (m) { + m.content = newContent; + } + }, + ); }} >
diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index f2297e10b..e845bfeac 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib"; import { useAllModels } from "../utils/hooks"; import { groupBy } from "lodash-es"; import styles from "./model-config.module.scss"; +import { getModelProvider } from "../utils/model"; export function ModelConfigList(props: { modelConfig: ModelConfig; @@ -28,7 +29,9 @@ export function ModelConfigList(props: { value={value} align="left" onChange={(e) => { - const [model, providerName] = e.currentTarget.value.split("@"); + const [model, providerName] = getModelProvider( + e.currentTarget.value, + ); props.updateConfig((config) => { config.model = ModalConfigValidator.model(model); config.providerName = providerName as ServiceProvider; @@ -247,7 +250,9 @@ export function ModelConfigList(props: { aria-label={Locale.Settings.CompressModel.Title} value={compressModelValue} onChange={(e) => { - const [model, providerName] = e.currentTarget.value.split("@"); + const [model, providerName] = getModelProvider( + e.currentTarget.value, + ); props.updateConfig((config) => { config.compressModel = ModalConfigValidator.model(model); config.compressProviderName = providerName as ServiceProvider; diff --git a/app/constant.ts b/app/constant.ts index d10e624cb..296878bd7 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -232,7 +232,7 @@ export const XAI = { export const ChatGLM = { ExampleEndpoint: CHATGLM_BASE_URL, - ChatPath: "/api/paas/v4/chat/completions", + ChatPath: "api/paas/v4/chat/completions", }; export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang @@ -327,12 +327,13 @@ const anthropicModels = [ "claude-2.1", "claude-3-sonnet-20240229", "claude-3-opus-20240229", + "claude-3-opus-latest", "claude-3-haiku-20240307", "claude-3-5-haiku-20241022", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022", "claude-3-5-sonnet-latest", - "claude-3-opus-latest", + "claude-3-5-haiku-latest", ]; const baiduModels = [ diff --git a/app/store/access.ts b/app/store/access.ts index 3b0e6357b..4796b2fe8 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client"; import { createPersistStore } from "../utils/store"; import { ensure } from "../utils/clone"; import { DEFAULT_CONFIG } from "./config"; +import { getModelProvider } from "../utils/model"; let fetchState = 0; // 0 not fetch, 1 fetching, 2 done @@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore( .then((res) => { const defaultModel = res.defaultModel ?? ""; if (defaultModel !== "") { - const [model, providerName] = defaultModel.split("@"); + const [model, providerName] = getModelProvider(defaultModel); DEFAULT_CONFIG.modelConfig.model = model; - DEFAULT_CONFIG.modelConfig.providerName = providerName; + DEFAULT_CONFIG.modelConfig.providerName = providerName as any; } return res; diff --git a/app/store/chat.ts b/app/store/chat.ts index 1bf2e1367..af4993d12 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -352,13 +352,13 @@ export const useChatStore = createPersistStore( return session; }, - onNewMessage(message: ChatMessage) { - get().updateCurrentSession((session) => { + onNewMessage(message: ChatMessage, targetSession: ChatSession) { + get().updateTargetSession(targetSession, (session) => { session.messages = session.messages.concat(); session.lastUpdate = Date.now(); }); - get().updateStat(message); - get().summarizeSession(); + get().updateStat(message, targetSession); + get().summarizeSession(false, targetSession); }, async onUserInput(content: string, attachImages?: string[]) { @@ -396,10 +396,10 @@ export const useChatStore = createPersistStore( // get recent messages const recentMessages = get().getMessagesWithMemory(); const sendMessages = recentMessages.concat(userMessage); - const messageIndex = get().currentSession().messages.length + 1; + const messageIndex = session.messages.length + 1; // save user's and bot's message - get().updateCurrentSession((session) => { + get().updateTargetSession(session, (session) => { const savedUserMessage = { ...userMessage, content: mContent, @@ -420,7 +420,7 @@ export const useChatStore = createPersistStore( if (message) { botMessage.content = message; } - get().updateCurrentSession((session) => { + get().updateTargetSession(session, (session) => { session.messages = session.messages.concat(); }); }, @@ -428,13 +428,14 @@ export const useChatStore = createPersistStore( botMessage.streaming = false; if (message) { botMessage.content = message; - get().onNewMessage(botMessage); + botMessage.date = new Date().toLocaleString(); + get().onNewMessage(botMessage, session); } ChatControllerPool.remove(session.id, botMessage.id); }, onBeforeTool(tool: ChatMessageTool) { (botMessage.tools = botMessage?.tools || []).push(tool); - get().updateCurrentSession((session) => { + get().updateTargetSession(session, (session) => { session.messages = session.messages.concat(); }); }, @@ -444,7 +445,7 @@ export const useChatStore = createPersistStore( tools[i] = { ...tool }; } }); - get().updateCurrentSession((session) => { + get().updateTargetSession(session, (session) => { session.messages = session.messages.concat(); }); }, @@ -459,7 +460,7 @@ export const useChatStore = createPersistStore( botMessage.streaming = false; userMessage.isError = !isAborted; botMessage.isError = !isAborted; - get().updateCurrentSession((session) => { + get().updateTargetSession(session, (session) => { session.messages = session.messages.concat(); }); ChatControllerPool.remove( @@ -591,16 +592,19 @@ export const useChatStore = createPersistStore( set(() => ({ sessions })); }, - resetSession() { - get().updateCurrentSession((session) => { + resetSession(session: ChatSession) { + get().updateTargetSession(session, (session) => { session.messages = []; session.memoryPrompt = ""; }); }, - summarizeSession(refreshTitle: boolean = false) { + summarizeSession( + refreshTitle: boolean = false, + targetSession: ChatSession, + ) { const config = useAppConfig.getState(); - const session = get().currentSession(); + const session = targetSession; const modelConfig = session.mask.modelConfig; // skip summarize when using dalle3? if (isDalle3(modelConfig.model)) { @@ -651,7 +655,8 @@ export const useChatStore = createPersistStore( }, onFinish(message, responseRes) { if (responseRes?.status === 200) { - get().updateCurrentSession( + get().updateTargetSession( + session, (session) => (session.topic = message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC), @@ -719,7 +724,7 @@ export const useChatStore = createPersistStore( onFinish(message, responseRes) { if (responseRes?.status === 200) { console.log("[Memory] ", message); - get().updateCurrentSession((session) => { + get().updateTargetSession(session, (session) => { session.lastSummarizeIndex = lastSummarizeIndex; session.memoryPrompt = message; // Update the memory prompt for stored it in local storage }); @@ -732,20 +737,22 @@ export const useChatStore = createPersistStore( } }, - updateStat(message: ChatMessage) { - get().updateCurrentSession((session) => { + updateStat(message: ChatMessage, session: ChatSession) { + get().updateTargetSession(session, (session) => { session.stat.charCount += message.content.length; // TODO: should update chat count and word count }); }, - - updateCurrentSession(updater: (session: ChatSession) => void) { + updateTargetSession( + targetSession: ChatSession, + updater: (session: ChatSession) => void, + ) { const sessions = get().sessions; - const index = get().currentSessionIndex; + const index = sessions.findIndex((s) => s.id === targetSession.id); + if (index < 0) return; updater(sessions[index]); set(() => ({ sessions })); }, - async clearAllData() { await indexedDBStorage.clear(); localStorage.clear(); diff --git a/app/utils/model.ts b/app/utils/model.ts index 0b62b53be..a1b7df1b6 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -37,6 +37,17 @@ const sortModelTable = (models: ReturnType) => } }); +/** + * get model name and provider from a formatted string, + * e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google` + * @param modelWithProvider model name with provider separated by last `@` char, + * @returns [model, provider] tuple, if no `@` char found, provider is undefined + */ +export function getModelProvider(modelWithProvider: string): [string, string?] { + const [model, provider] = modelWithProvider.split(/@(?!.*@)/); + return [model, provider]; +} + export function collectModelTable( models: readonly LLMModel[], customModels: string, @@ -79,10 +90,10 @@ export function collectModelTable( ); } else { // 1. find model by name, and set available value - const [customModelName, customProviderName] = name.split("@"); + const [customModelName, customProviderName] = getModelProvider(name); let count = 0; for (const fullName in modelTable) { - const [modelName, providerName] = fullName.split("@"); + const [modelName, providerName] = getModelProvider(fullName); if ( customModelName == modelName && (customProviderName === undefined || @@ -102,7 +113,7 @@ export function collectModelTable( } // 2. if model not exists, create new model with available value if (count === 0) { - let [customModelName, customProviderName] = name.split("@"); + let [customModelName, customProviderName] = getModelProvider(name); const provider = customProvider( customProviderName || customModelName, ); @@ -139,7 +150,7 @@ export function collectModelTableWithDefaultModel( for (const key of Object.keys(modelTable)) { if ( modelTable[key].available && - key.split("@").shift() == defaultModel + getModelProvider(key)[0] == defaultModel ) { modelTable[key].isDefault = true; break; diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index 415825b13..7e08d9070 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -9,7 +9,7 @@ }, "package": { "productName": "NextChat", - "version": "2.15.6" + "version": "2.15.7" }, "tauri": { "allowlist": { diff --git a/test/model-provider.test.ts b/test/model-provider.test.ts new file mode 100644 index 000000000..41f14be02 --- /dev/null +++ b/test/model-provider.test.ts @@ -0,0 +1,31 @@ +import { getModelProvider } from "../app/utils/model"; + +describe("getModelProvider", () => { + test("should return model and provider when input contains '@'", () => { + const input = "model@provider"; + const [model, provider] = getModelProvider(input); + expect(model).toBe("model"); + expect(provider).toBe("provider"); + }); + + test("should return model and undefined provider when input does not contain '@'", () => { + const input = "model"; + const [model, provider] = getModelProvider(input); + expect(model).toBe("model"); + expect(provider).toBeUndefined(); + }); + + test("should handle multiple '@' characters correctly", () => { + const input = "model@provider@extra"; + const [model, provider] = getModelProvider(input); + expect(model).toBe("model@provider"); + expect(provider).toBe("extra"); + }); + + test("should return empty strings when input is empty", () => { + const input = ""; + const [model, provider] = getModelProvider(input); + expect(model).toBe(""); + expect(provider).toBeUndefined(); + }); +});