feat: add multi-model support

Author: Yidadaa
Date: 2023-09-26 00:19:21 +08:00
parent b90dfb48ee
commit 5610f423d0
62 changed files with 1439 additions and 940 deletions

app/store/access.ts

@@ -1,23 +1,20 @@
-import { DEFAULT_API_HOST, DEFAULT_MODELS, StoreKey } from "../constant";
-import { getHeaders } from "../client/api";
+import { REMOTE_API_HOST, DEFAULT_MODELS, StoreKey } from "../constant";
 import { getClientConfig } from "../config/client";
 import { createPersistStore } from "../utils/store";
+import { getAuthHeaders } from "../client/common/auth";
 let fetchState = 0; // 0 = not fetched, 1 = fetching, 2 = done
 const DEFAULT_OPENAI_URL =
-getClientConfig()?.buildMode === "export" ? DEFAULT_API_HOST : "/api/openai/";
+getClientConfig()?.buildMode === "export" ? REMOTE_API_HOST : "/api/openai/";
console.log("[API] default openai url", DEFAULT_OPENAI_URL);
const DEFAULT_ACCESS_STATE = {
token: "",
accessCode: "",
needCode: true,
hideUserApiKey: false,
hideBalanceQuery: false,
disableGPT4: false,
openaiUrl: DEFAULT_OPENAI_URL,
};
export const useAccessStore = createPersistStore(
@@ -25,35 +22,24 @@ export const useAccessStore = createPersistStore(
(set, get) => ({
enabledAccessControl() {
-this.fetch();
+this.fetchConfig();
return get().needCode;
},
updateCode(code: string) {
set(() => ({ accessCode: code?.trim() }));
},
updateToken(token: string) {
set(() => ({ token: token?.trim() }));
},
updateOpenAiUrl(url: string) {
set(() => ({ openaiUrl: url?.trim() }));
},
isAuthorized() {
-this.fetch();
+this.fetchConfig();
 // has token or has code or disabled access control
-return (
-!!get().token || !!get().accessCode || !this.enabledAccessControl()
-);
+return !!get().accessCode || !this.enabledAccessControl();
},
-fetch() {
+fetchConfig() {
if (fetchState > 0 || getClientConfig()?.buildMode === "export") return;
fetchState = 1;
fetch("/api/config", {
method: "post",
body: null,
headers: {
-...getHeaders(),
+...getAuthHeaders(),
},
})
.then((res) => res.json())
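
Note: the store now sends auth info via getAuthHeaders() from ../client/common/auth, a module not included in this view. A minimal sketch of what such a helper could look like, assuming the access-store shape above (the module path, header layout, and bearer format are assumptions, not the commit's code):

// Hypothetical sketch of app/client/common/auth.ts -- not shown in this diff.
import { useAccessStore } from "../../store/access";

export function getAuthHeaders(): Record<string, string> {
  const accessStore = useAccessStore.getState();
  const headers: Record<string, string> = {
    "Content-Type": "application/json",
  };
  // Forward the access code so the server-side /api/config route can
  // check whether access control is satisfied.
  if (accessStore.accessCode) {
    headers["Authorization"] = `Bearer ${accessStore.accessCode}`;
  }
  return headers;
}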

app/store/chat.ts

@@ -2,7 +2,13 @@ import { trimTopic } from "../utils";
import Locale, { getLang } from "../locales";
import { showToast } from "../components/ui-lib";
-import { ModelConfig, ModelType, useAppConfig } from "./config";
+import {
+LLMProvider,
+MaskConfig,
+ModelConfig,
+ModelType,
+useAppConfig,
+} from "./config";
import { createEmptyMask, Mask } from "./mask";
import {
DEFAULT_INPUT_TEMPLATE,
@@ -10,19 +16,19 @@ import {
StoreKey,
SUMMARIZE_MODEL,
} from "../constant";
-import { api, RequestMessage } from "../client/api";
-import { ChatControllerPool } from "../client/controller";
+import { ChatControllerPool } from "../client/common/controller";
 import { prettyObject } from "../utils/format";
 import { estimateTokenLength } from "../utils/token";
 import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
+import { RequestMessage, api } from "../client";
export type ChatMessage = RequestMessage & {
date: string;
streaming?: boolean;
isError?: boolean;
id: string;
-model?: ModelType;
+model?: string;
};
export function createMessage(override: Partial<ChatMessage>): ChatMessage {
@@ -84,46 +90,25 @@ function getSummarizeModel(currentModel: string) {
return currentModel.startsWith("gpt") ? SUMMARIZE_MODEL : currentModel;
}
-interface ChatStore {
-sessions: ChatSession[];
-currentSessionIndex: number;
-clearSessions: () => void;
-moveSession: (from: number, to: number) => void;
-selectSession: (index: number) => void;
-newSession: (mask?: Mask) => void;
-deleteSession: (index: number) => void;
-currentSession: () => ChatSession;
-nextSession: (delta: number) => void;
-onNewMessage: (message: ChatMessage) => void;
-onUserInput: (content: string) => Promise<void>;
-summarizeSession: () => void;
-updateStat: (message: ChatMessage) => void;
-updateCurrentSession: (updater: (session: ChatSession) => void) => void;
-updateMessage: (
-sessionIndex: number,
-messageIndex: number,
-updater: (message?: ChatMessage) => void,
-) => void;
-resetSession: () => void;
-getMessagesWithMemory: () => ChatMessage[];
-getMemoryPrompt: () => ChatMessage;
-clearAllData: () => void;
-}
function countMessages(msgs: ChatMessage[]) {
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
}
-function fillTemplateWith(input: string, modelConfig: ModelConfig) {
+function fillTemplateWith(
+input: string,
+context: {
+model: string;
+template?: string;
+},
+) {
const vars = {
-model: modelConfig.model,
+model: context.model,
time: new Date().toLocaleString(),
lang: getLang(),
input: input,
};
-let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
+let output = context.template ?? DEFAULT_INPUT_TEMPLATE;
 // must contain {{input}}
const inputVar = "{{input}}";
@@ -197,13 +182,13 @@ export const useChatStore = createPersistStore(
if (mask) {
const config = useAppConfig.getState();
-const globalModelConfig = config.modelConfig;
+const globalModelConfig = config.globalMaskConfig;
session.mask = {
...mask,
-modelConfig: {
+config: {
 ...globalModelConfig,
-...mask.modelConfig,
+...mask.config,
},
};
session.topic = mask.name;
@@ -288,11 +273,39 @@ export const useChatStore = createPersistStore(
get().summarizeSession();
},
+getCurrentMaskConfig() {
+return get().currentSession().mask.config;
+},
+extractModelConfig(maskConfig: MaskConfig) {
+const provider = maskConfig.provider;
+if (!maskConfig.modelConfig[provider]) {
+throw Error("[Chat] failed to initialize provider: " + provider);
+}
+return maskConfig.modelConfig[provider];
+},
+getCurrentModelConfig() {
+const maskConfig = this.getCurrentMaskConfig();
+return this.extractModelConfig(maskConfig);
+},
+getClient() {
+const appConfig = useAppConfig.getState();
+const currentMaskConfig = get().getCurrentMaskConfig();
+return api.createLLMClient(appConfig.providerConfig, currentMaskConfig);
+},
 async onUserInput(content: string) {
 const session = get().currentSession();
-const modelConfig = session.mask.modelConfig;
+const maskConfig = this.getCurrentMaskConfig();
+const modelConfig = this.getCurrentModelConfig();
-const userContent = fillTemplateWith(content, modelConfig);
+const userContent = fillTemplateWith(content, {
+model: modelConfig.model,
+template: maskConfig.chatConfig.template,
+});
console.log("[User Input] after template: ", userContent);
const userMessage: ChatMessage = createMessage({
@@ -323,10 +336,11 @@ export const useChatStore = createPersistStore(
]);
});
+const client = this.getClient();
 // make request
-api.llm.chat({
+client.chatStream({
 messages: sendMessages,
-config: { ...modelConfig, stream: true },
 onUpdate(message) {
botMessage.streaming = true;
if (message) {
@@ -391,7 +405,9 @@ export const useChatStore = createPersistStore(
getMessagesWithMemory() {
const session = get().currentSession();
-const modelConfig = session.mask.modelConfig;
+const maskConfig = this.getCurrentMaskConfig();
+const chatConfig = maskConfig.chatConfig;
+const modelConfig = this.getCurrentModelConfig();
const clearContextIndex = session.clearContextIndex ?? 0;
const messages = session.messages.slice();
const totalMessageCount = session.messages.length;
@@ -400,14 +416,14 @@ export const useChatStore = createPersistStore(
const contextPrompts = session.mask.context.slice();
// system prompts, to get close to OpenAI Web ChatGPT
-const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts;
+const shouldInjectSystemPrompts = chatConfig.enableInjectSystemPrompts;
const systemPrompts = shouldInjectSystemPrompts
? [
createMessage({
role: "system",
content: fillTemplateWith("", {
-...modelConfig,
-template: DEFAULT_SYSTEM_TEMPLATE,
+model: modelConfig.model,
+template: chatConfig.template,
}),
}),
]
@@ -421,7 +437,7 @@ export const useChatStore = createPersistStore(
// long term memory
const shouldSendLongTermMemory =
-modelConfig.sendMemory &&
+chatConfig.sendMemory &&
session.memoryPrompt &&
session.memoryPrompt.length > 0 &&
session.lastSummarizeIndex > clearContextIndex;
@@ -433,7 +449,7 @@ export const useChatStore = createPersistStore(
// short term memory
const shortTermMemoryStartIndex = Math.max(
0,
-totalMessageCount - modelConfig.historyMessageCount,
+totalMessageCount - chatConfig.historyMessageCount,
);
// let's concat the messages to send, including 4 parts:
@@ -494,6 +510,8 @@ export const useChatStore = createPersistStore(
summarizeSession() {
const config = useAppConfig.getState();
+const maskConfig = this.getCurrentMaskConfig();
+const chatConfig = maskConfig.chatConfig;
const session = get().currentSession();
// remove error messages if any
@@ -502,7 +520,7 @@ export const useChatStore = createPersistStore(
// should summarize topic after chatting more than 50 words
const SUMMARIZE_MIN_LEN = 50;
if (
-config.enableAutoGenerateTitle &&
+chatConfig.enableAutoGenerateTitle &&
session.topic === DEFAULT_TOPIC &&
countMessages(messages) >= SUMMARIZE_MIN_LEN
) {
@@ -512,11 +530,12 @@ export const useChatStore = createPersistStore(
content: Locale.Store.Prompt.Topic,
}),
);
-api.llm.chat({
+const client = this.getClient();
+client.chat({
 messages: topicMessages,
-config: {
-model: getSummarizeModel(session.mask.modelConfig.model),
-},
+shouldSummarize: true,
onFinish(message) {
get().updateCurrentSession(
(session) =>
@@ -527,7 +546,7 @@ export const useChatStore = createPersistStore(
});
}
-const modelConfig = session.mask.modelConfig;
+const modelConfig = this.getCurrentModelConfig();
const summarizeIndex = Math.max(
session.lastSummarizeIndex,
session.clearContextIndex ?? 0,
@@ -541,7 +560,7 @@ export const useChatStore = createPersistStore(
if (historyMsgLength > (modelConfig?.max_tokens ?? 4000)) {
const n = toBeSummarizedMsgs.length;
toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
-Math.max(0, n - modelConfig.historyMessageCount),
+Math.max(0, n - chatConfig.historyMessageCount),
);
}
@@ -554,14 +573,14 @@ export const useChatStore = createPersistStore(
"[Chat History] ",
toBeSummarizedMsgs,
historyMsgLength,
-modelConfig.compressMessageLengthThreshold,
+chatConfig.compressMessageLengthThreshold,
);
 if (
-historyMsgLength > modelConfig.compressMessageLengthThreshold &&
-modelConfig.sendMemory
+historyMsgLength > chatConfig.compressMessageLengthThreshold &&
+chatConfig.sendMemory
 ) {
-api.llm.chat({
+this.getClient().chatStream({
messages: toBeSummarizedMsgs.concat(
createMessage({
role: "system",
@@ -569,11 +588,7 @@ export const useChatStore = createPersistStore(
date: "",
}),
),
-config: {
-...modelConfig,
-stream: true,
-model: getSummarizeModel(session.mask.modelConfig.model),
-},
+shouldSummarize: true,
onUpdate(message) {
session.memoryPrompt = message;
},
@@ -614,52 +629,9 @@ export const useChatStore = createPersistStore(
name: StoreKey.Chat,
version: 3.1,
migrate(persistedState, version) {
-const state = persistedState as any;
-const newState = JSON.parse(
-JSON.stringify(state),
-) as typeof DEFAULT_CHAT_STATE;
-if (version < 2) {
-newState.sessions = [];
-const oldSessions = state.sessions;
-for (const oldSession of oldSessions) {
-const newSession = createEmptySession();
-newSession.topic = oldSession.topic;
-newSession.messages = [...oldSession.messages];
-newSession.mask.modelConfig.sendMemory = true;
-newSession.mask.modelConfig.historyMessageCount = 4;
-newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
-newState.sessions.push(newSession);
-}
-}
-if (version < 3) {
-// migrate id to nanoid
-newState.sessions.forEach((s) => {
-s.id = nanoid();
-s.messages.forEach((m) => (m.id = nanoid()));
-});
-}
-// Enable `enableInjectSystemPrompts` attribute for old sessions.
-// Resolve issue of old sessions not automatically enabling.
-if (version < 3.1) {
-newState.sessions.forEach((s) => {
-if (
-// Exclude those already set by user
-!s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
-) {
-// Because users may have changed this configuration,
-// the user's current configuration is used instead of the default
-const config = useAppConfig.getState();
-s.mask.modelConfig.enableInjectSystemPrompts =
-config.modelConfig.enableInjectSystemPrompts;
-}
-});
-}
-return newState as any;
+// TODO(yifei): migrate from old versions
+return persistedState as any;
},
},
);
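
The new accessors above funnel every request through a provider-keyed lookup: extractModelConfig() picks maskConfig.modelConfig[provider], and getClient() hands the result to api.createLLMClient(). A self-contained sketch of that dispatch pattern (the LLMClient interface and factory body are illustrative assumptions; the commit's real implementation lives in ../client, which is not part of this diff):

// Sketch of provider-keyed client dispatch, mirroring extractModelConfig()/getClient().
type LLMProvider = "openai"; // extend with "azure" | "claude" | "google" later

interface ModelConfig {
  model: string;
  temperature: number;
}

interface MaskConfig {
  provider: LLMProvider;
  modelConfig: Record<LLMProvider, ModelConfig>;
}

interface LLMClient {
  chat(opts: { messages: string[]; onFinish: (m: string) => void }): Promise<void>;
}

function extractModelConfig(maskConfig: MaskConfig): ModelConfig {
  const config = maskConfig.modelConfig[maskConfig.provider];
  // Fail loudly when a mask references a provider with no config entry.
  if (!config) {
    throw Error("[Chat] failed to initialize provider: " + maskConfig.provider);
  }
  return config;
}

function createLLMClient(maskConfig: MaskConfig): LLMClient {
  const modelConfig = extractModelConfig(maskConfig);
  switch (maskConfig.provider) {
    case "openai":
      return {
        async chat({ messages, onFinish }) {
          // A real client would call the OpenAI-compatible endpoint here.
          onFinish(`[${modelConfig.model}] echo: ${messages.join(" ")}`);
        },
      };
    default:
      throw Error("[Chat] unknown provider: " + maskConfig.provider);
  }
}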

app/store/config.ts

@@ -1,4 +1,3 @@
-import { LLMModel } from "../client/api";
import { isMacOS } from "../utils";
import { getClientConfig } from "../config/client";
import {
@@ -8,24 +7,85 @@ import {
StoreKey,
} from "../constant";
import { createPersistStore } from "../utils/store";
+import { OpenAIConfig } from "../client/openai/config";
+import { api } from "../client";
+import { SubmitKey, Theme } from "../typing";
export type ModelType = (typeof DEFAULT_MODELS)[number]["name"];
-export enum SubmitKey {
-Enter = "Enter",
-CtrlEnter = "Ctrl + Enter",
-ShiftEnter = "Shift + Enter",
-AltEnter = "Alt + Enter",
-MetaEnter = "Meta + Enter",
-}
+export const DEFAULT_CHAT_CONFIG = {
+enableAutoGenerateTitle: true,
+sendMemory: true,
+historyMessageCount: 4,
+compressMessageLengthThreshold: 1000,
+enableInjectSystemPrompts: true,
+template: DEFAULT_INPUT_TEMPLATE,
+};
+export type ChatConfig = typeof DEFAULT_CHAT_CONFIG;
-export enum Theme {
-Auto = "auto",
-Dark = "dark",
-Light = "light",
-}
+export const DEFAULT_PROVIDER_CONFIG = {
+openai: OpenAIConfig.provider,
+// azure: {
+// endpoint: "https://api.openai.com",
+// apiKey: "",
+// version: "",
+// ...COMMON_PROVIDER_CONFIG,
+// },
+// claude: {
+// endpoint: "https://api.anthropic.com",
+// apiKey: "",
+// ...COMMON_PROVIDER_CONFIG,
+// },
+// google: {
+// endpoint: "https://api.anthropic.com",
+// apiKey: "",
+// ...COMMON_PROVIDER_CONFIG,
+// },
+};
-export const DEFAULT_CONFIG = {
+export const DEFAULT_MODEL_CONFIG = {
+openai: OpenAIConfig.model,
+// azure: {
+// model: "gpt-3.5-turbo" as string,
+// summarizeModel: "gpt-3.5-turbo",
+//
+// temperature: 0.5,
+// top_p: 1,
+// max_tokens: 2000,
+// presence_penalty: 0,
+// frequency_penalty: 0,
+// },
+// claude: {
+// model: "claude-2",
+// summarizeModel: "claude-2",
+//
+// max_tokens_to_sample: 100000,
+// temperature: 1,
+// top_p: 0.7,
+// top_k: 1,
+// },
+// google: {
+// model: "chat-bison-001",
+// summarizeModel: "claude-2",
+//
+// temperature: 1,
+// topP: 0.7,
+// topK: 1,
+// },
+};
+export type LLMProvider = keyof typeof DEFAULT_PROVIDER_CONFIG;
+export const LLMProviders = Array.from(
+Object.entries(DEFAULT_PROVIDER_CONFIG),
+).map(([k, v]) => [v.name, k]);
+export const DEFAULT_MASK_CONFIG = {
+provider: "openai" as LLMProvider,
+chatConfig: { ...DEFAULT_CHAT_CONFIG },
+modelConfig: { ...DEFAULT_MODEL_CONFIG },
+};
+export const DEFAULT_APP_CONFIG = {
lastUpdate: Date.now(), // timestamp, to merge state
submitKey: isMacOS() ? SubmitKey.MetaEnter : SubmitKey.CtrlEnter,
@@ -34,7 +94,6 @@ export const DEFAULT_CONFIG = {
theme: Theme.Auto as Theme,
tightBorder: !!getClientConfig()?.isApp,
sendPreviewBubble: true,
-enableAutoGenerateTitle: true,
sidebarWidth: DEFAULT_SIDEBAR_WIDTH,
disablePromptHint: false,
@@ -42,27 +101,14 @@ export const DEFAULT_CONFIG = {
dontShowMaskSplashScreen: false, // don't show the splash screen when creating a chat
hideBuiltinMasks: false, // don't add the builtin masks
customModels: "",
-models: DEFAULT_MODELS as any as LLMModel[],
-modelConfig: {
-model: "gpt-3.5-turbo" as ModelType,
-temperature: 0.5,
-top_p: 1,
-max_tokens: 2000,
-presence_penalty: 0,
-frequency_penalty: 0,
-sendMemory: true,
-historyMessageCount: 4,
-compressMessageLengthThreshold: 1000,
-enableInjectSystemPrompts: true,
-template: DEFAULT_INPUT_TEMPLATE,
-},
+providerConfig: { ...DEFAULT_PROVIDER_CONFIG },
+globalMaskConfig: { ...DEFAULT_MASK_CONFIG },
 };
-export type ChatConfig = typeof DEFAULT_CONFIG;
-export type ModelConfig = ChatConfig["modelConfig"];
+export type AppConfig = typeof DEFAULT_APP_CONFIG;
+export type ProviderConfig = typeof DEFAULT_PROVIDER_CONFIG;
+export type MaskConfig = typeof DEFAULT_MASK_CONFIG;
+export type ModelConfig = typeof DEFAULT_MODEL_CONFIG;
export function limitNumber(
x: number,
@@ -99,48 +145,21 @@ export const ModalConfigValidator = {
};
export const useAppConfig = createPersistStore(
-{ ...DEFAULT_CONFIG },
+{ ...DEFAULT_APP_CONFIG },
(set, get) => ({
reset() {
-set(() => ({ ...DEFAULT_CONFIG }));
+set(() => ({ ...DEFAULT_APP_CONFIG }));
},
-mergeModels(newModels: LLMModel[]) {
-if (!newModels || newModels.length === 0) {
-return;
-}
-const oldModels = get().models;
-const modelMap: Record<string, LLMModel> = {};
-for (const model of oldModels) {
-model.available = false;
-modelMap[model.name] = model;
-}
-for (const model of newModels) {
-model.available = true;
-modelMap[model.name] = model;
-}
-set(() => ({
-models: Object.values(modelMap),
-}));
-},
-allModels() {
-const customModels = get()
-.customModels.split(",")
-.filter((v) => !!v && v.length > 0)
-.map((m) => ({ name: m, available: true }));
-return get().models.concat(customModels);
+getDefaultClient() {
+return api.createLLMClient(get().providerConfig, get().globalMaskConfig);
},
}),
{
name: StoreKey.Config,
-version: 3.8,
+version: 4,
migrate(persistedState, version) {
-const state = persistedState as ChatConfig;
+const state = persistedState as any;
if (version < 3.4) {
state.modelConfig.sendMemory = true;
@@ -169,6 +188,10 @@ export const useAppConfig = createPersistStore(
state.lastUpdate = Date.now();
}
+if (version < 4) {
+// TODO: migrate from old versions
+}
return state as any;
},
},
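
The version < 4 branch above is left as a TODO. Based on the old and new shapes visible in this file, a migration could map the flat modelConfig onto the new provider-keyed layout roughly like this (a hedged sketch, not the commit's code; every fallback default is taken from the old DEFAULT_CONFIG shown above):

// Hypothetical v3.x -> v4 migration helper for the TODO branch above.
function migrateToV4(state: any) {
  const old = state.modelConfig ?? {};
  state.providerConfig = { ...DEFAULT_PROVIDER_CONFIG };
  state.globalMaskConfig = {
    provider: "openai" as const,
    chatConfig: {
      // These fields move from the flat modelConfig into chatConfig.
      enableAutoGenerateTitle: state.enableAutoGenerateTitle ?? true,
      sendMemory: old.sendMemory ?? true,
      historyMessageCount: old.historyMessageCount ?? 4,
      compressMessageLengthThreshold: old.compressMessageLengthThreshold ?? 1000,
      enableInjectSystemPrompts: old.enableInjectSystemPrompts ?? true,
      template: old.template ?? DEFAULT_INPUT_TEMPLATE,
    },
    modelConfig: {
      openai: {
        // Sampling parameters keep their old names, nested under the provider key.
        model: old.model ?? "gpt-3.5-turbo",
        temperature: old.temperature ?? 0.5,
        top_p: old.top_p ?? 1,
        max_tokens: old.max_tokens ?? 2000,
        presence_penalty: old.presence_penalty ?? 0,
        frequency_penalty: old.frequency_penalty ?? 0,
      },
    },
  };
  delete state.modelConfig;
  return state;
}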

app/store/mask.ts

@@ -1,10 +1,11 @@
import { BUILTIN_MASKS } from "../masks";
import { getLang, Lang } from "../locales";
import { DEFAULT_TOPIC, ChatMessage } from "./chat";
-import { ModelConfig, useAppConfig } from "./config";
+import { MaskConfig, ModelConfig, useAppConfig } from "./config";
import { StoreKey } from "../constant";
import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
+import { deepClone } from "../utils/clone";
export type Mask = {
id: string;
@@ -14,7 +15,9 @@ export type Mask = {
hideContext?: boolean;
context: ChatMessage[];
syncGlobalConfig?: boolean;
-modelConfig: ModelConfig;
+config: MaskConfig;
lang: Lang;
builtin: boolean;
};
@@ -33,7 +36,7 @@ export const createEmptyMask = () =>
name: DEFAULT_TOPIC,
context: [],
syncGlobalConfig: true, // use global config as default
-modelConfig: { ...useAppConfig.getState().modelConfig },
+config: deepClone(useAppConfig.getState().globalMaskConfig),
lang: getLang(),
builtin: false,
createdAt: Date.now(),
@@ -87,10 +90,11 @@ export const useMaskStore = createPersistStore(
const buildinMasks = BUILTIN_MASKS.map(
(m) =>
({
id: m.name,
...m,
-modelConfig: {
-...config.modelConfig,
-...m.modelConfig,
+config: {
+...config.globalMaskConfig,
+...m.config,
},
}) as Mask,
);
@@ -120,6 +124,8 @@ export const useMaskStore = createPersistStore(
newState.masks = updatedMasks;
}
+// TODO(yifei): migrate old masks
return newState as any;
},
},
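
createEmptyMask() now deep-clones the global mask config instead of shallow-spreading it, so editing a mask can no longer mutate nested objects shared with the global defaults. The deepClone helper from ../utils/clone is not shown in this commit; a minimal sketch, assuming a plain JSON round-trip is sufficient:

// Hypothetical app/utils/clone.ts -- a JSON round-trip works here because
// mask configs are plain data (no functions, Dates, or cyclic references).
export function deepClone<T>(obj: T): T {
  return JSON.parse(JSON.stringify(obj)) as T;
}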

app/store/sync.ts

@@ -13,7 +13,7 @@ import { downloadAs, readFromFile } from "../utils";
import { showToast } from "../components/ui-lib";
import Locale from "../locales";
import { createSyncClient, ProviderType } from "../utils/cloud";
-import { corsPath } from "../utils/cors";
+import { getApiPath } from "../utils/path";
export interface WebDavConfig {
server: string;
@@ -27,7 +27,7 @@ export type SyncStore = GetStoreState<typeof useSyncStore>;
const DEFAULT_SYNC_STATE = {
provider: ProviderType.WebDAV,
useProxy: true,
-proxyUrl: corsPath(ApiPath.Cors),
+proxyUrl: getApiPath(ApiPath.Cors),
webdav: {
endpoint: "",
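
getApiPath() replaces the old corsPath() helper but is not included in this view. A plausible sketch, assuming it keeps corsPath's behavior of using an absolute origin only in the packaged app build (the remote host constant is a placeholder, not the real value):

// Hypothetical app/utils/path.ts; names and logic are assumptions based on
// the corsPath helper this commit replaces.
import { getClientConfig } from "../config/client";

const REMOTE_HOST = "https://example.com"; // placeholder origin

export function getApiPath(path: string): string {
  // The packaged desktop app has no same-origin Next.js server, so API
  // routes need an absolute origin; the web build can stay relative.
  return getClientConfig()?.isApp ? REMOTE_HOST + path : path;
}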

app/store/update.ts

@@ -1,5 +1,4 @@
import { FETCH_COMMIT_URL, FETCH_TAG_URL, StoreKey } from "../constant";
-import { api } from "../client/api";
import { getClientConfig } from "../config/client";
import { createPersistStore } from "../utils/store";
import ChatGptIcon from "../icons/chatgpt.png";
@@ -85,35 +84,40 @@ export const useUpdateStore = createPersistStore(
}));
if (window.__TAURI__?.notification && isApp) {
// Check if notification permission is granted
-await window.__TAURI__?.notification.isPermissionGranted().then((granted) => {
-if (!granted) {
-return;
-} else {
-// Request permission to show notifications
-window.__TAURI__?.notification.requestPermission().then((permission) => {
-if (permission === 'granted') {
-if (version === remoteId) {
-// Show a notification using Tauri
-window.__TAURI__?.notification.sendNotification({
-title: "ChatGPT Next Web",
-body: `${Locale.Settings.Update.IsLatest}`,
-icon: `${ChatGptIcon.src}`,
-sound: "Default"
-});
-} else {
-const updateMessage = Locale.Settings.Update.FoundUpdate(`${remoteId}`);
-// Show a notification for the new version using Tauri
-window.__TAURI__?.notification.sendNotification({
-title: "ChatGPT Next Web",
-body: updateMessage,
-icon: `${ChatGptIcon.src}`,
-sound: "Default"
-});
-}
-}
-});
-}
-});
+await window.__TAURI__?.notification
+.isPermissionGranted()
+.then((granted) => {
+if (!granted) {
+return;
+} else {
+// Request permission to show notifications
+window.__TAURI__?.notification
+.requestPermission()
+.then((permission) => {
+if (permission === "granted") {
+if (version === remoteId) {
+// Show a notification using Tauri
+window.__TAURI__?.notification.sendNotification({
+title: "ChatGPT Next Web",
+body: `${Locale.Settings.Update.IsLatest}`,
+icon: `${ChatGptIcon.src}`,
+sound: "Default",
+});
+} else {
+const updateMessage =
+Locale.Settings.Update.FoundUpdate(`${remoteId}`);
+// Show a notification for the new version using Tauri
+window.__TAURI__?.notification.sendNotification({
+title: "ChatGPT Next Web",
+body: updateMessage,
+icon: `${ChatGptIcon.src}`,
+sound: "Default",
+});
+}
+}
+});
+}
+});
}
console.log("[Got Upstream] ", remoteId);
} catch (error) {
@@ -130,14 +134,7 @@ export const useUpdateStore = createPersistStore(
}));
try {
-const usage = await api.llm.usage();
-if (usage) {
-set(() => ({
-used: usage.used,
-subscription: usage.total,
-}));
-}
+// TODO: add check usage api here
} catch (e) {
console.error((e as Error).message);
}
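
The usage query that previously called api.llm.usage() is dropped and left as a TODO. If it were restored on top of the new client abstraction, it could look roughly like this inside updateUsage() (a speculative sketch; the client API in this commit does not expose a usage() method, so the shape below is an assumption):

// Hypothetical usage check; usage() returning { used, total } mirrors what
// the removed api.llm.usage() call provided.
try {
  const client = useChatStore.getState().getClient();
  const usage = await (client as any).usage?.();
  if (usage) {
    set(() => ({
      used: usage.used,
      subscription: usage.total,
    }));
  }
} catch (e) {
  console.error((e as Error).message);
}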