refactor: #1000 #1179 api layer for client-side only mode and local models

This commit is contained in:
Yidadaa
2023-05-15 01:33:46 +08:00
parent bd90caa99d
commit a3de277c43
15 changed files with 247 additions and 593 deletions

View File

@@ -1,39 +1,35 @@
import { fetchEventSource } from "@microsoft/fetch-event-source";
import { ACCESS_CODE_PREFIX } from "../constant";
import { ModelType, useAccessStore } from "../store";
import { ModelConfig, ModelType, useAccessStore } from "../store";
import { ChatGPTApi } from "./platforms/openai";
// NOTE(review): this is a rendered diff — the enum below is the *removed*
// MessageRole declaration; the `ROLES` const/type pair after it is its
// replacement. Both appear only because old and new lines are interleaved.
export enum MessageRole {
System = "system",
User = "user",
Assistant = "assistant",
}
// Allowed chat roles. MessageRole is derived from this tuple so the runtime
// list and the compile-time union can never drift apart.
export const ROLES = ["system", "user", "assistant"] as const;
export type MessageRole = (typeof ROLES)[number];
// NOTE(review): `Models` is the removed hard-coded model list; `ChatModel`
// now aliases ModelType from the store instead.
export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
export type ChatModel = ModelType;
export interface Message {
// One message in an outgoing chat request: who said it and what was said.
export interface RequestMessage {
role: MessageRole;
content: string;
}
// Per-request model settings forwarded to the chat backend.
// Field names are snake_case to mirror the OpenAI request body
// (`top_p`, `presence_penalty`, `frequency_penalty`).
// NOTE(review): the interleaved diff also showed the removed camelCase
// fields (topP, presencePenalty, frequencyPenalty); only the snake_case
// replacements are kept here.
export interface LLMConfig {
  model: string;
  temperature?: number;
  top_p?: number;
  // When true the response is consumed as a server-sent event stream.
  stream?: boolean;
  presence_penalty?: number;
  frequency_penalty?: number;
}
// Options for a single chat call. Only `onFinish` is required; the
// streaming, error, and cancellation hooks are optional.
// NOTE(review): the interleaved diff duplicated `messages`, `onUpdate`
// and `onError` (old + new signatures) and showed the removed `model`
// and `onUnAuth` members; this keeps only the added-line version, where
// the model is carried inside `config`.
export interface ChatOptions {
  messages: RequestMessage[];
  config: LLMConfig;

  // Invoked per streamed chunk with the accumulated text and the delta.
  onUpdate?: (message: string, chunk: string) => void;
  // Invoked exactly once with the final message text.
  onFinish: (message: string) => void;
  onError?: (err: Error) => void;
  // Hands the AbortController to the caller so the request can be cancelled.
  onController?: (controller: AbortController) => void;
}
export interface LLMUsage {
@@ -53,28 +49,6 @@ export class ClientApi {
this.llm = new ChatGPTApi();
}
headers() {
  // Build the Authorization header for outgoing requests.
  // Precedence: the user's own API key wins over the site access code.
  const accessStore = useAccessStore.getState();
  // `headers` is populated in place but never reassigned — use const.
  const headers: Record<string, string> = {};

  const makeBearer = (token: string) => `Bearer ${token.trim()}`;
  // Coerce to a real boolean; the original `x && x.length > 0` leaked
  // `string | boolean`. The `x &&` guard is kept in case a null/undefined
  // value sneaks past the `string` type at runtime.
  const validString = (x: string): boolean => Boolean(x && x.length > 0);

  if (validString(accessStore.token)) {
    // Use the user's own API key first.
    headers.Authorization = makeBearer(accessStore.token);
  } else if (
    accessStore.enabledAccessControl() &&
    validString(accessStore.accessCode)
  ) {
    // Fall back to the site access code, tagged with its prefix so the
    // server can distinguish it from a real API key.
    headers.Authorization = makeBearer(
      ACCESS_CODE_PREFIX + accessStore.accessCode,
    );
  }

  return headers;
}
// Placeholder: model/config listing is not implemented yet.
config() {}
// Placeholder: prompt management is not implemented yet.
prompts() {}

View File

@@ -1,13 +1,13 @@
import { REQUEST_TIMEOUT_MS } from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import {
EventStreamContentType,
fetchEventSource,
} from "@microsoft/fetch-event-source";
import { ChatOptions, LLMApi, LLMUsage } from "../api";
import { ChatOptions, getHeaders, LLMApi, LLMUsage } from "../api";
import Locale from "../../locales";
export class ChatGPTApi implements LLMApi {
public ChatPath = "v1/chat/completions";
public UsagePath = "dashboard/billing/usage";
public SubsPath = "dashboard/billing/subscription";
path(path: string): string {
const openaiUrl = useAccessStore.getState().openaiUrl;
@@ -29,7 +29,7 @@ export class ChatGPTApi implements LLMApi {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
...{
model: options.model,
model: options.config.model,
},
};
@@ -45,6 +45,7 @@ export class ChatGPTApi implements LLMApi {
const shouldStream = !!options.config.stream;
const controller = new AbortController();
options.onController?.(controller);
try {
const chatPath = this.path(this.ChatPath);
@@ -52,6 +53,7 @@ export class ChatGPTApi implements LLMApi {
method: "POST",
body: JSON.stringify(requestPayload),
signal: controller.signal,
headers: getHeaders(),
};
// make a fetch request
@@ -59,66 +61,128 @@ export class ChatGPTApi implements LLMApi {
() => controller.abort(),
REQUEST_TIMEOUT_MS,
);
if (shouldStream) {
let responseText = "";
fetchEventSource(chatPath, {
...chatPayload,
async onopen(res) {
if (
res.ok &&
res.headers.get("Content-Type") === EventStreamContentType
) {
return;
}
const finish = () => {
options.onFinish(responseText);
};
if (res.status === 401) {
// TODO: Unauthorized 401
responseText += "\n\n";
} else if (res.status !== 200) {
console.error("[Request] response", res);
throw new Error("[Request] server error");
const res = await fetch(chatPath, chatPayload);
clearTimeout(reqestTimeoutId);
if (res.status === 401) {
responseText += "\n\n" + Locale.Error.Unauthorized;
return finish();
}
if (
!res.ok ||
!res.headers.get("Content-Type")?.includes("stream") ||
!res.body
) {
return options.onError?.(new Error());
}
const reader = res.body.getReader();
const decoder = new TextDecoder("utf-8");
while (true) {
const { done, value } = await reader.read();
if (done) {
return finish();
}
const chunk = decoder.decode(value);
const lines = chunk.split("data: ");
for (const line of lines) {
const text = line.trim();
if (line.startsWith("[DONE]")) {
return finish();
}
},
onmessage: (ev) => {
if (ev.data === "[DONE]") {
return options.onFinish(responseText);
if (text.length === 0) continue;
const json = JSON.parse(text);
const delta = json.choices[0].delta.content;
if (delta) {
responseText += delta;
options.onUpdate?.(responseText, delta);
}
try {
const resJson = JSON.parse(ev.data);
const message = this.extractMessage(resJson);
responseText += message;
options.onUpdate(responseText, message);
} catch (e) {
console.error("[Request] stream error", e);
options.onError(e as Error);
}
},
onclose() {
options.onError(new Error("stream closed unexpected"));
},
onerror(err) {
options.onError(err);
},
});
}
}
} else {
const res = await fetch(chatPath, chatPayload);
clearTimeout(reqestTimeoutId);
const resJson = await res.json();
const message = this.extractMessage(resJson);
options.onFinish(message);
}
clearTimeout(reqestTimeoutId);
} catch (e) {
console.log("[Request] failed to make a chat reqeust", e);
options.onError(e as Error);
options.onError?.(e as Error);
}
}
// Fetch the account's month-to-date usage and its hard spending limit
// from the OpenAI billing endpoints. Returns an LLMUsage with `used`
// and `total` as dollar amounts; throws on auth failure or an API error.
async usage() {
  // Format a Date as "YYYY-MM-DD" for the billing query string.
  const formatDate = (d: Date) =>
    `${d.getFullYear()}-${(d.getMonth() + 1).toString().padStart(2, "0")}-${d
      .getDate()
      .toString()
      .padStart(2, "0")}`;
  const ONE_DAY = 1 * 24 * 60 * 60 * 1000;
  const now = new Date();
  const startOfMonth = new Date(now.getFullYear(), now.getMonth(), 1);
  const startDate = formatDate(startOfMonth);
  // End date is tomorrow so that today's usage is fully included.
  const endDate = formatDate(new Date(Date.now() + ONE_DAY));

  // Issue both billing requests in parallel.
  const [used, subs] = await Promise.all([
    fetch(
      this.path(
        `${this.UsagePath}?start_date=${startDate}&end_date=${endDate}`,
      ),
      {
        method: "GET",
        headers: getHeaders(),
      },
    ),
    fetch(this.path(this.SubsPath), {
      method: "GET",
      headers: getHeaders(),
    }),
  ]);

  if (!used.ok || !subs.ok || used.status === 401) {
    throw new Error(Locale.Error.Unauthorized);
  }

  const response = (await used.json()) as {
    total_usage?: number;
    error?: {
      type: string;
      message: string;
    };
  };
  const total = (await subs.json()) as {
    hard_limit_usd?: number;
  };

  // Surface an API-level error payload as an exception.
  if (response.error && response.error.type) {
    throw Error(response.error.message);
  }
  // Presumably total_usage is reported in cents — divide to get dollars.
  // TODO(review): confirm the unit against the billing API.
  if (response.total_usage) {
    response.total_usage = Math.round(response.total_usage) / 100;
  }
  // Round the limit to two decimal places.
  if (total.hard_limit_usd) {
    total.hard_limit_usd = Math.round(total.hard_limit_usd * 100) / 100;
  }

  // NOTE(review): the diff residue left duplicate `used: 0, total: 0`
  // placeholder keys in this literal; only the computed values are kept.
  return {
    used: response.total_usage,
    total: total.hard_limit_usd,
  } as LLMUsage;
}
}