feat: model provider refactor done

This commit is contained in:
Dean-YZG
2024-05-15 21:38:25 +08:00
parent 240d330001
commit a0e4a468d6
33 changed files with 3077 additions and 8 deletions

9
app/client/core/index.ts Normal file
View File

@@ -0,0 +1,9 @@
export * from "./types";
export * from "./providerClient";
export * from "./modelClient";
export * from "./locale";
export * from "./shim";

19
app/client/core/locale.ts Normal file
View File

@@ -0,0 +1,19 @@
import { Lang, getLang } from "@/app/locales";
interface PlainConfig {
[k: string]: PlainConfig | string;
}
export type LocaleMap<
TextPlainConfig extends PlainConfig,
Default extends Lang,
> = Partial<Record<Lang, TextPlainConfig>> & {
[name in Default]: TextPlainConfig;
};
export function getLocaleText<
TextPlainConfig extends PlainConfig,
DefaultLang extends Lang,
>(textMap: LocaleMap<TextPlainConfig, DefaultLang>, defaultLang: DefaultLang) {
return textMap[getLang()] || textMap[defaultLang];
}

View File

@@ -0,0 +1,46 @@
import { ChatRequestPayload, Model, ModelConfig, ChatHandlers } from "./types";
import { ProviderClient, ProviderTemplateName } from "./providerClient";
export class ModelClient {
static getAllProvidersDefaultModels = () => {
return ProviderClient.getAllProvidersDefaultModels();
};
constructor(
private model: Model,
private modelConfig: ModelConfig,
private providerClient: ProviderClient,
) {}
chat(payload: ChatRequestPayload, handlers: ChatHandlers) {
try {
return this.providerClient.streamChat(
{
...payload,
modelConfig: this.modelConfig,
model: this.model.name,
},
handlers,
);
} catch (e) {
handlers.onError(e as Error);
}
}
summerize(payload: ChatRequestPayload) {
try {
return this.providerClient.chat({
...payload,
modelConfig: this.modelConfig,
model: this.model.name,
});
} catch (e) {
return "";
}
}
}
export function ModelClientFactory(model: Model, modelConfig: ModelConfig) {
const providerClient = new ProviderClient(model.providerTemplateName);
return new ModelClient(model, modelConfig, providerClient);
}

View File

@@ -0,0 +1,137 @@
import {
ChatHandlers,
IProviderTemplate,
Model,
StandChatReponseMessage,
StandChatRequestPayload,
} from "./types";
import * as ProviderTemplates from "@/app/client/providers";
import { cloneDeep } from "lodash-es";
export type ProviderTemplate =
(typeof ProviderTemplates)[keyof typeof ProviderTemplates];
export type ProviderTemplateName =
(typeof ProviderTemplates)[keyof typeof ProviderTemplates]["prototype"]["name"];
export class ProviderClient {
provider: IProviderTemplate<any, any, any>;
static ProviderTemplates = ProviderTemplates;
static getAllProvidersDefaultModels = () => {
return Object.values(ProviderClient.ProviderTemplates).reduce(
(r, p) => ({
...r,
[p.prototype.name]: cloneDeep(p.prototype.models),
}),
{} as Record<ProviderTemplateName, Model[]>,
);
};
static getAllProviderTemplates = () => {
return Object.values(ProviderClient.ProviderTemplates).reduce(
(r, p) => ({
...r,
[p.prototype.name]: p,
}),
{} as Record<ProviderTemplateName, ProviderTemplate>,
);
};
static getProviderTemplateList = () => {
return Object.values(ProviderClient.ProviderTemplates);
};
constructor(providerTemplateName: string) {
this.provider = this.getProviderTemplate(providerTemplateName);
}
get settingItems() {
const { providerMeta } = this.provider;
const { settingItems } = providerMeta;
return settingItems;
}
private getProviderTemplate(providerTemplateName: string) {
const providerTemplate =
Object.values(ProviderTemplates).find(
(template) => template.prototype.name === providerTemplateName,
) || ProviderTemplates.NextChatProvider;
return new providerTemplate();
}
getModelConfig(modelName: string) {
const { models } = this.provider;
return (
models.find((config) => config.name === modelName) ||
models.find((config) => config.isDefaultSelected)
);
}
async chat(
payload: StandChatRequestPayload<string>,
): Promise<StandChatReponseMessage> {
return this.provider.chat({
...payload,
stream: false,
isVisionModel: this.getModelConfig(payload.model)?.isVisionModel,
});
}
streamChat(payload: StandChatRequestPayload<string>, handlers: ChatHandlers) {
return this.provider.streamChat(
{
...payload,
stream: true,
isVisionModel: this.getModelConfig(payload.model)?.isVisionModel,
},
handlers.onProgress,
handlers.onFinish,
handlers.onError,
);
}
}
export interface Provider {
name: string; // id of provider
displayName: string;
isActive: boolean;
providerTemplateName: ProviderTemplateName;
models: Model[];
}
function createProvider(
provider: ProviderTemplateName,
params?: Omit<Provider, "providerTemplateName">,
): Provider;
function createProvider(
provider: ProviderTemplate,
params?: Omit<Provider, "providerTemplateName">,
): Provider;
function createProvider(
provider: ProviderTemplate | ProviderTemplateName,
params?: Omit<Provider, "providerTemplateName">,
): Provider {
let providerTemplate: ProviderTemplate;
if (typeof provider === "string") {
providerTemplate = ProviderClient.getAllProviderTemplates()[provider];
} else {
providerTemplate = provider;
}
const {
name = providerTemplate.prototype.name,
displayName = providerTemplate.prototype.providerMeta.displayName,
models = providerTemplate.prototype.models,
} = params ?? {};
return {
name,
displayName,
isActive: true,
models,
providerTemplateName: providerTemplate.prototype.name,
};
}
export { createProvider };

25
app/client/core/shim.ts Normal file
View File

@@ -0,0 +1,25 @@
import { getClientConfig } from "@/app/config/client";
if (!(window.fetch as any).__hijacked__) {
let _fetch = window.fetch;
function fetch(...args: Parameters<typeof _fetch>) {
const { isApp } = getClientConfig() || {};
let fetch: typeof _fetch = _fetch;
if (isApp) {
try {
fetch = window.__TAURI__!.http.fetch;
} catch (e) {
fetch = _fetch;
}
}
return fetch(...args);
}
fetch.__hijacked__ = true;
window.fetch = fetch;
}

164
app/client/core/types.ts Normal file
View File

@@ -0,0 +1,164 @@
import { RequestMessage } from "../api";
// ===================================== LLM Types start ======================================
export interface ModelConfig {
temperature: number;
top_p: number;
presence_penalty: number;
frequency_penalty: number;
max_tokens: number;
}
export type Model = {
name: string; // id of model in a provider
displayName: string;
isVisionModel?: boolean;
isDefaultActive: boolean; // model is initialized to be active
isDefaultSelected?: boolean; // model is initialized to be as default used model
providerTemplateName: string;
};
// ===================================== LLM Types end ======================================
// ===================================== Chat Request Types start ======================================
export interface ChatRequestPayload<SettingKeys extends string = ""> {
messages: RequestMessage[];
providerConfig: Record<SettingKeys, string>;
context: {
isApp: boolean;
};
}
export interface StandChatRequestPayload<SettingKeys extends string = "">
extends ChatRequestPayload<SettingKeys> {
modelConfig: ModelConfig;
model: string;
}
export interface InternalChatRequestPayload<SettingKeys extends string = "">
extends StandChatRequestPayload<SettingKeys> {
isVisionModel: Model["isVisionModel"];
stream: boolean;
}
export interface ProviderRequestPayload {
headers: Record<string, string>;
body: string;
url: string;
method: string;
}
export interface ChatHandlers {
onProgress: (message: string, chunk: string) => void;
onFinish: (message: string) => void;
onError: (err: Error) => void;
}
// ===================================== Chat Request Types end ======================================
// ===================================== Chat Response Types start ======================================
export interface StandChatReponseMessage {
message: string;
}
// ===================================== Chat Request Types end ======================================
// ===================================== Provider Settings Types start ======================================
type NumberRange = [number, number];
export type Validator =
| "required"
| "number"
| "string"
| NumberRange
| NumberRange[];
export type CommonSettingItem<SettingKeys extends string> = {
name: SettingKeys;
title?: string;
description?: string;
validators?: Validator[];
};
export type InputSettingItem = {
type: "input";
placeholder?: string;
} & (
| {
inputType?: "password" | "normal";
defaultValue?: string;
}
| {
inputType?: "number";
defaultValue?: number;
}
);
export type SelectSettingItem = {
type: "select";
options: {
name: string;
value: "number" | "string" | "boolean";
}[];
placeholder?: string;
};
export type RangeSettingItem = {
type: "range";
range: NumberRange;
};
export type SwitchSettingItem = {
type: "switch";
};
export type SettingItem<SettingKeys extends string = ""> =
CommonSettingItem<SettingKeys> &
(
| InputSettingItem
| SelectSettingItem
| RangeSettingItem
| SwitchSettingItem
);
// ===================================== Provider Settings Types end ======================================
// ===================================== Provider Template Types start ======================================
export interface IProviderTemplate<
SettingKeys extends string,
NAME extends string,
Meta extends Record<string, any>,
> {
readonly name: NAME;
readonly metas: Meta;
readonly providerMeta: {
displayName: string;
settingItems: SettingItem<SettingKeys>[];
};
readonly models: Model[];
// formatChatPayload(payload: InternalChatRequestPayload<SettingKeys>): ProviderRequestPayload;
// readWholeMessageResponseBody(res: WholeMessageResponseBody): StandChatReponseMessage;
streamChat(
payload: InternalChatRequestPayload<SettingKeys>,
onProgress?: (message: string, chunk: string) => void,
onFinish?: (message: string) => void,
onError?: (err: Error) => void,
): AbortController;
chat(
payload: InternalChatRequestPayload<SettingKeys>,
): Promise<StandChatReponseMessage>;
}
export interface Serializable<Snapshot> {
serialize(): Snapshot;
}