From f68cd2c5c04a33dda4187ee7db4bbfb4026b9e40 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 12:27:44 +0800 Subject: [PATCH] review code --- app/client/platforms/baidu.ts | 10 +++++----- app/constant.ts | 23 +++++++++++++++++------ app/utils/model.ts | 6 ++---- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/app/client/platforms/baidu.ts b/app/client/platforms/baidu.ts index e2f6f12dd..4fc3d2f64 100644 --- a/app/client/platforms/baidu.ts +++ b/app/client/platforms/baidu.ts @@ -2,7 +2,7 @@ import { ApiPath, Baidu, - DEFAULT_API_HOST, + BAIDU_BASE_URL, REQUEST_TIMEOUT_MS, } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; @@ -21,7 +21,7 @@ import { } from "@fortaine/fetch-event-source"; import { prettyObject } from "@/app/utils/format"; import { getClientConfig } from "@/app/config/client"; -import { getMessageTextContent, isVisionModel } from "@/app/utils"; +import { getMessageTextContent } from "@/app/utils"; export interface OpenAIListModelResponse { object: string; @@ -58,7 +58,8 @@ export class ErnieApi implements LLMApi { if (baseUrl.length === 0) { const isApp = !!getClientConfig()?.isApp; - baseUrl = isApp ? DEFAULT_API_HOST + "/api/proxy/baidu" : ApiPath.Baidu; + // do not use proxy for baidubce api + baseUrl = isApp ? BAIDU_BASE_URL : ApiPath.Baidu; } if (baseUrl.endsWith("/")) { @@ -78,10 +79,9 @@ export class ErnieApi implements LLMApi { } async chat(options: ChatOptions) { - const visionModel = isVisionModel(options.config.model); const messages = options.messages.map((v) => ({ role: v.role, - content: visionModel ? v.content : getMessageTextContent(v), + content: getMessageTextContent(v), })); const modelConfig = { diff --git a/app/constant.ts b/app/constant.ts index 6ffc0e0b3..0fd4d1c24 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -112,9 +112,20 @@ export const Google = { }; export const Baidu = { - ExampleEndpoint: "https://aip.baidubce.com", - ChatPath: (modelName: string) => - `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${modelName}`, + ExampleEndpoint: BAIDU_BASE_URL, + ChatPath: (modelName: string) => { + let endpoint = modelName; + if (modelName === "ernie-4.0-8k") { + endpoint = "completions_pro"; + } + if (modelName === "ernie-4.0-8k-preview-0518") { + endpoint = "completions_adv_pro"; + } + if (modelName === "ernie-3.5-8k") { + endpoint = "completions"; + } + return `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`; + }, }; export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang @@ -188,11 +199,11 @@ const anthropicModels = [ const baiduModels = [ "ernie-4.0-turbo-8k", - "completions_pro=ernie-4.0-8k", + "ernie-4.0-8k", "ernie-4.0-8k-preview", - "completions_adv_pro=ernie-4.0-8k-preview-0518", + "ernie-4.0-8k-preview-0518", "ernie-4.0-8k-latest", - "completions=ernie-3.5-8k", + "ernie-3.5-8k", "ernie-3.5-8k-0205", ]; diff --git a/app/utils/model.ts b/app/utils/model.ts index 6a02ed7eb..7c778888e 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -24,13 +24,11 @@ export function collectModelTable( // default models models.forEach((m) => { - // supoort name=displayName eg:completions_pro=ernie-4.0-8k - const [name, displayName] = m.name?.split("="); // using @ as fullName - modelTable[`${name}@${m?.provider?.id}`] = { + modelTable[`${m.name}@${m?.provider?.id}`] = { ...m, name, - displayName: displayName || name, // 'provider' is copied over if it exists + displayName: m.name, // 'provider' is copied over if it exists }; });