From fadd7f6eb4cb9d70fb9758ee52c85aac768dc1be Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 14:50:40 +0800 Subject: [PATCH] try getAccessToken in app, fixbug to fetch in none stream mode --- app/api/baidu/[...path]/route.ts | 41 +++++++++++++------------------- app/client/platforms/baidu.ts | 37 +++++++++++++++++++++------- app/constant.ts | 2 +- 3 files changed, 47 insertions(+), 33 deletions(-) diff --git a/app/api/baidu/[...path]/route.ts b/app/api/baidu/[...path]/route.ts index 27676d29d..5444ba4fe 100644 --- a/app/api/baidu/[...path]/route.ts +++ b/app/api/baidu/[...path]/route.ts @@ -10,6 +10,7 @@ import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "@/app/api/auth"; import { isModelAvailableInServer } from "@/app/utils/model"; +import { getAccessToken } from "@/app/utils/baidu"; const serverConfig = getServerSideConfig(); @@ -30,6 +31,18 @@ async function handle( }); } + if (!serverConfig.baiduApiKey || !serverConfig.baiduSecretKey) { + return NextResponse.json( + { + error: true, + message: `missing BAIDU_API_KEY or BAIDU_SECRET_KEY in server env vars`, + }, + { + status: 401, + }, + ); + } + try { const response = await request(req); return response; @@ -88,7 +101,10 @@ async function request(req: NextRequest) { 10 * 60 * 1000, ); - const { access_token } = await getAccessToken(); + const { access_token } = await getAccessToken( + serverConfig.baiduApiKey, + serverConfig.baiduSecretKey, + ); const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`; const fetchOptions: RequestInit = { @@ -133,11 +149,9 @@ async function request(req: NextRequest) { console.error(`[Baidu] filter`, e); } } - console.log("[Baidu request]", fetchOptions.headers, req.method); try { const res = await fetch(fetchUrl, fetchOptions); - console.log("[Baidu response]", res.status, " ", res.headers, res.url); // to prevent browser prompt for credentials const newHeaders = new Headers(res.headers); newHeaders.delete("www-authenticate"); @@ -153,24 +167,3 @@ async function request(req: NextRequest) { clearTimeout(timeoutId); } } - -/** - * 使用 AK,SK 生成鉴权签名(Access Token) - * @return 鉴权签名信息 - */ -async function getAccessToken(): Promise<{ - access_token: string; - expires_in: number; - error?: number; -}> { - const AK = serverConfig.baiduApiKey; - const SK = serverConfig.baiduSecretKey; - const res = await fetch( - `${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${AK}&client_secret=${SK}`, - { - method: "POST", - }, - ); - const resJson = await res.json(); - return resJson; -} diff --git a/app/client/platforms/baidu.ts b/app/client/platforms/baidu.ts index 4fc3d2f64..188b78bf9 100644 --- a/app/client/platforms/baidu.ts +++ b/app/client/platforms/baidu.ts @@ -6,6 +6,7 @@ import { REQUEST_TIMEOUT_MS, } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; +import { getAccessToken } from "@/app/utils/baidu"; import { ChatOptions, @@ -74,16 +75,20 @@ export class ErnieApi implements LLMApi { return [baseUrl, path].join("/"); } - extractMessage(res: any) { - return res.choices?.at(0)?.message?.content ?? ""; - } - async chat(options: ChatOptions) { const messages = options.messages.map((v) => ({ role: v.role, content: getMessageTextContent(v), })); + // "error_code": 336006, "error_msg": "the length of messages must be an odd number", + if (messages.length % 2 === 0) { + messages.unshift({ + role: "user", + content: " ", + }); + } + const modelConfig = { ...useAppConfig.getState().modelConfig, ...useChatStore.getState().currentSession().mask.modelConfig, @@ -92,9 +97,10 @@ export class ErnieApi implements LLMApi { }, }; + const shouldStream = !!options.config.stream; const requestPayload: RequestPayload = { messages, - stream: options.config.stream, + stream: shouldStream, model: modelConfig.model, temperature: modelConfig.temperature, presence_penalty: modelConfig.presence_penalty, @@ -104,12 +110,27 @@ export class ErnieApi implements LLMApi { console.log("[Request] Baidu payload: ", requestPayload); - const shouldStream = !!options.config.stream; const controller = new AbortController(); options.onController?.(controller); try { - const chatPath = this.path(Baidu.ChatPath(modelConfig.model)); + let chatPath = this.path(Baidu.ChatPath(modelConfig.model)); + + // getAccessToken can not run in browser, because cors error + if (!!getClientConfig()?.isApp) { + const accessStore = useAccessStore.getState(); + if (accessStore.useCustomConfig) { + if (accessStore.isValidBaidu()) { + const { access_token } = await getAccessToken( + accessStore.baiduApiKey, + accessStore.baiduSecretKey, + ); + chatPath = `${chatPath}${ + chatPath.includes("?") ? "&" : "?" + }access_token=${access_token}`; + } + } + } const chatPayload = { method: "POST", body: JSON.stringify(requestPayload), @@ -230,7 +251,7 @@ export class ErnieApi implements LLMApi { clearTimeout(requestTimeoutId); const resJson = await res.json(); - const message = this.extractMessage(resJson); + const message = resJson?.result; options.onFinish(message); } } catch (e) { diff --git a/app/constant.ts b/app/constant.ts index 0fd4d1c24..3d48dbb62 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -124,7 +124,7 @@ export const Baidu = { if (modelName === "ernie-3.5-8k") { endpoint = "completions"; } - return `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`; + return `rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`; }, };