Merge branch 'Yidadaa:main' into main

This commit is contained in:
Gan-Xing
2023-05-15 07:31:52 +08:00
committed by GitHub
21 changed files with 524 additions and 593 deletions

View File

@@ -1,49 +1,8 @@
import { createParser } from "eventsource-parser";
import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server";
import { auth } from "../../auth";
import { requestOpenai } from "../../common";
async function createStream(res: Response) {
const encoder = new TextEncoder();
const decoder = new TextDecoder();
const stream = new ReadableStream({
async start(controller) {
function onParse(event: any) {
if (event.type === "event") {
const data = event.data;
// https://beta.openai.com/docs/api-reference/completions/create#completions/create-stream
if (data === "[DONE]") {
controller.close();
return;
}
try {
const json = JSON.parse(data);
const text = json.choices[0].delta.content;
const queue = encoder.encode(text);
controller.enqueue(queue);
} catch (e) {
controller.error(e);
}
}
}
const parser = createParser(onParse);
for await (const chunk of res.body as any) {
parser.feed(decoder.decode(chunk, { stream: true }));
}
},
});
return stream;
}
function formatResponse(msg: any) {
const jsonMsg = ["```json\n", JSON.stringify(msg, null, " "), "\n```"].join(
"",
);
return new Response(jsonMsg);
}
async function handle(
req: NextRequest,
{ params }: { params: { path: string[] } },
@@ -58,40 +17,10 @@ async function handle(
}
try {
const api = await requestOpenai(req);
const contentType = api.headers.get("Content-Type") ?? "";
// streaming response
if (contentType.includes("stream")) {
const stream = await createStream(api);
const res = new Response(stream);
res.headers.set("Content-Type", contentType);
return res;
}
// try to parse error msg
try {
const mayBeErrorBody = await api.json();
if (mayBeErrorBody.error) {
console.error("[OpenAI Response] ", mayBeErrorBody);
return formatResponse(mayBeErrorBody);
} else {
const res = new Response(JSON.stringify(mayBeErrorBody));
res.headers.set("Content-Type", "application/json");
res.headers.set("Cache-Control", "no-cache");
return res;
}
} catch (e) {
console.error("[OpenAI Parse] ", e);
return formatResponse({
msg: "invalid response from openai server",
error: e,
});
}
return await requestOpenai(req);
} catch (e) {
console.error("[OpenAI] ", e);
return formatResponse(e);
return NextResponse.json(prettyObject(e));
}
}

View File

@@ -1,9 +0,0 @@
import type {
CreateChatCompletionRequest,
CreateChatCompletionResponse,
} from "openai";
export type ChatRequest = CreateChatCompletionRequest;
export type ChatResponse = CreateChatCompletionResponse;
export type Updater<T> = (updater: (value: T) => void) => void;

83
app/client/api.ts Normal file
View File

@@ -0,0 +1,83 @@
import { ACCESS_CODE_PREFIX } from "../constant";
import { ModelConfig, ModelType, useAccessStore } from "../store";
import { ChatGPTApi } from "./platforms/openai";
export const ROLES = ["system", "user", "assistant"] as const;
export type MessageRole = (typeof ROLES)[number];
export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
export type ChatModel = ModelType;
export interface RequestMessage {
role: MessageRole;
content: string;
}
export interface LLMConfig {
model: string;
temperature?: number;
top_p?: number;
stream?: boolean;
presence_penalty?: number;
frequency_penalty?: number;
}
export interface ChatOptions {
messages: RequestMessage[];
config: LLMConfig;
onUpdate?: (message: string, chunk: string) => void;
onFinish: (message: string) => void;
onError?: (err: Error) => void;
onController?: (controller: AbortController) => void;
}
export interface LLMUsage {
used: number;
total: number;
}
export abstract class LLMApi {
abstract chat(options: ChatOptions): Promise<void>;
abstract usage(): Promise<LLMUsage>;
}
export class ClientApi {
public llm: LLMApi;
constructor() {
this.llm = new ChatGPTApi();
}
config() {}
prompts() {}
masks() {}
}
export const api = new ClientApi();
export function getHeaders() {
const accessStore = useAccessStore.getState();
let headers: Record<string, string> = {
"Content-Type": "application/json",
};
const makeBearer = (token: string) => `Bearer ${token.trim()}`;
const validString = (x: string) => x && x.length > 0;
// use user's api key first
if (validString(accessStore.token)) {
headers.Authorization = makeBearer(accessStore.token);
} else if (
accessStore.enabledAccessControl() &&
validString(accessStore.accessCode)
) {
headers.Authorization = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
return headers;
}

37
app/client/controller.ts Normal file
View File

@@ -0,0 +1,37 @@
// To store message streaming controller
export const ChatControllerPool = {
controllers: {} as Record<string, AbortController>,
addController(
sessionIndex: number,
messageId: number,
controller: AbortController,
) {
const key = this.key(sessionIndex, messageId);
this.controllers[key] = controller;
return key;
},
stop(sessionIndex: number, messageId: number) {
const key = this.key(sessionIndex, messageId);
const controller = this.controllers[key];
controller?.abort();
},
stopAll() {
Object.values(this.controllers).forEach((v) => v.abort());
},
hasPending() {
return Object.values(this.controllers).length > 0;
},
remove(sessionIndex: number, messageId: number) {
const key = this.key(sessionIndex, messageId);
delete this.controllers[key];
},
key(sessionIndex: number, messageIndex: number) {
return `${sessionIndex},${messageIndex}`;
},
};

View File

@@ -0,0 +1,192 @@
import { REQUEST_TIMEOUT_MS } from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { ChatOptions, getHeaders, LLMApi, LLMUsage } from "../api";
import Locale from "../../locales";
export class ChatGPTApi implements LLMApi {
public ChatPath = "v1/chat/completions";
public UsagePath = "dashboard/billing/usage";
public SubsPath = "dashboard/billing/subscription";
path(path: string): string {
const openaiUrl = useAccessStore.getState().openaiUrl;
if (openaiUrl.endsWith("/")) openaiUrl.slice(0, openaiUrl.length - 1);
return [openaiUrl, path].join("/");
}
extractMessage(res: any) {
return res.choices?.at(0)?.message?.content ?? "";
}
async chat(options: ChatOptions) {
const messages = options.messages.map((v) => ({
role: v.role,
content: v.content,
}));
const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
...{
model: options.config.model,
},
};
const requestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
};
console.log("[Request] openai payload: ", requestPayload);
const shouldStream = !!options.config.stream;
const controller = new AbortController();
options.onController?.(controller);
try {
const chatPath = this.path(this.ChatPath);
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
signal: controller.signal,
headers: getHeaders(),
};
// make a fetch request
const reqestTimeoutId = setTimeout(
() => controller.abort(),
REQUEST_TIMEOUT_MS,
);
if (shouldStream) {
let responseText = "";
const finish = () => {
options.onFinish(responseText);
};
const res = await fetch(chatPath, chatPayload);
clearTimeout(reqestTimeoutId);
if (res.status === 401) {
responseText += "\n\n" + Locale.Error.Unauthorized;
return finish();
}
if (
!res.ok ||
!res.headers.get("Content-Type")?.includes("stream") ||
!res.body
) {
return options.onError?.(new Error());
}
const reader = res.body.getReader();
const decoder = new TextDecoder("utf-8");
while (true) {
const { done, value } = await reader.read();
if (done) {
return finish();
}
const chunk = decoder.decode(value, { stream: true });
const lines = chunk.split("data: ");
for (const line of lines) {
const text = line.trim();
if (line.startsWith("[DONE]")) {
return finish();
}
if (text.length === 0) continue;
try {
const json = JSON.parse(text);
const delta = json.choices[0].delta.content;
if (delta) {
responseText += delta;
options.onUpdate?.(responseText, delta);
}
} catch (e) {
console.error("[Request] parse error", text, chunk);
}
}
}
} else {
const res = await fetch(chatPath, chatPayload);
clearTimeout(reqestTimeoutId);
const resJson = await res.json();
const message = this.extractMessage(resJson);
options.onFinish(message);
}
} catch (e) {
console.log("[Request] failed to make a chat reqeust", e);
options.onError?.(e as Error);
}
}
async usage() {
const formatDate = (d: Date) =>
`${d.getFullYear()}-${(d.getMonth() + 1).toString().padStart(2, "0")}-${d
.getDate()
.toString()
.padStart(2, "0")}`;
const ONE_DAY = 1 * 24 * 60 * 60 * 1000;
const now = new Date();
const startOfMonth = new Date(now.getFullYear(), now.getMonth(), 1);
const startDate = formatDate(startOfMonth);
const endDate = formatDate(new Date(Date.now() + ONE_DAY));
const [used, subs] = await Promise.all([
fetch(
this.path(
`${this.UsagePath}?start_date=${startDate}&end_date=${endDate}`,
),
{
method: "GET",
headers: getHeaders(),
},
),
fetch(this.path(this.SubsPath), {
method: "GET",
headers: getHeaders(),
}),
]);
if (!used.ok || !subs.ok || used.status === 401) {
throw new Error(Locale.Error.Unauthorized);
}
const response = (await used.json()) as {
total_usage?: number;
error?: {
type: string;
message: string;
};
};
const total = (await subs.json()) as {
hard_limit_usd?: number;
};
if (response.error && response.error.type) {
throw Error(response.error.message);
}
if (response.total_usage) {
response.total_usage = Math.round(response.total_usage) / 100;
}
if (total.hard_limit_usd) {
total.hard_limit_usd = Math.round(total.hard_limit_usd * 100) / 100;
}
return {
used: response.total_usage,
total: total.hard_limit_usd,
} as LLMUsage;
}
}

View File

@@ -22,7 +22,7 @@ import BottomIcon from "../icons/bottom.svg";
import StopIcon from "../icons/pause.svg";
import {
Message,
ChatMessage,
SubmitKey,
useChatStore,
BOT_HELLO,
@@ -43,7 +43,7 @@ import {
import dynamic from "next/dynamic";
import { ControllerPool } from "../requests";
import { ChatControllerPool } from "../client/controller";
import { Prompt, usePromptStore } from "../store/prompt";
import Locale from "../locales";
@@ -63,7 +63,7 @@ const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
loading: () => <LoadingIcon />,
});
function exportMessages(messages: Message[], topic: string) {
function exportMessages(messages: ChatMessage[], topic: string) {
const mdText =
`# ${topic}\n\n` +
messages
@@ -331,8 +331,8 @@ export function ChatActions(props: {
}
// stop all responses
const couldStop = ControllerPool.hasPending();
const stopAll = () => ControllerPool.stopAll();
const couldStop = ChatControllerPool.hasPending();
const stopAll = () => ChatControllerPool.stopAll();
return (
<div className={chatStyle["chat-input-actions"]}>
@@ -394,7 +394,7 @@ export function ChatActions(props: {
}
export function Chat() {
type RenderMessage = Message & { preview?: boolean };
type RenderMessage = ChatMessage & { preview?: boolean };
const chatStore = useChatStore();
const [session, sessionIndex] = useChatStore((state) => [
@@ -487,7 +487,7 @@ export function Chat() {
// stop response
const onUserStop = (messageId: number) => {
ControllerPool.stop(sessionIndex, messageId);
ChatControllerPool.stop(sessionIndex, messageId);
};
// check if should send message
@@ -507,7 +507,7 @@ export function Chat() {
e.preventDefault();
}
};
const onRightClick = (e: any, message: Message) => {
const onRightClick = (e: any, message: ChatMessage) => {
// copy to clipboard
if (selectOrCopy(e.currentTarget, message.content)) {
e.preventDefault();

View File

@@ -186,7 +186,7 @@
.chat-item-delete {
position: absolute;
top: 10px;
right: -20px;
right: 0;
transition: all ease 0.3s;
opacity: 0;
cursor: pointer;
@@ -194,7 +194,7 @@
.chat-item:hover > .chat-item-delete {
opacity: 0.5;
right: 10px;
transform: translateX(-10px);
}
.chat-item:hover > .chat-item-delete:hover {

View File

@@ -23,7 +23,6 @@ import {
} from "react-router-dom";
import { SideBar } from "./sidebar";
import { useAppConfig } from "../store/config";
import { useMaskStore } from "../store/mask";
export function Loading(props: { noLogo?: boolean }) {
return (
@@ -91,12 +90,24 @@ const useHasHydrated = () => {
return hasHydrated;
};
const loadAsyncGoogleFont = () => {
const linkEl = document.createElement("link");
linkEl.rel = "stylesheet";
linkEl.href =
"/google-fonts/css2?family=Noto+Sans+SC:wght@300;400;700;900&display=swap";
document.head.appendChild(linkEl);
};
function Screen() {
const config = useAppConfig();
const location = useLocation();
const isHome = location.pathname === Path.Home;
const isMobileScreen = useMobileScreen();
useEffect(() => {
loadAsyncGoogleFont();
}, []);
return (
<div
className={

View File

@@ -13,7 +13,8 @@ import EyeIcon from "../icons/eye.svg";
import CopyIcon from "../icons/copy.svg";
import { DEFAULT_MASK_AVATAR, Mask, useMaskStore } from "../store/mask";
import { Message, ModelConfig, ROLES, useChatStore } from "../store";
import { ChatMessage, ModelConfig, useChatStore } from "../store";
import { ROLES } from "../client/api";
import { Input, List, ListItem, Modal, Popover, Select } from "./ui-lib";
import { Avatar, AvatarPicker } from "./emoji";
import Locale, { AllLangs, Lang } from "../locales";
@@ -22,7 +23,7 @@ import { useNavigate } from "react-router-dom";
import chatStyle from "./chat.module.scss";
import { useState } from "react";
import { downloadAs, readFromFile } from "../utils";
import { Updater } from "../api/openai/typing";
import { Updater } from "../typing";
import { ModelConfigList } from "./model-config";
import { FileName, Path } from "../constant";
import { BUILTIN_MASK_STORE } from "../masks";
@@ -107,8 +108,8 @@ export function MaskConfig(props: {
}
function ContextPromptItem(props: {
prompt: Message;
update: (prompt: Message) => void;
prompt: ChatMessage;
update: (prompt: ChatMessage) => void;
remove: () => void;
}) {
const [focusingInput, setFocusingInput] = useState(false);
@@ -160,12 +161,12 @@ function ContextPromptItem(props: {
}
export function ContextPrompts(props: {
context: Message[];
updateContext: (updater: (context: Message[]) => void) => void;
context: ChatMessage[];
updateContext: (updater: (context: ChatMessage[]) => void) => void;
}) {
const context = props.context;
const addContextPrompt = (prompt: Message) => {
const addContextPrompt = (prompt: ChatMessage) => {
props.updateContext((context) => context.push(prompt));
};
@@ -173,7 +174,7 @@ export function ContextPrompts(props: {
props.updateContext((context) => context.splice(i, 1));
};
const updateContextPrompt = (i: number, prompt: Message) => {
const updateContextPrompt = (i: number, prompt: ChatMessage) => {
props.updateContext((context) => (context[i] = prompt));
};

View File

@@ -40,3 +40,5 @@ export const NARROW_SIDEBAR_WIDTH = 100;
export const ACCESS_CODE_PREFIX = "ak-";
export const LAST_INPUT_KEY = "last-input";
export const REQUEST_TIMEOUT_MS = 60000;

View File

@@ -34,11 +34,6 @@ export default function RootLayout({
<head>
<meta name="version" content={buildConfig.commitId} />
<link rel="manifest" href="/site.webmanifest"></link>
<link rel="preconnect" href="https://fonts.proxy.ustclug.org"></link>
<link
rel="stylesheet"
href="https://fonts.proxy.ustclug.org/css2?family=Noto+Sans+SC:wght@300;400;700;900&display=swap"
></link>
<script src="/serviceWorkerRegister.js" defer></script>
</head>
<body>{children}</body>

View File

@@ -1,285 +0,0 @@
import type { ChatRequest, ChatResponse } from "./api/openai/typing";
import {
Message,
ModelConfig,
ModelType,
useAccessStore,
useAppConfig,
useChatStore,
} from "./store";
import { showToast } from "./components/ui-lib";
import { ACCESS_CODE_PREFIX } from "./constant";
const TIME_OUT_MS = 60000;
const makeRequestParam = (
messages: Message[],
options?: {
stream?: boolean;
overrideModel?: ModelType;
},
): ChatRequest => {
let sendMessages = messages.map((v) => ({
role: v.role,
content: v.content,
}));
const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
};
// override model config
if (options?.overrideModel) {
modelConfig.model = options.overrideModel;
}
return {
messages: sendMessages,
stream: options?.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
};
};
export function getHeaders() {
const accessStore = useAccessStore.getState();
let headers: Record<string, string> = {};
const makeBearer = (token: string) => `Bearer ${token.trim()}`;
const validString = (x: string) => x && x.length > 0;
// use user's api key first
if (validString(accessStore.token)) {
headers.Authorization = makeBearer(accessStore.token);
} else if (
accessStore.enabledAccessControl() &&
validString(accessStore.accessCode)
) {
headers.Authorization = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
return headers;
}
export function requestOpenaiClient(path: string) {
const openaiUrl = useAccessStore.getState().openaiUrl;
return (body: any, method = "POST") =>
fetch(openaiUrl + path, {
method,
body: body && JSON.stringify(body),
headers: getHeaders(),
});
}
export async function requestChat(
messages: Message[],
options?: {
model?: ModelType;
},
) {
const req: ChatRequest = makeRequestParam(messages, {
overrideModel: options?.model,
});
const res = await requestOpenaiClient("v1/chat/completions")(req);
try {
const response = (await res.json()) as ChatResponse;
return response;
} catch (error) {
console.error("[Request Chat] ", error, res.body);
}
}
export async function requestUsage() {
const formatDate = (d: Date) =>
`${d.getFullYear()}-${(d.getMonth() + 1).toString().padStart(2, "0")}-${d
.getDate()
.toString()
.padStart(2, "0")}`;
const ONE_DAY = 1 * 24 * 60 * 60 * 1000;
const now = new Date();
const startOfMonth = new Date(now.getFullYear(), now.getMonth(), 1);
const startDate = formatDate(startOfMonth);
const endDate = formatDate(new Date(Date.now() + ONE_DAY));
const [used, subs] = await Promise.all([
requestOpenaiClient(
`dashboard/billing/usage?start_date=${startDate}&end_date=${endDate}`,
)(null, "GET"),
requestOpenaiClient("dashboard/billing/subscription")(null, "GET"),
]);
const response = (await used.json()) as {
total_usage?: number;
error?: {
type: string;
message: string;
};
};
const total = (await subs.json()) as {
hard_limit_usd?: number;
};
if (response.error && response.error.type) {
showToast(response.error.message);
return;
}
if (response.total_usage) {
response.total_usage = Math.round(response.total_usage) / 100;
}
if (total.hard_limit_usd) {
total.hard_limit_usd = Math.round(total.hard_limit_usd * 100) / 100;
}
return {
used: response.total_usage,
subscription: total.hard_limit_usd,
};
}
export async function requestChatStream(
messages: Message[],
options?: {
modelConfig?: ModelConfig;
overrideModel?: ModelType;
onMessage: (message: string, done: boolean) => void;
onError: (error: Error, statusCode?: number) => void;
onController?: (controller: AbortController) => void;
},
) {
const req = makeRequestParam(messages, {
stream: true,
overrideModel: options?.overrideModel,
});
console.log("[Request] ", req);
const controller = new AbortController();
const reqTimeoutId = setTimeout(() => controller.abort(), TIME_OUT_MS);
try {
const openaiUrl = useAccessStore.getState().openaiUrl;
const res = await fetch(openaiUrl + "v1/chat/completions", {
method: "POST",
headers: {
"Content-Type": "application/json",
...getHeaders(),
},
body: JSON.stringify(req),
signal: controller.signal,
});
clearTimeout(reqTimeoutId);
let responseText = "";
const finish = () => {
options?.onMessage(responseText, true);
controller.abort();
};
if (res.ok) {
const reader = res.body?.getReader();
const decoder = new TextDecoder();
options?.onController?.(controller);
while (true) {
const resTimeoutId = setTimeout(() => finish(), TIME_OUT_MS);
const content = await reader?.read();
clearTimeout(resTimeoutId);
if (!content || !content.value) {
break;
}
const text = decoder.decode(content.value, { stream: true });
responseText += text;
const done = content.done;
options?.onMessage(responseText, false);
if (done) {
break;
}
}
finish();
} else if (res.status === 401) {
console.error("Unauthorized");
options?.onError(new Error("Unauthorized"), res.status);
} else {
console.error("Stream Error", res.body);
options?.onError(new Error("Stream Error"), res.status);
}
} catch (err) {
console.error("NetWork Error", err);
options?.onError(err as Error);
}
}
export async function requestWithPrompt(
messages: Message[],
prompt: string,
options?: {
model?: ModelType;
},
) {
messages = messages.concat([
{
role: "user",
content: prompt,
date: new Date().toLocaleString(),
},
]);
const res = await requestChat(messages, options);
return res?.choices?.at(0)?.message?.content ?? "";
}
// To store message streaming controller
export const ControllerPool = {
controllers: {} as Record<string, AbortController>,
addController(
sessionIndex: number,
messageId: number,
controller: AbortController,
) {
const key = this.key(sessionIndex, messageId);
this.controllers[key] = controller;
return key;
},
stop(sessionIndex: number, messageId: number) {
const key = this.key(sessionIndex, messageId);
const controller = this.controllers[key];
controller?.abort();
},
stopAll() {
Object.values(this.controllers).forEach((v) => v.abort());
},
hasPending() {
return Object.values(this.controllers).length > 0;
},
remove(sessionIndex: number, messageId: number) {
const key = this.key(sessionIndex, messageId);
delete this.controllers[key];
},
key(sessionIndex: number, messageIndex: number) {
return `${sessionIndex},${messageIndex}`;
},
};

View File

@@ -1,7 +1,7 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { StoreKey } from "../constant";
import { getHeaders } from "../requests";
import { getHeaders } from "../client/api";
import { BOT_HELLO } from "./chat";
import { ALL_MODELS } from "./config";

View File

@@ -1,12 +1,6 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { type ChatCompletionResponseMessage } from "openai";
import {
ControllerPool,
requestChatStream,
requestWithPrompt,
} from "../requests";
import { trimTopic } from "../utils";
import Locale from "../locales";
@@ -14,8 +8,11 @@ import { showToast } from "../components/ui-lib";
import { ModelType } from "./config";
import { createEmptyMask, Mask } from "./mask";
import { StoreKey } from "../constant";
import { api, RequestMessage } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format";
export type Message = ChatCompletionResponseMessage & {
export type ChatMessage = RequestMessage & {
date: string;
streaming?: boolean;
isError?: boolean;
@@ -23,7 +20,7 @@ export type Message = ChatCompletionResponseMessage & {
model?: ModelType;
};
export function createMessage(override: Partial<Message>): Message {
export function createMessage(override: Partial<ChatMessage>): ChatMessage {
return {
id: Date.now(),
date: new Date().toLocaleString(),
@@ -33,8 +30,6 @@ export function createMessage(override: Partial<Message>): Message {
};
}
export const ROLES: Message["role"][] = ["system", "user", "assistant"];
export interface ChatStat {
tokenCount: number;
wordCount: number;
@@ -47,7 +42,7 @@ export interface ChatSession {
topic: string;
memoryPrompt: string;
messages: Message[];
messages: ChatMessage[];
stat: ChatStat;
lastUpdate: number;
lastSummarizeIndex: number;
@@ -56,7 +51,7 @@ export interface ChatSession {
}
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
export const BOT_HELLO: Message = createMessage({
export const BOT_HELLO: ChatMessage = createMessage({
role: "assistant",
content: Locale.Store.BotHello,
});
@@ -88,24 +83,24 @@ interface ChatStore {
newSession: (mask?: Mask) => void;
deleteSession: (index: number) => void;
currentSession: () => ChatSession;
onNewMessage: (message: Message) => void;
onNewMessage: (message: ChatMessage) => void;
onUserInput: (content: string) => Promise<void>;
summarizeSession: () => void;
updateStat: (message: Message) => void;
updateStat: (message: ChatMessage) => void;
updateCurrentSession: (updater: (session: ChatSession) => void) => void;
updateMessage: (
sessionIndex: number,
messageIndex: number,
updater: (message?: Message) => void,
updater: (message?: ChatMessage) => void,
) => void;
resetSession: () => void;
getMessagesWithMemory: () => Message[];
getMemoryPrompt: () => Message;
getMessagesWithMemory: () => ChatMessage[];
getMemoryPrompt: () => ChatMessage;
clearAllData: () => void;
}
function countMessages(msgs: Message[]) {
function countMessages(msgs: ChatMessage[]) {
return msgs.reduce((pre, cur) => pre + cur.content.length, 0);
}
@@ -240,12 +235,12 @@ export const useChatStore = create<ChatStore>()(
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
const userMessage: Message = createMessage({
const userMessage: ChatMessage = createMessage({
role: "user",
content,
});
const botMessage: Message = createMessage({
const botMessage: ChatMessage = createMessage({
role: "assistant",
streaming: true,
id: userMessage.id! + 1,
@@ -277,45 +272,54 @@ export const useChatStore = create<ChatStore>()(
// make request
console.log("[User Input] ", sendMessages);
requestChatStream(sendMessages, {
onMessage(content, done) {
// stream response
if (done) {
botMessage.streaming = false;
botMessage.content = content;
get().onNewMessage(botMessage);
ControllerPool.remove(
sessionIndex,
botMessage.id ?? messageIndex,
);
} else {
botMessage.content = content;
set(() => ({}));
}
api.llm.chat({
messages: sendMessages,
config: { ...modelConfig, stream: true },
onUpdate(message) {
botMessage.streaming = true;
botMessage.content = message;
set(() => ({}));
},
onError(error, statusCode) {
onFinish(message) {
botMessage.streaming = false;
botMessage.content = message;
get().onNewMessage(botMessage);
ChatControllerPool.remove(
sessionIndex,
botMessage.id ?? messageIndex,
);
set(() => ({}));
},
onError(error) {
const isAborted = error.message.includes("aborted");
if (statusCode === 401) {
botMessage.content = Locale.Error.Unauthorized;
} else if (!isAborted) {
if (
botMessage.content !== Locale.Error.Unauthorized &&
!isAborted
) {
botMessage.content += "\n\n" + Locale.Store.Error;
} else if (botMessage.content.length === 0) {
botMessage.content = prettyObject(error);
}
botMessage.streaming = false;
userMessage.isError = !isAborted;
botMessage.isError = !isAborted;
set(() => ({}));
ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
ChatControllerPool.remove(
sessionIndex,
botMessage.id ?? messageIndex,
);
console.error("[Chat] error ", error);
},
onController(controller) {
// collect controller for stop/retry
ControllerPool.addController(
ChatControllerPool.addController(
sessionIndex,
botMessage.id ?? messageIndex,
controller,
);
},
modelConfig: { ...modelConfig },
});
},
@@ -329,7 +333,7 @@ export const useChatStore = create<ChatStore>()(
? Locale.Store.Prompt.History(session.memoryPrompt)
: "",
date: "",
} as Message;
} as ChatMessage;
},
getMessagesWithMemory() {
@@ -384,7 +388,7 @@ export const useChatStore = create<ChatStore>()(
updateMessage(
sessionIndex: number,
messageIndex: number,
updater: (message?: Message) => void,
updater: (message?: ChatMessage) => void,
) {
const sessions = get().sessions;
const session = sessions.at(sessionIndex);
@@ -402,7 +406,7 @@ export const useChatStore = create<ChatStore>()(
summarizeSession() {
const session = get().currentSession();
// remove error messages if any
const cleanMessages = session.messages.filter((msg) => !msg.isError);
@@ -412,13 +416,24 @@ export const useChatStore = create<ChatStore>()(
session.topic === DEFAULT_TOPIC &&
countMessages(cleanMessages) >= SUMMARIZE_MIN_LEN
) {
requestWithPrompt(cleanMessages, Locale.Store.Prompt.Topic, {
model: "gpt-3.5-turbo",
}).then((res) => {
get().updateCurrentSession(
(session) =>
(session.topic = res ? trimTopic(res) : DEFAULT_TOPIC),
);
const topicMessages = cleanMessages.concat(
createMessage({
role: "user",
content: Locale.Store.Prompt.Topic,
}),
);
api.llm.chat({
messages: topicMessages,
config: {
model: "gpt-3.5-turbo",
},
onFinish(message) {
get().updateCurrentSession(
(session) =>
(session.topic =
message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
);
},
});
}
@@ -426,7 +441,7 @@ export const useChatStore = create<ChatStore>()(
let toBeSummarizedMsgs = cleanMessages.slice(
session.lastSummarizeIndex,
);
const historyMsgLength = countMessages(toBeSummarizedMsgs);
if (historyMsgLength > modelConfig?.max_tokens ?? 4000) {
@@ -452,26 +467,24 @@ export const useChatStore = create<ChatStore>()(
historyMsgLength > modelConfig.compressMessageLengthThreshold &&
session.mask.modelConfig.sendMemory
) {
requestChatStream(
toBeSummarizedMsgs.concat({
api.llm.chat({
messages: toBeSummarizedMsgs.concat({
role: "system",
content: Locale.Store.Prompt.Summarize,
date: "",
}),
{
overrideModel: "gpt-3.5-turbo",
onMessage(message, done) {
session.memoryPrompt = message;
if (done) {
console.log("[Memory] ", session.memoryPrompt);
session.lastSummarizeIndex = lastSummarizeIndex;
}
},
onError(error) {
console.error("[Summarize] ", error);
},
config: { ...modelConfig, stream: true },
onUpdate(message) {
session.memoryPrompt = message;
},
);
onFinish(message) {
console.log("[Memory] ", message);
session.lastSummarizeIndex = lastSummarizeIndex;
},
onError(err) {
console.error("[Summarize] ", err);
},
});
}
},

View File

@@ -2,7 +2,7 @@ import { create } from "zustand";
import { persist } from "zustand/middleware";
import { BUILTIN_MASKS } from "../masks";
import { getLang, Lang } from "../locales";
import { DEFAULT_TOPIC, Message } from "./chat";
import { DEFAULT_TOPIC, ChatMessage } from "./chat";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { StoreKey } from "../constant";
@@ -10,7 +10,7 @@ export type Mask = {
id: number;
avatar: string;
name: string;
context: Message[];
context: ChatMessage[];
modelConfig: ModelConfig;
lang: Lang;
builtin: boolean;

View File

@@ -1,7 +1,8 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { FETCH_COMMIT_URL, FETCH_TAG_URL, StoreKey } from "../constant";
import { requestUsage } from "../requests";
import { FETCH_COMMIT_URL, StoreKey } from "../constant";
import { api } from "../client/api";
import { showToast } from "../components/ui-lib";
export interface UpdateStore {
lastUpdate: number;
@@ -73,10 +74,17 @@ export const useUpdateStore = create<UpdateStore>()(
lastUpdateUsage: Date.now(),
}));
const usage = await requestUsage();
try {
const usage = await api.llm.usage();
if (usage) {
set(() => usage);
if (usage) {
set(() => ({
used: usage.used,
subscription: usage.total,
}));
}
} catch (e) {
showToast((e as Error).message);
}
},
}),

1
app/typing.ts Normal file
View File

@@ -0,0 +1 @@
export type Updater<T> = (updater: (value: T) => void) => void;

8
app/utils/format.ts Normal file
View File

@@ -0,0 +1,8 @@
export function prettyObject(msg: any) {
const prettyMsg = [
"```json\n",
JSON.stringify(msg, null, " "),
"\n```",
].join("");
return prettyMsg;
}