feat: add multi-model support

This commit is contained in:
Yidadaa
2023-09-26 00:19:21 +08:00
parent b90dfb48ee
commit 5610f423d0
62 changed files with 1439 additions and 940 deletions

View File

@@ -2,7 +2,13 @@ import { trimTopic } from "../utils";
import Locale, { getLang } from "../locales";
import { showToast } from "../components/ui-lib";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import {
LLMProvider,
MaskConfig,
ModelConfig,
ModelType,
useAppConfig,
} from "./config";
import { createEmptyMask, Mask } from "./mask";
import {
DEFAULT_INPUT_TEMPLATE,
@@ -10,19 +16,19 @@ import {
StoreKey,
SUMMARIZE_MODEL,
} from "../constant";
import { api, RequestMessage } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { ChatControllerPool } from "../client/common/controller";
import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token";
import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
import { RequestMessage, api } from "../client";
export type ChatMessage = RequestMessage & {
date: string;
streaming?: boolean;
isError?: boolean;
id: string;
model?: ModelType;
model?: string;
};
export function createMessage(override: Partial<ChatMessage>): ChatMessage {
@@ -84,46 +90,25 @@ function getSummarizeModel(currentModel: string) {
return currentModel.startsWith("gpt") ? SUMMARIZE_MODEL : currentModel;
}
interface ChatStore {
sessions: ChatSession[];
currentSessionIndex: number;
clearSessions: () => void;
moveSession: (from: number, to: number) => void;
selectSession: (index: number) => void;
newSession: (mask?: Mask) => void;
deleteSession: (index: number) => void;
currentSession: () => ChatSession;
nextSession: (delta: number) => void;
onNewMessage: (message: ChatMessage) => void;
onUserInput: (content: string) => Promise<void>;
summarizeSession: () => void;
updateStat: (message: ChatMessage) => void;
updateCurrentSession: (updater: (session: ChatSession) => void) => void;
updateMessage: (
sessionIndex: number,
messageIndex: number,
updater: (message?: ChatMessage) => void,
) => void;
resetSession: () => void;
getMessagesWithMemory: () => ChatMessage[];
getMemoryPrompt: () => ChatMessage;
clearAllData: () => void;
}
function countMessages(msgs: ChatMessage[]) {
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
}
function fillTemplateWith(input: string, modelConfig: ModelConfig) {
function fillTemplateWith(
input: string,
context: {
model: string;
template?: string;
},
) {
const vars = {
model: modelConfig.model,
model: context.model,
time: new Date().toLocaleString(),
lang: getLang(),
input: input,
};
let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
let output = context.template ?? DEFAULT_INPUT_TEMPLATE;
// must contains {{input}}
const inputVar = "{{input}}";
@@ -197,13 +182,13 @@ export const useChatStore = createPersistStore(
if (mask) {
const config = useAppConfig.getState();
const globalModelConfig = config.modelConfig;
const globalModelConfig = config.globalMaskConfig;
session.mask = {
...mask,
modelConfig: {
config: {
...globalModelConfig,
...mask.modelConfig,
...mask.config,
},
};
session.topic = mask.name;
@@ -288,11 +273,39 @@ export const useChatStore = createPersistStore(
get().summarizeSession();
},
getCurrentMaskConfig() {
return get().currentSession().mask.config;
},
extractModelConfig(maskConfig: MaskConfig) {
const provider = maskConfig.provider;
if (!maskConfig.modelConfig[provider]) {
throw Error("[Chat] failed to initialize provider: " + provider);
}
return maskConfig.modelConfig[provider];
},
getCurrentModelConfig() {
const maskConfig = this.getCurrentMaskConfig();
return this.extractModelConfig(maskConfig);
},
getClient() {
const appConfig = useAppConfig.getState();
const currentMaskConfig = get().getCurrentMaskConfig();
return api.createLLMClient(appConfig.providerConfig, currentMaskConfig);
},
async onUserInput(content: string) {
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
const maskConfig = this.getCurrentMaskConfig();
const modelConfig = this.getCurrentModelConfig();
const userContent = fillTemplateWith(content, modelConfig);
const userContent = fillTemplateWith(content, {
model: modelConfig.model,
template: maskConfig.chatConfig.template,
});
console.log("[User Input] after template: ", userContent);
const userMessage: ChatMessage = createMessage({
@@ -323,10 +336,11 @@ export const useChatStore = createPersistStore(
]);
});
const client = this.getClient();
// make request
api.llm.chat({
client.chatStream({
messages: sendMessages,
config: { ...modelConfig, stream: true },
onUpdate(message) {
botMessage.streaming = true;
if (message) {
@@ -391,7 +405,9 @@ export const useChatStore = createPersistStore(
getMessagesWithMemory() {
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
const maskConfig = this.getCurrentMaskConfig();
const chatConfig = maskConfig.chatConfig;
const modelConfig = this.getCurrentModelConfig();
const clearContextIndex = session.clearContextIndex ?? 0;
const messages = session.messages.slice();
const totalMessageCount = session.messages.length;
@@ -400,14 +416,14 @@ export const useChatStore = createPersistStore(
const contextPrompts = session.mask.context.slice();
// system prompts, to get close to OpenAI Web ChatGPT
const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts;
const shouldInjectSystemPrompts = chatConfig.enableInjectSystemPrompts;
const systemPrompts = shouldInjectSystemPrompts
? [
createMessage({
role: "system",
content: fillTemplateWith("", {
...modelConfig,
template: DEFAULT_SYSTEM_TEMPLATE,
model: modelConfig.model,
template: chatConfig.template,
}),
}),
]
@@ -421,7 +437,7 @@ export const useChatStore = createPersistStore(
// long term memory
const shouldSendLongTermMemory =
modelConfig.sendMemory &&
chatConfig.sendMemory &&
session.memoryPrompt &&
session.memoryPrompt.length > 0 &&
session.lastSummarizeIndex > clearContextIndex;
@@ -433,7 +449,7 @@ export const useChatStore = createPersistStore(
// short term memory
const shortTermMemoryStartIndex = Math.max(
0,
totalMessageCount - modelConfig.historyMessageCount,
totalMessageCount - chatConfig.historyMessageCount,
);
// lets concat send messages, including 4 parts:
@@ -494,6 +510,8 @@ export const useChatStore = createPersistStore(
summarizeSession() {
const config = useAppConfig.getState();
const maskConfig = this.getCurrentMaskConfig();
const chatConfig = maskConfig.chatConfig;
const session = get().currentSession();
// remove error messages if any
@@ -502,7 +520,7 @@ export const useChatStore = createPersistStore(
// should summarize topic after chating more than 50 words
const SUMMARIZE_MIN_LEN = 50;
if (
config.enableAutoGenerateTitle &&
chatConfig.enableAutoGenerateTitle &&
session.topic === DEFAULT_TOPIC &&
countMessages(messages) >= SUMMARIZE_MIN_LEN
) {
@@ -512,11 +530,12 @@ export const useChatStore = createPersistStore(
content: Locale.Store.Prompt.Topic,
}),
);
api.llm.chat({
const client = this.getClient();
client.chat({
messages: topicMessages,
config: {
model: getSummarizeModel(session.mask.modelConfig.model),
},
shouldSummarize: true,
onFinish(message) {
get().updateCurrentSession(
(session) =>
@@ -527,7 +546,7 @@ export const useChatStore = createPersistStore(
});
}
const modelConfig = session.mask.modelConfig;
const modelConfig = this.getCurrentModelConfig();
const summarizeIndex = Math.max(
session.lastSummarizeIndex,
session.clearContextIndex ?? 0,
@@ -541,7 +560,7 @@ export const useChatStore = createPersistStore(
if (historyMsgLength > modelConfig?.max_tokens ?? 4000) {
const n = toBeSummarizedMsgs.length;
toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
Math.max(0, n - modelConfig.historyMessageCount),
Math.max(0, n - chatConfig.historyMessageCount),
);
}
@@ -554,14 +573,14 @@ export const useChatStore = createPersistStore(
"[Chat History] ",
toBeSummarizedMsgs,
historyMsgLength,
modelConfig.compressMessageLengthThreshold,
chatConfig.compressMessageLengthThreshold,
);
if (
historyMsgLength > modelConfig.compressMessageLengthThreshold &&
modelConfig.sendMemory
historyMsgLength > chatConfig.compressMessageLengthThreshold &&
chatConfig.sendMemory
) {
api.llm.chat({
this.getClient().chatStream({
messages: toBeSummarizedMsgs.concat(
createMessage({
role: "system",
@@ -569,11 +588,7 @@ export const useChatStore = createPersistStore(
date: "",
}),
),
config: {
...modelConfig,
stream: true,
model: getSummarizeModel(session.mask.modelConfig.model),
},
shouldSummarize: true,
onUpdate(message) {
session.memoryPrompt = message;
},
@@ -614,52 +629,9 @@ export const useChatStore = createPersistStore(
name: StoreKey.Chat,
version: 3.1,
migrate(persistedState, version) {
const state = persistedState as any;
const newState = JSON.parse(
JSON.stringify(state),
) as typeof DEFAULT_CHAT_STATE;
// TODO(yifei): migrate from old versions
if (version < 2) {
newState.sessions = [];
const oldSessions = state.sessions;
for (const oldSession of oldSessions) {
const newSession = createEmptySession();
newSession.topic = oldSession.topic;
newSession.messages = [...oldSession.messages];
newSession.mask.modelConfig.sendMemory = true;
newSession.mask.modelConfig.historyMessageCount = 4;
newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
newState.sessions.push(newSession);
}
}
if (version < 3) {
// migrate id to nanoid
newState.sessions.forEach((s) => {
s.id = nanoid();
s.messages.forEach((m) => (m.id = nanoid()));
});
}
// Enable `enableInjectSystemPrompts` attribute for old sessions.
// Resolve issue of old sessions not automatically enabling.
if (version < 3.1) {
newState.sessions.forEach((s) => {
if (
// Exclude those already set by user
!s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
) {
// Because users may have changed this configuration,
// the user's current configuration is used instead of the default
const config = useAppConfig.getState();
s.mask.modelConfig.enableInjectSystemPrompts =
config.modelConfig.enableInjectSystemPrompts;
}
});
}
return newState as any;
return persistedState as any;
},
},
);