try getAccessToken in app, fixbug to fetch in none stream mode
This commit is contained in:
parent
011b76e4e7
commit
fadd7f6eb4
|
@ -10,6 +10,7 @@ import { prettyObject } from "@/app/utils/format";
|
||||||
import { NextRequest, NextResponse } from "next/server";
|
import { NextRequest, NextResponse } from "next/server";
|
||||||
import { auth } from "@/app/api/auth";
|
import { auth } from "@/app/api/auth";
|
||||||
import { isModelAvailableInServer } from "@/app/utils/model";
|
import { isModelAvailableInServer } from "@/app/utils/model";
|
||||||
|
import { getAccessToken } from "@/app/utils/baidu";
|
||||||
|
|
||||||
const serverConfig = getServerSideConfig();
|
const serverConfig = getServerSideConfig();
|
||||||
|
|
||||||
|
@ -30,6 +31,18 @@ async function handle(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!serverConfig.baiduApiKey || !serverConfig.baiduSecretKey) {
|
||||||
|
return NextResponse.json(
|
||||||
|
{
|
||||||
|
error: true,
|
||||||
|
message: `missing BAIDU_API_KEY or BAIDU_SECRET_KEY in server env vars`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
status: 401,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await request(req);
|
const response = await request(req);
|
||||||
return response;
|
return response;
|
||||||
|
@ -88,7 +101,10 @@ async function request(req: NextRequest) {
|
||||||
10 * 60 * 1000,
|
10 * 60 * 1000,
|
||||||
);
|
);
|
||||||
|
|
||||||
const { access_token } = await getAccessToken();
|
const { access_token } = await getAccessToken(
|
||||||
|
serverConfig.baiduApiKey,
|
||||||
|
serverConfig.baiduSecretKey,
|
||||||
|
);
|
||||||
const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`;
|
const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`;
|
||||||
|
|
||||||
const fetchOptions: RequestInit = {
|
const fetchOptions: RequestInit = {
|
||||||
|
@ -133,11 +149,9 @@ async function request(req: NextRequest) {
|
||||||
console.error(`[Baidu] filter`, e);
|
console.error(`[Baidu] filter`, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
console.log("[Baidu request]", fetchOptions.headers, req.method);
|
|
||||||
try {
|
try {
|
||||||
const res = await fetch(fetchUrl, fetchOptions);
|
const res = await fetch(fetchUrl, fetchOptions);
|
||||||
|
|
||||||
console.log("[Baidu response]", res.status, " ", res.headers, res.url);
|
|
||||||
// to prevent browser prompt for credentials
|
// to prevent browser prompt for credentials
|
||||||
const newHeaders = new Headers(res.headers);
|
const newHeaders = new Headers(res.headers);
|
||||||
newHeaders.delete("www-authenticate");
|
newHeaders.delete("www-authenticate");
|
||||||
|
@ -153,24 +167,3 @@ async function request(req: NextRequest) {
|
||||||
clearTimeout(timeoutId);
|
clearTimeout(timeoutId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 使用 AK,SK 生成鉴权签名(Access Token)
|
|
||||||
* @return 鉴权签名信息
|
|
||||||
*/
|
|
||||||
async function getAccessToken(): Promise<{
|
|
||||||
access_token: string;
|
|
||||||
expires_in: number;
|
|
||||||
error?: number;
|
|
||||||
}> {
|
|
||||||
const AK = serverConfig.baiduApiKey;
|
|
||||||
const SK = serverConfig.baiduSecretKey;
|
|
||||||
const res = await fetch(
|
|
||||||
`${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${AK}&client_secret=${SK}`,
|
|
||||||
{
|
|
||||||
method: "POST",
|
|
||||||
},
|
|
||||||
);
|
|
||||||
const resJson = await res.json();
|
|
||||||
return resJson;
|
|
||||||
}
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import {
|
||||||
REQUEST_TIMEOUT_MS,
|
REQUEST_TIMEOUT_MS,
|
||||||
} from "@/app/constant";
|
} from "@/app/constant";
|
||||||
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
|
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
|
||||||
|
import { getAccessToken } from "@/app/utils/baidu";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
ChatOptions,
|
ChatOptions,
|
||||||
|
@ -74,16 +75,20 @@ export class ErnieApi implements LLMApi {
|
||||||
return [baseUrl, path].join("/");
|
return [baseUrl, path].join("/");
|
||||||
}
|
}
|
||||||
|
|
||||||
extractMessage(res: any) {
|
|
||||||
return res.choices?.at(0)?.message?.content ?? "";
|
|
||||||
}
|
|
||||||
|
|
||||||
async chat(options: ChatOptions) {
|
async chat(options: ChatOptions) {
|
||||||
const messages = options.messages.map((v) => ({
|
const messages = options.messages.map((v) => ({
|
||||||
role: v.role,
|
role: v.role,
|
||||||
content: getMessageTextContent(v),
|
content: getMessageTextContent(v),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
// "error_code": 336006, "error_msg": "the length of messages must be an odd number",
|
||||||
|
if (messages.length % 2 === 0) {
|
||||||
|
messages.unshift({
|
||||||
|
role: "user",
|
||||||
|
content: " ",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const modelConfig = {
|
const modelConfig = {
|
||||||
...useAppConfig.getState().modelConfig,
|
...useAppConfig.getState().modelConfig,
|
||||||
...useChatStore.getState().currentSession().mask.modelConfig,
|
...useChatStore.getState().currentSession().mask.modelConfig,
|
||||||
|
@ -92,9 +97,10 @@ export class ErnieApi implements LLMApi {
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const shouldStream = !!options.config.stream;
|
||||||
const requestPayload: RequestPayload = {
|
const requestPayload: RequestPayload = {
|
||||||
messages,
|
messages,
|
||||||
stream: options.config.stream,
|
stream: shouldStream,
|
||||||
model: modelConfig.model,
|
model: modelConfig.model,
|
||||||
temperature: modelConfig.temperature,
|
temperature: modelConfig.temperature,
|
||||||
presence_penalty: modelConfig.presence_penalty,
|
presence_penalty: modelConfig.presence_penalty,
|
||||||
|
@ -104,12 +110,27 @@ export class ErnieApi implements LLMApi {
|
||||||
|
|
||||||
console.log("[Request] Baidu payload: ", requestPayload);
|
console.log("[Request] Baidu payload: ", requestPayload);
|
||||||
|
|
||||||
const shouldStream = !!options.config.stream;
|
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
options.onController?.(controller);
|
options.onController?.(controller);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const chatPath = this.path(Baidu.ChatPath(modelConfig.model));
|
let chatPath = this.path(Baidu.ChatPath(modelConfig.model));
|
||||||
|
|
||||||
|
// getAccessToken can not run in browser, because cors error
|
||||||
|
if (!!getClientConfig()?.isApp) {
|
||||||
|
const accessStore = useAccessStore.getState();
|
||||||
|
if (accessStore.useCustomConfig) {
|
||||||
|
if (accessStore.isValidBaidu()) {
|
||||||
|
const { access_token } = await getAccessToken(
|
||||||
|
accessStore.baiduApiKey,
|
||||||
|
accessStore.baiduSecretKey,
|
||||||
|
);
|
||||||
|
chatPath = `${chatPath}${
|
||||||
|
chatPath.includes("?") ? "&" : "?"
|
||||||
|
}access_token=${access_token}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
const chatPayload = {
|
const chatPayload = {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
body: JSON.stringify(requestPayload),
|
body: JSON.stringify(requestPayload),
|
||||||
|
@ -230,7 +251,7 @@ export class ErnieApi implements LLMApi {
|
||||||
clearTimeout(requestTimeoutId);
|
clearTimeout(requestTimeoutId);
|
||||||
|
|
||||||
const resJson = await res.json();
|
const resJson = await res.json();
|
||||||
const message = this.extractMessage(resJson);
|
const message = resJson?.result;
|
||||||
options.onFinish(message);
|
options.onFinish(message);
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
|
|
|
@ -124,7 +124,7 @@ export const Baidu = {
|
||||||
if (modelName === "ernie-3.5-8k") {
|
if (modelName === "ernie-3.5-8k") {
|
||||||
endpoint = "completions";
|
endpoint = "completions";
|
||||||
}
|
}
|
||||||
return `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`;
|
return `rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`;
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue