From 8093d1ffbaafa8eddfa9139f20b971f957351423 Mon Sep 17 00:00:00 2001 From: Dean-YZG Date: Fri, 17 May 2024 21:11:21 +0800 Subject: [PATCH] feat: 1) Present 'maxtokens' as properties tied to a single model. 2) Remove the original author's implementation of the send verification logic and replace it with a user input validator. Pre-verification 3) Provides the ability to pull the 'User Visible modellist' provided by 'provider' 4) Provider-related parameters are passed in the constructor of 'providerClient'. Not passed in the 'chat' method --- app/client/common/index.ts | 5 + app/client/{core => common}/locale.ts | 0 app/client/{core => common}/types.ts | 56 ++++-- app/client/common/utils.ts | 26 +++ app/client/core/index.ts | 4 +- app/client/core/modelClient.ts | 78 +++++++-- app/client/core/providerClient.ts | 206 ++++++++++++++++------- app/client/providers/anthropic/config.ts | 31 +++- app/client/providers/anthropic/index.ts | 202 ++++++++-------------- app/client/providers/anthropic/locale.ts | 26 ++- app/client/providers/anthropic/utils.ts | 38 +++++ app/client/providers/azure/config.ts | 18 +- app/client/providers/azure/index.ts | 153 +++++++---------- app/client/providers/azure/locale.ts | 26 ++- app/client/providers/azure/utils.ts | 27 +++ app/client/providers/google/config.ts | 23 ++- app/client/providers/google/index.ts | 114 ++++++------- app/client/providers/google/locale.ts | 22 ++- app/client/providers/google/utils.ts | 10 ++ app/client/providers/nextchat/config.ts | 2 +- app/client/providers/nextchat/index.ts | 87 +++------- app/client/providers/nextchat/utils.ts | 18 ++ app/client/providers/openai/config.ts | 34 +++- app/client/providers/openai/index.ts | 149 +++++++--------- app/client/providers/openai/locale.ts | 20 ++- app/client/providers/openai/utils.ts | 18 ++ app/components/List/index.tsx | 14 +- app/store/provider.ts | 48 ++++-- app/utils/hooks.ts | 2 +- yarn.lock | 7 +- 30 files changed, 883 insertions(+), 581 deletions(-) create mode 100644 app/client/common/index.ts rename app/client/{core => common}/locale.ts (100%) rename app/client/{core => common}/types.ts (75%) create mode 100644 app/client/common/utils.ts create mode 100644 app/client/providers/anthropic/utils.ts create mode 100644 app/client/providers/azure/utils.ts create mode 100644 app/client/providers/google/utils.ts create mode 100644 app/client/providers/nextchat/utils.ts create mode 100644 app/client/providers/openai/utils.ts diff --git a/app/client/common/index.ts b/app/client/common/index.ts new file mode 100644 index 000000000..807aac0cd --- /dev/null +++ b/app/client/common/index.ts @@ -0,0 +1,5 @@ +export * from "./types"; + +export * from "./locale"; + +export * from "./utils"; diff --git a/app/client/core/locale.ts b/app/client/common/locale.ts similarity index 100% rename from app/client/core/locale.ts rename to app/client/common/locale.ts diff --git a/app/client/core/types.ts b/app/client/common/types.ts similarity index 75% rename from app/client/core/types.ts rename to app/client/common/types.ts index 19c9975e1..49031ea7f 100644 --- a/app/client/core/types.ts +++ b/app/client/common/types.ts @@ -1,5 +1,7 @@ import { RequestMessage } from "../api"; +export { type RequestMessage }; + // ===================================== LLM Types start ====================================== export interface ModelConfig { @@ -10,35 +12,50 @@ export interface ModelConfig { max_tokens: number; } -export type Model = { +export interface ModelSettings extends Omit { + global_max_tokens: number; +} + +export type ModelTemplate = { name: string; // id of model in a provider displayName: string; isVisionModel?: boolean; isDefaultActive: boolean; // model is initialized to be active isDefaultSelected?: boolean; // model is initialized to be as default used model - providerTemplateName: string; + max_tokens?: number; }; +export interface Model extends Omit { + providerTemplateName: string; + isActive: boolean; + providerName: string; + available: boolean; + customized: boolean; // Only customized model is allowed to be modified +} + +export interface ModelInfo extends Pick { + [k: string]: any; +} + // ===================================== LLM Types end ====================================== // ===================================== Chat Request Types start ====================================== -export interface ChatRequestPayload { +export interface ChatRequestPayload { messages: RequestMessage[]; - providerConfig: Record; context: { isApp: boolean; }; } -export interface StandChatRequestPayload - extends ChatRequestPayload { +export interface StandChatRequestPayload extends ChatRequestPayload { modelConfig: ModelConfig; model: string; } export interface InternalChatRequestPayload - extends StandChatRequestPayload { + extends StandChatRequestPayload { + providerConfig: Partial>; isVisionModel: Model["isVisionModel"]; stream: boolean; } @@ -50,12 +67,18 @@ export interface ProviderRequestPayload { method: string; } -export interface ChatHandlers { +export interface InternalChatHandlers { onProgress: (message: string, chunk: string) => void; onFinish: (message: string) => void; onError: (err: Error) => void; } +export interface ChatHandlers extends InternalChatHandlers { + onProgress: (chunk: string) => void; + onFinish: () => void; + onFlash: (message: string) => void; +} + // ===================================== Chat Request Types end ====================================== // ===================================== Chat Response Types start ====================================== @@ -75,7 +98,8 @@ export type Validator = | "number" | "string" | NumberRange - | NumberRange[]; + | NumberRange[] + | ((v: any) => Promise); export type CommonSettingItem = { name: SettingKeys; @@ -141,22 +165,20 @@ export interface IProviderTemplate< displayName: string; settingItems: SettingItem[]; }; - readonly models: Model[]; - - // formatChatPayload(payload: InternalChatRequestPayload): ProviderRequestPayload; - - // readWholeMessageResponseBody(res: WholeMessageResponseBody): StandChatReponseMessage; + readonly defaultModels: ModelTemplate[]; streamChat( payload: InternalChatRequestPayload, - onProgress?: (message: string, chunk: string) => void, - onFinish?: (message: string) => void, - onError?: (err: Error) => void, + handlers: ChatHandlers, ): AbortController; chat( payload: InternalChatRequestPayload, ): Promise; + + getAvailableModels?( + providerConfig: InternalChatRequestPayload["providerConfig"], + ): Promise; } export interface Serializable { diff --git a/app/client/common/utils.ts b/app/client/common/utils.ts new file mode 100644 index 000000000..de23c7825 --- /dev/null +++ b/app/client/common/utils.ts @@ -0,0 +1,26 @@ +import { RequestMessage } from "./types"; + +export function getMessageTextContent(message: RequestMessage) { + if (typeof message.content === "string") { + return message.content; + } + for (const c of message.content) { + if (c.type === "text") { + return c.text ?? ""; + } + } + return ""; +} + +export function getMessageImages(message: RequestMessage): string[] { + if (typeof message.content === "string") { + return []; + } + const urls: string[] = []; + for (const c of message.content) { + if (c.type === "image_url") { + urls.push(c.image_url?.url ?? ""); + } + } + return urls; +} diff --git a/app/client/core/index.ts b/app/client/core/index.ts index 3b4c3610f..2ffc6679e 100644 --- a/app/client/core/index.ts +++ b/app/client/core/index.ts @@ -1,9 +1,9 @@ -export * from "./types"; +export * from "../common/types"; export * from "./providerClient"; export * from "./modelClient"; -export * from "./locale"; +export * from "../common/locale"; export * from "./shim"; diff --git a/app/client/core/modelClient.ts b/app/client/core/modelClient.ts index 17c1f2639..eeb160bdb 100644 --- a/app/client/core/modelClient.ts +++ b/app/client/core/modelClient.ts @@ -1,23 +1,28 @@ -import { ChatRequestPayload, Model, ModelConfig, ChatHandlers } from "./types"; -import { ProviderClient, ProviderTemplateName } from "./providerClient"; +import { + ChatRequestPayload, + Model, + ModelSettings, + InternalChatHandlers, +} from "../common"; +import { Provider, ProviderClient } from "./providerClient"; export class ModelClient { - static getAllProvidersDefaultModels = () => { - return ProviderClient.getAllProvidersDefaultModels(); - }; - constructor( private model: Model, - private modelConfig: ModelConfig, + private modelSettings: ModelSettings, private providerClient: ProviderClient, ) {} - chat(payload: ChatRequestPayload, handlers: ChatHandlers) { + chat(payload: ChatRequestPayload, handlers: InternalChatHandlers) { try { return this.providerClient.streamChat( { ...payload, - modelConfig: this.modelConfig, + modelConfig: { + ...this.modelSettings, + max_tokens: + this.model.max_tokens ?? this.modelSettings.global_max_tokens, + }, model: this.model.name, }, handlers, @@ -31,7 +36,11 @@ export class ModelClient { try { return this.providerClient.chat({ ...payload, - modelConfig: this.modelConfig, + modelConfig: { + ...this.modelSettings, + max_tokens: + this.model.max_tokens ?? this.modelSettings.global_max_tokens, + }, model: this.model.name, }); } catch (e) { @@ -40,7 +49,50 @@ export class ModelClient { } } -export function ModelClientFactory(model: Model, modelConfig: ModelConfig) { - const providerClient = new ProviderClient(model.providerTemplateName); - return new ModelClient(model, modelConfig, providerClient); +// must generate new ModelClient during every chat +export function ModelClientFactory( + model: Model, + provider: Provider, + modelSettings: ModelSettings, +) { + const providerClient = new ProviderClient(provider); + return new ModelClient(model, modelSettings, providerClient); +} + +export function getFiltertModels( + models: readonly Model[], + customModels: string, +) { + const modelTable: Record = {}; + + // default models + models.forEach((m) => { + modelTable[m.name] = m; + }); + + // server custom models + customModels + .split(",") + .filter((v) => !!v && v.length > 0) + .forEach((m) => { + const available = !m.startsWith("-"); + const nameConfig = + m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m; + const [name, displayName] = nameConfig.split("="); + + // enable or disable all models + if (name === "all") { + Object.values(modelTable).forEach( + (model) => (model.available = available), + ); + } else { + modelTable[name] = { + ...modelTable[name], + displayName, + available, + }; + } + }); + + return modelTable; } diff --git a/app/client/core/providerClient.ts b/app/client/core/providerClient.ts index 65849b5c2..863527eaf 100644 --- a/app/client/core/providerClient.ts +++ b/app/client/core/providerClient.ts @@ -1,118 +1,182 @@ import { - ChatHandlers, IProviderTemplate, + InternalChatHandlers, Model, + ModelTemplate, StandChatReponseMessage, StandChatRequestPayload, -} from "./types"; +} from "../common"; import * as ProviderTemplates from "@/app/client/providers"; -import { cloneDeep } from "lodash-es"; +import { nanoid } from "nanoid"; -export type ProviderTemplate = - (typeof ProviderTemplates)[keyof typeof ProviderTemplates]; +export type ProviderTemplate = IProviderTemplate; export type ProviderTemplateName = (typeof ProviderTemplates)[keyof typeof ProviderTemplates]["prototype"]["name"]; +export interface Provider< + Providerconfig extends Record = Record, +> { + name: string; // id of provider + isActive: boolean; + providerTemplateName: ProviderTemplateName; + providerConfig: Providerconfig; + isDefault: boolean; // Not allow to modify models of default provider + updated: boolean; // provider initial is finished + + displayName: string; + models: Model[]; +} + +const providerTemplates = Object.values(ProviderTemplates).reduce( + (r, t) => ({ + ...r, + [t.prototype.name]: new t(), + }), + {} as Record, +); + export class ProviderClient { - provider: IProviderTemplate; + providerTemplate: IProviderTemplate; - static ProviderTemplates = ProviderTemplates; - - static getAllProvidersDefaultModels = () => { - return Object.values(ProviderClient.ProviderTemplates).reduce( - (r, p) => ({ - ...r, - [p.prototype.name]: cloneDeep(p.prototype.models), - }), - {} as Record, - ); - }; + static ProviderTemplates = providerTemplates; static getAllProviderTemplates = () => { - return Object.values(ProviderClient.ProviderTemplates).reduce( - (r, p) => ({ + return Object.values(providerTemplates).reduce( + (r, t) => ({ ...r, - [p.prototype.name]: p, + [t.name]: t, }), {} as Record, ); }; - static getProviderTemplateList = () => { - return Object.values(ProviderClient.ProviderTemplates); + static getProviderTemplateMetaList = () => { + return Object.values(providerTemplates).map((t) => ({ + ...t.providerMeta, + name: t.name, + })); }; - constructor(providerTemplateName: string) { - this.provider = this.getProviderTemplate(providerTemplateName); - } - - get settingItems() { - const { providerMeta } = this.provider; - const { settingItems } = providerMeta; - return settingItems; + constructor(private provider: Provider) { + const { providerTemplateName } = provider; + this.providerTemplate = this.getProviderTemplate(providerTemplateName); } private getProviderTemplate(providerTemplateName: string) { - const providerTemplate = - Object.values(ProviderTemplates).find( - (template) => template.prototype.name === providerTemplateName, - ) || ProviderTemplates.NextChatProvider; + const providerTemplate = Object.values(providerTemplates).find( + (template) => template.name === providerTemplateName, + ); - return new providerTemplate(); + return providerTemplate || providerTemplates.openai; } - getModelConfig(modelName: string) { + private getModelConfig(modelName: string) { const { models } = this.provider; return ( - models.find((config) => config.name === modelName) || - models.find((config) => config.isDefaultSelected) + models.find((m) => m.name === modelName) || + models.find((m) => m.isDefaultSelected) ); } + getAvailableModels() { + return Promise.resolve( + this.providerTemplate.getAvailableModels?.(this.provider.providerConfig), + ) + .then((res) => { + const { defaultModels } = this.providerTemplate; + const availableModelsSet = new Set( + (res ?? defaultModels).map((o) => o.name), + ); + return defaultModels.filter((m) => availableModelsSet.has(m.name)); + }) + .catch(() => { + return this.providerTemplate.defaultModels; + }); + } + async chat( - payload: StandChatRequestPayload, + payload: StandChatRequestPayload, ): Promise { - return this.provider.chat({ + return this.providerTemplate.chat({ ...payload, stream: false, isVisionModel: this.getModelConfig(payload.model)?.isVisionModel, + providerConfig: this.provider.providerConfig, }); } - streamChat(payload: StandChatRequestPayload, handlers: ChatHandlers) { - return this.provider.streamChat( + streamChat(payload: StandChatRequestPayload, handlers: InternalChatHandlers) { + let responseText = ""; + let remainText = ""; + + const timer = this.providerTemplate.streamChat( { ...payload, stream: true, isVisionModel: this.getModelConfig(payload.model)?.isVisionModel, + providerConfig: this.provider.providerConfig, + }, + { + onProgress: (chunk) => { + remainText += chunk; + }, + onError: (err) => { + handlers.onError(err); + }, + onFinish: () => {}, + onFlash: (message: string) => { + handlers.onFinish(message); + }, }, - handlers.onProgress, - handlers.onFinish, - handlers.onError, ); + + timer.signal.onabort = () => { + const message = responseText + remainText; + remainText = ""; + handlers.onFinish(message); + }; + + const animateResponseText = () => { + if (remainText.length > 0) { + const fetchCount = Math.max(1, Math.round(remainText.length / 60)); + const fetchText = remainText.slice(0, fetchCount); + responseText += fetchText; + remainText = remainText.slice(fetchCount); + handlers.onProgress(responseText, fetchText); + } + + requestAnimationFrame(animateResponseText); + }; + + // start animaion + animateResponseText(); + + return timer; } } -export interface Provider { - name: string; // id of provider - displayName: string; - isActive: boolean; - providerTemplateName: ProviderTemplateName; - models: Model[]; -} +type Params = Omit; function createProvider( provider: ProviderTemplateName, - params?: Omit, + isDefault: true, +): Provider; +function createProvider(provider: ProviderTemplate, isDefault: true): Provider; +function createProvider( + provider: ProviderTemplateName, + isDefault: false, + params: Params, ): Provider; function createProvider( provider: ProviderTemplate, - params?: Omit, + isDefault: false, + params: Params, ): Provider; function createProvider( provider: ProviderTemplate | ProviderTemplateName, - params?: Omit, + isDefault: boolean, + params?: Params, ): Provider { let providerTemplate: ProviderTemplate; if (typeof provider === "string") { @@ -120,17 +184,41 @@ function createProvider( } else { providerTemplate = provider; } + + const name = `${providerTemplate.name}__${nanoid()}`; + const { - name = providerTemplate.prototype.name, - displayName = providerTemplate.prototype.providerMeta.displayName, - models = providerTemplate.prototype.models, + displayName = providerTemplate.providerMeta.displayName, + models = providerTemplate.defaultModels.map((m) => + createModelFromModelTemplate(m, providerTemplate, name), + ), + providerConfig, } = params ?? {}; + return { name, displayName, isActive: true, models, - providerTemplateName: providerTemplate.prototype.name, + providerTemplateName: providerTemplate.name, + providerConfig: isDefault ? {} : providerConfig!, + isDefault, + updated: true, + }; +} + +function createModelFromModelTemplate( + m: ModelTemplate, + p: ProviderTemplate, + providerName: string, +) { + return { + ...m, + providerTemplateName: p.name, + providerName, + isActive: m.isDefaultActive, + available: true, + customized: false, }; } diff --git a/app/client/providers/anthropic/config.ts b/app/client/providers/anthropic/config.ts index d58270d18..fe45a7aaf 100644 --- a/app/client/providers/anthropic/config.ts +++ b/app/client/providers/anthropic/config.ts @@ -1,4 +1,4 @@ -import { SettingItem } from "../../core/types"; +import { SettingItem } from "../../common"; import Locale from "./locale"; export type SettingKeys = @@ -13,6 +13,12 @@ export const AnthropicMetas = { Vision: "2023-06-01", }; +export const ClaudeMapper = { + assistant: "assistant", + user: "user", + system: "user", +} as const; + export const modelConfigs = [ { name: "claude-instant-1.2", @@ -58,6 +64,8 @@ export const modelConfigs = [ }, ]; +const defaultEndpoint = "/api/anthropic"; + export const settingItems: SettingItem[] = [ { name: "anthropicUrl", @@ -65,7 +73,22 @@ export const settingItems: SettingItem[] = [ description: Locale.Endpoint.SubTitle + AnthropicMetas.ExampleEndpoint, placeholder: AnthropicMetas.ExampleEndpoint, type: "input", - validators: ["required"], + defaultValue: defaultEndpoint, + validators: [ + "required", + async (v: any) => { + if (typeof v === "string" && !v.startsWith(defaultEndpoint)) { + try { + new URL(v); + } catch (e) { + return Locale.Endpoint.Error.IllegalURL; + } + } + if (typeof v === "string" && v.endsWith("/")) { + return Locale.Endpoint.Error.EndWithBackslash; + } + }, + ], }, { name: "anthropicApiKey", @@ -74,7 +97,7 @@ export const settingItems: SettingItem[] = [ placeholder: Locale.ApiKey.Placeholder, type: "input", inputType: "password", - validators: ["required"], + // validators: ["required"], }, { name: "anthropicApiVersion", @@ -82,6 +105,6 @@ export const settingItems: SettingItem[] = [ description: Locale.ApiVerion.SubTitle, placeholder: AnthropicMetas.Vision, type: "input", - validators: ["required"], + // validators: ["required"], }, ]; diff --git a/app/client/providers/anthropic/index.ts b/app/client/providers/anthropic/index.ts index eb40987b4..7d2f03350 100644 --- a/app/client/providers/anthropic/index.ts +++ b/app/client/providers/anthropic/index.ts @@ -1,29 +1,27 @@ -import { getMessageTextContent } from "@/app/utils"; import { AnthropicMetas, + ClaudeMapper, SettingKeys, modelConfigs, settingItems, } from "./config"; import { + ChatHandlers, InternalChatRequestPayload, IProviderTemplate, -} from "../../core/types"; + getMessageTextContent, + RequestMessage, +} from "../../common"; import { EventStreamContentType, fetchEventSource, } from "@fortaine/fetch-event-source"; import Locale from "@/app/locales"; -import { prettyObject } from "@/app/utils/format"; +import { getAuthKey, trimEnd, prettyObject } from "./utils"; +import { cloneDeep } from "lodash-es"; export type AnthropicProviderSettingKeys = SettingKeys; -const ClaudeMapper = { - assistant: "assistant", - user: "user", - system: "user", -} as const; - export type MultiBlockContent = { type: "image" | "text"; source?: { @@ -75,64 +73,25 @@ export default class AnthropicProvider settingItems, }; - models = modelConfigs.map((c) => ({ ...c, providerTemplateName: this.name })); + defaultModels = modelConfigs; readonly REQUEST_TIMEOUT_MS = 60000; private path(payload: InternalChatRequestPayload) { const { providerConfig: { anthropicUrl }, - context: { isApp }, } = payload; - let baseUrl: string = anthropicUrl; - - // if endpoint is empty, use default endpoint - if (baseUrl.trim().length === 0) { - baseUrl = "/api/anthropic"; - } - - if (!baseUrl.startsWith("http") && !baseUrl.startsWith("/api")) { - baseUrl = "https://" + baseUrl; - } - - baseUrl = trimEnd(baseUrl, "/"); - - return `${baseUrl}/${AnthropicMetas.ChatPath}`; + return `${trimEnd(anthropicUrl!)}/${AnthropicMetas.ChatPath}`; } - private formatChatPayload(payload: InternalChatRequestPayload) { - const { - messages, - isVisionModel, - model, - stream, - modelConfig, - providerConfig, - } = payload; - const { anthropicApiKey, anthropicApiVersion, anthropicUrl } = - providerConfig; - const { temperature, top_p, max_tokens } = modelConfig; + private formatMessage( + messages: RequestMessage[], + payload: InternalChatRequestPayload, + ) { + const { isVisionModel } = payload; - const keys = ["system", "user"]; - - // roles must alternate between "user" and "assistant" in claude, so add a fake assistant message between two user messages - for (let i = 0; i < messages.length - 1; i++) { - const message = messages[i]; - const nextMessage = messages[i + 1]; - - if (keys.includes(message.role) && keys.includes(nextMessage.role)) { - messages[i] = [ - message, - { - role: "assistant", - content: ";", - }, - ] as any; - } - } - - const prompt = messages + return messages .flat() .filter((v) => { if (!v.content) return false; @@ -180,6 +139,40 @@ export default class AnthropicProvider }), }; }); + } + + private formatChatPayload(payload: InternalChatRequestPayload) { + const { + messages: outsideMessages, + model, + stream, + modelConfig, + providerConfig, + } = payload; + const { anthropicApiKey, anthropicApiVersion } = providerConfig; + const { temperature, top_p, max_tokens } = modelConfig; + + const keys = ["system", "user"]; + + // roles must alternate between "user" and "assistant" in claude, so add a fake assistant message between two user messages + const messages = cloneDeep(outsideMessages); + + for (let i = 0; i < messages.length - 1; i++) { + const message = messages[i]; + const nextMessage = messages[i + 1]; + + if (keys.includes(message.role) && keys.includes(nextMessage.role)) { + messages[i] = [ + message, + { + role: "assistant", + content: ";", + }, + ] as any; + } + } + + const prompt = this.formatMessage(messages, payload); const requestBody: AnthropicChatRequest = { messages: prompt, @@ -196,7 +189,7 @@ export default class AnthropicProvider "Content-Type": "application/json", Accept: "application/json", "x-api-key": anthropicApiKey ?? "", - "anthropic-version": anthropicApiVersion, + "anthropic-version": anthropicApiVersion ?? "", Authorization: getAuthKey(anthropicApiKey), }, body: JSON.stringify(requestBody), @@ -204,6 +197,7 @@ export default class AnthropicProvider url: this.path(payload), }; } + private readWholeMessageResponseBody(res: any) { return { message: res?.content?.[0]?.text ?? "", @@ -259,50 +253,12 @@ export default class AnthropicProvider streamChat( payload: InternalChatRequestPayload, - onProgress: (message: string, chunk: string) => void, - onFinish: (message: string) => void, - onError: (err: Error) => void, + handlers: ChatHandlers, ) { const requestPayload = this.formatChatPayload(payload); - let responseText = ""; - let remainText = ""; - let finished = false; - const timer = this.getTimer(); - // animate response to make it looks smooth - const animateResponseText = () => { - if (finished || timer.signal.aborted) { - responseText += remainText; - console.log("[Response Animation] finished"); - if (responseText?.length === 0) { - onError(new Error("empty response from server")); - } - return; - } - - if (remainText.length > 0) { - const fetchCount = Math.max(1, Math.round(remainText.length / 60)); - const fetchText = remainText.slice(0, fetchCount); - responseText += fetchText; - remainText = remainText.slice(fetchCount); - onProgress(responseText, fetchText); - } - - requestAnimationFrame(animateResponseText); - }; - - // start animaion - animateResponseText(); - - const finish = () => { - if (!finished) { - finished = true; - onFinish(responseText + remainText); - } - }; - fetchEventSource(requestPayload.url, { ...requestPayload, async onopen(res) { @@ -311,8 +267,8 @@ export default class AnthropicProvider console.log("[OpenAI] request response content type: ", contentType); if (contentType?.startsWith("text/plain")) { - responseText = await res.clone().text(); - return finish(); + const responseText = await res.clone().text(); + return handlers.onFlash(responseText); } if ( @@ -322,29 +278,29 @@ export default class AnthropicProvider ?.startsWith(EventStreamContentType) || res.status !== 200 ) { - const responseTexts = [responseText]; + const responseTexts = []; + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + let extraInfo = await res.clone().text(); try { const resJson = await res.clone().json(); extraInfo = prettyObject(resJson); } catch {} - if (res.status === 401) { - responseTexts.push(Locale.Error.Unauthorized); - } - if (extraInfo) { responseTexts.push(extraInfo); } - responseText = responseTexts.join("\n\n"); + const responseText = responseTexts.join("\n\n"); - return finish(); + return handlers.onFlash(responseText); } }, onmessage(msg) { - if (msg.data === "[DONE]" || finished) { - return finish(); + if (msg.data === "[DONE]") { + return; } const text = msg.data; try { @@ -353,20 +309,19 @@ export default class AnthropicProvider delta: { content: string }; }>; const delta = choices[0]?.delta?.content; - const textmoderation = json?.prompt_filter_results; if (delta) { - remainText += delta; + handlers.onProgress(delta); } } catch (e) { console.error("[Request] parse error", text, msg); } }, onclose() { - finish(); + handlers.onFinish(); }, onerror(e) { - onError(e); + handlers.onError(e); throw e; }, openWhenHidden: true, @@ -375,28 +330,3 @@ export default class AnthropicProvider return timer; } } - -function trimEnd(s: string, end = " ") { - if (end.length === 0) return s; - - while (s.endsWith(end)) { - s = s.slice(0, -end.length); - } - - return s; -} - -function bearer(value: string) { - return `Bearer ${value.trim()}`; -} - -function getAuthKey(apiKey = "") { - let authKey = ""; - - if (apiKey) { - // use user's api key first - authKey = bearer(apiKey); - } - - return authKey; -} diff --git a/app/client/providers/anthropic/locale.ts b/app/client/providers/anthropic/locale.ts index a683eea9b..9aabd9e16 100644 --- a/app/client/providers/anthropic/locale.ts +++ b/app/client/providers/anthropic/locale.ts @@ -1,4 +1,4 @@ -import { getLocaleText } from "../../core/locale"; +import { getLocaleText } from "../../common"; export default getLocaleText< { @@ -10,6 +10,10 @@ export default getLocaleText< Endpoint: { Title: string; SubTitle: string; + Error: { + EndWithBackslash: string; + IllegalURL: string; + }; }; ApiVerion: { Title: string; @@ -29,6 +33,10 @@ export default getLocaleText< Endpoint: { Title: "接口地址", SubTitle: "样例:", + Error: { + EndWithBackslash: "不能以「/」结尾", + IllegalURL: "请输入一个完整可用的url", + }, }, ApiVerion: { @@ -47,6 +55,10 @@ export default getLocaleText< Endpoint: { Title: "Endpoint Address", SubTitle: "Example:", + Error: { + EndWithBackslash: "Cannot end with '/'", + IllegalURL: "Please enter a complete available url", + }, }, ApiVerion: { @@ -64,6 +76,10 @@ export default getLocaleText< Endpoint: { Title: "Endpoint Address", SubTitle: "Exemplo: ", + Error: { + EndWithBackslash: "Não é possível terminar com '/'", + IllegalURL: "Insira um URL completo disponível", + }, }, ApiVerion: { @@ -81,6 +97,10 @@ export default getLocaleText< Endpoint: { Title: "Adresa koncového bodu", SubTitle: "Príklad:", + Error: { + EndWithBackslash: "Nemôže končiť znakom „/“", + IllegalURL: "Zadajte úplnú dostupnú adresu URL", + }, }, ApiVerion: { @@ -98,6 +118,10 @@ export default getLocaleText< Endpoint: { Title: "終端地址", SubTitle: "範例:", + Error: { + EndWithBackslash: "不能以「/」結尾", + IllegalURL: "請輸入一個完整可用的url", + }, }, ApiVerion: { diff --git a/app/client/providers/anthropic/utils.ts b/app/client/providers/anthropic/utils.ts new file mode 100644 index 000000000..9a36f2d72 --- /dev/null +++ b/app/client/providers/anthropic/utils.ts @@ -0,0 +1,38 @@ +export function trimEnd(s: string, end = " ") { + if (end.length === 0) return s; + + while (s.endsWith(end)) { + s = s.slice(0, -end.length); + } + + return s; +} + +export function bearer(value: string) { + return `Bearer ${value.trim()}`; +} + +export function getAuthKey(apiKey = "") { + let authKey = ""; + + if (apiKey) { + // use user's api key first + authKey = bearer(apiKey); + } + + return authKey; +} + +export function prettyObject(msg: any) { + const obj = msg; + if (typeof msg !== "string") { + msg = JSON.stringify(msg, null, " "); + } + if (msg === "{}") { + return obj.toString(); + } + if (msg.startsWith("```json")) { + return msg; + } + return ["```json", msg, "```"].join("\n"); +} diff --git a/app/client/providers/azure/config.ts b/app/client/providers/azure/config.ts index 01e978503..c26b29f8b 100644 --- a/app/client/providers/azure/config.ts +++ b/app/client/providers/azure/config.ts @@ -1,12 +1,11 @@ import Locale from "./locale"; -import { SettingItem } from "../../core/types"; +import { SettingItem } from "../../common"; import { modelConfigs as openaiModelConfigs } from "../openai/config"; export const AzureMetas = { ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}", ChatPath: "v1/chat/completions", - OpenAI: "/api/openai", }; export type SettingKeys = "azureUrl" | "azureApiKey" | "azureApiVersion"; @@ -20,6 +19,21 @@ export const settingItems: SettingItem[] = [ description: Locale.Endpoint.SubTitle + AzureMetas.ExampleEndpoint, placeholder: AzureMetas.ExampleEndpoint, type: "input", + validators: [ + async (v: any) => { + if (typeof v === "string") { + try { + new URL(v); + } catch (e) { + return Locale.Endpoint.Error.IllegalURL; + } + } + if (typeof v === "string" && v.endsWith("/")) { + return Locale.Endpoint.Error.EndWithBackslash; + } + }, + "required", + ], }, { name: "azureApiKey", diff --git a/app/client/providers/azure/index.ts b/app/client/providers/azure/index.ts index 38b892ebf..2d5ee112e 100644 --- a/app/client/providers/azure/index.ts +++ b/app/client/providers/azure/index.ts @@ -1,15 +1,17 @@ import { settingItems, SettingKeys, modelConfigs, AzureMetas } from "./config"; import { + ChatHandlers, InternalChatRequestPayload, IProviderTemplate, -} from "../../core/types"; -import { getMessageTextContent } from "@/app/utils"; + ModelInfo, + getMessageTextContent, +} from "../../common"; import { EventStreamContentType, fetchEventSource, } from "@fortaine/fetch-event-source"; -import { prettyObject } from "@/app/utils/format"; import Locale from "@/app/locales"; +import { makeAzurePath, makeBearer, prettyObject, validString } from "./utils"; export type AzureProviderSettingKeys = SettingKeys; @@ -43,13 +45,30 @@ interface RequestPayload { max_tokens?: number; } +interface ModelList { + object: "list"; + data: Array<{ + capabilities: { + fine_tune: boolean; + inference: boolean; + completion: boolean; + chat_completion: boolean; + embeddings: boolean; + }; + lifecycle_status: "generally-available"; + id: string; + created_at: number; + object: "model"; + }>; +} + export default class Azure implements IProviderTemplate { name = "azure" as const; metas = AzureMetas; - models = modelConfigs.map((c) => ({ ...c, providerTemplateName: this.name })); + defaultModels = modelConfigs; providerMeta = { displayName: "Azure", @@ -62,25 +81,11 @@ export default class Azure const { providerConfig: { azureUrl, azureApiVersion }, } = payload; + const path = makeAzurePath(AzureMetas.ChatPath, azureApiVersion!); - const path = makeAzurePath(AzureMetas.ChatPath, azureApiVersion); + console.log("[Proxy Endpoint] ", azureUrl, path); - let baseUrl = azureUrl; - - if (!baseUrl) { - baseUrl = "/api/openai"; - } - - if (baseUrl.endsWith("/")) { - baseUrl = baseUrl.slice(0, baseUrl.length - 1); - } - if (!baseUrl.startsWith("http") && !baseUrl.startsWith(AzureMetas.OpenAI)) { - baseUrl = "https://" + baseUrl; - } - - console.log("[Proxy Endpoint] ", baseUrl, path); - - return [baseUrl, path].join("/"); + return [azureUrl!, path].join("/"); } private getHeaders(payload: InternalChatRequestPayload) { @@ -90,14 +95,9 @@ export default class Azure "Content-Type": "application/json", Accept: "application/json", }; - const authHeader = "Authorization"; - const makeBearer = (s: string) => `Bearer ${s.trim()}`; - const validString = (x?: string): x is string => Boolean(x && x.length > 0); - - // when using google api in app, not set auth header if (validString(azureApiKey)) { - headers[authHeader] = makeBearer(azureApiKey); + headers["Authorization"] = makeBearer(azureApiKey); } return headers; @@ -197,52 +197,12 @@ export default class Azure streamChat( payload: InternalChatRequestPayload, - onProgress: (message: string, chunk: string) => void, - onFinish: (message: string) => void, - onError: (err: Error) => void, + handlers: ChatHandlers, ) { const requestPayload = this.formatChatPayload(payload); const timer = this.getTimer(); - let responseText = ""; - let remainText = ""; - let finished = false; - - // animate response to make it looks smooth - const animateResponseText = () => { - if (finished || timer.signal.aborted) { - responseText += remainText; - console.log("[Response Animation] finished"); - if (responseText?.length === 0) { - onError(new Error("empty response from server")); - } - return; - } - - if (remainText.length > 0) { - const fetchCount = Math.max(1, Math.round(remainText.length / 60)); - const fetchText = remainText.slice(0, fetchCount); - responseText += fetchText; - remainText = remainText.slice(fetchCount); - onProgress(responseText, fetchText); - } - - requestAnimationFrame(animateResponseText); - }; - - // start animaion - animateResponseText(); - - const finish = () => { - if (!finished) { - finished = true; - onFinish(responseText + remainText); - } - }; - - timer.signal.onabort = finish; - fetchEventSource(requestPayload.url, { ...requestPayload, async onopen(res) { @@ -251,8 +211,8 @@ export default class Azure console.log("[OpenAI] request response content type: ", contentType); if (contentType?.startsWith("text/plain")) { - responseText = await res.clone().text(); - return finish(); + const responseText = await res.clone().text(); + return handlers.onFlash(responseText); } if ( @@ -262,29 +222,29 @@ export default class Azure ?.startsWith(EventStreamContentType) || res.status !== 200 ) { - const responseTexts = [responseText]; + const responseTexts = []; + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + let extraInfo = await res.clone().text(); try { const resJson = await res.clone().json(); extraInfo = prettyObject(resJson); } catch {} - if (res.status === 401) { - responseTexts.push(Locale.Error.Unauthorized); - } - if (extraInfo) { responseTexts.push(extraInfo); } - responseText = responseTexts.join("\n\n"); + const responseText = responseTexts.join("\n\n"); - return finish(); + return handlers.onFlash(responseText); } }, onmessage(msg) { - if (msg.data === "[DONE]" || finished) { - return finish(); + if (msg.data === "[DONE]") { + return; } const text = msg.data; try { @@ -293,34 +253,41 @@ export default class Azure delta: { content: string }; }>; const delta = choices[0]?.delta?.content; - const textmoderation = json?.prompt_filter_results; if (delta) { - remainText += delta; + handlers.onProgress(delta); } } catch (e) { console.error("[Request] parse error", text, msg); } }, onclose() { - finish(); + handlers.onFinish(); }, onerror(e) { - onError(e); + handlers.onError(e); throw e; }, openWhenHidden: true, }); + return timer; } -} - -function makeAzurePath(path: string, apiVersion: string) { - // should omit /v1 prefix - path = path.replaceAll("v1/", ""); - - // should add api-key to query string - path += `${path.includes("?") ? "&" : "?"}api-version=${apiVersion}`; - - return path; + + async getAvailableModels( + providerConfig: Record, + ): Promise { + const { azureApiKey, azureUrl } = providerConfig; + const res = await fetch(`${azureUrl}/vi/models`, { + headers: { + Authorization: `Bearer ${azureApiKey}`, + }, + method: "GET", + }); + const data: ModelList = await res.json(); + + return data.data.map((o) => ({ + name: o.id, + })); + } } diff --git a/app/client/providers/azure/locale.ts b/app/client/providers/azure/locale.ts index b559b5b13..b6b7d2e75 100644 --- a/app/client/providers/azure/locale.ts +++ b/app/client/providers/azure/locale.ts @@ -1,4 +1,4 @@ -import { getLocaleText } from "../../core/locale"; +import { getLocaleText } from "../../common"; export default getLocaleText< { @@ -10,6 +10,10 @@ export default getLocaleText< Endpoint: { Title: string; SubTitle: string; + Error: { + EndWithBackslash: string; + IllegalURL: string; + }; }; ApiVerion: { Title: string; @@ -29,6 +33,10 @@ export default getLocaleText< Endpoint: { Title: "接口地址", SubTitle: "样例:", + Error: { + EndWithBackslash: "不能以「/」结尾", + IllegalURL: "请输入一个完整可用的url", + }, }, ApiVerion: { @@ -46,6 +54,10 @@ export default getLocaleText< Endpoint: { Title: "Azure Endpoint", SubTitle: "Example: ", + Error: { + EndWithBackslash: "Cannot end with '/'", + IllegalURL: "Please enter a complete available url", + }, }, ApiVerion: { @@ -63,6 +75,10 @@ export default getLocaleText< Endpoint: { Title: "Endpoint Azure", SubTitle: "Exemplo: ", + Error: { + EndWithBackslash: "Não é possível terminar com '/'", + IllegalURL: "Insira um URL completo disponível", + }, }, ApiVerion: { @@ -80,6 +96,10 @@ export default getLocaleText< Endpoint: { Title: "Koncový bod Azure", SubTitle: "Príklad: ", + Error: { + EndWithBackslash: "Nemôže končiť znakom „/“", + IllegalURL: "Zadajte úplnú dostupnú adresu URL", + }, }, ApiVerion: { @@ -97,6 +117,10 @@ export default getLocaleText< Endpoint: { Title: "介面(Endpoint) 地址", SubTitle: "樣例:", + Error: { + EndWithBackslash: "不能以「/」結尾", + IllegalURL: "請輸入一個完整可用的url", + }, }, ApiVerion: { diff --git a/app/client/providers/azure/utils.ts b/app/client/providers/azure/utils.ts new file mode 100644 index 000000000..fea7457c8 --- /dev/null +++ b/app/client/providers/azure/utils.ts @@ -0,0 +1,27 @@ +export function makeAzurePath(path: string, apiVersion: string) { + // should omit /v1 prefix + path = path.replaceAll("v1/", ""); + + // should add api-key to query string + path += `${path.includes("?") ? "&" : "?"}api-version=${apiVersion}`; + + return path; +} + +export function prettyObject(msg: any) { + const obj = msg; + if (typeof msg !== "string") { + msg = JSON.stringify(msg, null, " "); + } + if (msg === "{}") { + return obj.toString(); + } + if (msg.startsWith("```json")) { + return msg; + } + return ["```json", msg, "```"].join("\n"); +} + +export const makeBearer = (s: string) => `Bearer ${s.trim()}`; +export const validString = (x?: string): x is string => + Boolean(x && x.length > 0); diff --git a/app/client/providers/google/config.ts b/app/client/providers/google/config.ts index 6248b4913..3f7e02cdb 100644 --- a/app/client/providers/google/config.ts +++ b/app/client/providers/google/config.ts @@ -1,11 +1,9 @@ -import { SettingItem } from "../../core/types"; +import { SettingItem } from "../../common"; import Locale from "./locale"; export const GoogleMetas = { ExampleEndpoint: "https://generativelanguage.googleapis.com/", ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`, - VisionChatPath: (modelName: string) => - `v1beta/models/${modelName}:generateContent`, }; export type SettingKeys = "googleUrl" | "googleApiKey" | "googleApiVersion"; @@ -41,7 +39,20 @@ export const settingItems: SettingItem[] = [ description: Locale.Endpoint.SubTitle + GoogleMetas.ExampleEndpoint, placeholder: GoogleMetas.ExampleEndpoint, type: "input", - validators: ["required"], + validators: [ + async (v: any) => { + if (typeof v === "string") { + try { + new URL(v); + } catch (e) { + return Locale.Endpoint.Error.IllegalURL; + } + } + if (typeof v === "string" && v.endsWith("/")) { + return Locale.Endpoint.Error.EndWithBackslash; + } + }, + ], }, { name: "googleApiKey", @@ -50,7 +61,7 @@ export const settingItems: SettingItem[] = [ placeholder: Locale.ApiKey.Placeholder, type: "input", inputType: "password", - validators: ["required"], + // validators: ["required"], }, { name: "googleApiVersion", @@ -58,6 +69,6 @@ export const settingItems: SettingItem[] = [ description: Locale.ApiVersion.SubTitle, placeholder: "2023-08-01-preview", type: "input", - validators: ["required"], + // validators: ["required"], }, ]; diff --git a/app/client/providers/google/index.ts b/app/client/providers/google/index.ts index 9bed21ea6..a0ea89605 100644 --- a/app/client/providers/google/index.ts +++ b/app/client/providers/google/index.ts @@ -1,13 +1,34 @@ -import { getMessageImages, getMessageTextContent } from "@/app/utils"; import { SettingKeys, modelConfigs, settingItems, GoogleMetas } from "./config"; import { + ChatHandlers, InternalChatRequestPayload, IProviderTemplate, + ModelInfo, StandChatReponseMessage, -} from "../../core/types"; + getMessageTextContent, + getMessageImages, +} from "../../common"; +import { ensureProperEnding, makeBearer, validString } from "./utils"; export type GoogleProviderSettingKeys = SettingKeys; +interface ModelList { + models: Array<{ + name: string; + baseModelId: string; + version: string; + displayName: string; + description: string; + inputTokenLimit: number; // Integer + outputTokenLimit: number; // Integer + supportedGenerationMethods: [string]; + temperature: number; + topP: number; + topK: number; // Integer + }>; + nextPageToken: string; +} + export default class GoogleProvider implements IProviderTemplate { @@ -18,7 +39,7 @@ export default class GoogleProvider displayName: "Google", settingItems, }; - models = modelConfigs.map((c) => ({ ...c, providerTemplateName: this.name })); + defaultModels = modelConfigs; readonly REQUEST_TIMEOUT_MS = 60000; @@ -33,19 +54,8 @@ export default class GoogleProvider Accept: "application/json", }; - const authHeader = "Authorization"; - - const makeBearer = (s: string) => `Bearer ${s.trim()}`; - const validString = (x?: string): x is string => Boolean(x && x.length > 0); - - // when using google api in app, not set auth header - if (!isApp) { - // use user's api key first - if (validString(googleApiKey)) { - headers[authHeader] = makeBearer(googleApiKey); - } else { - throw new Error("no apiKey when chat through google"); - } + if (!isApp && validString(googleApiKey)) { + headers["Authorization"] = makeBearer(googleApiKey); } return headers; @@ -135,15 +145,9 @@ export default class GoogleProvider ], }; - let baseUrl = googleUrl; + let googleChatPath = GoogleMetas.ChatPath(model); - let googleChatPath = isVisionModel - ? GoogleMetas.VisionChatPath(model) - : GoogleMetas.ChatPath(model); - - if (!baseUrl) { - baseUrl = "/api/google/" + googleChatPath; - } + let baseUrl = googleUrl ?? "/api/google/" + googleChatPath; if (isApp) { baseUrl += `?key=${googleApiKey}`; @@ -193,44 +197,13 @@ export default class GoogleProvider streamChat( payload: InternalChatRequestPayload, - onProgress: (message: string, chunk: string) => void, - onFinish: (message: string) => void, - onError: (err: Error) => void, + handlers: ChatHandlers, ) { const requestPayload = this.formatChatPayload(payload); - let responseText = ""; - let remainText = ""; - let finished = false; const timer = this.getTimer(); let existingTexts: string[] = []; - const finish = () => { - finished = true; - onFinish(existingTexts.join("")); - }; - - // animate response to make it looks smooth - const animateResponseText = () => { - if (finished || timer.signal.aborted) { - responseText += remainText; - finish(); - return; - } - - if (remainText.length > 0) { - const fetchCount = Math.max(1, Math.round(remainText.length / 60)); - const fetchText = remainText.slice(0, fetchCount); - responseText += fetchText; - remainText = remainText.slice(fetchCount); - onProgress(responseText, fetchText); - } - - requestAnimationFrame(animateResponseText); - }; - - // start animaion - animateResponseText(); fetch(requestPayload.url, { ...requestPayload, @@ -250,18 +223,16 @@ export default class GoogleProvider try { let data = JSON.parse(ensureProperEnding(partialData)); if (data && data[0].error) { - onError(new Error(data[0].error.message)); + handlers.onError(new Error(data[0].error.message)); } else { - onError(new Error("Request failed")); + handlers.onError(new Error("Request failed")); } } catch (_) { - onError(new Error("Request failed")); + handlers.onError(new Error("Request failed")); } } console.log("Stream complete"); - // options.onFinish(responseText + remainText); - finished = true; return Promise.resolve(); } @@ -285,7 +256,7 @@ export default class GoogleProvider if (textArray.length > existingTexts.length) { const deltaArray = textArray.slice(existingTexts.length); existingTexts = textArray; - remainText += deltaArray.join(""); + handlers.onProgress(deltaArray.join("")); } } catch (error) { // console.log("[Response Animation] error: ", error,partialData); @@ -300,6 +271,7 @@ export default class GoogleProvider }); return timer; } + async chat( payload: InternalChatRequestPayload, ): Promise { @@ -328,11 +300,19 @@ export default class GoogleProvider return message; } -} -function ensureProperEnding(str: string) { - if (str.startsWith("[") && !str.endsWith("]")) { - return str + "]"; + async getAvailableModels( + providerConfig: Record, + ): Promise { + const { googleApiKey, googleUrl } = providerConfig; + const res = await fetch(`${googleUrl}/v1beta/models?key=${googleApiKey}`, { + headers: { + Authorization: `Bearer ${googleApiKey}`, + }, + method: "GET", + }); + const data: ModelList = await res.json(); + + return data.models; } - return str; } diff --git a/app/client/providers/google/locale.ts b/app/client/providers/google/locale.ts index 82b933b6a..5d92cd970 100644 --- a/app/client/providers/google/locale.ts +++ b/app/client/providers/google/locale.ts @@ -1,4 +1,4 @@ -import { getLocaleText } from "../../core/locale"; +import { getLocaleText } from "../../common"; export default getLocaleText< { @@ -10,6 +10,10 @@ export default getLocaleText< Endpoint: { Title: string; SubTitle: string; + Error: { + EndWithBackslash: string; + IllegalURL: string; + }; }; ApiVersion: { Title: string; @@ -29,6 +33,10 @@ export default getLocaleText< Endpoint: { Title: "终端地址", SubTitle: "示例:", + Error: { + EndWithBackslash: "不能以「/」结尾", + IllegalURL: "请输入一个完整可用的url", + }, }, ApiVersion: { @@ -46,6 +54,10 @@ export default getLocaleText< Endpoint: { Title: "Endpoint Address", SubTitle: "Example:", + Error: { + EndWithBackslash: "Cannot end with '/'", + IllegalURL: "Please enter a complete available url", + }, }, ApiVersion: { @@ -64,6 +76,10 @@ export default getLocaleText< Endpoint: { Title: "Adresa koncového bodu", SubTitle: "Príklad:", + Error: { + EndWithBackslash: "Nemôže končiť znakom „/“", + IllegalURL: "Zadajte úplnú dostupnú adresu URL", + }, }, ApiVersion: { @@ -81,6 +97,10 @@ export default getLocaleText< Endpoint: { Title: "終端地址", SubTitle: "範例:", + Error: { + EndWithBackslash: "不能以「/」結尾", + IllegalURL: "請輸入一個完整可用的url", + }, }, ApiVersion: { diff --git a/app/client/providers/google/utils.ts b/app/client/providers/google/utils.ts new file mode 100644 index 000000000..78258ef27 --- /dev/null +++ b/app/client/providers/google/utils.ts @@ -0,0 +1,10 @@ +export const makeBearer = (s: string) => `Bearer ${s.trim()}`; +export const validString = (x?: string): x is string => + Boolean(x && x.length > 0); + +export function ensureProperEnding(str: string) { + if (str.startsWith("[") && !str.endsWith("]")) { + return str + "]"; + } + return str; +} diff --git a/app/client/providers/nextchat/config.ts b/app/client/providers/nextchat/config.ts index 67dbd9005..a9eab766b 100644 --- a/app/client/providers/nextchat/config.ts +++ b/app/client/providers/nextchat/config.ts @@ -1,4 +1,4 @@ -import { SettingItem } from "../../core/types"; +import { SettingItem } from "../../common"; import { isVisionModel } from "@/app/utils"; import Locale from "@/app/locales"; diff --git a/app/client/providers/nextchat/index.ts b/app/client/providers/nextchat/index.ts index b33595391..a66001d95 100644 --- a/app/client/providers/nextchat/index.ts +++ b/app/client/providers/nextchat/index.ts @@ -4,19 +4,21 @@ import { SettingKeys, NextChatMetas, } from "./config"; -import { getMessageTextContent } from "@/app/utils"; import { ACCESS_CODE_PREFIX } from "@/app/constant"; import { + ChatHandlers, + getMessageTextContent, InternalChatRequestPayload, IProviderTemplate, StandChatReponseMessage, -} from "../../core/types"; +} from "../../common"; import { EventStreamContentType, fetchEventSource, } from "@fortaine/fetch-event-source"; import { prettyObject } from "@/app/utils/format"; import Locale from "@/app/locales"; +import { makeBearer, validString } from "./utils"; export type NextChatProviderSettingKeys = SettingKeys; @@ -56,7 +58,7 @@ export default class NextChatProvider name = "nextchat" as const; metas = NextChatMetas; - models = modelConfigs.map((c) => ({ ...c, providerTemplateName: this.name })); + defaultModels = modelConfigs; providerMeta = { displayName: "NextChat", @@ -82,14 +84,9 @@ export default class NextChatProvider "Content-Type": "application/json", Accept: "application/json", }; - const authHeader = "Authorization"; - const makeBearer = (s: string) => `Bearer ${s.trim()}`; - const validString = (x?: string): x is string => Boolean(x && x.length > 0); - - // when using google api in app, not set auth header if (validString(accessCode)) { - headers[authHeader] = makeBearer(ACCESS_CODE_PREFIX + accessCode); + headers["Authorization"] = makeBearer(ACCESS_CODE_PREFIX + accessCode); } return headers; @@ -160,52 +157,12 @@ export default class NextChatProvider streamChat( payload: InternalChatRequestPayload, - onProgress: (message: string, chunk: string) => void, - onFinish: (message: string) => void, - onError: (err: Error) => void, + handlers: ChatHandlers, ) { const requestPayload = this.formatChatPayload(payload); - let responseText = ""; - let remainText = ""; - let finished = false; - const timer = this.getTimer(); - // animate response to make it looks smooth - const animateResponseText = () => { - if (finished || timer.signal.aborted) { - responseText += remainText; - console.log("[Response Animation] finished"); - if (responseText?.length === 0) { - onError(new Error("empty response from server")); - } - return; - } - - if (remainText.length > 0) { - const fetchCount = Math.max(1, Math.round(remainText.length / 60)); - const fetchText = remainText.slice(0, fetchCount); - responseText += fetchText; - remainText = remainText.slice(fetchCount); - onProgress(responseText, fetchText); - } - - requestAnimationFrame(animateResponseText); - }; - - // start animaion - animateResponseText(); - - const finish = () => { - if (!finished) { - finished = true; - onFinish(responseText + remainText); - } - }; - - timer.signal.onabort = finish; - fetchEventSource(requestPayload.url, { ...requestPayload, async onopen(res) { @@ -214,8 +171,8 @@ export default class NextChatProvider console.log("[OpenAI] request response content type: ", contentType); if (contentType?.startsWith("text/plain")) { - responseText = await res.clone().text(); - return finish(); + const responseText = await res.clone().text(); + return handlers.onFlash(responseText); } if ( @@ -225,29 +182,29 @@ export default class NextChatProvider ?.startsWith(EventStreamContentType) || res.status !== 200 ) { - const responseTexts = [responseText]; + const responseTexts = []; + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + let extraInfo = await res.clone().text(); try { const resJson = await res.clone().json(); extraInfo = prettyObject(resJson); } catch {} - if (res.status === 401) { - responseTexts.push(Locale.Error.Unauthorized); - } - if (extraInfo) { responseTexts.push(extraInfo); } - responseText = responseTexts.join("\n\n"); + const responseText = responseTexts.join("\n\n"); - return finish(); + return handlers.onFlash(responseText); } }, onmessage(msg) { - if (msg.data === "[DONE]" || finished) { - return finish(); + if (msg.data === "[DONE]") { + return; } const text = msg.data; try { @@ -256,20 +213,19 @@ export default class NextChatProvider delta: { content: string }; }>; const delta = choices[0]?.delta?.content; - const textmoderation = json?.prompt_filter_results; if (delta) { - remainText += delta; + handlers.onProgress(delta); } } catch (e) { console.error("[Request] parse error", text, msg); } }, onclose() { - finish(); + handlers.onFinish(); }, onerror(e) { - onError(e); + handlers.onError(e); throw e; }, openWhenHidden: true, @@ -277,6 +233,7 @@ export default class NextChatProvider return timer; } + async chat( payload: InternalChatRequestPayload<"accessCode">, ): Promise { diff --git a/app/client/providers/nextchat/utils.ts b/app/client/providers/nextchat/utils.ts new file mode 100644 index 000000000..24f6ef4f0 --- /dev/null +++ b/app/client/providers/nextchat/utils.ts @@ -0,0 +1,18 @@ +export const makeBearer = (s: string) => `Bearer ${s.trim()}`; + +export const validString = (x?: string): x is string => + Boolean(x && x.length > 0); + +export function prettyObject(msg: any) { + const obj = msg; + if (typeof msg !== "string") { + msg = JSON.stringify(msg, null, " "); + } + if (msg === "{}") { + return obj.toString(); + } + if (msg.startsWith("```json")) { + return msg; + } + return ["```json", msg, "```"].join("\n"); +} diff --git a/app/client/providers/openai/config.ts b/app/client/providers/openai/config.ts index 9cd47d92f..60c7073d4 100644 --- a/app/client/providers/openai/config.ts +++ b/app/client/providers/openai/config.ts @@ -1,8 +1,10 @@ -import { SettingItem } from "../../core/types"; +import { SettingItem } from "../../common"; import Locale from "./locale"; export const OPENAI_BASE_URL = "https://api.openai.com"; +export const ROLES = ["system", "user", "assistant"] as const; + export const OpenaiMetas = { ChatPath: "v1/chat/completions", UsagePath: "dashboard/billing/usage", @@ -12,15 +14,20 @@ export const OpenaiMetas = { export type SettingKeys = "openaiUrl" | "openaiApiKey"; -export const defaultModal = "gpt-3.5-turbo"; - export const modelConfigs = [ + { + name: "gpt-4o", + displayName: "gpt-4o", + isVision: false, + isDefaultActive: true, + isDefaultSelected: true, + }, { name: "gpt-3.5-turbo", displayName: "gpt-3.5-turbo", isVision: false, isDefaultActive: true, - isDefaultSelected: true, + isDefaultSelected: false, }, { name: "gpt-3.5-turbo-0301", @@ -150,13 +157,30 @@ export const modelConfigs = [ }, ]; +const defaultEndpoint = "/api/openai"; + export const settingItems: SettingItem[] = [ { name: "openaiUrl", title: Locale.Endpoint.Title, description: Locale.Endpoint.SubTitle, - defaultValue: OPENAI_BASE_URL, + defaultValue: defaultEndpoint, type: "input", + validators: [ + "required", + async (v: any) => { + if (typeof v === "string" && v.endsWith("/")) { + return Locale.Endpoint.Error.EndWithBackslash; + } + if ( + typeof v === "string" && + !v.startsWith(defaultEndpoint) && + !v.startsWith("http") + ) { + return Locale.Endpoint.SubTitle; + } + }, + ], }, { name: "openaiApiKey", diff --git a/app/client/providers/openai/index.ts b/app/client/providers/openai/index.ts index e1c051d42..1d336494a 100644 --- a/app/client/providers/openai/index.ts +++ b/app/client/providers/openai/index.ts @@ -1,19 +1,26 @@ -import { modelConfigs, settingItems, SettingKeys, OpenaiMetas } from "./config"; -import { getMessageTextContent } from "@/app/utils"; import { + ChatHandlers, InternalChatRequestPayload, IProviderTemplate, -} from "../../core/types"; + ModelInfo, + getMessageTextContent, +} from "../../common"; import { EventStreamContentType, fetchEventSource, } from "@fortaine/fetch-event-source"; -import { prettyObject } from "@/app/utils/format"; import Locale from "@/app/locales"; +import { makeBearer, validString, prettyObject } from "./utils"; +import { + modelConfigs, + settingItems, + SettingKeys, + OpenaiMetas, + ROLES, +} from "./config"; export type OpenAIProviderSettingKeys = SettingKeys; -export const ROLES = ["system", "user", "assistant"] as const; export type MessageRole = (typeof ROLES)[number]; export interface MultimodalContent { @@ -28,7 +35,6 @@ export interface RequestMessage { role: MessageRole; content: string | MultimodalContent[]; } - interface RequestPayload { messages: { role: "system" | "user" | "assistant"; @@ -43,6 +49,16 @@ interface RequestPayload { max_tokens?: number; } +interface ModelList { + object: "list"; + data: Array<{ + id: string; + object: "model"; + created: number; + owned_by: "system" | "openai-internal"; + }>; +} + class OpenAIProvider implements IProviderTemplate { @@ -51,7 +67,7 @@ class OpenAIProvider readonly REQUEST_TIMEOUT_MS = 60000; - models = modelConfigs.map((c) => ({ ...c, providerTemplateName: this.name })); + defaultModels = modelConfigs; providerMeta = { displayName: "OpenAI", @@ -62,25 +78,11 @@ class OpenAIProvider const { providerConfig: { openaiUrl }, } = payload; - const path = OpenaiMetas.ChatPath; - let baseUrl = openaiUrl; + console.log("[Proxy Endpoint] ", openaiUrl, path); - if (!baseUrl) { - baseUrl = "/api/openai"; - } - - if (baseUrl.endsWith("/")) { - baseUrl = baseUrl.slice(0, baseUrl.length - 1); - } - if (!baseUrl.startsWith("http") && !baseUrl.startsWith("/api/openai")) { - baseUrl = "https://" + baseUrl; - } - - console.log("[Proxy Endpoint] ", baseUrl, path); - - return [baseUrl, path].join("/"); + return [openaiUrl, path].join("/"); } private getHeaders(payload: InternalChatRequestPayload) { @@ -90,14 +92,9 @@ class OpenAIProvider "Content-Type": "application/json", Accept: "application/json", }; - const authHeader = "Authorization"; - const makeBearer = (s: string) => `Bearer ${s.trim()}`; - const validString = (x?: string): x is string => Boolean(x && x.length > 0); - - // when using google api in app, not set auth header if (validString(openaiApiKey)) { - headers[authHeader] = makeBearer(openaiApiKey); + headers["Authorization"] = makeBearer(openaiApiKey); } return headers; @@ -143,9 +140,11 @@ class OpenAIProvider }; } - private readWholeMessageResponseBody(res: any) { + private readWholeMessageResponseBody(res: { + choices: { message: { content: any } }[]; + }) { return { - message: res.choices?.at(0)?.message?.content ?? "", + message: res.choices?.[0]?.message?.content ?? "", }; } @@ -190,52 +189,12 @@ class OpenAIProvider streamChat( payload: InternalChatRequestPayload, - onProgress: (message: string, chunk: string) => void, - onFinish: (message: string) => void, - onError: (err: Error) => void, + handlers: ChatHandlers, ) { const requestPayload = this.formatChatPayload(payload); const timer = this.getTimer(); - let responseText = ""; - let remainText = ""; - let finished = false; - - // animate response to make it looks smooth - const animateResponseText = () => { - if (finished || timer.signal.aborted) { - responseText += remainText; - console.log("[Response Animation] finished"); - if (responseText?.length === 0) { - onError(new Error("empty response from server")); - } - return; - } - - if (remainText.length > 0) { - const fetchCount = Math.max(1, Math.round(remainText.length / 60)); - const fetchText = remainText.slice(0, fetchCount); - responseText += fetchText; - remainText = remainText.slice(fetchCount); - onProgress(responseText, fetchText); - } - - requestAnimationFrame(animateResponseText); - }; - - // start animaion - animateResponseText(); - - const finish = () => { - if (!finished) { - finished = true; - onFinish(responseText + remainText); - } - }; - - timer.signal.onabort = finish; - fetchEventSource(requestPayload.url, { ...requestPayload, async onopen(res) { @@ -244,8 +203,8 @@ class OpenAIProvider console.log("[OpenAI] request response content type: ", contentType); if (contentType?.startsWith("text/plain")) { - responseText = await res.clone().text(); - return finish(); + const responseText = await res.clone().text(); + return handlers.onFlash(responseText); } if ( @@ -255,29 +214,29 @@ class OpenAIProvider ?.startsWith(EventStreamContentType) || res.status !== 200 ) { - const responseTexts = [responseText]; + const responseTexts = []; + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + let extraInfo = await res.clone().text(); try { const resJson = await res.clone().json(); extraInfo = prettyObject(resJson); } catch {} - if (res.status === 401) { - responseTexts.push(Locale.Error.Unauthorized); - } - if (extraInfo) { responseTexts.push(extraInfo); } - responseText = responseTexts.join("\n\n"); + const responseText = responseTexts.join("\n\n"); - return finish(); + return handlers.onFlash(responseText); } }, onmessage(msg) { - if (msg.data === "[DONE]" || finished) { - return finish(); + if (msg.data === "[DONE]") { + return; } const text = msg.data; try { @@ -286,20 +245,19 @@ class OpenAIProvider delta: { content: string }; }>; const delta = choices[0]?.delta?.content; - const textmoderation = json?.prompt_filter_results; if (delta) { - remainText += delta; + handlers.onProgress(delta); } } catch (e) { console.error("[Request] parse error", text, msg); } }, onclose() { - finish(); + handlers.onFinish(); }, onerror(e) { - onError(e); + handlers.onError(e); throw e; }, openWhenHidden: true, @@ -307,6 +265,23 @@ class OpenAIProvider return timer; } + + async getAvailableModels( + providerConfig: Record, + ): Promise { + const { openaiApiKey, openaiUrl } = providerConfig; + const res = await fetch(`${openaiUrl}/vi/models`, { + headers: { + Authorization: `Bearer ${openaiApiKey}`, + }, + method: "GET", + }); + const data: ModelList = await res.json(); + + return data.data.map((o) => ({ + name: o.id, + })); + } } export default OpenAIProvider; diff --git a/app/client/providers/openai/locale.ts b/app/client/providers/openai/locale.ts index 30e269c4e..dab14e34f 100644 --- a/app/client/providers/openai/locale.ts +++ b/app/client/providers/openai/locale.ts @@ -1,4 +1,4 @@ -import { getLocaleText } from "../../core/locale"; +import { getLocaleText } from "../../common/locale"; export default getLocaleText< { @@ -11,6 +11,9 @@ export default getLocaleText< Endpoint: { Title: string; SubTitle: string; + Error: { + EndWithBackslash: string; + }; }; }, "en" @@ -26,6 +29,9 @@ export default getLocaleText< Endpoint: { Title: "接口地址", SubTitle: "除默认地址外,必须包含 http(s)://", + Error: { + EndWithBackslash: "不能以「/」结尾", + }, }, }, en: { @@ -38,6 +44,9 @@ export default getLocaleText< Endpoint: { Title: "OpenAI Endpoint", SubTitle: "Must starts with http(s):// or use /api/openai as default", + Error: { + EndWithBackslash: "Cannot end with '/'", + }, }, }, pt: { @@ -50,6 +59,9 @@ export default getLocaleText< Endpoint: { Title: "Endpoint OpenAI", SubTitle: "Deve começar com http(s):// ou usar /api/openai como padrão", + Error: { + EndWithBackslash: "Não é possível terminar com '/'", + }, }, }, sk: { @@ -63,6 +75,9 @@ export default getLocaleText< Title: "Koncový bod OpenAI", SubTitle: "Musí začínať http(s):// alebo použiť /api/openai ako predvolený", + Error: { + EndWithBackslash: "Nemôže končiť znakom „/“", + }, }, }, tw: { @@ -75,6 +90,9 @@ export default getLocaleText< Endpoint: { Title: "介面(Endpoint) 地址", SubTitle: "除預設地址外,必須包含 http(s)://", + Error: { + EndWithBackslash: "不能以「/」結尾", + }, }, }, }, diff --git a/app/client/providers/openai/utils.ts b/app/client/providers/openai/utils.ts new file mode 100644 index 000000000..24f6ef4f0 --- /dev/null +++ b/app/client/providers/openai/utils.ts @@ -0,0 +1,18 @@ +export const makeBearer = (s: string) => `Bearer ${s.trim()}`; + +export const validString = (x?: string): x is string => + Boolean(x && x.length > 0); + +export function prettyObject(msg: any) { + const obj = msg; + if (typeof msg !== "string") { + msg = JSON.stringify(msg, null, " "); + } + if (msg === "{}") { + return obj.toString(); + } + if (msg.startsWith("```json")) { + return msg; + } + return ["```json", msg, "```"].join("\n"); +} diff --git a/app/components/List/index.tsx b/app/components/List/index.tsx index f71456544..46e3036ee 100644 --- a/app/components/List/index.tsx +++ b/app/components/List/index.tsx @@ -37,6 +37,8 @@ type Error = error: false; }; +type Validate = (v: any) => Error | Promise; + export interface ListItemProps { title: string; subTitle?: string; @@ -44,7 +46,7 @@ export interface ListItemProps { className?: string; onClick?: () => void; nextline?: boolean; - validator?: (v: any) => Error | Promise; + validator?: Validate | Validate[]; } export const ListContext = createContext< @@ -92,7 +94,15 @@ export function ListItem(props: ListItemProps) { }, []); const handleValidate = useCallback((v: any) => { - const insideValidator = validator || (() => {}); + let insideValidator; + if (!validator) { + insideValidator = () => {}; + } else if (Array.isArray(validator)) { + insideValidator = (v: any) => + Promise.race(validator.map((validate) => validate(v))); + } else { + insideValidator = validator; + } Promise.resolve(insideValidator(v)).then((result) => { if (result && result.error) { diff --git a/app/store/provider.ts b/app/store/provider.ts index 691bc937e..bd7b366e3 100644 --- a/app/store/provider.ts +++ b/app/store/provider.ts @@ -9,22 +9,37 @@ import { import { StoreKey } from "../constant"; import { createPersistStore } from "../utils/store"; -export const DEFAULT_CONFIG = { - lastUpdate: Date.now(), // timestamp, to merge state +const firstUpdate = Date.now(); - providers: ProviderClient.getProviderTemplateList() - .filter((p) => p !== NextChatProvider) - .map((p) => createProvider(p)), -}; +function getDefaultConfig() { + const providers = Object.values(ProviderClient.ProviderTemplates) + .filter((t) => !(t instanceof NextChatProvider)) + .map((t) => createProvider(t, true)); -export type ProvidersConfig = typeof DEFAULT_CONFIG; + const initProvider = providers[0]; + + const currentModel = + initProvider.models.find((m) => m.isDefaultSelected) || + initProvider.models[0]; + + return { + lastUpdate: firstUpdate, // timestamp, to merge state + + currentModel: currentModel.name, + currentProvider: initProvider.name, + + providers, + }; +} + +export type ProvidersConfig = ReturnType; export const useProviders = createPersistStore( - { ...DEFAULT_CONFIG }, + { ...getDefaultConfig() }, (set, get) => { const methods = { reset() { - set(() => ({ ...DEFAULT_CONFIG })); + set(() => getDefaultConfig()); }, addProvider(provider: Provider) { @@ -53,10 +68,14 @@ export const useProviders = createPersistStore( return get().providers.find((p) => p.name === providerName); }, - addModel(model: Omit, provider: Provider) { + addModel( + model: Omit, + provider: Provider, + ) { const newModel: Model = { - providerTemplateName: provider.providerTemplateName, ...model, + providerTemplateName: provider.providerTemplateName, + customized: true, }; return methods.updateProvider({ ...provider, @@ -80,6 +99,13 @@ export const useProviders = createPersistStore( }); }, + switchModel(model: Model, provider: Provider) { + set(() => ({ + currentModel: model.name, + currentProvider: provider.name, + })); + }, + getModel( modelName: string, providerName: string, diff --git a/app/utils/hooks.ts b/app/utils/hooks.ts index 55d5d4fca..c5927ee14 100644 --- a/app/utils/hooks.ts +++ b/app/utils/hooks.ts @@ -1,6 +1,6 @@ import { useMemo } from "react"; import { useAccessStore, useAppConfig } from "../store"; -import { collectModels, collectModelsWithDefaultModel } from "./model"; +import { collectModelsWithDefaultModel } from "./model"; export function useAllModels() { const accessStore = useAccessStore(); diff --git a/yarn.lock b/yarn.lock index 7b6e4f5b5..95459d3db 100644 --- a/yarn.lock +++ b/yarn.lock @@ -5312,16 +5312,11 @@ mz@^2.7.0: object-assign "^4.0.1" thenify-all "^1.0.0" -nanoid@^3.3.6: +nanoid@^3.3.6, nanoid@^3.3.7: version "3.3.7" resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.7.tgz#d0c301a691bc8d54efa0a2226ccf3fe2fd656bd8" integrity sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g== -nanoid@^3.3.7: - version "3.3.7" - resolved "https://registry.npmmirror.com/nanoid/-/nanoid-3.3.7.tgz#d0c301a691bc8d54efa0a2226ccf3fe2fd656bd8" - integrity sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g== - nanoid@^5.0.3: version "5.0.3" resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-5.0.3.tgz#6c97f53d793a7a1de6a38ebb46f50f95bf9793c7"