Merge branch 'main' into tts-stt

This commit is contained in:
DDMeaqua
2024-09-18 10:39:56 +08:00
62 changed files with 2254 additions and 524 deletions

View File

@@ -1,32 +1,43 @@
import { trimTopic, getMessageTextContent } from "../utils";
import { getMessageTextContent, trimTopic } from "../utils";
import Locale, { getLang } from "../locales";
import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
import { nanoid } from "nanoid";
import type {
ClientApi,
MultimodalContent,
RequestMessage,
} from "../client/api";
import { getClientApi } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { showToast } from "../components/ui-lib";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask";
import {
DEFAULT_INPUT_TEMPLATE,
DEFAULT_MODELS,
DEFAULT_SYSTEM_TEMPLATE,
KnowledgeCutOffDate,
StoreKey,
SUMMARIZE_MODEL,
GEMINI_SUMMARIZE_MODEL,
} from "../constant";
import { getClientApi } from "../client/api";
import type {
ClientApi,
RequestMessage,
MultimodalContent,
} from "../client/api";
import { ChatControllerPool } from "../client/controller";
import Locale, { getLang } from "../locales";
import { isDalle3, safeLocalStorage } from "../utils";
import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token";
import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
import { collectModelsWithDefaultModel } from "../utils/model";
import { useAccessStore } from "./access";
import { isDalle3 } from "../utils";
import { estimateTokenLength } from "../utils/token";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask";
const localStorage = safeLocalStorage();
export type ChatMessageTool = {
id: string;
index?: number;
type?: string;
function?: {
name: string;
arguments?: string;
};
content?: string;
isError?: boolean;
};
export type ChatMessage = RequestMessage & {
date: string;
@@ -34,6 +45,7 @@ export type ChatMessage = RequestMessage & {
isError?: boolean;
id: string;
model?: ModelType;
tools?: ChatMessageTool[];
};
export function createMessage(override: Partial<ChatMessage>): ChatMessage {
@@ -90,27 +102,6 @@ function createEmptySession(): ChatSession {
};
}
function getSummarizeModel(currentModel: string) {
// if it is using gpt-* models, force to use 4o-mini to summarize
if (currentModel.startsWith("gpt")) {
const configStore = useAppConfig.getState();
const accessStore = useAccessStore.getState();
const allModel = collectModelsWithDefaultModel(
configStore.models,
[configStore.customModels, accessStore.customModels].join(","),
accessStore.defaultModel,
);
const summarizeModel = allModel.find(
(m) => m.name === SUMMARIZE_MODEL && m.available,
);
return summarizeModel?.name ?? currentModel;
}
if (currentModel.startsWith("gemini")) {
return GEMINI_SUMMARIZE_MODEL;
}
return currentModel;
}
function countMessages(msgs: ChatMessage[]) {
return msgs.reduce(
(pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
@@ -165,6 +156,7 @@ function fillTemplateWith(input: string, modelConfig: ModelConfig) {
const DEFAULT_CHAT_STATE = {
sessions: [createEmptySession()],
currentSessionIndex: 0,
lastInput: "",
};
export const useChatStore = createPersistStore(
@@ -389,8 +381,24 @@ export const useChatStore = createPersistStore(
}
ChatControllerPool.remove(session.id, botMessage.id);
},
onBeforeTool(tool: ChatMessageTool) {
(botMessage.tools = botMessage?.tools || []).push(tool);
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
},
onAfterTool(tool: ChatMessageTool) {
botMessage?.tools?.forEach((t, i, tools) => {
if (tool.id == t.id) {
tools[i] = { ...tool };
}
});
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
},
onError(error) {
const isAborted = error.message.includes("aborted");
const isAborted = error.message?.includes?.("aborted");
botMessage.content +=
"\n\n" +
prettyObject({
@@ -446,7 +454,8 @@ export const useChatStore = createPersistStore(
// system prompts, to get close to OpenAI Web ChatGPT
const shouldInjectSystemPrompts =
modelConfig.enableInjectSystemPrompts &&
session.mask.modelConfig.model.startsWith("gpt-");
(session.mask.modelConfig.model.startsWith("gpt-") ||
session.mask.modelConfig.model.startsWith("chatgpt-"));
var systemPrompts: ChatMessage[] = [];
systemPrompts = shouldInjectSystemPrompts
@@ -538,7 +547,7 @@ export const useChatStore = createPersistStore(
});
},
summarizeSession() {
summarizeSession(refreshTitle: boolean = false) {
const config = useAppConfig.getState();
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
@@ -547,7 +556,7 @@ export const useChatStore = createPersistStore(
return;
}
const providerName = modelConfig.providerName;
const providerName = modelConfig.compressProviderName;
const api: ClientApi = getClientApi(providerName);
// remove error messages if any
@@ -556,20 +565,30 @@ export const useChatStore = createPersistStore(
// should summarize topic after chating more than 50 words
const SUMMARIZE_MIN_LEN = 50;
if (
config.enableAutoGenerateTitle &&
session.topic === DEFAULT_TOPIC &&
countMessages(messages) >= SUMMARIZE_MIN_LEN
(config.enableAutoGenerateTitle &&
session.topic === DEFAULT_TOPIC &&
countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
refreshTitle
) {
const topicMessages = messages.concat(
createMessage({
role: "user",
content: Locale.Store.Prompt.Topic,
}),
const startIndex = Math.max(
0,
messages.length - modelConfig.historyMessageCount,
);
const topicMessages = messages
.slice(
startIndex < messages.length ? startIndex : messages.length - 1,
messages.length,
)
.concat(
createMessage({
role: "user",
content: Locale.Store.Prompt.Topic,
}),
);
api.llm.chat({
messages: topicMessages,
config: {
model: getSummarizeModel(session.mask.modelConfig.model),
model: modelConfig.compressModel,
stream: false,
providerName,
},
@@ -632,7 +651,7 @@ export const useChatStore = createPersistStore(
config: {
...modelcfg,
stream: true,
model: getSummarizeModel(session.mask.modelConfig.model),
model: modelConfig.compressModel,
},
onUpdate(message) {
session.memoryPrompt = message;
@@ -665,17 +684,23 @@ export const useChatStore = createPersistStore(
set(() => ({ sessions }));
},
clearAllData() {
async clearAllData() {
await indexedDBStorage.clear();
localStorage.clear();
location.reload();
},
setLastInput(lastInput: string) {
set({
lastInput,
});
},
};
return methods;
},
{
name: StoreKey.Chat,
version: 3.1,
version: 3.2,
migrate(persistedState, version) {
const state = persistedState as any;
const newState = JSON.parse(
@@ -722,6 +747,16 @@ export const useChatStore = createPersistStore(
});
}
// add default summarize model for every session
if (version < 3.2) {
newState.sessions.forEach((s) => {
const config = useAppConfig.getState();
s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
s.mask.modelConfig.compressProviderName =
config.modelConfig.compressProviderName;
});
}
return newState as any;
},
},

View File

@@ -63,7 +63,7 @@ export const DEFAULT_CONFIG = {
models: DEFAULT_MODELS as any as LLMModel[],
modelConfig: {
model: "gpt-3.5-turbo" as ModelType,
model: "gpt-4o-mini" as ModelType,
providerName: "OpenAI" as ServiceProvider,
temperature: 0.5,
top_p: 1,
@@ -73,6 +73,8 @@ export const DEFAULT_CONFIG = {
sendMemory: true,
historyMessageCount: 4,
compressMessageLengthThreshold: 1000,
compressModel: "gpt-4o-mini" as ModelType,
compressProviderName: "OpenAI" as ServiceProvider,
enableInjectSystemPrompts: true,
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
size: "1024x1024" as DalleSize,
@@ -189,7 +191,7 @@ export const useAppConfig = createPersistStore(
}),
{
name: StoreKey.Config,
version: 3.9,
version: 4,
migrate(persistedState, version) {
const state = persistedState as ChatConfig;
@@ -227,6 +229,13 @@ export const useAppConfig = createPersistStore(
: config?.template ?? DEFAULT_INPUT_TEMPLATE;
}
if (version < 4) {
state.modelConfig.compressModel =
DEFAULT_CONFIG.modelConfig.compressModel;
state.modelConfig.compressProviderName =
DEFAULT_CONFIG.modelConfig.compressProviderName;
}
return state as any;
},
},

View File

@@ -2,3 +2,4 @@ export * from "./chat";
export * from "./update";
export * from "./access";
export * from "./config";
export * from "./plugin";

View File

@@ -2,7 +2,7 @@ import { BUILTIN_MASKS } from "../masks";
import { getLang, Lang } from "../locales";
import { DEFAULT_TOPIC, ChatMessage } from "./chat";
import { ModelConfig, useAppConfig } from "./config";
import { StoreKey, Plugin } from "../constant";
import { StoreKey } from "../constant";
import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
@@ -17,14 +17,18 @@ export type Mask = {
modelConfig: ModelConfig;
lang: Lang;
builtin: boolean;
plugin?: Plugin[];
plugin?: string[];
enableArtifacts?: boolean;
};
export const DEFAULT_MASK_STATE = {
masks: {} as Record<string, Mask>,
language: undefined as Lang | undefined,
};
export type MaskState = typeof DEFAULT_MASK_STATE;
export type MaskState = typeof DEFAULT_MASK_STATE & {
language?: Lang | undefined;
};
export const DEFAULT_MASK_AVATAR = "gpt-bot";
export const createEmptyMask = () =>
@@ -38,7 +42,7 @@ export const createEmptyMask = () =>
lang: getLang(),
builtin: false,
createdAt: Date.now(),
plugin: [Plugin.Artifacts],
plugin: [],
}) as Mask;
export const useMaskStore = createPersistStore(
@@ -101,6 +105,11 @@ export const useMaskStore = createPersistStore(
search(text: string) {
return Object.values(get().masks);
},
setLanguage(language: Lang | undefined) {
set({
language,
});
},
}),
{
name: StoreKey.Mask,

225
app/store/plugin.ts Normal file
View File

@@ -0,0 +1,225 @@
import OpenAPIClientAxios from "openapi-client-axios";
import { getLang, Lang } from "../locales";
import { StoreKey } from "../constant";
import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
import yaml from "js-yaml";
import { adapter } from "../utils";
export type Plugin = {
id: string;
createdAt: number;
title: string;
version: string;
content: string;
builtin: boolean;
authType?: string;
authLocation?: string;
authHeader?: string;
authToken?: string;
usingProxy?: boolean;
};
export type FunctionToolItem = {
type: string;
function: {
name: string;
description?: string;
parameters: Object;
};
};
type FunctionToolServiceItem = {
api: OpenAPIClientAxios;
length: number;
tools: FunctionToolItem[];
funcs: Record<string, Function>;
};
export const FunctionToolService = {
tools: {} as Record<string, FunctionToolServiceItem>,
add(plugin: Plugin, replace = false) {
if (!replace && this.tools[plugin.id]) return this.tools[plugin.id];
const headerName = (
plugin?.authType == "custom" ? plugin?.authHeader : "Authorization"
) as string;
const tokenValue =
plugin?.authType == "basic"
? `Basic ${plugin?.authToken}`
: plugin?.authType == "bearer"
? ` Bearer ${plugin?.authToken}`
: plugin?.authToken;
const authLocation = plugin?.authLocation || "header";
const definition = yaml.load(plugin.content) as any;
const serverURL = definition?.servers?.[0]?.url;
const baseURL = !!plugin?.usingProxy ? "/api/proxy" : serverURL;
const headers: Record<string, string | undefined> = {
"X-Base-URL": !!plugin?.usingProxy ? serverURL : undefined,
};
if (authLocation == "header") {
headers[headerName] = tokenValue;
}
const api = new OpenAPIClientAxios({
definition: yaml.load(plugin.content) as any,
axiosConfigDefaults: {
adapter: (window.__TAURI__ ? adapter : ["xhr"]) as any,
baseURL,
headers,
},
});
try {
api.initSync();
} catch (e) {}
const operations = api.getOperations();
return (this.tools[plugin.id] = {
api,
length: operations.length,
tools: operations.map((o) => {
// @ts-ignore
const parameters = o?.requestBody?.content["application/json"]
?.schema || {
type: "object",
properties: {},
};
if (!parameters["required"]) {
parameters["required"] = [];
}
if (o.parameters instanceof Array) {
o.parameters.forEach((p) => {
// @ts-ignore
if (p?.in == "query" || p?.in == "path") {
// const name = `${p.in}__${p.name}`
// @ts-ignore
const name = p?.name;
parameters["properties"][name] = {
// @ts-ignore
type: p.schema.type,
// @ts-ignore
description: p.description,
};
// @ts-ignore
if (p.required) {
parameters["required"].push(name);
}
}
});
}
return {
type: "function",
function: {
name: o.operationId,
description: o.description || o.summary,
parameters: parameters,
},
} as FunctionToolItem;
}),
funcs: operations.reduce((s, o) => {
// @ts-ignore
s[o.operationId] = function (args) {
const parameters: Record<string, any> = {};
if (o.parameters instanceof Array) {
o.parameters.forEach((p) => {
// @ts-ignore
parameters[p?.name] = args[p?.name];
// @ts-ignore
delete args[p?.name];
});
}
if (authLocation == "query") {
parameters[headerName] = tokenValue;
} else if (authLocation == "body") {
args[headerName] = tokenValue;
}
// @ts-ignore
return api.client[o.operationId](
parameters,
args,
api.axiosConfigDefaults,
);
};
return s;
}, {}),
});
},
get(id: string) {
return this.tools[id];
},
};
export const createEmptyPlugin = () =>
({
id: nanoid(),
title: "",
version: "1.0.0",
content: "",
builtin: false,
createdAt: Date.now(),
}) as Plugin;
export const DEFAULT_PLUGIN_STATE = {
plugins: {} as Record<string, Plugin>,
};
export const usePluginStore = createPersistStore(
{ ...DEFAULT_PLUGIN_STATE },
(set, get) => ({
create(plugin?: Partial<Plugin>) {
const plugins = get().plugins;
const id = nanoid();
plugins[id] = {
...createEmptyPlugin(),
...plugin,
id,
builtin: false,
};
set(() => ({ plugins }));
get().markUpdate();
return plugins[id];
},
updatePlugin(id: string, updater: (plugin: Plugin) => void) {
const plugins = get().plugins;
const plugin = plugins[id];
if (!plugin) return;
const updatePlugin = { ...plugin };
updater(updatePlugin);
plugins[id] = updatePlugin;
FunctionToolService.add(updatePlugin, true);
set(() => ({ plugins }));
get().markUpdate();
},
delete(id: string) {
const plugins = get().plugins;
delete plugins[id];
set(() => ({ plugins }));
get().markUpdate();
},
getAsTools(ids: string[]) {
const plugins = get().plugins;
const selected = (ids || [])
.map((id) => plugins[id])
.filter((i) => i)
.map((p) => FunctionToolService.add(p));
return [
// @ts-ignore
selected.reduce((s, i) => s.concat(i.tools), []),
selected.reduce((s, i) => Object.assign(s, i.funcs), {}),
];
},
get(id?: string) {
return get().plugins[id ?? 1145141919810];
},
getAll() {
return Object.values(get().plugins).sort(
(a, b) => b.createdAt - a.createdAt,
);
},
}),
{
name: StoreKey.Plugin,
version: 1,
},
);