try getAccessToken in app, fixbug to fetch in none stream mode

This commit is contained in:
lloydzhou
2024-07-09 14:50:40 +08:00
parent 011b76e4e7
commit fadd7f6eb4
3 changed files with 47 additions and 33 deletions

View File

@@ -6,6 +6,7 @@ import {
REQUEST_TIMEOUT_MS,
} from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { getAccessToken } from "@/app/utils/baidu";
import {
ChatOptions,
@@ -74,16 +75,20 @@ export class ErnieApi implements LLMApi {
return [baseUrl, path].join("/");
}
extractMessage(res: any) {
return res.choices?.at(0)?.message?.content ?? "";
}
async chat(options: ChatOptions) {
const messages = options.messages.map((v) => ({
role: v.role,
content: getMessageTextContent(v),
}));
// "error_code": 336006, "error_msg": "the length of messages must be an odd number",
if (messages.length % 2 === 0) {
messages.unshift({
role: "user",
content: " ",
});
}
const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
@@ -92,9 +97,10 @@ export class ErnieApi implements LLMApi {
},
};
const shouldStream = !!options.config.stream;
const requestPayload: RequestPayload = {
messages,
stream: options.config.stream,
stream: shouldStream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
@@ -104,12 +110,27 @@ export class ErnieApi implements LLMApi {
console.log("[Request] Baidu payload: ", requestPayload);
const shouldStream = !!options.config.stream;
const controller = new AbortController();
options.onController?.(controller);
try {
const chatPath = this.path(Baidu.ChatPath(modelConfig.model));
let chatPath = this.path(Baidu.ChatPath(modelConfig.model));
// getAccessToken can not run in browser, because cors error
if (!!getClientConfig()?.isApp) {
const accessStore = useAccessStore.getState();
if (accessStore.useCustomConfig) {
if (accessStore.isValidBaidu()) {
const { access_token } = await getAccessToken(
accessStore.baiduApiKey,
accessStore.baiduSecretKey,
);
chatPath = `${chatPath}${
chatPath.includes("?") ? "&" : "?"
}access_token=${access_token}`;
}
}
}
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
@@ -230,7 +251,7 @@ export class ErnieApi implements LLMApi {
clearTimeout(requestTimeoutId);
const resJson = await res.json();
const message = this.extractMessage(resJson);
const message = resJson?.result;
options.onFinish(message);
}
} catch (e) {