mirror of https://github.com/Yidadaa/ChatGPT-Next-Web.git
feat: add multi-model support
@@ -2,7 +2,13 @@ import { trimTopic } from "../utils";
 
 import Locale, { getLang } from "../locales";
 import { showToast } from "../components/ui-lib";
-import { ModelConfig, ModelType, useAppConfig } from "./config";
+import {
+  LLMProvider,
+  MaskConfig,
+  ModelConfig,
+  ModelType,
+  useAppConfig,
+} from "./config";
 import { createEmptyMask, Mask } from "./mask";
 import {
   DEFAULT_INPUT_TEMPLATE,
@@ -10,19 +16,19 @@ import {
   StoreKey,
   SUMMARIZE_MODEL,
 } from "../constant";
-import { api, RequestMessage } from "../client/api";
-import { ChatControllerPool } from "../client/controller";
+import { ChatControllerPool } from "../client/common/controller";
 import { prettyObject } from "../utils/format";
 import { estimateTokenLength } from "../utils/token";
 import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
+import { RequestMessage, api } from "../client";
 
 export type ChatMessage = RequestMessage & {
   date: string;
   streaming?: boolean;
   isError?: boolean;
   id: string;
-  model?: ModelType;
+  model?: string;
 };
 
 export function createMessage(override: Partial<ChatMessage>): ChatMessage {
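Editorial note (not part of the diff): widening `model` from the OpenAI-specific `ModelType` union to a plain `string` is what lets a chat message record a model id from any provider. A minimal sketch of the widened field in use; the shapes and values below are illustrative only, and `RequestMessage` is assumed to carry `role` and `content`:

```ts
// Sketch only -- mirrors the ChatMessage shape shown in the hunk above.
import { nanoid } from "nanoid";

type Role = "system" | "user" | "assistant";
interface RequestMessageSketch {
  role: Role;
  content: string;
}
type ChatMessageSketch = RequestMessageSketch & {
  date: string;
  streaming?: boolean;
  isError?: boolean;
  id: string;
  model?: string; // was `ModelType`, an OpenAI-only union
};

// A non-OpenAI model id now type-checks (the value is just an example):
const botMessage: ChatMessageSketch = {
  id: nanoid(),
  role: "assistant",
  content: "",
  date: new Date().toLocaleString(),
  streaming: true,
  model: "claude-2",
};
```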
@@ -84,46 +90,25 @@ function getSummarizeModel(currentModel: string) {
   return currentModel.startsWith("gpt") ? SUMMARIZE_MODEL : currentModel;
 }
 
-interface ChatStore {
-  sessions: ChatSession[];
-  currentSessionIndex: number;
-  clearSessions: () => void;
-  moveSession: (from: number, to: number) => void;
-  selectSession: (index: number) => void;
-  newSession: (mask?: Mask) => void;
-  deleteSession: (index: number) => void;
-  currentSession: () => ChatSession;
-  nextSession: (delta: number) => void;
-  onNewMessage: (message: ChatMessage) => void;
-  onUserInput: (content: string) => Promise<void>;
-  summarizeSession: () => void;
-  updateStat: (message: ChatMessage) => void;
-  updateCurrentSession: (updater: (session: ChatSession) => void) => void;
-  updateMessage: (
-    sessionIndex: number,
-    messageIndex: number,
-    updater: (message?: ChatMessage) => void,
-  ) => void;
-  resetSession: () => void;
-  getMessagesWithMemory: () => ChatMessage[];
-  getMemoryPrompt: () => ChatMessage;
-
-  clearAllData: () => void;
-}
-
 function countMessages(msgs: ChatMessage[]) {
   return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
 }
 
-function fillTemplateWith(input: string, modelConfig: ModelConfig) {
+function fillTemplateWith(
+  input: string,
+  context: {
+    model: string;
+    template?: string;
+  },
+) {
   const vars = {
-    model: modelConfig.model,
+    model: context.model,
     time: new Date().toLocaleString(),
     lang: getLang(),
     input: input,
   };
 
-  let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
+  let output = context.template ?? DEFAULT_INPUT_TEMPLATE;
 
   // must contains {{input}}
   const inputVar = "{{input}}";
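Editorial note (not part of the diff): `fillTemplateWith` now takes a small context object instead of the whole `ModelConfig`, so callers can mix a model name and a template coming from different config objects. A minimal sketch of the substitution; only the signature and the `vars` object come from the hunk above, while `DEFAULT_INPUT_TEMPLATE` and the replacement loop are assumptions:

```ts
// Illustrative re-implementation of fillTemplateWith's visible behavior.
const DEFAULT_INPUT_TEMPLATE_SKETCH = "{{input}}"; // assumed template shape

function fillTemplateWithSketch(
  input: string,
  context: { model: string; template?: string },
) {
  const vars: Record<string, string> = {
    model: context.model,
    time: new Date().toLocaleString(),
    lang: "en", // getLang() in the real code
    input,
  };

  let output = context.template ?? DEFAULT_INPUT_TEMPLATE_SKETCH;

  // the template must contain {{input}}, otherwise the user's text is appended
  if (!output.includes("{{input}}")) {
    output += "\n{{input}}";
  }
  for (const [name, value] of Object.entries(vars)) {
    output = output.replaceAll(`{{${name}}}`, value);
  }
  return output;
}

// fillTemplateWithSketch("hello", { model: "gpt-3.5-turbo" }) === "hello"
```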
@@ -197,13 +182,13 @@ export const useChatStore = createPersistStore(
 
       if (mask) {
         const config = useAppConfig.getState();
-        const globalModelConfig = config.modelConfig;
+        const globalModelConfig = config.globalMaskConfig;
 
         session.mask = {
           ...mask,
-          modelConfig: {
+          config: {
             ...globalModelConfig,
-            ...mask.modelConfig,
+            ...mask.config,
           },
         };
         session.topic = mask.name;
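Editorial note (not part of the diff): `newSession` now seeds the mask's single `config` object from the app-level `globalMaskConfig` and lets the mask's own `config` win on conflicts. A tiny sketch of that spread precedence; the field values are hypothetical, for illustration only:

```ts
// Later spreads override earlier ones: per-mask values beat global defaults.
const globalMaskConfig = { provider: "openai", sendMemory: true }; // hypothetical values
const maskOwnConfig = { provider: "anthropic" };                   // hypothetical values

const merged = { ...globalMaskConfig, ...maskOwnConfig };
// => { provider: "anthropic", sendMemory: true }
```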
@@ -288,11 +273,39 @@ export const useChatStore = createPersistStore(
       get().summarizeSession();
     },
 
+    getCurrentMaskConfig() {
+      return get().currentSession().mask.config;
+    },
+
+    extractModelConfig(maskConfig: MaskConfig) {
+      const provider = maskConfig.provider;
+      if (!maskConfig.modelConfig[provider]) {
+        throw Error("[Chat] failed to initialize provider: " + provider);
+      }
+
+      return maskConfig.modelConfig[provider];
+    },
+
+    getCurrentModelConfig() {
+      const maskConfig = this.getCurrentMaskConfig();
+      return this.extractModelConfig(maskConfig);
+    },
+
+    getClient() {
+      const appConfig = useAppConfig.getState();
+      const currentMaskConfig = get().getCurrentMaskConfig();
+      return api.createLLMClient(appConfig.providerConfig, currentMaskConfig);
+    },
+
     async onUserInput(content: string) {
       const session = get().currentSession();
-      const modelConfig = session.mask.modelConfig;
+      const maskConfig = this.getCurrentMaskConfig();
+      const modelConfig = this.getCurrentModelConfig();
 
-      const userContent = fillTemplateWith(content, modelConfig);
+      const userContent = fillTemplateWith(content, {
+        model: modelConfig.model,
+        template: maskConfig.chatConfig.template,
+      });
       console.log("[User Input] after template: ", userContent);
 
       const userMessage: ChatMessage = createMessage({
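Editorial note (not part of the diff): the new helpers assume a mask config whose `modelConfig` is keyed by provider, so the active provider selects which per-provider settings (and, via `getClient`, which client) are used. The shapes below are an inference from the hunk above, not the actual type definitions in app/store/config.ts:

```ts
// Inferred, simplified shapes; the real MaskConfig / provider config types
// are not shown in this diff and may differ.
type LLMProviderSketch = "openai" | "anthropic"; // assumed members

interface MaskConfigSketch {
  provider: LLMProviderSketch;
  chatConfig: { template?: string };
  modelConfig: Partial<Record<LLMProviderSketch, { model: string }>>;
}

// Mirrors extractModelConfig from the hunk above: the active provider picks
// which per-provider settings block is returned.
function extractModelConfigSketch(maskConfig: MaskConfigSketch) {
  const provider = maskConfig.provider;
  const config = maskConfig.modelConfig[provider];
  if (!config) {
    throw Error("[Chat] failed to initialize provider: " + provider);
  }
  return config; // e.g. { model: "gpt-4" } when provider === "openai"
}
```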
@@ -323,10 +336,11 @@ export const useChatStore = createPersistStore(
         ]);
       });
 
+      const client = this.getClient();
+
       // make request
-      api.llm.chat({
+      client.chatStream({
         messages: sendMessages,
-        config: { ...modelConfig, stream: true },
         onUpdate(message) {
           botMessage.streaming = true;
           if (message) {
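Editorial note (not part of the diff): `onUserInput` now asks the store for a provider-specific client via `this.getClient()` and streams through it, instead of calling the fixed `api.llm.chat`. The interface below is a guess at the minimal client surface these hunks rely on; the real definitions live under app/client and may differ:

```ts
// Assumed minimal client surface, inferred from how the store calls it.
interface RequestMessageSketch {
  role: "system" | "user" | "assistant";
  content: string;
}

interface LLMClientSketch {
  // one-shot completion, used for topic generation in a later hunk
  chat(options: {
    messages: RequestMessageSketch[];
    shouldSummarize?: boolean;
    onFinish?: (message: string) => void;
  }): Promise<void>;
  // streaming completion, used for the conversation and memory compression
  chatStream(options: {
    messages: RequestMessageSketch[];
    shouldSummarize?: boolean;
    onUpdate?: (message: string) => void;
    onFinish?: (message: string) => void;
  }): Promise<void>;
}
```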
@@ -391,7 +405,9 @@ export const useChatStore = createPersistStore(
 
     getMessagesWithMemory() {
       const session = get().currentSession();
-      const modelConfig = session.mask.modelConfig;
+      const maskConfig = this.getCurrentMaskConfig();
+      const chatConfig = maskConfig.chatConfig;
+      const modelConfig = this.getCurrentModelConfig();
       const clearContextIndex = session.clearContextIndex ?? 0;
       const messages = session.messages.slice();
       const totalMessageCount = session.messages.length;
@@ -400,14 +416,14 @@ export const useChatStore = createPersistStore(
       const contextPrompts = session.mask.context.slice();
 
       // system prompts, to get close to OpenAI Web ChatGPT
-      const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts;
+      const shouldInjectSystemPrompts = chatConfig.enableInjectSystemPrompts;
       const systemPrompts = shouldInjectSystemPrompts
         ? [
             createMessage({
               role: "system",
               content: fillTemplateWith("", {
-                ...modelConfig,
-                template: DEFAULT_SYSTEM_TEMPLATE,
+                model: modelConfig.model,
+                template: chatConfig.template,
               }),
             }),
           ]
@@ -421,7 +437,7 @@ export const useChatStore = createPersistStore(
 
       // long term memory
       const shouldSendLongTermMemory =
-        modelConfig.sendMemory &&
+        chatConfig.sendMemory &&
         session.memoryPrompt &&
         session.memoryPrompt.length > 0 &&
         session.lastSummarizeIndex > clearContextIndex;
@@ -433,7 +449,7 @@ export const useChatStore = createPersistStore(
       // short term memory
       const shortTermMemoryStartIndex = Math.max(
         0,
-        totalMessageCount - modelConfig.historyMessageCount,
+        totalMessageCount - chatConfig.historyMessageCount,
       );
 
       // lets concat send messages, including 4 parts:
@@ -494,6 +510,8 @@ export const useChatStore = createPersistStore(
 
     summarizeSession() {
       const config = useAppConfig.getState();
+      const maskConfig = this.getCurrentMaskConfig();
+      const chatConfig = maskConfig.chatConfig;
       const session = get().currentSession();
 
       // remove error messages if any
@@ -502,7 +520,7 @@ export const useChatStore = createPersistStore(
       // should summarize topic after chating more than 50 words
       const SUMMARIZE_MIN_LEN = 50;
       if (
-        config.enableAutoGenerateTitle &&
+        chatConfig.enableAutoGenerateTitle &&
         session.topic === DEFAULT_TOPIC &&
        countMessages(messages) >= SUMMARIZE_MIN_LEN
       ) {
@@ -512,11 +530,12 @@ export const useChatStore = createPersistStore(
             content: Locale.Store.Prompt.Topic,
           }),
         );
-        api.llm.chat({
+
+        const client = this.getClient();
+        client.chat({
           messages: topicMessages,
-          config: {
-            model: getSummarizeModel(session.mask.modelConfig.model),
-          },
+          shouldSummarize: true,
+
           onFinish(message) {
             get().updateCurrentSession(
               (session) =>
@@ -527,7 +546,7 @@ export const useChatStore = createPersistStore(
         });
       }
 
-      const modelConfig = session.mask.modelConfig;
+      const modelConfig = this.getCurrentModelConfig();
       const summarizeIndex = Math.max(
         session.lastSummarizeIndex,
         session.clearContextIndex ?? 0,
@@ -541,7 +560,7 @@ export const useChatStore = createPersistStore(
       if (historyMsgLength > modelConfig?.max_tokens ?? 4000) {
         const n = toBeSummarizedMsgs.length;
         toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
-          Math.max(0, n - modelConfig.historyMessageCount),
+          Math.max(0, n - chatConfig.historyMessageCount),
         );
       }
 
@@ -554,14 +573,14 @@ export const useChatStore = createPersistStore(
         "[Chat History] ",
         toBeSummarizedMsgs,
         historyMsgLength,
-        modelConfig.compressMessageLengthThreshold,
+        chatConfig.compressMessageLengthThreshold,
       );
 
       if (
-        historyMsgLength > modelConfig.compressMessageLengthThreshold &&
-        modelConfig.sendMemory
+        historyMsgLength > chatConfig.compressMessageLengthThreshold &&
+        chatConfig.sendMemory
       ) {
-        api.llm.chat({
+        this.getClient().chatStream({
          messages: toBeSummarizedMsgs.concat(
            createMessage({
              role: "system",
@@ -569,11 +588,7 @@ export const useChatStore = createPersistStore(
              date: "",
            }),
          ),
-          config: {
-            ...modelConfig,
-            stream: true,
-            model: getSummarizeModel(session.mask.modelConfig.model),
-          },
+          shouldSummarize: true,
          onUpdate(message) {
            session.memoryPrompt = message;
          },
@@ -614,52 +629,9 @@ export const useChatStore = createPersistStore(
     name: StoreKey.Chat,
     version: 3.1,
     migrate(persistedState, version) {
       const state = persistedState as any;
-      const newState = JSON.parse(
-        JSON.stringify(state),
-      ) as typeof DEFAULT_CHAT_STATE;
+      // TODO(yifei): migrate from old versions
-
-      if (version < 2) {
-        newState.sessions = [];
-
-        const oldSessions = state.sessions;
-        for (const oldSession of oldSessions) {
-          const newSession = createEmptySession();
-          newSession.topic = oldSession.topic;
-          newSession.messages = [...oldSession.messages];
-          newSession.mask.modelConfig.sendMemory = true;
-          newSession.mask.modelConfig.historyMessageCount = 4;
-          newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
-          newState.sessions.push(newSession);
-        }
-      }
-
-      if (version < 3) {
-        // migrate id to nanoid
-        newState.sessions.forEach((s) => {
-          s.id = nanoid();
-          s.messages.forEach((m) => (m.id = nanoid()));
-        });
-      }
-
-      // Enable `enableInjectSystemPrompts` attribute for old sessions.
-      // Resolve issue of old sessions not automatically enabling.
-      if (version < 3.1) {
-        newState.sessions.forEach((s) => {
-          if (
-            // Exclude those already set by user
-            !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
-          ) {
-            // Because users may have changed this configuration,
-            // the user's current configuration is used instead of the default
-            const config = useAppConfig.getState();
-            s.mask.modelConfig.enableInjectSystemPrompts =
-              config.modelConfig.enableInjectSystemPrompts;
-          }
-        });
-      }
-
-      return newState as any;
+      return persistedState as any;
     },
   },
 );