Merge pull request #5774 from ConnectAI-E/feature/update-target-session

fix: updateCurrentSession => updateTargetSession
This commit is contained in:
Lloyd Zhou 2024-11-06 11:16:33 +08:00 committed by GitHub
commit f526d6f560
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 53 deletions

View File

@ -149,7 +149,8 @@ export function SessionConfigModel(props: { onClose: () => void }) {
text={Locale.Chat.Config.Reset} text={Locale.Chat.Config.Reset}
onClick={async () => { onClick={async () => {
if (await showConfirm(Locale.Memory.ResetConfirm)) { if (await showConfirm(Locale.Memory.ResetConfirm)) {
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session.memoryPrompt = ""), (session) => (session.memoryPrompt = ""),
); );
} }
@ -174,7 +175,10 @@ export function SessionConfigModel(props: { onClose: () => void }) {
updateMask={(updater) => { updateMask={(updater) => {
const mask = { ...session.mask }; const mask = { ...session.mask };
updater(mask); updater(mask);
chatStore.updateCurrentSession((session) => (session.mask = mask)); chatStore.updateTargetSession(
session,
(session) => (session.mask = mask),
);
}} }}
shouldSyncFromGlobal shouldSyncFromGlobal
extraListItems={ extraListItems={
@ -346,12 +350,14 @@ export function PromptHints(props: {
function ClearContextDivider() { function ClearContextDivider() {
const chatStore = useChatStore(); const chatStore = useChatStore();
const session = chatStore.currentSession();
return ( return (
<div <div
className={styles["clear-context"]} className={styles["clear-context"]}
onClick={() => onClick={() =>
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session.clearContextIndex = undefined), (session) => (session.clearContextIndex = undefined),
) )
} }
@ -461,6 +467,7 @@ export function ChatActions(props: {
const navigate = useNavigate(); const navigate = useNavigate();
const chatStore = useChatStore(); const chatStore = useChatStore();
const pluginStore = usePluginStore(); const pluginStore = usePluginStore();
const session = chatStore.currentSession();
// switch themes // switch themes
const theme = config.theme; const theme = config.theme;
@ -477,10 +484,9 @@ export function ChatActions(props: {
const stopAll = () => ChatControllerPool.stopAll(); const stopAll = () => ChatControllerPool.stopAll();
// switch model // switch model
const currentModel = chatStore.currentSession().mask.modelConfig.model; const currentModel = session.mask.modelConfig.model;
const currentProviderName = const currentProviderName =
chatStore.currentSession().mask.modelConfig?.providerName || session.mask.modelConfig?.providerName || ServiceProvider.OpenAI;
ServiceProvider.OpenAI;
const allModels = useAllModels(); const allModels = useAllModels();
const models = useMemo(() => { const models = useMemo(() => {
const filteredModels = allModels.filter((m) => m.available); const filteredModels = allModels.filter((m) => m.available);
@ -514,12 +520,9 @@ export function ChatActions(props: {
const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"]; const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
const dalle3Qualitys: DalleQuality[] = ["standard", "hd"]; const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
const dalle3Styles: DalleStyle[] = ["vivid", "natural"]; const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
const currentSize = const currentSize = session.mask.modelConfig?.size ?? "1024x1024";
chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024"; const currentQuality = session.mask.modelConfig?.quality ?? "standard";
const currentQuality = const currentStyle = session.mask.modelConfig?.style ?? "vivid";
chatStore.currentSession().mask.modelConfig?.quality ?? "standard";
const currentStyle =
chatStore.currentSession().mask.modelConfig?.style ?? "vivid";
const isMobileScreen = useMobileScreen(); const isMobileScreen = useMobileScreen();
@ -537,7 +540,7 @@ export function ChatActions(props: {
if (isUnavailableModel && models.length > 0) { if (isUnavailableModel && models.length > 0) {
// show next model to default model if exist // show next model to default model if exist
let nextModel = models.find((model) => model.isDefault) || models[0]; let nextModel = models.find((model) => model.isDefault) || models[0];
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.model = nextModel.name; session.mask.modelConfig.model = nextModel.name;
session.mask.modelConfig.providerName = nextModel?.provider session.mask.modelConfig.providerName = nextModel?.provider
?.providerName as ServiceProvider; ?.providerName as ServiceProvider;
@ -548,7 +551,7 @@ export function ChatActions(props: {
: nextModel.name, : nextModel.name,
); );
} }
}, [chatStore, currentModel, models]); }, [chatStore, currentModel, models, session]);
return ( return (
<div className={styles["chat-input-actions"]}> <div className={styles["chat-input-actions"]}>
@ -615,7 +618,7 @@ export function ChatActions(props: {
text={Locale.Chat.InputActions.Clear} text={Locale.Chat.InputActions.Clear}
icon={<BreakIcon />} icon={<BreakIcon />}
onClick={() => { onClick={() => {
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
if (session.clearContextIndex === session.messages.length) { if (session.clearContextIndex === session.messages.length) {
session.clearContextIndex = undefined; session.clearContextIndex = undefined;
} else { } else {
@ -647,7 +650,7 @@ export function ChatActions(props: {
onSelection={(s) => { onSelection={(s) => {
if (s.length === 0) return; if (s.length === 0) return;
const [model, providerName] = getModelProvider(s[0]); const [model, providerName] = getModelProvider(s[0]);
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.model = model as ModelType; session.mask.modelConfig.model = model as ModelType;
session.mask.modelConfig.providerName = session.mask.modelConfig.providerName =
providerName as ServiceProvider; providerName as ServiceProvider;
@ -685,7 +688,7 @@ export function ChatActions(props: {
onSelection={(s) => { onSelection={(s) => {
if (s.length === 0) return; if (s.length === 0) return;
const size = s[0]; const size = s[0];
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.size = size; session.mask.modelConfig.size = size;
}); });
showToast(size); showToast(size);
@ -712,7 +715,7 @@ export function ChatActions(props: {
onSelection={(q) => { onSelection={(q) => {
if (q.length === 0) return; if (q.length === 0) return;
const quality = q[0]; const quality = q[0];
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.quality = quality; session.mask.modelConfig.quality = quality;
}); });
showToast(quality); showToast(quality);
@ -739,7 +742,7 @@ export function ChatActions(props: {
onSelection={(s) => { onSelection={(s) => {
if (s.length === 0) return; if (s.length === 0) return;
const style = s[0]; const style = s[0];
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.modelConfig.style = style; session.mask.modelConfig.style = style;
}); });
showToast(style); showToast(style);
@ -770,7 +773,7 @@ export function ChatActions(props: {
}))} }))}
onClose={() => setShowPluginSelector(false)} onClose={() => setShowPluginSelector(false)}
onSelection={(s) => { onSelection={(s) => {
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
session.mask.plugin = s as string[]; session.mask.plugin = s as string[];
}); });
}} }}
@ -813,7 +816,8 @@ export function EditMessageModal(props: { onClose: () => void }) {
icon={<ConfirmIcon />} icon={<ConfirmIcon />}
key="ok" key="ok"
onClick={() => { onClick={() => {
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session.messages = messages), (session) => (session.messages = messages),
); );
props.onClose(); props.onClose();
@ -830,7 +834,8 @@ export function EditMessageModal(props: { onClose: () => void }) {
type="text" type="text"
value={session.topic} value={session.topic}
onInput={(e) => onInput={(e) =>
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session.topic = e.currentTarget.value), (session) => (session.topic = e.currentTarget.value),
) )
} }
@ -991,7 +996,8 @@ function _Chat() {
prev: () => chatStore.nextSession(-1), prev: () => chatStore.nextSession(-1),
next: () => chatStore.nextSession(1), next: () => chatStore.nextSession(1),
clear: () => clear: () =>
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session.clearContextIndex = session.messages.length), (session) => (session.clearContextIndex = session.messages.length),
), ),
fork: () => chatStore.forkSession(), fork: () => chatStore.forkSession(),
@ -1062,7 +1068,7 @@ function _Chat() {
}; };
useEffect(() => { useEffect(() => {
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(session, (session) => {
const stopTiming = Date.now() - REQUEST_TIMEOUT_MS; const stopTiming = Date.now() - REQUEST_TIMEOUT_MS;
session.messages.forEach((m) => { session.messages.forEach((m) => {
// check if should stop all stale messages // check if should stop all stale messages
@ -1088,7 +1094,7 @@ function _Chat() {
} }
}); });
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []); }, [session]);
// check if should send message // check if should send message
const onInputKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => { const onInputKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
@ -1119,7 +1125,8 @@ function _Chat() {
}; };
const deleteMessage = (msgId?: string) => { const deleteMessage = (msgId?: string) => {
chatStore.updateCurrentSession( chatStore.updateTargetSession(
session,
(session) => (session) =>
(session.messages = session.messages.filter((m) => m.id !== msgId)), (session.messages = session.messages.filter((m) => m.id !== msgId)),
); );
@ -1186,7 +1193,7 @@ function _Chat() {
}; };
const onPinMessage = (message: ChatMessage) => { const onPinMessage = (message: ChatMessage) => {
chatStore.updateCurrentSession((session) => chatStore.updateTargetSession(session, (session) =>
session.mask.context.push(message), session.mask.context.push(message),
); );
@ -1712,14 +1719,17 @@ function _Chat() {
}); });
} }
} }
chatStore.updateCurrentSession((session) => { chatStore.updateTargetSession(
const m = session.mask.context session,
.concat(session.messages) (session) => {
.find((m) => m.id === message.id); const m = session.mask.context
if (m) { .concat(session.messages)
m.content = newContent; .find((m) => m.id === message.id);
} if (m) {
}); m.content = newContent;
}
},
);
}} }}
></IconButton> ></IconButton>
</div> </div>

View File

@ -357,7 +357,7 @@ export const useChatStore = createPersistStore(
session.messages = session.messages.concat(); session.messages = session.messages.concat();
session.lastUpdate = Date.now(); session.lastUpdate = Date.now();
}); });
get().updateStat(message); get().updateStat(message, targetSession);
get().summarizeSession(false, targetSession); get().summarizeSession(false, targetSession);
}, },
@ -396,10 +396,10 @@ export const useChatStore = createPersistStore(
// get recent messages // get recent messages
const recentMessages = get().getMessagesWithMemory(); const recentMessages = get().getMessagesWithMemory();
const sendMessages = recentMessages.concat(userMessage); const sendMessages = recentMessages.concat(userMessage);
const messageIndex = get().currentSession().messages.length + 1; const messageIndex = session.messages.length + 1;
// save user's and bot's message // save user's and bot's message
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
const savedUserMessage = { const savedUserMessage = {
...userMessage, ...userMessage,
content: mContent, content: mContent,
@ -420,7 +420,7 @@ export const useChatStore = createPersistStore(
if (message) { if (message) {
botMessage.content = message; botMessage.content = message;
} }
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.messages = session.messages.concat(); session.messages = session.messages.concat();
}); });
}, },
@ -434,7 +434,7 @@ export const useChatStore = createPersistStore(
}, },
onBeforeTool(tool: ChatMessageTool) { onBeforeTool(tool: ChatMessageTool) {
(botMessage.tools = botMessage?.tools || []).push(tool); (botMessage.tools = botMessage?.tools || []).push(tool);
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.messages = session.messages.concat(); session.messages = session.messages.concat();
}); });
}, },
@ -444,7 +444,7 @@ export const useChatStore = createPersistStore(
tools[i] = { ...tool }; tools[i] = { ...tool };
} }
}); });
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.messages = session.messages.concat(); session.messages = session.messages.concat();
}); });
}, },
@ -459,7 +459,7 @@ export const useChatStore = createPersistStore(
botMessage.streaming = false; botMessage.streaming = false;
userMessage.isError = !isAborted; userMessage.isError = !isAborted;
botMessage.isError = !isAborted; botMessage.isError = !isAborted;
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.messages = session.messages.concat(); session.messages = session.messages.concat();
}); });
ChatControllerPool.remove( ChatControllerPool.remove(
@ -591,8 +591,8 @@ export const useChatStore = createPersistStore(
set(() => ({ sessions })); set(() => ({ sessions }));
}, },
resetSession() { resetSession(session: ChatSession) {
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.messages = []; session.messages = [];
session.memoryPrompt = ""; session.memoryPrompt = "";
}); });
@ -736,19 +736,12 @@ export const useChatStore = createPersistStore(
} }
}, },
updateStat(message: ChatMessage) { updateStat(message: ChatMessage, session: ChatSession) {
get().updateCurrentSession((session) => { get().updateTargetSession(session, (session) => {
session.stat.charCount += message.content.length; session.stat.charCount += message.content.length;
// TODO: should update chat count and word count // TODO: should update chat count and word count
}); });
}, },
updateCurrentSession(updater: (session: ChatSession) => void) {
const sessions = get().sessions;
const index = get().currentSessionIndex;
updater(sessions[index]);
set(() => ({ sessions }));
},
updateTargetSession( updateTargetSession(
targetSession: ChatSession, targetSession: ChatSession,
updater: (session: ChatSession) => void, updater: (session: ChatSession) => void,