try getAccessToken in app, fix bug when fetching in non-stream mode

lloydzhou 2024-07-09 14:50:40 +08:00
parent 011b76e4e7
commit fadd7f6eb4
3 changed files with 47 additions and 33 deletions

View File

@@ -10,6 +10,7 @@ import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
 import { isModelAvailableInServer } from "@/app/utils/model";
+import { getAccessToken } from "@/app/utils/baidu";

 const serverConfig = getServerSideConfig();
@@ -30,6 +31,18 @@ async function handle(
     });
   }

+  if (!serverConfig.baiduApiKey || !serverConfig.baiduSecretKey) {
+    return NextResponse.json(
+      {
+        error: true,
+        message: `missing BAIDU_API_KEY or BAIDU_SECRET_KEY in server env vars`,
+      },
+      {
+        status: 401,
+      },
+    );
+  }

   try {
     const response = await request(req);
     return response;
@@ -88,7 +101,10 @@ async function request(req: NextRequest) {
     10 * 60 * 1000,
   );

-  const { access_token } = await getAccessToken();
+  const { access_token } = await getAccessToken(
+    serverConfig.baiduApiKey,
+    serverConfig.baiduSecretKey,
+  );
   const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`;

   const fetchOptions: RequestInit = {
@@ -133,11 +149,9 @@ async function request(req: NextRequest) {
       console.error(`[Baidu] filter`, e);
     }
   }
-  console.log("[Baidu request]", fetchOptions.headers, req.method);

   try {
     const res = await fetch(fetchUrl, fetchOptions);
-    console.log("[Baidu response]", res.status, " ", res.headers, res.url);
     // to prevent browser prompt for credentials
     const newHeaders = new Headers(res.headers);
     newHeaders.delete("www-authenticate");
@@ -153,24 +167,3 @@ async function request(req: NextRequest) {
     clearTimeout(timeoutId);
   }
 }
-
-/**
- * Generate an auth token (Access Token) from the AK/SK pair
- * @return the access token response
- */
-async function getAccessToken(): Promise<{
-  access_token: string;
-  expires_in: number;
-  error?: number;
-}> {
-  const AK = serverConfig.baiduApiKey;
-  const SK = serverConfig.baiduSecretKey;
-  const res = await fetch(
-    `${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${AK}&client_secret=${SK}`,
-    {
-      method: "POST",
-    },
-  );
-  const resJson = await res.json();
-  return resJson;
-}
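The server-side helper above is deleted in favor of a shared getAccessToken(apiKey, secretKey) imported from "@/app/utils/baidu" in both this route and the client file below. That shared module is not part of this diff; the following is a minimal sketch of what it plausibly contains, reconstructed from the deleted code and the new call sites (the constant name, endpoint URL, and exact types are assumptions):

// app/utils/baidu.ts (sketch only, not shown in this diff)
// Assumed OAuth endpoint; the deleted helper read it from the BAIDU_OATUH_URL constant.
const BAIDU_OAUTH_URL = "https://aip.baidubce.com/oauth/2.0/token";

export interface BaiduAccessToken {
  access_token: string;
  expires_in: number;
  error?: number;
}

// Exchange the Baidu API key (AK) and secret key (SK) for a short-lived access token
// via the client_credentials grant, exactly as the deleted server-side helper did.
export async function getAccessToken(
  apiKey: string,
  secretKey: string,
): Promise<BaiduAccessToken> {
  const url =
    `${BAIDU_OAUTH_URL}?grant_type=client_credentials` +
    `&client_id=${apiKey}&client_secret=${secretKey}`;
  const res = await fetch(url, { method: "POST" });
  return (await res.json()) as BaiduAccessToken;
}

Callers pass whichever key pair applies: the API route uses serverConfig.baiduApiKey / serverConfig.baiduSecretKey, while the app-only client path below uses the values from the access store.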

View File

@@ -6,6 +6,7 @@ import {
   REQUEST_TIMEOUT_MS,
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+import { getAccessToken } from "@/app/utils/baidu";

 import {
   ChatOptions,
@@ -74,16 +75,20 @@ export class ErnieApi implements LLMApi {
     return [baseUrl, path].join("/");
   }

-  extractMessage(res: any) {
-    return res.choices?.at(0)?.message?.content ?? "";
-  }
-
   async chat(options: ChatOptions) {
     const messages = options.messages.map((v) => ({
       role: v.role,
       content: getMessageTextContent(v),
     }));

+    // "error_code": 336006, "error_msg": "the length of messages must be an odd number",
+    if (messages.length % 2 === 0) {
+      messages.unshift({
+        role: "user",
+        content: " ",
+      });
+    }

     const modelConfig = {
       ...useAppConfig.getState().modelConfig,
       ...useChatStore.getState().currentSession().mask.modelConfig,
@@ -92,9 +97,10 @@ export class ErnieApi implements LLMApi {
       },
     };

+    const shouldStream = !!options.config.stream;
     const requestPayload: RequestPayload = {
       messages,
-      stream: options.config.stream,
+      stream: shouldStream,
       model: modelConfig.model,
       temperature: modelConfig.temperature,
       presence_penalty: modelConfig.presence_penalty,
@@ -104,12 +110,27 @@ export class ErnieApi implements LLMApi {
     console.log("[Request] Baidu payload: ", requestPayload);

-    const shouldStream = !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);

     try {
-      const chatPath = this.path(Baidu.ChatPath(modelConfig.model));
+      let chatPath = this.path(Baidu.ChatPath(modelConfig.model));

+      // getAccessToken cannot run in the browser because of CORS, so only do it in the app
+      if (!!getClientConfig()?.isApp) {
+        const accessStore = useAccessStore.getState();
+        if (accessStore.useCustomConfig) {
+          if (accessStore.isValidBaidu()) {
+            const { access_token } = await getAccessToken(
+              accessStore.baiduApiKey,
+              accessStore.baiduSecretKey,
+            );
+            chatPath = `${chatPath}${
+              chatPath.includes("?") ? "&" : "?"
+            }access_token=${access_token}`;
+          }
+        }
+      }

       const chatPayload = {
         method: "POST",
         body: JSON.stringify(requestPayload),
@@ -230,7 +251,7 @@ export class ErnieApi implements LLMApi {
         clearTimeout(requestTimeoutId);

         const resJson = await res.json();
-        const message = this.extractMessage(resJson);
+        const message = resJson?.result;
         options.onFinish(message);
       }
     } catch (e) {
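Two behavioural fixes in this file are easy to miss in the diff: Baidu's ERNIE chat API rejects an even number of messages (error_code 336006), so a blank user turn is prepended when the count is even, and in non-stream mode the completion text lives in the response's top-level result field rather than an OpenAI-style choices array, which is why this.extractMessage(resJson) becomes resJson?.result. A standalone sketch of both ideas, using illustrative names that are not from the diff:

// Hypothetical helpers illustrating the two fixes above.
interface ErnieMessage {
  role: string;
  content: string;
}

// ERNIE requires an odd-length message list
// ("the length of messages must be an odd number"), so pad with a blank user turn.
function ensureOddMessageCount(messages: ErnieMessage[]): ErnieMessage[] {
  return messages.length % 2 === 0
    ? [{ role: "user", content: " " }, ...messages]
    : messages;
}

// In non-stream mode the ERNIE response body looks like { result: "...", ... },
// so the completion text is read from `result` instead of `choices`.
function extractNonStreamMessage(resJson: { result?: string }): string {
  return resJson?.result ?? "";
}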

View File

@@ -124,7 +124,7 @@ export const Baidu = {
     if (modelName === "ernie-3.5-8k") {
       endpoint = "completions";
     }
-    return `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`;
+    return `rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`;
   },
 };
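The only change here drops the leading slash from the ChatPath return value. The ErnieApi client joins it onto the base URL with [baseUrl, path].join("/") (visible in the path() context above), so a leading slash would typically produce a double slash in the final URL. A quick illustration with a made-up base URL:

// Hypothetical base URL, for illustration only.
const baseUrl = "https://example.com/api/baidu";
const path = "rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions";

// With the old leading slash the join yields ".../api/baidu//rpc/...".
console.log([baseUrl, "/" + path].join("/"));
// Without it the URL is clean: ".../api/baidu/rpc/...".
console.log([baseUrl, path].join("/"));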