refactor: #1000 #1179 api layer for client-side only mode and local models

This commit is contained in:
Yidadaa
2023-05-15 01:33:46 +08:00
parent bd90caa99d
commit a3de277c43
15 changed files with 247 additions and 593 deletions

View File

@@ -1,7 +1,7 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { StoreKey } from "../constant";
import { getHeaders } from "../requests";
import { getHeaders } from "../client/api";
import { BOT_HELLO } from "./chat";
import { ALL_MODELS } from "./config";

View File

@@ -1,12 +1,6 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { type ChatCompletionResponseMessage } from "openai";
import {
ControllerPool,
requestChatStream,
requestWithPrompt,
} from "../requests";
import { trimTopic } from "../utils";
import Locale from "../locales";
@@ -14,9 +8,11 @@ import { showToast } from "../components/ui-lib";
import { ModelType } from "./config";
import { createEmptyMask, Mask } from "./mask";
import { StoreKey } from "../constant";
import { api } from "../client/api";
import { api, RequestMessage } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format";
export type Message = ChatCompletionResponseMessage & {
export type ChatMessage = RequestMessage & {
date: string;
streaming?: boolean;
isError?: boolean;
@@ -24,7 +20,7 @@ export type Message = ChatCompletionResponseMessage & {
model?: ModelType;
};
export function createMessage(override: Partial<Message>): Message {
export function createMessage(override: Partial<ChatMessage>): ChatMessage {
return {
id: Date.now(),
date: new Date().toLocaleString(),
@@ -34,8 +30,6 @@ export function createMessage(override: Partial<Message>): Message {
};
}
export const ROLES: Message["role"][] = ["system", "user", "assistant"];
export interface ChatStat {
tokenCount: number;
wordCount: number;
@@ -48,7 +42,7 @@ export interface ChatSession {
topic: string;
memoryPrompt: string;
messages: Message[];
messages: ChatMessage[];
stat: ChatStat;
lastUpdate: number;
lastSummarizeIndex: number;
@@ -57,7 +51,7 @@ export interface ChatSession {
}
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
export const BOT_HELLO: Message = createMessage({
export const BOT_HELLO: ChatMessage = createMessage({
role: "assistant",
content: Locale.Store.BotHello,
});
@@ -89,24 +83,24 @@ interface ChatStore {
newSession: (mask?: Mask) => void;
deleteSession: (index: number) => void;
currentSession: () => ChatSession;
onNewMessage: (message: Message) => void;
onNewMessage: (message: ChatMessage) => void;
onUserInput: (content: string) => Promise<void>;
summarizeSession: () => void;
updateStat: (message: Message) => void;
updateStat: (message: ChatMessage) => void;
updateCurrentSession: (updater: (session: ChatSession) => void) => void;
updateMessage: (
sessionIndex: number,
messageIndex: number,
updater: (message?: Message) => void,
updater: (message?: ChatMessage) => void,
) => void;
resetSession: () => void;
getMessagesWithMemory: () => Message[];
getMemoryPrompt: () => Message;
getMessagesWithMemory: () => ChatMessage[];
getMemoryPrompt: () => ChatMessage;
clearAllData: () => void;
}
function countMessages(msgs: Message[]) {
function countMessages(msgs: ChatMessage[]) {
return msgs.reduce((pre, cur) => pre + cur.content.length, 0);
}
@@ -241,12 +235,12 @@ export const useChatStore = create<ChatStore>()(
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
const userMessage: Message = createMessage({
const userMessage: ChatMessage = createMessage({
role: "user",
content,
});
const botMessage: Message = createMessage({
const botMessage: ChatMessage = createMessage({
role: "assistant",
streaming: true,
id: userMessage.id! + 1,
@@ -278,45 +272,54 @@ export const useChatStore = create<ChatStore>()(
// make request
console.log("[User Input] ", sendMessages);
requestChatStream(sendMessages, {
onMessage(content, done) {
// stream response
if (done) {
botMessage.streaming = false;
botMessage.content = content;
get().onNewMessage(botMessage);
ControllerPool.remove(
sessionIndex,
botMessage.id ?? messageIndex,
);
} else {
botMessage.content = content;
set(() => ({}));
}
api.llm.chat({
messages: sendMessages,
config: { ...modelConfig, stream: true },
onUpdate(message) {
botMessage.streaming = true;
botMessage.content = message;
set(() => ({}));
},
onError(error, statusCode) {
onFinish(message) {
botMessage.streaming = false;
botMessage.content = message;
get().onNewMessage(botMessage);
ChatControllerPool.remove(
sessionIndex,
botMessage.id ?? messageIndex,
);
set(() => ({}));
},
onError(error) {
const isAborted = error.message.includes("aborted");
if (statusCode === 401) {
botMessage.content = Locale.Error.Unauthorized;
} else if (!isAborted) {
if (
botMessage.content !== Locale.Error.Unauthorized &&
!isAborted
) {
botMessage.content += "\n\n" + Locale.Store.Error;
} else if (botMessage.content.length === 0) {
botMessage.content = prettyObject(error);
}
botMessage.streaming = false;
userMessage.isError = !isAborted;
botMessage.isError = !isAborted;
set(() => ({}));
ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
ChatControllerPool.remove(
sessionIndex,
botMessage.id ?? messageIndex,
);
console.error("[Chat] error ", error);
},
onController(controller) {
// collect controller for stop/retry
ControllerPool.addController(
ChatControllerPool.addController(
sessionIndex,
botMessage.id ?? messageIndex,
controller,
);
},
modelConfig: { ...modelConfig },
});
},
@@ -330,7 +333,7 @@ export const useChatStore = create<ChatStore>()(
? Locale.Store.Prompt.History(session.memoryPrompt)
: "",
date: "",
} as Message;
} as ChatMessage;
},
getMessagesWithMemory() {
@@ -385,7 +388,7 @@ export const useChatStore = create<ChatStore>()(
updateMessage(
sessionIndex: number,
messageIndex: number,
updater: (message?: Message) => void,
updater: (message?: ChatMessage) => void,
) {
const sessions = get().sessions;
const session = sessions.at(sessionIndex);
@@ -410,13 +413,24 @@ export const useChatStore = create<ChatStore>()(
session.topic === DEFAULT_TOPIC &&
countMessages(session.messages) >= SUMMARIZE_MIN_LEN
) {
requestWithPrompt(session.messages, Locale.Store.Prompt.Topic, {
model: "gpt-3.5-turbo",
}).then((res) => {
get().updateCurrentSession(
(session) =>
(session.topic = res ? trimTopic(res) : DEFAULT_TOPIC),
);
const topicMessages = session.messages.concat(
createMessage({
role: "user",
content: Locale.Store.Prompt.Topic,
}),
);
api.llm.chat({
messages: topicMessages,
config: {
model: "gpt-3.5-turbo",
},
onFinish(message) {
get().updateCurrentSession(
(session) =>
(session.topic =
message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
);
},
});
}
@@ -450,26 +464,24 @@ export const useChatStore = create<ChatStore>()(
historyMsgLength > modelConfig.compressMessageLengthThreshold &&
session.mask.modelConfig.sendMemory
) {
requestChatStream(
toBeSummarizedMsgs.concat({
api.llm.chat({
messages: toBeSummarizedMsgs.concat({
role: "system",
content: Locale.Store.Prompt.Summarize,
date: "",
}),
{
overrideModel: "gpt-3.5-turbo",
onMessage(message, done) {
session.memoryPrompt = message;
if (done) {
console.log("[Memory] ", session.memoryPrompt);
session.lastSummarizeIndex = lastSummarizeIndex;
}
},
onError(error) {
console.error("[Summarize] ", error);
},
config: { ...modelConfig, stream: true },
onUpdate(message) {
session.memoryPrompt = message;
},
);
onFinish(message) {
console.log("[Memory] ", message);
session.lastSummarizeIndex = lastSummarizeIndex;
},
onError(err) {
console.error("[Summarize] ", err);
},
});
}
},

View File

@@ -2,7 +2,7 @@ import { create } from "zustand";
import { persist } from "zustand/middleware";
import { BUILTIN_MASKS } from "../masks";
import { getLang, Lang } from "../locales";
import { DEFAULT_TOPIC, Message } from "./chat";
import { DEFAULT_TOPIC, ChatMessage } from "./chat";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { StoreKey } from "../constant";
@@ -10,7 +10,7 @@ export type Mask = {
id: number;
avatar: string;
name: string;
context: Message[];
context: ChatMessage[];
modelConfig: ModelConfig;
lang: Lang;
builtin: boolean;

View File

@@ -1,7 +1,8 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { FETCH_COMMIT_URL, FETCH_TAG_URL, StoreKey } from "../constant";
import { requestUsage } from "../requests";
import { FETCH_COMMIT_URL, StoreKey } from "../constant";
import { api } from "../client/api";
import { showToast } from "../components/ui-lib";
export interface UpdateStore {
lastUpdate: number;
@@ -73,10 +74,17 @@ export const useUpdateStore = create<UpdateStore>()(
lastUpdateUsage: Date.now(),
}));
const usage = await requestUsage();
try {
const usage = await api.llm.usage();
if (usage) {
set(() => usage);
if (usage) {
set(() => ({
used: usage.used,
subscription: usage.total,
}));
}
} catch (e) {
showToast((e as Error).message);
}
},
}),