From 785d3748e10c6c2fa5b21129aa8e35905876a171 Mon Sep 17 00:00:00 2001
From: Dogtiti <499960698@qq.com>
Date: Sat, 6 Jul 2024 13:05:09 +0800
Subject: [PATCH 01/20] feat: support baidu model

---
 .gitignore                       |   2 +-
 app/api/auth.ts                  |   3 +
 app/api/baidu/[...path]/route.ts | 176 +++++++++++++++++++++
 app/client/api.ts                |   5 +
 app/client/platforms/baidu.ts    | 252 +++++++++++++++++++++++++++++++
 app/components/exporter.tsx      |   2 +
 app/components/home.tsx          |   2 +
 app/components/settings.tsx      |  62 ++++++++
 app/config/server.ts             |  17 ++-
 app/constant.ts                  |  32 ++++
 app/locales/cn.ts                |  16 ++
 app/store/access.ts              |  10 ++
 app/store/chat.ts                |   4 +
 app/utils/model.ts               |   7 +-
 14 files changed, 586 insertions(+), 4 deletions(-)
 create mode 100644 app/api/baidu/[...path]/route.ts
 create mode 100644 app/client/platforms/baidu.ts

diff --git a/.gitignore b/.gitignore
index b00b0e325..a24c6e047 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,4 +43,4 @@ dev
 .env
 
 *.key
-*.key.pub
\ No newline at end of file
+*.key.pub
diff --git a/app/api/auth.ts b/app/api/auth.ts
index 2b4702aed..cce8847f4 100644
--- a/app/api/auth.ts
+++ b/app/api/auth.ts
@@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
       case ModelProvider.Claude:
         systemApiKey = serverConfig.anthropicApiKey;
         break;
+      case ModelProvider.Ernie:
+        systemApiKey = serverConfig.baiduApiKey;
+        break;
       case ModelProvider.GPT:
       default:
         if (req.nextUrl.pathname.includes("azure/deployments")) {
diff --git a/app/api/baidu/[...path]/route.ts b/app/api/baidu/[...path]/route.ts
new file mode 100644
index 000000000..27676d29d
--- /dev/null
+++ b/app/api/baidu/[...path]/route.ts
@@ -0,0 +1,176 @@
+import { getServerSideConfig } from "@/app/config/server";
+import {
+  BAIDU_BASE_URL,
+  ApiPath,
+  ModelProvider,
+  BAIDU_OATUH_URL,
+  ServiceProvider,
+} from "@/app/constant";
+import { prettyObject } from "@/app/utils/format";
+import { NextRequest, NextResponse } from "next/server";
+import { auth } from "@/app/api/auth";
+import { isModelAvailableInServer } from "@/app/utils/model";
+
+const serverConfig = getServerSideConfig();
+
+async function handle(
+  req: NextRequest,
+  { params }: { params: { path: string[] } },
+) {
+  console.log("[Baidu Route] params ", params);
+
+  if (req.method === "OPTIONS") {
+    return NextResponse.json({ body: "OK" }, { status: 200 });
+  }
+
+  const authResult = auth(req, ModelProvider.Ernie);
+  if (authResult.error) {
+    return NextResponse.json(authResult, {
+      status: 401,
+    });
+  }
+
+  try {
+    const response = await request(req);
+    return response;
+  } catch (e) {
+    console.error("[Baidu] ", e);
+    return NextResponse.json(prettyObject(e));
+  }
+}
+
+export const GET = handle;
+export const POST = handle;
+
+export const runtime = "edge";
+export const preferredRegion = [
+  "arn1",
+  "bom1",
+  "cdg1",
+  "cle1",
+  "cpt1",
+  "dub1",
+  "fra1",
+  "gru1",
+  "hnd1",
+  "iad1",
+  "icn1",
+  "kix1",
+  "lhr1",
+  "pdx1",
+  "sfo1",
+  "sin1",
+  "syd1",
+];
+
+async function request(req: NextRequest) {
+  const controller = new AbortController();
+
+  let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.Baidu, "");
+
+  let baseUrl = serverConfig.baiduUrl || BAIDU_BASE_URL;
+
+  if (!baseUrl.startsWith("http")) {
+    baseUrl = `https://${baseUrl}`;
+  }
+
+  if (baseUrl.endsWith("/")) {
+    baseUrl = baseUrl.slice(0, -1);
+  }
+
+  console.log("[Proxy] ", path);
+  console.log("[Base Url]", baseUrl);
+
+  const timeoutId = setTimeout(
+    () => {
+      controller.abort();
+    },
+    10 * 60 * 1000,
+  );
+
+  const { access_token } = await getAccessToken();
+  const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`;
+
+  const fetchOptions: RequestInit = {
+    headers: {
+      "Content-Type": "application/json",
+    },
+    method: req.method,
+    body: req.body,
+    redirect: "manual",
+    // @ts-ignore
+    duplex: "half",
+    signal: controller.signal,
+  };
+
+  // #1815 try to refuse some request to some models
+  if (serverConfig.customModels && req.body) {
+    try {
+      const clonedBody = await req.text();
+      fetchOptions.body = clonedBody;
+
+      const jsonBody = JSON.parse(clonedBody) as { model?: string };
+
+      // not undefined and is false
+      if (
+        isModelAvailableInServer(
+          serverConfig.customModels,
+          jsonBody?.model as string,
+          ServiceProvider.Baidu as string,
+        )
+      ) {
+        return NextResponse.json(
+          {
+            error: true,
+            message: `you are not allowed to use ${jsonBody?.model} model`,
+          },
+          {
+            status: 403,
+          },
+        );
+      }
+    } catch (e) {
+      console.error(`[Baidu] filter`, e);
+    }
+  }
+  console.log("[Baidu request]", fetchOptions.headers, req.method);
+  try {
+    const res = await fetch(fetchUrl, fetchOptions);
+
+    console.log("[Baidu response]", res.status, "   ", res.headers, res.url);
+    // to prevent browser prompt for credentials
+    const newHeaders = new Headers(res.headers);
+    newHeaders.delete("www-authenticate");
+    // to disable nginx buffering
+    newHeaders.set("X-Accel-Buffering", "no");
+
+    return new Response(res.body, {
+      status: res.status,
+      statusText: res.statusText,
+      headers: newHeaders,
+    });
+  } finally {
+    clearTimeout(timeoutId);
+  }
+}
+
+/**
+ * 使用 AK,SK 生成鉴权签名(Access Token)
+ * @return 鉴权签名信息
+ */
+async function getAccessToken(): Promise<{
+  access_token: string;
+  expires_in: number;
+  error?: number;
+}> {
+  const AK = serverConfig.baiduApiKey;
+  const SK = serverConfig.baiduSecretKey;
+  const res = await fetch(
+    `${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${AK}&client_secret=${SK}`,
+    {
+      method: "POST",
+    },
+  );
+  const resJson = await res.json();
+  return resJson;
+}
diff --git a/app/client/api.ts b/app/client/api.ts
index 41ccbd8e1..74e0ef9a9 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -9,6 +9,8 @@ import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
 import { ChatGPTApi } from "./platforms/openai";
 import { GeminiProApi } from "./platforms/google";
 import { ClaudeApi } from "./platforms/anthropic";
+import { ErnieApi } from "./platforms/baidu";
+
 export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];
 
@@ -104,6 +106,9 @@ export class ClientApi {
       case ModelProvider.Claude:
         this.llm = new ClaudeApi();
         break;
+      case ModelProvider.Ernie:
+        this.llm = new ErnieApi();
+        break;
       default:
         this.llm = new ChatGPTApi();
     }
diff --git a/app/client/platforms/baidu.ts b/app/client/platforms/baidu.ts
new file mode 100644
index 000000000..e2f6f12dd
--- /dev/null
+++ b/app/client/platforms/baidu.ts
@@ -0,0 +1,252 @@
+"use client";
+import {
+  ApiPath,
+  Baidu,
+  DEFAULT_API_HOST,
+  REQUEST_TIMEOUT_MS,
+} from "@/app/constant";
+import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+
+import {
+  ChatOptions,
+  getHeaders,
+  LLMApi,
+  LLMModel,
+  MultimodalContent,
+} from "../api";
+import Locale from "../../locales";
+import {
+  EventStreamContentType,
+  fetchEventSource,
+} from "@fortaine/fetch-event-source";
+import { prettyObject } from "@/app/utils/format";
+import { getClientConfig } from "@/app/config/client";
+import { getMessageTextContent, isVisionModel } from "@/app/utils";
+
+export interface OpenAIListModelResponse {
+  object: string;
+  data: Array<{
+    id: string;
+    object: string;
+    root: string;
+  }>;
+}
+
+interface RequestPayload {
+  messages: {
+    role: "system" | "user" | "assistant";
+    content: string | MultimodalContent[];
+  }[];
+  stream?: boolean;
+  model: string;
+  temperature: number;
+  presence_penalty: number;
+  frequency_penalty: number;
+  top_p: number;
+  max_tokens?: number;
+}
+
+export class ErnieApi implements LLMApi {
+  path(path: string): string {
+    const accessStore = useAccessStore.getState();
+
+    let baseUrl = "";
+
+    if (accessStore.useCustomConfig) {
+      baseUrl = accessStore.baiduUrl;
+    }
+
+    if (baseUrl.length === 0) {
+      const isApp = !!getClientConfig()?.isApp;
+      baseUrl = isApp ? DEFAULT_API_HOST + "/api/proxy/baidu" : ApiPath.Baidu;
+    }
+
+    if (baseUrl.endsWith("/")) {
+      baseUrl = baseUrl.slice(0, baseUrl.length - 1);
+    }
+    if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Baidu)) {
+      baseUrl = "https://" + baseUrl;
+    }
+
+    console.log("[Proxy Endpoint] ", baseUrl, path);
+
+    return [baseUrl, path].join("/");
+  }
+
+  extractMessage(res: any) {
+    return res.choices?.at(0)?.message?.content ?? "";
+  }
+
+  async chat(options: ChatOptions) {
+    const visionModel = isVisionModel(options.config.model);
+    const messages = options.messages.map((v) => ({
+      role: v.role,
+      content: visionModel ? v.content : getMessageTextContent(v),
+    }));
+
+    const modelConfig = {
+      ...useAppConfig.getState().modelConfig,
+      ...useChatStore.getState().currentSession().mask.modelConfig,
+      ...{
+        model: options.config.model,
+      },
+    };
+
+    const requestPayload: RequestPayload = {
+      messages,
+      stream: options.config.stream,
+      model: modelConfig.model,
+      temperature: modelConfig.temperature,
+      presence_penalty: modelConfig.presence_penalty,
+      frequency_penalty: modelConfig.frequency_penalty,
+      top_p: modelConfig.top_p,
+    };
+
+    console.log("[Request] Baidu payload: ", requestPayload);
+
+    const shouldStream = !!options.config.stream;
+    const controller = new AbortController();
+    options.onController?.(controller);
+
+    try {
+      const chatPath = this.path(Baidu.ChatPath(modelConfig.model));
+      const chatPayload = {
+        method: "POST",
+        body: JSON.stringify(requestPayload),
+        signal: controller.signal,
+        headers: getHeaders(),
+      };
+
+      // make a fetch request
+      const requestTimeoutId = setTimeout(
+        () => controller.abort(),
+        REQUEST_TIMEOUT_MS,
+      );
+
+      if (shouldStream) {
+        let responseText = "";
+        let remainText = "";
+        let finished = false;
+
+        // animate response to make it looks smooth
+        function animateResponseText() {
+          if (finished || controller.signal.aborted) {
+            responseText += remainText;
+            console.log("[Response Animation] finished");
+            if (responseText?.length === 0) {
+              options.onError?.(new Error("empty response from server"));
+            }
+            return;
+          }
+
+          if (remainText.length > 0) {
+            const fetchCount = Math.max(1, Math.round(remainText.length / 60));
+            const fetchText = remainText.slice(0, fetchCount);
+            responseText += fetchText;
+            remainText = remainText.slice(fetchCount);
+            options.onUpdate?.(responseText, fetchText);
+          }
+
+          requestAnimationFrame(animateResponseText);
+        }
+
+        // start animaion
+        animateResponseText();
+
+        const finish = () => {
+          if (!finished) {
+            finished = true;
+            options.onFinish(responseText + remainText);
+          }
+        };
+
+        controller.signal.onabort = finish;
+
+        fetchEventSource(chatPath, {
+          ...chatPayload,
+          async onopen(res) {
+            clearTimeout(requestTimeoutId);
+            const contentType = res.headers.get("content-type");
+            console.log("[Baidu] request response content type: ", contentType);
+
+            if (contentType?.startsWith("text/plain")) {
+              responseText = await res.clone().text();
+              return finish();
+            }
+
+            if (
+              !res.ok ||
+              !res.headers
+                .get("content-type")
+                ?.startsWith(EventStreamContentType) ||
+              res.status !== 200
+            ) {
+              const responseTexts = [responseText];
+              let extraInfo = await res.clone().text();
+              try {
+                const resJson = await res.clone().json();
+                extraInfo = prettyObject(resJson);
+              } catch {}
+
+              if (res.status === 401) {
+                responseTexts.push(Locale.Error.Unauthorized);
+              }
+
+              if (extraInfo) {
+                responseTexts.push(extraInfo);
+              }
+
+              responseText = responseTexts.join("\n\n");
+
+              return finish();
+            }
+          },
+          onmessage(msg) {
+            if (msg.data === "[DONE]" || finished) {
+              return finish();
+            }
+            const text = msg.data;
+            try {
+              const json = JSON.parse(text);
+              const delta = json?.result;
+              if (delta) {
+                remainText += delta;
+              }
+            } catch (e) {
+              console.error("[Request] parse error", text, msg);
+            }
+          },
+          onclose() {
+            finish();
+          },
+          onerror(e) {
+            options.onError?.(e);
+            throw e;
+          },
+          openWhenHidden: true,
+        });
+      } else {
+        const res = await fetch(chatPath, chatPayload);
+        clearTimeout(requestTimeoutId);
+
+        const resJson = await res.json();
+        const message = this.extractMessage(resJson);
+        options.onFinish(message);
+      }
+    } catch (e) {
+      console.log("[Request] failed to make a chat request", e);
+      options.onError?.(e as Error);
+    }
+  }
+  async usage() {
+    return {
+      used: 0,
+      total: 0,
+    };
+  }
+
+  async models(): Promise<LLMModel[]> {
+    return [];
+  }
+}
+export { Baidu };
diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx
index 7281fc2f1..ec0060c72 100644
--- a/app/components/exporter.tsx
+++ b/app/components/exporter.tsx
@@ -321,6 +321,8 @@ export function PreviewActions(props: {
       api = new ClientApi(ModelProvider.GeminiPro);
     } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
       api = new ClientApi(ModelProvider.Claude);
+    } else if (config.modelConfig.providerName == ServiceProvider.Baidu) {
+      api = new ClientApi(ModelProvider.Ernie);
     } else {
       api = new ClientApi(ModelProvider.GPT);
     }
diff --git a/app/components/home.tsx b/app/components/home.tsx
index addb5e803..00af1f4ba 100644
--- a/app/components/home.tsx
+++ b/app/components/home.tsx
@@ -175,6 +175,8 @@ export function useLoadData() {
     api = new ClientApi(ModelProvider.GeminiPro);
   } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
     api = new ClientApi(ModelProvider.Claude);
+  } else if (config.modelConfig.providerName == ServiceProvider.Baidu) {
+    api = new ClientApi(ModelProvider.Ernie);
   } else {
     api = new ClientApi(ModelProvider.GPT);
   }
diff --git a/app/components/settings.tsx b/app/components/settings.tsx
index db08b48a9..7db09940d 100644
--- a/app/components/settings.tsx
+++ b/app/components/settings.tsx
@@ -53,6 +53,7 @@ import Link from "next/link";
 import {
   Anthropic,
   Azure,
+  Baidu,
   Google,
   OPENAI_BASE_URL,
   Path,
@@ -1187,6 +1188,67 @@ export function Settings() {
                       </ListItem>
                     </>
                   )}
+                  {accessStore.provider === ServiceProvider.Baidu && (
+                    <>
+                      <ListItem
+                        title={Locale.Settings.Access.Baidu.Endpoint.Title}
+                        subTitle={
+                          Locale.Settings.Access.Anthropic.Endpoint.SubTitle +
+                          Baidu.ExampleEndpoint
+                        }
+                      >
+                        <input
+                          type="text"
+                          value={accessStore.baiduUrl}
+                          placeholder={Baidu.ExampleEndpoint}
+                          onChange={(e) =>
+                            accessStore.update(
+                              (access) =>
+                                (access.baiduUrl = e.currentTarget.value),
+                            )
+                          }
+                        ></input>
+                      </ListItem>
+                      <ListItem
+                        title={Locale.Settings.Access.Baidu.ApiKey.Title}
+                        subTitle={Locale.Settings.Access.Baidu.ApiKey.SubTitle}
+                      >
+                        <PasswordInput
+                          value={accessStore.baiduApiKey}
+                          type="text"
+                          placeholder={
+                            Locale.Settings.Access.Baidu.ApiKey.Placeholder
+                          }
+                          onChange={(e) => {
+                            accessStore.update(
+                              (access) =>
+                                (access.baiduApiKey = e.currentTarget.value),
+                            );
+                          }}
+                        />
+                      </ListItem>
+                      <ListItem
+                        title={Locale.Settings.Access.Baidu.SecretKey.Title}
+                        subTitle={
+                          Locale.Settings.Access.Baidu.SecretKey.SubTitle
+                        }
+                      >
+                        <PasswordInput
+                          value={accessStore.baiduSecretKey}
+                          type="text"
+                          placeholder={
+                            Locale.Settings.Access.Baidu.SecretKey.Placeholder
+                          }
+                          onChange={(e) => {
+                            accessStore.update(
+                              (access) =>
+                                (access.baiduSecretKey = e.currentTarget.value),
+                            );
+                          }}
+                        />
+                      </ListItem>
+                    </>
+                  )}
                 </>
               )}
             </>
diff --git a/app/config/server.ts b/app/config/server.ts
index b7c85ce6a..2d09c5479 100644
--- a/app/config/server.ts
+++ b/app/config/server.ts
@@ -35,6 +35,16 @@ declare global {
       // google tag manager
       GTM_ID?: string;
 
+      // anthropic only
+      ANTHROPIC_URL?: string;
+      ANTHROPIC_API_KEY?: string;
+      ANTHROPIC_API_VERSION?: string;
+
+      // baidu only
+      BAIDU_URL?: string;
+      BAIDU_API_KEY?: string;
+      BAIDU_SECRET_KEY?: string;
+
       // custom template for preprocessing user input
       DEFAULT_INPUT_TEMPLATE?: string;
     }
@@ -92,7 +102,7 @@ export const getServerSideConfig = () => {
   const isAzure = !!process.env.AZURE_URL;
   const isGoogle = !!process.env.GOOGLE_API_KEY;
   const isAnthropic = !!process.env.ANTHROPIC_API_KEY;
-
+  const isBaidu = !!process.env.BAIDU_API_KEY;
   // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? "";
   // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim());
   // const randomIndex = Math.floor(Math.random() * apiKeys.length);
@@ -124,6 +134,11 @@ export const getServerSideConfig = () => {
     anthropicApiVersion: process.env.ANTHROPIC_API_VERSION,
     anthropicUrl: process.env.ANTHROPIC_URL,
 
+    isBaidu,
+    baiduUrl: process.env.BAIDU_URL,
+    baiduApiKey: getApiKey(process.env.BAIDU_API_KEY),
+    baiduSecretKey: process.env.BAIDU_SECRET_KEY,
+
     gtmId: process.env.GTM_ID,
 
     needCode: ACCESS_CODES.size > 0,
diff --git a/app/constant.ts b/app/constant.ts
index d44b5b817..6ffc0e0b3 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -14,6 +14,10 @@ export const ANTHROPIC_BASE_URL = "https://api.anthropic.com";
 
 export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/";
 
+export const BAIDU_BASE_URL = "https://aip.baidubce.com";
+
+export const BAIDU_OATUH_URL = `${BAIDU_BASE_URL}/oauth/2.0/token`;
+
 export enum Path {
   Home = "/",
   Chat = "/chat",
@@ -28,6 +32,7 @@ export enum ApiPath {
   Azure = "/api/azure",
   OpenAI = "/api/openai",
   Anthropic = "/api/anthropic",
+  Baidu = "/api/baidu",
 }
 
 export enum SlotID {
@@ -71,12 +76,14 @@ export enum ServiceProvider {
   Azure = "Azure",
   Google = "Google",
   Anthropic = "Anthropic",
+  Baidu = "Baidu",
 }
 
 export enum ModelProvider {
   GPT = "GPT",
   GeminiPro = "GeminiPro",
   Claude = "Claude",
+  Ernie = "Ernie",
 }
 
 export const Anthropic = {
@@ -104,6 +111,12 @@ export const Google = {
   ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
 };
 
+export const Baidu = {
+  ExampleEndpoint: "https://aip.baidubce.com",
+  ChatPath: (modelName: string) =>
+    `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${modelName}`,
+};
+
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
 // export const DEFAULT_SYSTEM_TEMPLATE = `
 // You are ChatGPT, a large language model trained by {{ServiceProvider}}.
@@ -173,6 +186,16 @@ const anthropicModels = [
   "claude-3-5-sonnet-20240620",
 ];
 
+const baiduModels = [
+  "ernie-4.0-turbo-8k",
+  "completions_pro=ernie-4.0-8k",
+  "ernie-4.0-8k-preview",
+  "completions_adv_pro=ernie-4.0-8k-preview-0518",
+  "ernie-4.0-8k-latest",
+  "completions=ernie-3.5-8k",
+  "ernie-3.5-8k-0205",
+];
+
 export const DEFAULT_MODELS = [
   ...openaiModels.map((name) => ({
     name,
@@ -210,6 +233,15 @@ export const DEFAULT_MODELS = [
       providerType: "anthropic",
     },
   })),
+  ...baiduModels.map((name) => ({
+    name,
+    available: true,
+    provider: {
+      id: "baidu",
+      providerName: "Baidu",
+      providerType: "baidu",
+    },
+  })),
 ] as const;
 
 export const CHAT_PAGE_SIZE = 15;
diff --git a/app/locales/cn.ts b/app/locales/cn.ts
index 2ff94e32d..a872ee75a 100644
--- a/app/locales/cn.ts
+++ b/app/locales/cn.ts
@@ -347,6 +347,22 @@ const cn = {
           SubTitle: "选择一个特定的 API 版本",
         },
       },
+      Baidu: {
+        ApiKey: {
+          Title: "接口密钥",
+          SubTitle: "使用自定义 Baidu API Key 绕过密码访问限制",
+          Placeholder: "Baidu API Key",
+        },
+        SecretKey: {
+          Title: "接口密钥",
+          SubTitle: "使用自定义 Baidu Secret Key 绕过密码访问限制",
+          Placeholder: "Baidu Secret Key",
+        },
+        Endpoint: {
+          Title: "接口地址",
+          SubTitle: "样例:",
+        },
+      },
       CustomModel: {
         Title: "自定义模型名",
         SubTitle: "增加自定义模型可选项,使用英文逗号隔开",
diff --git a/app/store/access.ts b/app/store/access.ts
index 03780779e..7e6d01b34 100644
--- a/app/store/access.ts
+++ b/app/store/access.ts
@@ -47,6 +47,11 @@ const DEFAULT_ACCESS_STATE = {
   anthropicApiVersion: "2023-06-01",
   anthropicUrl: "",
 
+  // baidu
+  baiduUrl: "",
+  baiduApiKey: "",
+  baiduSecretKey: "",
+
   // server config
   needCode: true,
   hideUserApiKey: false,
@@ -83,6 +88,10 @@ export const useAccessStore = createPersistStore(
       return ensure(get(), ["anthropicApiKey"]);
     },
 
+    isValidBaidu() {
+      return ensure(get(), ["baiduApiKey", "baiduSecretKey"]);
+    },
+
     isAuthorized() {
       this.fetch();
 
@@ -92,6 +101,7 @@ export const useAccessStore = createPersistStore(
         this.isValidAzure() ||
         this.isValidGoogle() ||
         this.isValidAnthropic() ||
+        this.isValidBaidu() ||
         !this.enabledAccessControl() ||
         (this.enabledAccessControl() && ensure(get(), ["accessCode"]))
       );
diff --git a/app/store/chat.ts b/app/store/chat.ts
index 44d41830a..45ab479d9 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -368,6 +368,8 @@ export const useChatStore = createPersistStore(
           api = new ClientApi(ModelProvider.GeminiPro);
         } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
           api = new ClientApi(ModelProvider.Claude);
+        } else if (modelConfig.providerName == ServiceProvider.Baidu) {
+          api = new ClientApi(ModelProvider.Ernie);
         } else {
           api = new ClientApi(ModelProvider.GPT);
         }
@@ -552,6 +554,8 @@ export const useChatStore = createPersistStore(
           api = new ClientApi(ModelProvider.GeminiPro);
         } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
           api = new ClientApi(ModelProvider.Claude);
+        } else if (modelConfig.providerName == ServiceProvider.Baidu) {
+          api = new ClientApi(ModelProvider.Ernie);
         } else {
           api = new ClientApi(ModelProvider.GPT);
         }
diff --git a/app/utils/model.ts b/app/utils/model.ts
index 249987726..6a02ed7eb 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -24,10 +24,13 @@ export function collectModelTable(
 
   // default models
   models.forEach((m) => {
+    // supoort name=displayName eg:completions_pro=ernie-4.0-8k
+    const [name, displayName] = m.name?.split("=");
     // using <modelName>@<providerId> as fullName
-    modelTable[`${m.name}@${m?.provider?.id}`] = {
+    modelTable[`${name}@${m?.provider?.id}`] = {
       ...m,
-      displayName: m.name, // 'provider' is copied over if it exists
+      name,
+      displayName: displayName || name, // 'provider' is copied over if it exists
     };
   });
 

From 9b3b4494ba6ff6a517ca17376d2550b1aa651c00 Mon Sep 17 00:00:00 2001
From: Dogtiti <499960698@qq.com>
Date: Sat, 6 Jul 2024 14:59:37 +0800
Subject: [PATCH 02/20] wip: doubao

---
 app/api/auth.ts                      |   3 +
 app/api/bytedance/[...path]/route.ts | 160 +++++++++++++++++
 app/client/api.ts                    |   5 +
 app/client/platforms/bytedance.ts    | 260 +++++++++++++++++++++++++++
 app/components/exporter.tsx          |   2 +
 app/components/home.tsx              |   2 +
 app/config/server.ts                 |   9 +
 app/constant.ts                      |  21 +++
 app/store/access.ts                  |   9 +
 app/store/chat.ts                    |   4 +
 10 files changed, 475 insertions(+)
 create mode 100644 app/api/bytedance/[...path]/route.ts
 create mode 100644 app/client/platforms/bytedance.ts

diff --git a/app/api/auth.ts b/app/api/auth.ts
index 2b4702aed..9c334f2fe 100644
--- a/app/api/auth.ts
+++ b/app/api/auth.ts
@@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
       case ModelProvider.Claude:
         systemApiKey = serverConfig.anthropicApiKey;
         break;
+      case ModelProvider.Doubao:
+        systemApiKey = serverConfig.bytedanceApiKey;
+        break;
       case ModelProvider.GPT:
       default:
         if (req.nextUrl.pathname.includes("azure/deployments")) {
diff --git a/app/api/bytedance/[...path]/route.ts b/app/api/bytedance/[...path]/route.ts
new file mode 100644
index 000000000..bffb60f6c
--- /dev/null
+++ b/app/api/bytedance/[...path]/route.ts
@@ -0,0 +1,160 @@
+import { getServerSideConfig } from "@/app/config/server";
+import {
+  BYTEDANCE_BASE_URL,
+  ApiPath,
+  ModelProvider,
+  ServiceProvider,
+} from "@/app/constant";
+import { prettyObject } from "@/app/utils/format";
+import { NextRequest, NextResponse } from "next/server";
+import { auth } from "@/app/api/auth";
+import { isModelAvailableInServer } from "@/app/utils/model";
+
+const serverConfig = getServerSideConfig();
+
+async function handle(
+  req: NextRequest,
+  { params }: { params: { path: string[] } },
+) {
+  console.log("[ByteDance Route] params ", params);
+
+  if (req.method === "OPTIONS") {
+    return NextResponse.json({ body: "OK" }, { status: 200 });
+  }
+
+  const authResult = auth(req, ModelProvider.Doubao);
+  if (authResult.error) {
+    return NextResponse.json(authResult, {
+      status: 401,
+    });
+  }
+
+  try {
+    const response = await request(req);
+    return response;
+  } catch (e) {
+    console.error("[ByteDance] ", e);
+    return NextResponse.json(prettyObject(e));
+  }
+}
+
+export const GET = handle;
+export const POST = handle;
+
+export const runtime = "edge";
+export const preferredRegion = [
+  "arn1",
+  "bom1",
+  "cdg1",
+  "cle1",
+  "cpt1",
+  "dub1",
+  "fra1",
+  "gru1",
+  "hnd1",
+  "iad1",
+  "icn1",
+  "kix1",
+  "lhr1",
+  "pdx1",
+  "sfo1",
+  "sin1",
+  "syd1",
+];
+
+async function request(req: NextRequest) {
+  const controller = new AbortController();
+
+  let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.ByteDance, "");
+
+  let baseUrl = serverConfig.bytedanceUrl || BYTEDANCE_BASE_URL;
+
+  if (!baseUrl.startsWith("http")) {
+    baseUrl = `https://${baseUrl}`;
+  }
+
+  if (baseUrl.endsWith("/")) {
+    baseUrl = baseUrl.slice(0, -1);
+  }
+
+  console.log("[Proxy] ", path);
+  console.log("[Base Url]", baseUrl);
+
+  const timeoutId = setTimeout(
+    () => {
+      controller.abort();
+    },
+    10 * 60 * 1000,
+  );
+
+  const fetchUrl = `${baseUrl}${path}`;
+
+  const fetchOptions: RequestInit = {
+    headers: {
+      "Content-Type": "application/json",
+      Authorization: req.headers.get("Authorization") ?? "",
+    },
+    method: req.method,
+    body: req.body,
+    redirect: "manual",
+    // @ts-ignore
+    duplex: "half",
+    signal: controller.signal,
+  };
+
+  // #1815 try to refuse some request to some models
+  if (serverConfig.customModels && req.body) {
+    try {
+      const clonedBody = await req.text();
+      fetchOptions.body = clonedBody;
+
+      const jsonBody = JSON.parse(clonedBody) as { model?: string };
+
+      // not undefined and is false
+      if (
+        isModelAvailableInServer(
+          serverConfig.customModels,
+          jsonBody?.model as string,
+          ServiceProvider.ByteDance as string,
+        )
+      ) {
+        return NextResponse.json(
+          {
+            error: true,
+            message: `you are not allowed to use ${jsonBody?.model} model`,
+          },
+          {
+            status: 403,
+          },
+        );
+      }
+    } catch (e) {
+      console.error(`[ByteDance] filter`, e);
+    }
+  }
+  console.log("[ByteDance request]", fetchOptions.headers, req.method);
+  try {
+    const res = await fetch(fetchUrl, fetchOptions);
+
+    console.log(
+      "[ByteDance response]",
+      res.status,
+      "   ",
+      res.headers,
+      res.url,
+    );
+    // to prevent browser prompt for credentials
+    const newHeaders = new Headers(res.headers);
+    newHeaders.delete("www-authenticate");
+    // to disable nginx buffering
+    newHeaders.set("X-Accel-Buffering", "no");
+
+    return new Response(res.body, {
+      status: res.status,
+      statusText: res.statusText,
+      headers: newHeaders,
+    });
+  } finally {
+    clearTimeout(timeoutId);
+  }
+}
diff --git a/app/client/api.ts b/app/client/api.ts
index 41ccbd8e1..ee43fc7cc 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -9,6 +9,8 @@ import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
 import { ChatGPTApi } from "./platforms/openai";
 import { GeminiProApi } from "./platforms/google";
 import { ClaudeApi } from "./platforms/anthropic";
+import { DoubaoApi } from "./platforms/bytedance";
+
 export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];
 
@@ -104,6 +106,9 @@ export class ClientApi {
       case ModelProvider.Claude:
         this.llm = new ClaudeApi();
         break;
+      case ModelProvider.Doubao:
+        this.llm = new DoubaoApi();
+        break;
       default:
         this.llm = new ChatGPTApi();
     }
diff --git a/app/client/platforms/bytedance.ts b/app/client/platforms/bytedance.ts
new file mode 100644
index 000000000..92c1fd558
--- /dev/null
+++ b/app/client/platforms/bytedance.ts
@@ -0,0 +1,260 @@
+"use client";
+import {
+  ApiPath,
+  ByteDance,
+  DEFAULT_API_HOST,
+  REQUEST_TIMEOUT_MS,
+} from "@/app/constant";
+import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+
+import {
+  ChatOptions,
+  getHeaders,
+  LLMApi,
+  LLMModel,
+  MultimodalContent,
+} from "../api";
+import Locale from "../../locales";
+import {
+  EventStreamContentType,
+  fetchEventSource,
+} from "@fortaine/fetch-event-source";
+import { prettyObject } from "@/app/utils/format";
+import { getClientConfig } from "@/app/config/client";
+import { getMessageTextContent, isVisionModel } from "@/app/utils";
+
+export interface OpenAIListModelResponse {
+  object: string;
+  data: Array<{
+    id: string;
+    object: string;
+    root: string;
+  }>;
+}
+
+interface RequestPayload {
+  messages: {
+    role: "system" | "user" | "assistant";
+    content: string | MultimodalContent[];
+  }[];
+  stream?: boolean;
+  model: string;
+  temperature: number;
+  presence_penalty: number;
+  frequency_penalty: number;
+  top_p: number;
+  max_tokens?: number;
+}
+
+export class DoubaoApi implements LLMApi {
+  path(path: string): string {
+    const accessStore = useAccessStore.getState();
+
+    let baseUrl = "";
+
+    if (accessStore.useCustomConfig) {
+      baseUrl = accessStore.bytedanceUrl;
+    }
+
+    if (baseUrl.length === 0) {
+      const isApp = !!getClientConfig()?.isApp;
+      baseUrl = isApp
+        ? DEFAULT_API_HOST + "/api/proxy/bytedance"
+        : ApiPath.ByteDance;
+    }
+
+    if (baseUrl.endsWith("/")) {
+      baseUrl = baseUrl.slice(0, baseUrl.length - 1);
+    }
+    if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.ByteDance)) {
+      baseUrl = "https://" + baseUrl;
+    }
+
+    console.log("[Proxy Endpoint] ", baseUrl, path);
+
+    return [baseUrl, path].join("/");
+  }
+
+  extractMessage(res: any) {
+    return res.choices?.at(0)?.message?.content ?? "";
+  }
+
+  async chat(options: ChatOptions) {
+    const visionModel = isVisionModel(options.config.model);
+    const messages = options.messages.map((v) => ({
+      role: v.role,
+      content: visionModel ? v.content : getMessageTextContent(v),
+    }));
+
+    const modelConfig = {
+      ...useAppConfig.getState().modelConfig,
+      ...useChatStore.getState().currentSession().mask.modelConfig,
+      ...{
+        model: options.config.model,
+      },
+    };
+
+    const requestPayload: RequestPayload = {
+      messages,
+      stream: options.config.stream,
+      model: modelConfig.model,
+      temperature: modelConfig.temperature,
+      presence_penalty: modelConfig.presence_penalty,
+      frequency_penalty: modelConfig.frequency_penalty,
+      top_p: modelConfig.top_p,
+    };
+
+    console.log("[Request] ByteDance payload: ", requestPayload);
+
+    const shouldStream = !!options.config.stream;
+    const controller = new AbortController();
+    options.onController?.(controller);
+
+    try {
+      const chatPath = this.path(ByteDance.ChatPath);
+      const chatPayload = {
+        method: "POST",
+        body: JSON.stringify(requestPayload),
+        signal: controller.signal,
+        headers: getHeaders(),
+      };
+
+      // make a fetch request
+      const requestTimeoutId = setTimeout(
+        () => controller.abort(),
+        REQUEST_TIMEOUT_MS,
+      );
+
+      if (shouldStream) {
+        let responseText = "";
+        let remainText = "";
+        let finished = false;
+
+        // animate response to make it looks smooth
+        function animateResponseText() {
+          if (finished || controller.signal.aborted) {
+            responseText += remainText;
+            console.log("[Response Animation] finished");
+            if (responseText?.length === 0) {
+              options.onError?.(new Error("empty response from server"));
+            }
+            return;
+          }
+
+          if (remainText.length > 0) {
+            const fetchCount = Math.max(1, Math.round(remainText.length / 60));
+            const fetchText = remainText.slice(0, fetchCount);
+            responseText += fetchText;
+            remainText = remainText.slice(fetchCount);
+            options.onUpdate?.(responseText, fetchText);
+          }
+
+          requestAnimationFrame(animateResponseText);
+        }
+
+        // start animaion
+        animateResponseText();
+
+        const finish = () => {
+          if (!finished) {
+            finished = true;
+            options.onFinish(responseText + remainText);
+          }
+        };
+
+        controller.signal.onabort = finish;
+
+        fetchEventSource(chatPath, {
+          ...chatPayload,
+          async onopen(res) {
+            clearTimeout(requestTimeoutId);
+            const contentType = res.headers.get("content-type");
+            console.log(
+              "[ByteDance] request response content type: ",
+              contentType,
+            );
+
+            if (contentType?.startsWith("text/plain")) {
+              responseText = await res.clone().text();
+              return finish();
+            }
+
+            if (
+              !res.ok ||
+              !res.headers
+                .get("content-type")
+                ?.startsWith(EventStreamContentType) ||
+              res.status !== 200
+            ) {
+              const responseTexts = [responseText];
+              let extraInfo = await res.clone().text();
+              try {
+                const resJson = await res.clone().json();
+                extraInfo = prettyObject(resJson);
+              } catch {}
+
+              if (res.status === 401) {
+                responseTexts.push(Locale.Error.Unauthorized);
+              }
+
+              if (extraInfo) {
+                responseTexts.push(extraInfo);
+              }
+
+              responseText = responseTexts.join("\n\n");
+
+              return finish();
+            }
+          },
+          onmessage(msg) {
+            if (msg.data === "[DONE]" || finished) {
+              return finish();
+            }
+            const text = msg.data;
+            try {
+              const json = JSON.parse(text);
+              const choices = json.choices as Array<{
+                delta: { content: string };
+              }>;
+              const delta = choices[0]?.delta?.content;
+              if (delta) {
+                remainText += delta;
+              }
+            } catch (e) {
+              console.error("[Request] parse error", text, msg);
+            }
+          },
+          onclose() {
+            finish();
+          },
+          onerror(e) {
+            options.onError?.(e);
+            throw e;
+          },
+          openWhenHidden: true,
+        });
+      } else {
+        const res = await fetch(chatPath, chatPayload);
+        clearTimeout(requestTimeoutId);
+
+        const resJson = await res.json();
+        const message = this.extractMessage(resJson);
+        options.onFinish(message);
+      }
+    } catch (e) {
+      console.log("[Request] failed to make a chat request", e);
+      options.onError?.(e as Error);
+    }
+  }
+  async usage() {
+    return {
+      used: 0,
+      total: 0,
+    };
+  }
+
+  async models(): Promise<LLMModel[]> {
+    return [];
+  }
+}
+export { ByteDance };
diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx
index 7281fc2f1..1cc531eb8 100644
--- a/app/components/exporter.tsx
+++ b/app/components/exporter.tsx
@@ -321,6 +321,8 @@ export function PreviewActions(props: {
       api = new ClientApi(ModelProvider.GeminiPro);
     } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
       api = new ClientApi(ModelProvider.Claude);
+    } else if (config.modelConfig.providerName == ServiceProvider.ByteDance) {
+      api = new ClientApi(ModelProvider.Doubao);
     } else {
       api = new ClientApi(ModelProvider.GPT);
     }
diff --git a/app/components/home.tsx b/app/components/home.tsx
index addb5e803..7da20df22 100644
--- a/app/components/home.tsx
+++ b/app/components/home.tsx
@@ -175,6 +175,8 @@ export function useLoadData() {
     api = new ClientApi(ModelProvider.GeminiPro);
   } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
     api = new ClientApi(ModelProvider.Claude);
+  } else if (config.modelConfig.providerName == ServiceProvider.ByteDance) {
+    api = new ClientApi(ModelProvider.Doubao);
   } else {
     api = new ClientApi(ModelProvider.GPT);
   }
diff --git a/app/config/server.ts b/app/config/server.ts
index b7c85ce6a..d50dbf1a1 100644
--- a/app/config/server.ts
+++ b/app/config/server.ts
@@ -32,6 +32,10 @@ declare global {
       GOOGLE_API_KEY?: string;
       GOOGLE_URL?: string;
 
+      // bytedance only
+      BYTEDANCE_URL?: string;
+      BYTEDANCE_API_KEY?: string;
+
       // google tag manager
       GTM_ID?: string;
 
@@ -92,6 +96,7 @@ export const getServerSideConfig = () => {
   const isAzure = !!process.env.AZURE_URL;
   const isGoogle = !!process.env.GOOGLE_API_KEY;
   const isAnthropic = !!process.env.ANTHROPIC_API_KEY;
+  const isBytedance = !!process.env.BYTEDANCE_API_KEY;
 
   // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? "";
   // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim());
@@ -126,6 +131,10 @@ export const getServerSideConfig = () => {
 
     gtmId: process.env.GTM_ID,
 
+    isBytedance,
+    bytedanceApiKey: getApiKey(process.env.BYTEDANCE_API_KEY),
+    bytedanceUrl: process.env.BYTEDANCE_URL,
+
     needCode: ACCESS_CODES.size > 0,
     code: process.env.CODE,
     codes: ACCESS_CODES,
diff --git a/app/constant.ts b/app/constant.ts
index d44b5b817..1ed292d21 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -14,6 +14,8 @@ export const ANTHROPIC_BASE_URL = "https://api.anthropic.com";
 
 export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/";
 
+export const BYTEDANCE_BASE_URL = "https://ark.cn-beijing.volces.com";
+
 export enum Path {
   Home = "/",
   Chat = "/chat",
@@ -28,6 +30,7 @@ export enum ApiPath {
   Azure = "/api/azure",
   OpenAI = "/api/openai",
   Anthropic = "/api/anthropic",
+  ByteDance = "/api/bytedance",
 }
 
 export enum SlotID {
@@ -71,12 +74,14 @@ export enum ServiceProvider {
   Azure = "Azure",
   Google = "Google",
   Anthropic = "Anthropic",
+  ByteDance = "ByteDance",
 }
 
 export enum ModelProvider {
   GPT = "GPT",
   GeminiPro = "GeminiPro",
   Claude = "Claude",
+  Doubao = "Doubao",
 }
 
 export const Anthropic = {
@@ -104,6 +109,11 @@ export const Google = {
   ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
 };
 
+export const ByteDance = {
+  ExampleEndpoint: "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
+  ChatPath: "/api/v3/chat/completions",
+};
+
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
 // export const DEFAULT_SYSTEM_TEMPLATE = `
 // You are ChatGPT, a large language model trained by {{ServiceProvider}}.
@@ -173,6 +183,8 @@ const anthropicModels = [
   "claude-3-5-sonnet-20240620",
 ];
 
+const bytedanceModels = ["ep-20240520082937-424bw=Doubao-lite-4k"];
+
 export const DEFAULT_MODELS = [
   ...openaiModels.map((name) => ({
     name,
@@ -210,6 +222,15 @@ export const DEFAULT_MODELS = [
       providerType: "anthropic",
     },
   })),
+  ...bytedanceModels.map((name) => ({
+    name,
+    available: true,
+    provider: {
+      id: "bytedance",
+      providerName: "ByteDance",
+      providerType: "bytedance",
+    },
+  })),
 ] as const;
 
 export const CHAT_PAGE_SIZE = 15;
diff --git a/app/store/access.ts b/app/store/access.ts
index 03780779e..b04748b8c 100644
--- a/app/store/access.ts
+++ b/app/store/access.ts
@@ -47,6 +47,10 @@ const DEFAULT_ACCESS_STATE = {
   anthropicApiVersion: "2023-06-01",
   anthropicUrl: "",
 
+  // bytedance
+  bytedanceApiKey: "",
+  bytedanceUrl: "",
+
   // server config
   needCode: true,
   hideUserApiKey: false,
@@ -83,6 +87,10 @@ export const useAccessStore = createPersistStore(
       return ensure(get(), ["anthropicApiKey"]);
     },
 
+    isValidByteDance() {
+      return ensure(get(), ["bytedanceApiKey"]);
+    },
+
     isAuthorized() {
       this.fetch();
 
@@ -92,6 +100,7 @@ export const useAccessStore = createPersistStore(
         this.isValidAzure() ||
         this.isValidGoogle() ||
         this.isValidAnthropic() ||
+        this.isValidByteDance() ||
         !this.enabledAccessControl() ||
         (this.enabledAccessControl() && ensure(get(), ["accessCode"]))
       );
diff --git a/app/store/chat.ts b/app/store/chat.ts
index 44d41830a..475d436d9 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -368,6 +368,8 @@ export const useChatStore = createPersistStore(
           api = new ClientApi(ModelProvider.GeminiPro);
         } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
           api = new ClientApi(ModelProvider.Claude);
+        } else if (modelConfig.providerName == ServiceProvider.ByteDance) {
+          api = new ClientApi(ModelProvider.Doubao);
         } else {
           api = new ClientApi(ModelProvider.GPT);
         }
@@ -552,6 +554,8 @@ export const useChatStore = createPersistStore(
           api = new ClientApi(ModelProvider.GeminiPro);
         } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
           api = new ClientApi(ModelProvider.Claude);
+        } else if (modelConfig.providerName == ServiceProvider.ByteDance) {
+          api = new ClientApi(ModelProvider.Doubao);
         } else {
           api = new ClientApi(ModelProvider.GPT);
         }

From f3e3f083774ab01db558a213a0b180fe995ad2c4 Mon Sep 17 00:00:00 2001
From: Dogtiti <499960698@qq.com>
Date: Sat, 6 Jul 2024 21:25:00 +0800
Subject: [PATCH 03/20] fix: apiClient

---
 app/client/api.ts | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/app/client/api.ts b/app/client/api.ts
index a3d5a36e0..f650139f9 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -225,6 +225,8 @@ export function getClientApi(provider: ServiceProvider): ClientApi {
       return new ClientApi(ModelProvider.GeminiPro);
     case ServiceProvider.Anthropic:
       return new ClientApi(ModelProvider.Claude);
+    case ServiceProvider.Baidu:
+      return new ClientApi(ModelProvider.Ernie);
     default:
       return new ClientApi(ModelProvider.GPT);
   }

From 1caa61f4c0e8d35bfff2dd670925f8c1ceb8267a Mon Sep 17 00:00:00 2001
From: Dogtiti <499960698@qq.com>
Date: Sat, 6 Jul 2024 22:59:20 +0800
Subject: [PATCH 04/20] feat: swap name and displayName for bytedance in custom
 models

---
 app/client/api.ts    |  2 ++
 app/config/server.ts |  6 +++---
 app/constant.ts      |  9 ++++++++-
 app/utils/model.ts   | 12 ++++++++++--
 4 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/app/client/api.ts b/app/client/api.ts
index d2eeca46a..f2e83c391 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -225,6 +225,8 @@ export function getClientApi(provider: ServiceProvider): ClientApi {
       return new ClientApi(ModelProvider.GeminiPro);
     case ServiceProvider.Anthropic:
       return new ClientApi(ModelProvider.Claude);
+    case ServiceProvider.ByteDance:
+      return new ClientApi(ModelProvider.Doubao);
     default:
       return new ClientApi(ModelProvider.GPT);
   }
diff --git a/app/config/server.ts b/app/config/server.ts
index d50dbf1a1..0f57d2d6d 100644
--- a/app/config/server.ts
+++ b/app/config/server.ts
@@ -32,13 +32,13 @@ declare global {
       GOOGLE_API_KEY?: string;
       GOOGLE_URL?: string;
 
+      // google tag manager
+      GTM_ID?: string;
+
       // bytedance only
       BYTEDANCE_URL?: string;
       BYTEDANCE_API_KEY?: string;
 
-      // google tag manager
-      GTM_ID?: string;
-
       // custom template for preprocessing user input
       DEFAULT_INPUT_TEMPLATE?: string;
     }
diff --git a/app/constant.ts b/app/constant.ts
index 1ed292d21..5b52073bb 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -183,7 +183,14 @@ const anthropicModels = [
   "claude-3-5-sonnet-20240620",
 ];
 
-const bytedanceModels = ["ep-20240520082937-424bw=Doubao-lite-4k"];
+const bytedanceModels = [
+  "Doubao-lite-4k",
+  "Doubao-lite-32k",
+  "Doubao-lite-128k",
+  "Doubao-pro-4k",
+  "Doubao-pro-32k",
+  "Doubao-pro-128k",
+];
 
 export const DEFAULT_MODELS = [
   ...openaiModels.map((name) => ({
diff --git a/app/utils/model.ts b/app/utils/model.ts
index 249987726..62ecc55b3 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -39,7 +39,7 @@ export function collectModelTable(
       const available = !m.startsWith("-");
       const nameConfig =
         m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m;
-      const [name, displayName] = nameConfig.split("=");
+      let [name, displayName] = nameConfig.split("=");
 
       // enable or disable all models
       if (name === "all") {
@@ -50,9 +50,17 @@ export function collectModelTable(
         // 1. find model by name(), and set available value
         let count = 0;
         for (const fullName in modelTable) {
-          if (fullName.split("@").shift() == name) {
+          const [modelName, providerName] = fullName.split("@");
+          if (modelName === name) {
             count += 1;
             modelTable[fullName]["available"] = available;
+            // swap name and displayName for bytedance
+            if (providerName === "bytedance") {
+              const tempName = name;
+              name = displayName;
+              displayName = tempName;
+              modelTable[fullName]["name"] = name;
+            }
             if (displayName) {
               modelTable[fullName]["displayName"] = displayName;
             }

From 71af2628eb8d791070fc2b4818f6f46c9068c962 Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 00:32:18 +0800
Subject: [PATCH 05/20] hotfix: old AZURE_URL config error:
 "DeploymentNotFound". #4945 #4930

---
 app/api/common.ts  | 25 +++++++++++++++++++++++++
 app/utils/model.ts | 10 ++++++++--
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/app/api/common.ts b/app/api/common.ts
index b2fae6df2..5223646d2 100644
--- a/app/api/common.ts
+++ b/app/api/common.ts
@@ -66,6 +66,31 @@ export async function requestOpenai(req: NextRequest) {
       "/api/azure/",
       "",
     )}?api-version=${azureApiVersion}`;
+
+    // Forward compatibility:
+    // if display_name(deployment_name) not set, and '{deploy-id}' in AZURE_URL
+    // then using default '{deploy-id}'
+    if (serverConfig.customModels) {
+      const modelName = path.split("/")[1];
+      let realDeployName = "";
+      serverConfig.customModels
+        .split(",")
+        .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
+        .forEach((m) => {
+          const [fullName, displayName] = m.split("=");
+          const [_, providerName] = fullName.split("@");
+          if (providerName === "azure" && !displayName) {
+            const [_, deployId] = serverConfig.azureUrl.split("deployments/");
+            if (deployId) {
+              realDeployName = deployId;
+            }
+          }
+        });
+      if (realDeployName) {
+        console.log("[Replace with DeployId", realDeployName);
+        path = path.replaceAll(modelName, realDeployName);
+      }
+    }
   }
 
   const fetchUrl = `${baseUrl}/${path}`;
diff --git a/app/utils/model.ts b/app/utils/model.ts
index 249987726..0b160f101 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -47,10 +47,16 @@ export function collectModelTable(
           (model) => (model.available = available),
         );
       } else {
-        // 1. find model by name(), and set available value
+        // 1. find model by name, and set available value
+        const [customModelName, customProviderName] = name.split("@");
         let count = 0;
         for (const fullName in modelTable) {
-          if (fullName.split("@").shift() == name) {
+          const [modelName, providerName] = fullName.split("@");
+          if (
+            customModelName == modelName &&
+            (customProviderName === undefined ||
+              customProviderName === providerName)
+          ) {
             count += 1;
             modelTable[fullName]["available"] = available;
             if (displayName) {

From 34ab37f31e1fe968c86a4ddc8421a1bfe6a20a27 Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 00:47:35 +0800
Subject: [PATCH 06/20] update CUSTOM_MODELS config for Azure mode.

---
 README.md    | 4 ++++
 README_CN.md | 5 +++++
 2 files changed, 9 insertions(+)

diff --git a/README.md b/README.md
index c77d2023c..2cac1088a 100644
--- a/README.md
+++ b/README.md
@@ -181,6 +181,7 @@ Specify OpenAI organization ID.
 ### `AZURE_URL` (optional)
 
 > Example: https://{azure-resource-url}/openai/deployments/{deploy-name}
+> if you config deployment name in `CUSTOM_MODELS`, you can remove `{deploy-name}` in `AZURE_URL`
 
 Azure deploy url.
 
@@ -245,6 +246,9 @@ To control custom models, use `+` to add a custom model, use `-` to hide a model
 
 User `-all` to disable all default models, `+all` to enable all default models.
 
+For Azure: use `modelName@azure=deploymentName` to customize model name and deployment name.
+> Example: `+gpt-3.5-turbo@azure=gpt35` will show option `gpt35(Azure)` in model list.
+
 ### `DEFAULT_MODEL` (optional)
 
 Change default model
diff --git a/README_CN.md b/README_CN.md
index 970ecdef2..c6cbf6539 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -95,6 +95,7 @@ OpenAI 接口代理 URL,如果你手动配置了 openai 接口代理,请填
 ### `AZURE_URL` (可选)
 
 > 形如:https://{azure-resource-url}/openai/deployments/{deploy-name}
+> 如果你已经在`CUSTOM_MODELS`中参考`displayName`的方式配置了{deploy-name},那么可以从`AZURE_URL`中移除`{deploy-name}`
 
 Azure 部署地址。
 
@@ -156,6 +157,10 @@ anthropic claude Api Url.
 
 用来控制模型列表,使用 `+` 增加一个模型,使用 `-` 来隐藏一个模型,使用 `模型名=展示名` 来自定义模型的展示名,用英文逗号隔开。
 
+在Azure的模式下,支持使用`modelName@azure=deploymentName`的方式配置模型名称和部署名称(deploy-name)
+> 示例:`+gpt-3.5-turbo@azure=gpt35`这个配置会在模型列表显示一个`gpt35(Azure)`的选项
+
+
 ### `DEFAULT_MODEL` (可选)
 
 更改默认模型

From 6ac9789a1c4065c19cdd1bab7a808fbc54c0b1a2 Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 12:16:37 +0800
Subject: [PATCH 07/20] hotfix

---
 app/store/config.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/app/store/config.ts b/app/store/config.ts
index 4b0a34f4f..1eaafe12b 100644
--- a/app/store/config.ts
+++ b/app/store/config.ts
@@ -49,7 +49,7 @@ export const DEFAULT_CONFIG = {
 
   modelConfig: {
     model: "gpt-3.5-turbo" as ModelType,
-    providerName: "Openai" as ServiceProvider,
+    providerName: "OpenAI" as ServiceProvider,
     temperature: 0.5,
     top_p: 1,
     max_tokens: 4000,

From f68cd2c5c04a33dda4187ee7db4bbfb4026b9e40 Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 12:27:44 +0800
Subject: [PATCH 08/20] review code

---
 app/client/platforms/baidu.ts | 10 +++++-----
 app/constant.ts               | 23 +++++++++++++++++------
 app/utils/model.ts            |  6 ++----
 3 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/app/client/platforms/baidu.ts b/app/client/platforms/baidu.ts
index e2f6f12dd..4fc3d2f64 100644
--- a/app/client/platforms/baidu.ts
+++ b/app/client/platforms/baidu.ts
@@ -2,7 +2,7 @@
 import {
   ApiPath,
   Baidu,
-  DEFAULT_API_HOST,
+  BAIDU_BASE_URL,
   REQUEST_TIMEOUT_MS,
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
@@ -21,7 +21,7 @@ import {
 } from "@fortaine/fetch-event-source";
 import { prettyObject } from "@/app/utils/format";
 import { getClientConfig } from "@/app/config/client";
-import { getMessageTextContent, isVisionModel } from "@/app/utils";
+import { getMessageTextContent } from "@/app/utils";
 
 export interface OpenAIListModelResponse {
   object: string;
@@ -58,7 +58,8 @@ export class ErnieApi implements LLMApi {
 
     if (baseUrl.length === 0) {
       const isApp = !!getClientConfig()?.isApp;
-      baseUrl = isApp ? DEFAULT_API_HOST + "/api/proxy/baidu" : ApiPath.Baidu;
+      // do not use proxy for baidubce api
+      baseUrl = isApp ? BAIDU_BASE_URL : ApiPath.Baidu;
     }
 
     if (baseUrl.endsWith("/")) {
@@ -78,10 +79,9 @@ export class ErnieApi implements LLMApi {
   }
 
   async chat(options: ChatOptions) {
-    const visionModel = isVisionModel(options.config.model);
     const messages = options.messages.map((v) => ({
       role: v.role,
-      content: visionModel ? v.content : getMessageTextContent(v),
+      content: getMessageTextContent(v),
     }));
 
     const modelConfig = {
diff --git a/app/constant.ts b/app/constant.ts
index 6ffc0e0b3..0fd4d1c24 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -112,9 +112,20 @@ export const Google = {
 };
 
 export const Baidu = {
-  ExampleEndpoint: "https://aip.baidubce.com",
-  ChatPath: (modelName: string) =>
-    `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${modelName}`,
+  ExampleEndpoint: BAIDU_BASE_URL,
+  ChatPath: (modelName: string) => {
+    let endpoint = modelName;
+    if (modelName === "ernie-4.0-8k") {
+      endpoint = "completions_pro";
+    }
+    if (modelName === "ernie-4.0-8k-preview-0518") {
+      endpoint = "completions_adv_pro";
+    }
+    if (modelName === "ernie-3.5-8k") {
+      endpoint = "completions";
+    }
+    return `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`;
+  },
 };
 
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@@ -188,11 +199,11 @@ const anthropicModels = [
 
 const baiduModels = [
   "ernie-4.0-turbo-8k",
-  "completions_pro=ernie-4.0-8k",
+  "ernie-4.0-8k",
   "ernie-4.0-8k-preview",
-  "completions_adv_pro=ernie-4.0-8k-preview-0518",
+  "ernie-4.0-8k-preview-0518",
   "ernie-4.0-8k-latest",
-  "completions=ernie-3.5-8k",
+  "ernie-3.5-8k",
   "ernie-3.5-8k-0205",
 ];
 
diff --git a/app/utils/model.ts b/app/utils/model.ts
index 6a02ed7eb..7c778888e 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -24,13 +24,11 @@ export function collectModelTable(
 
   // default models
   models.forEach((m) => {
-    // supoort name=displayName eg:completions_pro=ernie-4.0-8k
-    const [name, displayName] = m.name?.split("=");
     // using <modelName>@<providerId> as fullName
-    modelTable[`${name}@${m?.provider?.id}`] = {
+    modelTable[`${m.name}@${m?.provider?.id}`] = {
       ...m,
       name,
-      displayName: displayName || name, // 'provider' is copied over if it exists
+      displayName: m.name, // 'provider' is copied over if it exists
     };
   });
 

From 011b76e4e720be49db847a12ba02a78961a0159e Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 13:39:39 +0800
Subject: [PATCH 09/20] review code

---
 app/utils/model.ts | 1 -
 1 file changed, 1 deletion(-)

diff --git a/app/utils/model.ts b/app/utils/model.ts
index 7c778888e..249987726 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -27,7 +27,6 @@ export function collectModelTable(
     // using <modelName>@<providerId> as fullName
     modelTable[`${m.name}@${m?.provider?.id}`] = {
       ...m,
-      name,
       displayName: m.name, // 'provider' is copied over if it exists
     };
   });

From fadd7f6eb4cb9d70fb9758ee52c85aac768dc1be Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 14:50:40 +0800
Subject: [PATCH 10/20] try getAccessToken in app, fixbug to fetch in none
 stream mode

---
 app/api/baidu/[...path]/route.ts | 41 +++++++++++++-------------------
 app/client/platforms/baidu.ts    | 37 +++++++++++++++++++++-------
 app/constant.ts                  |  2 +-
 3 files changed, 47 insertions(+), 33 deletions(-)

diff --git a/app/api/baidu/[...path]/route.ts b/app/api/baidu/[...path]/route.ts
index 27676d29d..5444ba4fe 100644
--- a/app/api/baidu/[...path]/route.ts
+++ b/app/api/baidu/[...path]/route.ts
@@ -10,6 +10,7 @@ import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
 import { isModelAvailableInServer } from "@/app/utils/model";
+import { getAccessToken } from "@/app/utils/baidu";
 
 const serverConfig = getServerSideConfig();
 
@@ -30,6 +31,18 @@ async function handle(
     });
   }
 
+  if (!serverConfig.baiduApiKey || !serverConfig.baiduSecretKey) {
+    return NextResponse.json(
+      {
+        error: true,
+        message: `missing BAIDU_API_KEY or BAIDU_SECRET_KEY in server env vars`,
+      },
+      {
+        status: 401,
+      },
+    );
+  }
+
   try {
     const response = await request(req);
     return response;
@@ -88,7 +101,10 @@ async function request(req: NextRequest) {
     10 * 60 * 1000,
   );
 
-  const { access_token } = await getAccessToken();
+  const { access_token } = await getAccessToken(
+    serverConfig.baiduApiKey,
+    serverConfig.baiduSecretKey,
+  );
   const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`;
 
   const fetchOptions: RequestInit = {
@@ -133,11 +149,9 @@ async function request(req: NextRequest) {
       console.error(`[Baidu] filter`, e);
     }
   }
-  console.log("[Baidu request]", fetchOptions.headers, req.method);
   try {
     const res = await fetch(fetchUrl, fetchOptions);
 
-    console.log("[Baidu response]", res.status, "   ", res.headers, res.url);
     // to prevent browser prompt for credentials
     const newHeaders = new Headers(res.headers);
     newHeaders.delete("www-authenticate");
@@ -153,24 +167,3 @@ async function request(req: NextRequest) {
     clearTimeout(timeoutId);
   }
 }
-
-/**
- * 使用 AK,SK 生成鉴权签名(Access Token)
- * @return 鉴权签名信息
- */
-async function getAccessToken(): Promise<{
-  access_token: string;
-  expires_in: number;
-  error?: number;
-}> {
-  const AK = serverConfig.baiduApiKey;
-  const SK = serverConfig.baiduSecretKey;
-  const res = await fetch(
-    `${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${AK}&client_secret=${SK}`,
-    {
-      method: "POST",
-    },
-  );
-  const resJson = await res.json();
-  return resJson;
-}
diff --git a/app/client/platforms/baidu.ts b/app/client/platforms/baidu.ts
index 4fc3d2f64..188b78bf9 100644
--- a/app/client/platforms/baidu.ts
+++ b/app/client/platforms/baidu.ts
@@ -6,6 +6,7 @@ import {
   REQUEST_TIMEOUT_MS,
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+import { getAccessToken } from "@/app/utils/baidu";
 
 import {
   ChatOptions,
@@ -74,16 +75,20 @@ export class ErnieApi implements LLMApi {
     return [baseUrl, path].join("/");
   }
 
-  extractMessage(res: any) {
-    return res.choices?.at(0)?.message?.content ?? "";
-  }
-
   async chat(options: ChatOptions) {
     const messages = options.messages.map((v) => ({
       role: v.role,
       content: getMessageTextContent(v),
     }));
 
+    // "error_code": 336006, "error_msg": "the length of messages must be an odd number",
+    if (messages.length % 2 === 0) {
+      messages.unshift({
+        role: "user",
+        content: " ",
+      });
+    }
+
     const modelConfig = {
       ...useAppConfig.getState().modelConfig,
       ...useChatStore.getState().currentSession().mask.modelConfig,
@@ -92,9 +97,10 @@ export class ErnieApi implements LLMApi {
       },
     };
 
+    const shouldStream = !!options.config.stream;
     const requestPayload: RequestPayload = {
       messages,
-      stream: options.config.stream,
+      stream: shouldStream,
       model: modelConfig.model,
       temperature: modelConfig.temperature,
       presence_penalty: modelConfig.presence_penalty,
@@ -104,12 +110,27 @@ export class ErnieApi implements LLMApi {
 
     console.log("[Request] Baidu payload: ", requestPayload);
 
-    const shouldStream = !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
 
     try {
-      const chatPath = this.path(Baidu.ChatPath(modelConfig.model));
+      let chatPath = this.path(Baidu.ChatPath(modelConfig.model));
+
+      // getAccessToken can not run in browser, because cors error
+      if (!!getClientConfig()?.isApp) {
+        const accessStore = useAccessStore.getState();
+        if (accessStore.useCustomConfig) {
+          if (accessStore.isValidBaidu()) {
+            const { access_token } = await getAccessToken(
+              accessStore.baiduApiKey,
+              accessStore.baiduSecretKey,
+            );
+            chatPath = `${chatPath}${
+              chatPath.includes("?") ? "&" : "?"
+            }access_token=${access_token}`;
+          }
+        }
+      }
       const chatPayload = {
         method: "POST",
         body: JSON.stringify(requestPayload),
@@ -230,7 +251,7 @@ export class ErnieApi implements LLMApi {
         clearTimeout(requestTimeoutId);
 
         const resJson = await res.json();
-        const message = this.extractMessage(resJson);
+        const message = resJson?.result;
         options.onFinish(message);
       }
     } catch (e) {
diff --git a/app/constant.ts b/app/constant.ts
index 0fd4d1c24..3d48dbb62 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -124,7 +124,7 @@ export const Baidu = {
     if (modelName === "ernie-3.5-8k") {
       endpoint = "completions";
     }
-    return `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`;
+    return `rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`;
   },
 };
 

From b14a0f24ae2b5d3dee298f6f573016b2356d7fac Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 14:57:19 +0800
Subject: [PATCH 11/20] update locales

---
 app/locales/cn.ts |  4 ++--
 app/locales/en.ts | 16 ++++++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/app/locales/cn.ts b/app/locales/cn.ts
index a872ee75a..d7268807c 100644
--- a/app/locales/cn.ts
+++ b/app/locales/cn.ts
@@ -350,12 +350,12 @@ const cn = {
       Baidu: {
         ApiKey: {
           Title: "接口密钥",
-          SubTitle: "使用自定义 Baidu API Key 绕过密码访问限制",
+          SubTitle: "使用自定义 Baidu API Key",
           Placeholder: "Baidu API Key",
         },
         SecretKey: {
           Title: "接口密钥",
-          SubTitle: "使用自定义 Baidu Secret Key 绕过密码访问限制",
+          SubTitle: "使用自定义 Baidu Secret Key",
           Placeholder: "Baidu Secret Key",
         },
         Endpoint: {
diff --git a/app/locales/en.ts b/app/locales/en.ts
index aa153f523..3c0d8851f 100644
--- a/app/locales/en.ts
+++ b/app/locales/en.ts
@@ -334,6 +334,22 @@ const en: LocaleType = {
           SubTitle: "Select and input a specific API version",
         },
       },
+      Baidu: {
+        ApiKey: {
+          Title: "Baidu API Key",
+          SubTitle: "Use a custom Baidu API Key",
+          Placeholder: "Baidu API Key",
+        },
+        SecretKey: {
+          Title: "Baidu Secret Key",
+          SubTitle: "Use a custom Baidu Secret Key",
+          Placeholder: "Baidu Secret Key",
+        },
+        Endpoint: {
+          Title: "Endpoint Address",
+          SubTitle: "Example:",
+        },
+      },
       CustomModel: {
         Title: "Custom Models",
         SubTitle: "Custom model options, seperated by comma",

From 230e3823a90df67800f29be43d40e87ab42c1a76 Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 15:02:44 +0800
Subject: [PATCH 12/20] update readme

---
 README.md    | 12 ++++++++++++
 README_CN.md | 12 ++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/README.md b/README.md
index c77d2023c..feaf197c4 100644
--- a/README.md
+++ b/README.md
@@ -212,6 +212,18 @@ anthropic claude Api version.
 
 anthropic claude Api Url.
 
+### `BAIDU_API_KEY` (optional)
+
+Baidu Api Key.
+
+### `BAIDU_SECRET_KEY` (optional)
+
+Baidu Secret Key.
+
+### `BAIDU_URL` (optional)
+
+Baidu Api Url.
+
 ### `HIDE_USER_API_KEY` (optional)
 
 > Default: Empty
diff --git a/README_CN.md b/README_CN.md
index 970ecdef2..827d4850f 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -126,6 +126,18 @@ anthropic claude Api version.
 
 anthropic claude Api Url.
 
+### `BAIDU_API_KEY` (可选)
+
+Baidu Api Key.
+
+### `BAIDU_SECRET_KEY` (可选)
+
+Baidu Secret Key.
+
+### `BAIDU_URL` (可选)
+
+Baidu Api Url.
+
 ### `HIDE_USER_API_KEY` (可选)
 
 如果你不想让用户自行填入 API Key,将此环境变量设置为 1 即可。

From 147fc9a35a39187babb2b5aae156d47949547423 Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 15:10:23 +0800
Subject: [PATCH 13/20] fix ts type error

---
 app/api/baidu/[...path]/route.ts | 4 ++--
 app/api/common.ts                | 6 ++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/app/api/baidu/[...path]/route.ts b/app/api/baidu/[...path]/route.ts
index 5444ba4fe..94c9963c7 100644
--- a/app/api/baidu/[...path]/route.ts
+++ b/app/api/baidu/[...path]/route.ts
@@ -102,8 +102,8 @@ async function request(req: NextRequest) {
   );
 
   const { access_token } = await getAccessToken(
-    serverConfig.baiduApiKey,
-    serverConfig.baiduSecretKey,
+    serverConfig.baiduApiKey as string,
+    serverConfig.baiduSecretKey as string,
   );
   const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`;
 
diff --git a/app/api/common.ts b/app/api/common.ts
index 5223646d2..1ffac7fce 100644
--- a/app/api/common.ts
+++ b/app/api/common.ts
@@ -70,7 +70,7 @@ export async function requestOpenai(req: NextRequest) {
     // Forward compatibility:
     // if display_name(deployment_name) not set, and '{deploy-id}' in AZURE_URL
     // then using default '{deploy-id}'
-    if (serverConfig.customModels) {
+    if (serverConfig.customModels && serverConfig.azureUrl) {
       const modelName = path.split("/")[1];
       let realDeployName = "";
       serverConfig.customModels
@@ -80,7 +80,9 @@ export async function requestOpenai(req: NextRequest) {
           const [fullName, displayName] = m.split("=");
           const [_, providerName] = fullName.split("@");
           if (providerName === "azure" && !displayName) {
-            const [_, deployId] = serverConfig.azureUrl.split("deployments/");
+            const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
+              "deployments/",
+            );
             if (deployId) {
               realDeployName = deployId;
             }

From f2a35f11140b4ee41828ad9024fee88ceebb24b0 Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 16:38:22 +0800
Subject: [PATCH 14/20] add missing file

---
 app/utils/baidu.ts | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)
 create mode 100644 app/utils/baidu.ts

diff --git a/app/utils/baidu.ts b/app/utils/baidu.ts
new file mode 100644
index 000000000..ddeb17bd5
--- /dev/null
+++ b/app/utils/baidu.ts
@@ -0,0 +1,23 @@
+import { BAIDU_OATUH_URL } from "../constant";
+/**
+ * 使用 AK,SK 生成鉴权签名(Access Token)
+ * @return 鉴权签名信息
+ */
+export async function getAccessToken(
+  clientId: string,
+  clientSecret: string,
+): Promise<{
+  access_token: string;
+  expires_in: number;
+  error?: number;
+}> {
+  const res = await fetch(
+    `${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${clientId}&client_secret=${clientSecret}`,
+    {
+      method: "POST",
+      mode: "cors",
+    },
+  );
+  const resJson = await res.json();
+  return resJson;
+}

From b3023543d67589c30f1c1ffd8f68fd712bc6c1aa Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 16:55:33 +0800
Subject: [PATCH 15/20] update

---
 app/utils/model.ts | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/app/utils/model.ts b/app/utils/model.ts
index adfbe287b..a3a014877 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -61,9 +61,7 @@ export function collectModelTable(
             modelTable[fullName]["available"] = available;
             // swap name and displayName for bytedance
             if (providerName === "bytedance") {
-              const tempName = name;
-              name = displayName;
-              displayName = tempName;
+              [name, displayName] = [displayName, name];
               modelTable[fullName]["name"] = name;
             }
             if (displayName) {

From 9d7e19cebf762ac7cd58e579040bd41c4d2cc15e Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 18:05:23 +0800
Subject: [PATCH 16/20] display doubao model name when select model

---
 app/api/bytedance/[...path]/route.ts |  9 +--------
 app/client/platforms/bytedance.ts    | 12 ++++--------
 app/components/chat.tsx              | 26 +++++++++++++++++++++++---
 3 files changed, 28 insertions(+), 19 deletions(-)

diff --git a/app/api/bytedance/[...path]/route.ts b/app/api/bytedance/[...path]/route.ts
index bffb60f6c..336c837f0 100644
--- a/app/api/bytedance/[...path]/route.ts
+++ b/app/api/bytedance/[...path]/route.ts
@@ -132,17 +132,10 @@ async function request(req: NextRequest) {
       console.error(`[ByteDance] filter`, e);
     }
   }
-  console.log("[ByteDance request]", fetchOptions.headers, req.method);
+
   try {
     const res = await fetch(fetchUrl, fetchOptions);
 
-    console.log(
-      "[ByteDance response]",
-      res.status,
-      "   ",
-      res.headers,
-      res.url,
-    );
     // to prevent browser prompt for credentials
     const newHeaders = new Headers(res.headers);
     newHeaders.delete("www-authenticate");
diff --git a/app/client/platforms/bytedance.ts b/app/client/platforms/bytedance.ts
index 92c1fd558..ce401e68d 100644
--- a/app/client/platforms/bytedance.ts
+++ b/app/client/platforms/bytedance.ts
@@ -2,7 +2,7 @@
 import {
   ApiPath,
   ByteDance,
-  DEFAULT_API_HOST,
+  BYTEDANCE_BASE_URL,
   REQUEST_TIMEOUT_MS,
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
@@ -58,9 +58,7 @@ export class DoubaoApi implements LLMApi {
 
     if (baseUrl.length === 0) {
       const isApp = !!getClientConfig()?.isApp;
-      baseUrl = isApp
-        ? DEFAULT_API_HOST + "/api/proxy/bytedance"
-        : ApiPath.ByteDance;
+      baseUrl = isApp ? BYTEDANCE_BASE_URL : ApiPath.ByteDance;
     }
 
     if (baseUrl.endsWith("/")) {
@@ -94,9 +92,10 @@ export class DoubaoApi implements LLMApi {
       },
     };
 
+    const shouldStream = !!options.config.stream;
     const requestPayload: RequestPayload = {
       messages,
-      stream: options.config.stream,
+      stream: shouldStream,
       model: modelConfig.model,
       temperature: modelConfig.temperature,
       presence_penalty: modelConfig.presence_penalty,
@@ -104,9 +103,6 @@ export class DoubaoApi implements LLMApi {
       top_p: modelConfig.top_p,
     };
 
-    console.log("[Request] ByteDance payload: ", requestPayload);
-
-    const shouldStream = !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
 
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index b1bdf757f..ace404c10 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -467,6 +467,14 @@ export function ChatActions(props: {
       return filteredModels;
     }
   }, [allModels]);
+  const currentModelName = useMemo(() => {
+    const model = models.find(
+      (m) =>
+        m.name == currentModel &&
+        m?.provider?.providerName == currentProviderName,
+    );
+    return model?.displayName ?? "";
+  }, [models, currentModel, currentProviderName]);
   const [showModelSelector, setShowModelSelector] = useState(false);
   const [showUploadImage, setShowUploadImage] = useState(false);
 
@@ -489,7 +497,11 @@ export function ChatActions(props: {
         session.mask.modelConfig.providerName = nextModel?.provider
           ?.providerName as ServiceProvider;
       });
-      showToast(nextModel.name);
+      showToast(
+        nextModel?.provider?.providerName == "ByteDance"
+          ? nextModel.displayName
+          : nextModel.name,
+      );
     }
   }, [chatStore, currentModel, models]);
 
@@ -571,7 +583,7 @@ export function ChatActions(props: {
 
       <ChatAction
         onClick={() => setShowModelSelector(true)}
-        text={currentModel}
+        text={currentModelName}
         icon={<RobotIcon />}
       />
 
@@ -596,7 +608,15 @@ export function ChatActions(props: {
                 providerName as ServiceProvider;
               session.mask.syncGlobalConfig = false;
             });
-            showToast(model);
+            if (providerName == "ByteDance") {
+              const selectedModel = models.find(
+                (m) =>
+                  m.name == model && m?.provider.providerName == providerName,
+              );
+              showToast(selectedModel?.displayName ?? "");
+            } else {
+              showToast(model);
+            }
           }}
         />
       )}

From 1149d455890bfb73df98026d9fad11ecbfa88e52 Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 18:06:59 +0800
Subject: [PATCH 17/20] remove check vision model

---
 app/client/platforms/bytedance.ts | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/app/client/platforms/bytedance.ts b/app/client/platforms/bytedance.ts
index ce401e68d..7677cafe1 100644
--- a/app/client/platforms/bytedance.ts
+++ b/app/client/platforms/bytedance.ts
@@ -21,7 +21,7 @@ import {
 } from "@fortaine/fetch-event-source";
 import { prettyObject } from "@/app/utils/format";
 import { getClientConfig } from "@/app/config/client";
-import { getMessageTextContent, isVisionModel } from "@/app/utils";
+import { getMessageTextContent } from "@/app/utils";
 
 export interface OpenAIListModelResponse {
   object: string;
@@ -78,10 +78,9 @@ export class DoubaoApi implements LLMApi {
   }
 
   async chat(options: ChatOptions) {
-    const visionModel = isVisionModel(options.config.model);
     const messages = options.messages.map((v) => ({
       role: v.role,
-      content: visionModel ? v.content : getMessageTextContent(v),
+      content: getMessageTextContent(v),
     }));
 
     const modelConfig = {

From 9d2a633f5e900c67343797a92de41635cdcbe25d Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 18:15:43 +0800
Subject: [PATCH 18/20] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=96=87=E6=A1=A3?=
 =?UTF-8?q?=E6=94=AF=E6=8C=81=E9=85=8D=E7=BD=AE=E8=B1=86=E5=8C=85=E7=9A=84?=
 =?UTF-8?q?=E6=A8=A1=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md    | 11 +++++++++++
 README_CN.md | 11 +++++++++++
 2 files changed, 22 insertions(+)

diff --git a/README.md b/README.md
index 467bfbbe0..0815b723f 100644
--- a/README.md
+++ b/README.md
@@ -225,6 +225,14 @@ Baidu Secret Key.
 
 Baidu Api Url.
 
+### `BYTEDANCE_API_KEY` (optional)
+
+ByteDance Api Key.
+
+### `BYTEDANCE_URL` (optional)
+
+ByteDance Api Url.
+
 ### `HIDE_USER_API_KEY` (optional)
 
 > Default: Empty
@@ -261,6 +269,9 @@ User `-all` to disable all default models, `+all` to enable all default models.
 For Azure: use `modelName@azure=deploymentName` to customize model name and deployment name.
 > Example: `+gpt-3.5-turbo@azure=gpt35` will show option `gpt35(Azure)` in model list.
 
+For ByteDance: use `modelName@bytedance=deploymentName` to customize model name and deployment name.
+> Example: `+Doubao-lite-4k@bytedance=ep-xxxxx-xxx` will show option `Doubao-lite-4k(ByteDance)` in model list.
+
 ### `DEFAULT_MODEL` (optional)
 
 Change default model
diff --git a/README_CN.md b/README_CN.md
index e6c4d2011..321efe441 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -139,6 +139,14 @@ Baidu Secret Key.
 
 Baidu Api Url.
 
+### `BYTEDANCE_API_KEY` (可选)
+
+ByteDance Api Key.
+
+### `BYTEDANCE_URL` (可选)
+
+ByteDance Api Url.
+
 ### `HIDE_USER_API_KEY` (可选)
 
 如果你不想让用户自行填入 API Key,将此环境变量设置为 1 即可。
@@ -172,6 +180,9 @@ Baidu Api Url.
 在Azure的模式下,支持使用`modelName@azure=deploymentName`的方式配置模型名称和部署名称(deploy-name)
 > 示例:`+gpt-3.5-turbo@azure=gpt35`这个配置会在模型列表显示一个`gpt35(Azure)`的选项
 
+在ByteDance的模式下,支持使用`modelName@bytedance=deploymentName`的方式配置模型名称和部署名称(deploy-name)
+> 示例: `+Doubao-lite-4k@bytedance=ep-xxxxx-xxx`这个配置会在模型列表显示一个`Doubao-lite-4k(ByteDance)`的选项
+
 
 ### `DEFAULT_MODEL` (可选)
 

From 82be426f78449840158adab56a88aa94dfcfc2c7 Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 18:19:34 +0800
Subject: [PATCH 19/20] fix eslint error

---
 app/components/chat.tsx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index ace404c10..40e02cb57 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -611,7 +611,7 @@ export function ChatActions(props: {
             if (providerName == "ByteDance") {
               const selectedModel = models.find(
                 (m) =>
-                  m.name == model && m?.provider.providerName == providerName,
+                  m.name == model && m?.provider?.providerName == providerName,
               );
               showToast(selectedModel?.displayName ?? "");
             } else {

From bb349a03dac8e006c4d125779c506efa98283286 Mon Sep 17 00:00:00 2001
From: lloydzhou <lloydzhou@qq.com>
Date: Tue, 9 Jul 2024 19:21:27 +0800
Subject: [PATCH 20/20] fix get headers for bytedance

---
 app/client/api.ts | 26 +++++++++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/app/client/api.ts b/app/client/api.ts
index ff81f5372..147b11ad2 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -179,6 +179,8 @@ export function getHeaders() {
     const isGoogle = modelConfig.providerName == ServiceProvider.Google;
     const isAzure = modelConfig.providerName === ServiceProvider.Azure;
     const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic;
+    const isBaidu = modelConfig.providerName == ServiceProvider.Baidu;
+    const isByteDance = modelConfig.providerName === ServiceProvider.ByteDance;
     const isEnabledAccessControl = accessStore.enabledAccessControl();
     const apiKey = isGoogle
       ? accessStore.googleApiKey
@@ -186,8 +188,18 @@ export function getHeaders() {
       ? accessStore.azureApiKey
       : isAnthropic
       ? accessStore.anthropicApiKey
+      : isByteDance
+      ? accessStore.bytedanceApiKey
       : accessStore.openaiApiKey;
-    return { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl };
+    return {
+      isGoogle,
+      isAzure,
+      isAnthropic,
+      isBaidu,
+      isByteDance,
+      apiKey,
+      isEnabledAccessControl,
+    };
   }
 
   function getAuthHeader(): string {
@@ -203,10 +215,18 @@ export function getHeaders() {
   function validString(x: string): boolean {
     return x?.length > 0;
   }
-  const { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl } =
-    getConfig();
+  const {
+    isGoogle,
+    isAzure,
+    isAnthropic,
+    isBaidu,
+    apiKey,
+    isEnabledAccessControl,
+  } = getConfig();
   // when using google api in app, not set auth header
   if (isGoogle && clientConfig?.isApp) return headers;
+  // when using baidu api in app, not set auth header
+  if (isBaidu && clientConfig?.isApp) return headers;
 
   const authHeader = getAuthHeader();