feat: #138 add context prompt, close #330 #321

This commit is contained in:
Yifei Zhang
2023-04-02 17:48:43 +00:00
parent c978de2c10
commit b85245e317
14 changed files with 296 additions and 69 deletions

View File

@@ -53,6 +53,8 @@ export interface ChatConfig {
export type ModelConfig = ChatConfig["modelConfig"];
export const ROLES: Message["role"][] = ["system", "user", "assistant"];
const ENABLE_GPT4 = true;
export const ALL_MODELS = [
@@ -151,6 +153,7 @@ export interface ChatSession {
id: number;
topic: string;
memoryPrompt: string;
context: Message[];
messages: Message[];
stat: ChatStat;
lastUpdate: string;
@@ -158,7 +161,7 @@ export interface ChatSession {
}
const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
export const BOT_HELLO = {
export const BOT_HELLO: Message = {
role: "assistant",
content: Locale.Store.BotHello,
date: "",
@@ -171,6 +174,7 @@ function createEmptySession(): ChatSession {
id: Date.now(),
topic: DEFAULT_TOPIC,
memoryPrompt: "",
context: [],
messages: [],
stat: {
tokenCount: 0,
@@ -380,16 +384,18 @@ export const useChatStore = create<ChatStore>()(
const session = get().currentSession();
const config = get().config;
const n = session.messages.length;
const recentMessages = session.messages.slice(
Math.max(0, n - config.historyMessageCount),
);
const memoryPrompt = get().getMemoryPrompt();
const context = session.context.slice();
if (session.memoryPrompt) {
recentMessages.unshift(memoryPrompt);
if (session.memoryPrompt && session.memoryPrompt.length > 0) {
const memoryPrompt = get().getMemoryPrompt();
context.push(memoryPrompt);
}
const recentMessages = context.concat(
session.messages.slice(Math.max(0, n - config.historyMessageCount)),
);
return recentMessages;
},
@@ -427,11 +433,13 @@ export const useChatStore = create<ChatStore>()(
let toBeSummarizedMsgs = session.messages.slice(
session.lastSummarizeIndex,
);
const historyMsgLength = countMessages(toBeSummarizedMsgs);
if (historyMsgLength > 4000) {
if (historyMsgLength > get().config?.modelConfig?.max_tokens ?? 4000) {
const n = toBeSummarizedMsgs.length;
toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
-config.historyMessageCount,
Math.max(0, n - config.historyMessageCount),
);
}
@@ -494,7 +502,16 @@ export const useChatStore = create<ChatStore>()(
}),
{
name: LOCAL_KEY,
version: 1,
version: 1.1,
migrate(persistedState, version) {
const state = persistedState as ChatStore;
if (version === 1) {
state.sessions.forEach((s) => (s.context = []));
}
return state;
},
},
),
);