merge main

This commit is contained in:
opchips 2024-11-06 15:08:18 +08:00
commit 6667ee1c7f
11 changed files with 155 additions and 73 deletions

View File

@ -301,6 +301,14 @@ iflytek Api Key.
iflytek Api Secret. iflytek Api Secret.
### `CHATGLM_API_KEY` (optional)
ChatGLM Api Key.
### `CHATGLM_URL` (optional)
ChatGLM Api Url.
### `HIDE_USER_API_KEY` (optional) ### `HIDE_USER_API_KEY` (optional)
> Default: Empty > Default: Empty

View File

@ -184,6 +184,13 @@ ByteDance Api Url.
讯飞星火Api Secret. 讯飞星火Api Secret.
### `CHATGLM_API_KEY` (可选)
ChatGLM Api Key.
### `CHATGLM_URL` (可选)
ChatGLM Api Url.
### `HIDE_USER_API_KEY` (可选) ### `HIDE_USER_API_KEY` (可选)

View File

@ -1,8 +1,8 @@
import { NextRequest, NextResponse } from "next/server"; import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "../config/server"; import { getServerSideConfig } from "../config/server";
import { OPENAI_BASE_URL, ServiceProvider } from "../constant"; import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
import { isModelAvailableInServer } from "../utils/model";
import { cloudflareAIGatewayUrl } from "../utils/cloudflare"; import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
import { getModelProvider, isModelAvailableInServer } from "../utils/model";
const serverConfig = getServerSideConfig(); const serverConfig = getServerSideConfig();
@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
.filter((v) => !!v && !v.startsWith("-") && v.includes(modelName)) .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
.forEach((m) => { .forEach((m) => {
const [fullName, displayName] = m.split("="); const [fullName, displayName] = m.split("=");
const [_, providerName] = fullName.split("@"); const [_, providerName] = getModelProvider(fullName);
if (providerName === "azure" && !displayName) { if (providerName === "azure" && !displayName) {
const [_, deployId] = (serverConfig?.azureUrl ?? "").split( const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
"deployments/", "deployments/",

View File

@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio";
import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts"; import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";
import { isEmpty } from "lodash-es"; import { isEmpty } from "lodash-es";
import { getModelProvider } from "../utils/model";
const localStorage = safeLocalStorage(); const localStorage = safeLocalStorage();
@ -148,7 +149,8 @@ export function SessionConfigModel(props: { onClose: () => void }) {
text={Locale.Chat.Config.Reset} text={Locale.Chat.Config.Reset}
onClick={async () => { onClick={async () => {
if (await showConfirm(Locale.Memory.ResetConfirm)) { if (await showConfirm(Locale.Memory.ResetConfirm)) {
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session.memoryPrompt = ""), (session) => (session.memoryPrompt = ""),
); );
} }
@ -173,7 +175,10 @@ export function SessionConfigModel(props: { onClose: () => void }) {
updateMask={(updater) => { updateMask={(updater) => {
const mask = { ...session.mask }; const mask = { ...session.mask };
updater(mask); updater(mask);
chatStore.updateCurrentSession((session) => (session.mask = mask)); chatStore.updateTargetSession(
session,
(session) => (session.mask = mask),
);
}} }}
shouldSyncFromGlobal shouldSyncFromGlobal
extraListItems={ extraListItems={
@ -345,12 +350,14 @@ export function PromptHints(props: {
function ClearContextDivider() { function ClearContextDivider() {
const chatStore = useChatStore(); const chatStore = useChatStore();
const session = chatStore.currentSession();
return ( return (
<div <div
className={styles["clear-context"]} className={styles["clear-context"]}
onClick={() => onClick={() =>
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session.clearContextIndex = undefined), (session) => (session.clearContextIndex = undefined),
) )
} }
@ -460,6 +467,7 @@ export function ChatActions(props: {
const navigate = useNavigate(); const navigate = useNavigate();
const chatStore = useChatStore(); const chatStore = useChatStore();
const pluginStore = usePluginStore(); const pluginStore = usePluginStore();
const session = chatStore.currentSession();
// switch themes // switch themes
const theme = config.theme; const theme = config.theme;
@ -476,10 +484,9 @@ export function ChatActions(props: {
const stopAll = () => ChatControllerPool.stopAll(); const stopAll = () => ChatControllerPool.stopAll();
// switch model // switch model
const currentModel = chatStore.currentSession().mask.modelConfig.model; const currentModel = session.mask.modelConfig.model;
const currentProviderName = const currentProviderName =
chatStore.currentSession().mask.modelConfig?.providerName || session.mask.modelConfig?.providerName || ServiceProvider.OpenAI;
ServiceProvider.OpenAI;
const allModels = useAllModels(); const allModels = useAllModels();
const models = useMemo(() => { const models = useMemo(() => {
const filteredModels = allModels.filter((m) => m.available); const filteredModels = allModels.filter((m) => m.available);
@ -513,12 +520,9 @@ export function ChatActions(props: {
const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"]; const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
const dalle3Qualitys: DalleQuality[] = ["standard", "hd"]; const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
const dalle3Styles: DalleStyle[] = ["vivid", "natural"]; const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
const currentSize = const currentSize = session.mask.modelConfig?.size ?? "1024x1024";
chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024"; const currentQuality = session.mask.modelConfig?.quality ?? "standard";
const currentQuality = const currentStyle = session.mask.modelConfig?.style ?? "vivid";
chatStore.currentSession().mask.modelConfig?.quality ?? "standard";
const currentStyle =
chatStore.currentSession().mask.modelConfig?.style ?? "vivid";
const isMobileScreen = useMobileScreen(); const isMobileScreen = useMobileScreen();
@ -536,7 +540,7 @@ export function ChatActions(props: {
if (isUnavailableModel && models.length > 0) { if (isUnavailableModel && models.length > 0) {
// show next model to default model if exist // show next model to default model if exist
let nextModel = models.find((model) => model.isDefault) || models[0]; let nextModel = models.find((model) => model.isDefault) || models[0];
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.model = nextModel.name; session.mask.modelConfig.model = nextModel.name;
session.mask.modelConfig.providerName = nextModel?.provider session.mask.modelConfig.providerName = nextModel?.provider
?.providerName as ServiceProvider; ?.providerName as ServiceProvider;
@ -547,7 +551,7 @@ export function ChatActions(props: {
: nextModel.name, : nextModel.name,
); );
} }
}, [chatStore, currentModel, models]); }, [chatStore, currentModel, models, session]);
return ( return (
<div className={styles["chat-input-actions"]}> <div className={styles["chat-input-actions"]}>
@ -614,7 +618,7 @@ export function ChatActions(props: {
text={Locale.Chat.InputActions.Clear} text={Locale.Chat.InputActions.Clear}
icon={<BreakIcon />} icon={<BreakIcon />}
onClick={() => { onClick={() => {
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
if (session.clearContextIndex === session.messages.length) { if (session.clearContextIndex === session.messages.length) {
session.clearContextIndex = undefined; session.clearContextIndex = undefined;
} else { } else {
@ -645,8 +649,8 @@ export function ChatActions(props: {
onClose={() => setShowModelSelector(false)} onClose={() => setShowModelSelector(false)}
onSelection={(s) => { onSelection={(s) => {
if (s.length === 0) return; if (s.length === 0) return;
const [model, providerName] = s[0].split("@"); const [model, providerName] = getModelProvider(s[0]);
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.model = model as ModelType; session.mask.modelConfig.model = model as ModelType;
session.mask.modelConfig.providerName = session.mask.modelConfig.providerName =
providerName as ServiceProvider; providerName as ServiceProvider;
@ -684,7 +688,7 @@ export function ChatActions(props: {
onSelection={(s) => { onSelection={(s) => {
if (s.length === 0) return; if (s.length === 0) return;
const size = s[0]; const size = s[0];
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.size = size; session.mask.modelConfig.size = size;
}); });
showToast(size); showToast(size);
@ -711,7 +715,7 @@ export function ChatActions(props: {
onSelection={(q) => { onSelection={(q) => {
if (q.length === 0) return; if (q.length === 0) return;
const quality = q[0]; const quality = q[0];
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.quality = quality; session.mask.modelConfig.quality = quality;
}); });
showToast(quality); showToast(quality);
@ -738,7 +742,7 @@ export function ChatActions(props: {
onSelection={(s) => { onSelection={(s) => {
if (s.length === 0) return; if (s.length === 0) return;
const style = s[0]; const style = s[0];
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.style = style; session.mask.modelConfig.style = style;
}); });
showToast(style); showToast(style);
@ -769,7 +773,7 @@ export function ChatActions(props: {
}))} }))}
onClose={() => setShowPluginSelector(false)} onClose={() => setShowPluginSelector(false)}
onSelection={(s) => { onSelection={(s) => {
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.plugin = s as string[]; session.mask.plugin = s as string[];
}); });
}} }}
@ -812,7 +816,8 @@ export function EditMessageModal(props: { onClose: () => void }) {
icon={<ConfirmIcon />} icon={<ConfirmIcon />}
key="ok" key="ok"
onClick={() => { onClick={() => {
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session.messages = messages), (session) => (session.messages = messages),
); );
props.onClose(); props.onClose();
@ -829,7 +834,8 @@ export function EditMessageModal(props: { onClose: () => void }) {
type="text" type="text"
value={session.topic} value={session.topic}
onInput={(e) => onInput={(e) =>
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session.topic = e.currentTarget.value), (session) => (session.topic = e.currentTarget.value),
) )
} }
@ -990,7 +996,8 @@ function _Chat() {
prev: () => chatStore.nextSession(-1), prev: () => chatStore.nextSession(-1),
next: () => chatStore.nextSession(1), next: () => chatStore.nextSession(1),
clear: () => clear: () =>
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session.clearContextIndex = session.messages.length), (session) => (session.clearContextIndex = session.messages.length),
), ),
fork: () => chatStore.forkSession(), fork: () => chatStore.forkSession(),
@ -1061,7 +1068,7 @@ function _Chat() {
}; };
useEffect(() => { useEffect(() => {
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
const stopTiming = Date.now() - REQUEST_TIMEOUT_MS; const stopTiming = Date.now() - REQUEST_TIMEOUT_MS;
session.messages.forEach((m) => { session.messages.forEach((m) => {
// check if should stop all stale messages // check if should stop all stale messages
@ -1087,7 +1094,7 @@ function _Chat() {
} }
}); });
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []); }, [session]);
// check if should send message // check if should send message
const onInputKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => { const onInputKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
@ -1118,7 +1125,8 @@ function _Chat() {
}; };
const deleteMessage = (msgId?: string) => { const deleteMessage = (msgId?: string) => {
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session) =>
(session.messages = session.messages.filter((m) => m.id !== msgId)), (session.messages = session.messages.filter((m) => m.id !== msgId)),
); );
@ -1185,7 +1193,7 @@ function _Chat() {
}; };
const onPinMessage = (message: ChatMessage) => { const onPinMessage = (message: ChatMessage) => {
chatStore.updateCurrentSession((session) => chatStore.updateTargetSession(session, (session) =>
session.mask.context.push(message), session.mask.context.push(message),
); );
@ -1607,7 +1615,7 @@ function _Chat() {
title={Locale.Chat.Actions.RefreshTitle} title={Locale.Chat.Actions.RefreshTitle}
onClick={() => { onClick={() => {
showToast(Locale.Chat.Actions.RefreshToast); showToast(Locale.Chat.Actions.RefreshToast);
chatStore.summarizeSession(true); chatStore.summarizeSession(true, session);
}} }}
/> />
</div> </div>
@ -1711,14 +1719,17 @@ function _Chat() {
}); });
} }
} }
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(
const m = session.mask.context session,
.concat(session.messages) (session) => {
.find((m) => m.id === message.id); const m = session.mask.context
if (m) { .concat(session.messages)
m.content = newContent; .find((m) => m.id === message.id);
} if (m) {
}); m.content = newContent;
}
},
);
}} }}
></IconButton> ></IconButton>
</div> </div>

View File

@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib";
import { useAllModels } from "../utils/hooks"; import { useAllModels } from "../utils/hooks";
import { groupBy } from "lodash-es"; import { groupBy } from "lodash-es";
import styles from "./model-config.module.scss"; import styles from "./model-config.module.scss";
import { getModelProvider } from "../utils/model";
export function ModelConfigList(props: { export function ModelConfigList(props: {
modelConfig: ModelConfig; modelConfig: ModelConfig;
@ -28,7 +29,9 @@ export function ModelConfigList(props: {
value={value} value={value}
align="left" align="left"
onChange={(e) => { onChange={(e) => {
const [model, providerName] = e.currentTarget.value.split("@"); const [model, providerName] = getModelProvider(
e.currentTarget.value,
);
props.updateConfig((config) => { props.updateConfig((config) => {
config.model = ModalConfigValidator.model(model); config.model = ModalConfigValidator.model(model);
config.providerName = providerName as ServiceProvider; config.providerName = providerName as ServiceProvider;
@ -247,7 +250,9 @@ export function ModelConfigList(props: {
aria-label={Locale.Settings.CompressModel.Title} aria-label={Locale.Settings.CompressModel.Title}
value={compressModelValue} value={compressModelValue}
onChange={(e) => { onChange={(e) => {
const [model, providerName] = e.currentTarget.value.split("@"); const [model, providerName] = getModelProvider(
e.currentTarget.value,
);
props.updateConfig((config) => { props.updateConfig((config) => {
config.compressModel = ModalConfigValidator.model(model); config.compressModel = ModalConfigValidator.model(model);
config.compressProviderName = providerName as ServiceProvider; config.compressProviderName = providerName as ServiceProvider;

View File

@ -232,7 +232,7 @@ export const XAI = {
export const ChatGLM = { export const ChatGLM = {
ExampleEndpoint: CHATGLM_BASE_URL, ExampleEndpoint: CHATGLM_BASE_URL,
ChatPath: "/api/paas/v4/chat/completions", ChatPath: "api/paas/v4/chat/completions",
}; };
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@ -327,12 +327,13 @@ const anthropicModels = [
"claude-2.1", "claude-2.1",
"claude-3-sonnet-20240229", "claude-3-sonnet-20240229",
"claude-3-opus-20240229", "claude-3-opus-20240229",
"claude-3-opus-latest",
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
"claude-3-5-haiku-20241022", "claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest",
"claude-3-opus-latest", "claude-3-5-haiku-latest",
]; ];
const baiduModels = [ const baiduModels = [

View File

@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client";
import { createPersistStore } from "../utils/store"; import { createPersistStore } from "../utils/store";
import { ensure } from "../utils/clone"; import { ensure } from "../utils/clone";
import { DEFAULT_CONFIG } from "./config"; import { DEFAULT_CONFIG } from "./config";
import { getModelProvider } from "../utils/model";
let fetchState = 0; // 0 not fetch, 1 fetching, 2 done let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore(
.then((res) => { .then((res) => {
const defaultModel = res.defaultModel ?? ""; const defaultModel = res.defaultModel ?? "";
if (defaultModel !== "") { if (defaultModel !== "") {
const [model, providerName] = defaultModel.split("@"); const [model, providerName] = getModelProvider(defaultModel);
DEFAULT_CONFIG.modelConfig.model = model; DEFAULT_CONFIG.modelConfig.model = model;
DEFAULT_CONFIG.modelConfig.providerName = providerName; DEFAULT_CONFIG.modelConfig.providerName = providerName as any;
} }
return res; return res;

View File

@ -352,13 +352,13 @@ export const useChatStore = createPersistStore(
return session; return session;
}, },
onNewMessage(message: ChatMessage) { onNewMessage(message: ChatMessage, targetSession: ChatSession) {
get().updateCurrentSession((session) => { get().updateTargetSession(targetSession, (session) => {
session.messages = session.messages.concat(); session.messages = session.messages.concat();
session.lastUpdate = Date.now(); session.lastUpdate = Date.now();
}); });
get().updateStat(message); get().updateStat(message, targetSession);
get().summarizeSession(); get().summarizeSession(false, targetSession);
}, },
async onUserInput(content: string, attachImages?: string[]) { async onUserInput(content: string, attachImages?: string[]) {
@ -396,10 +396,10 @@ export const useChatStore = createPersistStore(
// get recent messages // get recent messages
const recentMessages = get().getMessagesWithMemory(); const recentMessages = get().getMessagesWithMemory();
const sendMessages = recentMessages.concat(userMessage); const sendMessages = recentMessages.concat(userMessage);
const messageIndex = get().currentSession().messages.length + 1; const messageIndex = session.messages.length + 1;
// save user's and bot's message // save user's and bot's message
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
const savedUserMessage = { const savedUserMessage = {
...userMessage, ...userMessage,
content: mContent, content: mContent,
@ -420,7 +420,7 @@ export const useChatStore = createPersistStore(
if (message) { if (message) {
botMessage.content = message; botMessage.content = message;
} }
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.messages = session.messages.concat(); session.messages = session.messages.concat();
}); });
}, },
@ -428,13 +428,14 @@ export const useChatStore = createPersistStore(
botMessage.streaming = false; botMessage.streaming = false;
if (message) { if (message) {
botMessage.content = message; botMessage.content = message;
get().onNewMessage(botMessage); botMessage.date = new Date().toLocaleString();
get().onNewMessage(botMessage, session);
} }
ChatControllerPool.remove(session.id, botMessage.id); ChatControllerPool.remove(session.id, botMessage.id);
}, },
onBeforeTool(tool: ChatMessageTool) { onBeforeTool(tool: ChatMessageTool) {
(botMessage.tools = botMessage?.tools || []).push(tool); (botMessage.tools = botMessage?.tools || []).push(tool);
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.messages = session.messages.concat(); session.messages = session.messages.concat();
}); });
}, },
@ -444,7 +445,7 @@ export const useChatStore = createPersistStore(
tools[i] = { ...tool }; tools[i] = { ...tool };
} }
}); });
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.messages = session.messages.concat(); session.messages = session.messages.concat();
}); });
}, },
@ -459,7 +460,7 @@ export const useChatStore = createPersistStore(
botMessage.streaming = false; botMessage.streaming = false;
userMessage.isError = !isAborted; userMessage.isError = !isAborted;
botMessage.isError = !isAborted; botMessage.isError = !isAborted;
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.messages = session.messages.concat(); session.messages = session.messages.concat();
}); });
ChatControllerPool.remove( ChatControllerPool.remove(
@ -591,16 +592,19 @@ export const useChatStore = createPersistStore(
set(() => ({ sessions })); set(() => ({ sessions }));
}, },
resetSession() { resetSession(session: ChatSession) {
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.messages = []; session.messages = [];
session.memoryPrompt = ""; session.memoryPrompt = "";
}); });
}, },
summarizeSession(refreshTitle: boolean = false) { summarizeSession(
refreshTitle: boolean = false,
targetSession: ChatSession,
) {
const config = useAppConfig.getState(); const config = useAppConfig.getState();
const session = get().currentSession(); const session = targetSession;
const modelConfig = session.mask.modelConfig; const modelConfig = session.mask.modelConfig;
// skip summarize when using dalle3? // skip summarize when using dalle3?
if (isDalle3(modelConfig.model)) { if (isDalle3(modelConfig.model)) {
@ -651,7 +655,8 @@ export const useChatStore = createPersistStore(
}, },
onFinish(message, responseRes) { onFinish(message, responseRes) {
if (responseRes?.status === 200) { if (responseRes?.status === 200) {
get().updateCurrentSession( get().updateTargetSession(
session,
(session) => (session) =>
(session.topic = (session.topic =
message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC), message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
@ -719,7 +724,7 @@ export const useChatStore = createPersistStore(
onFinish(message, responseRes) { onFinish(message, responseRes) {
if (responseRes?.status === 200) { if (responseRes?.status === 200) {
console.log("[Memory] ", message); console.log("[Memory] ", message);
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.lastSummarizeIndex = lastSummarizeIndex; session.lastSummarizeIndex = lastSummarizeIndex;
session.memoryPrompt = message; // Update the memory prompt for stored it in local storage session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
}); });
@ -732,20 +737,22 @@ export const useChatStore = createPersistStore(
} }
}, },
updateStat(message: ChatMessage) { updateStat(message: ChatMessage, session: ChatSession) {
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.stat.charCount += message.content.length; session.stat.charCount += message.content.length;
// TODO: should update chat count and word count // TODO: should update chat count and word count
}); });
}, },
updateTargetSession(
updateCurrentSession(updater: (session: ChatSession) => void) { targetSession: ChatSession,
updater: (session: ChatSession) => void,
) {
const sessions = get().sessions; const sessions = get().sessions;
const index = get().currentSessionIndex; const index = sessions.findIndex((s) => s.id === targetSession.id);
if (index < 0) return;
updater(sessions[index]); updater(sessions[index]);
set(() => ({ sessions })); set(() => ({ sessions }));
}, },
async clearAllData() { async clearAllData() {
await indexedDBStorage.clear(); await indexedDBStorage.clear();
localStorage.clear(); localStorage.clear();

View File

@ -37,6 +37,17 @@ const sortModelTable = (models: ReturnType<typeof collectModels>) =>
} }
}); });
/**
* get model name and provider from a formatted string,
* e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google`
* @param modelWithProvider model name with provider separated by last `@` char,
* @returns [model, provider] tuple, if no `@` char found, provider is undefined
*/
export function getModelProvider(modelWithProvider: string): [string, string?] {
const [model, provider] = modelWithProvider.split(/@(?!.*@)/);
return [model, provider];
}
export function collectModelTable( export function collectModelTable(
models: readonly LLMModel[], models: readonly LLMModel[],
customModels: string, customModels: string,
@ -79,10 +90,10 @@ export function collectModelTable(
); );
} else { } else {
// 1. find model by name, and set available value // 1. find model by name, and set available value
const [customModelName, customProviderName] = name.split("@"); const [customModelName, customProviderName] = getModelProvider(name);
let count = 0; let count = 0;
for (const fullName in modelTable) { for (const fullName in modelTable) {
const [modelName, providerName] = fullName.split("@"); const [modelName, providerName] = getModelProvider(fullName);
if ( if (
customModelName == modelName && customModelName == modelName &&
(customProviderName === undefined || (customProviderName === undefined ||
@ -102,7 +113,7 @@ export function collectModelTable(
} }
// 2. if model not exists, create new model with available value // 2. if model not exists, create new model with available value
if (count === 0) { if (count === 0) {
let [customModelName, customProviderName] = name.split("@"); let [customModelName, customProviderName] = getModelProvider(name);
const provider = customProvider( const provider = customProvider(
customProviderName || customModelName, customProviderName || customModelName,
); );
@ -139,7 +150,7 @@ export function collectModelTableWithDefaultModel(
for (const key of Object.keys(modelTable)) { for (const key of Object.keys(modelTable)) {
if ( if (
modelTable[key].available && modelTable[key].available &&
key.split("@").shift() == defaultModel getModelProvider(key)[0] == defaultModel
) { ) {
modelTable[key].isDefault = true; modelTable[key].isDefault = true;
break; break;

View File

@ -9,7 +9,7 @@
}, },
"package": { "package": {
"productName": "NextChat", "productName": "NextChat",
"version": "2.15.6" "version": "2.15.7"
}, },
"tauri": { "tauri": {
"allowlist": { "allowlist": {

View File

@ -0,0 +1,31 @@
import { getModelProvider } from "../app/utils/model";
describe("getModelProvider", () => {
test("should return model and provider when input contains '@'", () => {
const input = "model@provider";
const [model, provider] = getModelProvider(input);
expect(model).toBe("model");
expect(provider).toBe("provider");
});
test("should return model and undefined provider when input does not contain '@'", () => {
const input = "model";
const [model, provider] = getModelProvider(input);
expect(model).toBe("model");
expect(provider).toBeUndefined();
});
test("should handle multiple '@' characters correctly", () => {
const input = "model@provider@extra";
const [model, provider] = getModelProvider(input);
expect(model).toBe("model@provider");
expect(provider).toBe("extra");
});
test("should return empty strings when input is empty", () => {
const input = "";
const [model, provider] = getModelProvider(input);
expect(model).toBe("");
expect(provider).toBeUndefined();
});
});