feat: model provider refactor done

This commit is contained in:
Dean-YZG
2024-05-15 21:38:25 +08:00
parent 240d330001
commit a0e4a468d6
33 changed files with 3077 additions and 8 deletions

View File

@@ -0,0 +1,41 @@
import Locale from "./locale";
import { SettingItem } from "../../core/types";
import { modelConfigs as openaiModelConfigs } from "../openai/config";
export const AzureMetas = {
ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
ChatPath: "v1/chat/completions",
OpenAI: "/api/openai",
};
export type SettingKeys = "azureUrl" | "azureApiKey" | "azureApiVersion";
export const modelConfigs = openaiModelConfigs;
export const settingItems: SettingItem<SettingKeys>[] = [
{
name: "azureUrl",
title: Locale.Endpoint.Title,
description: Locale.Endpoint.SubTitle + AzureMetas.ExampleEndpoint,
placeholder: AzureMetas.ExampleEndpoint,
type: "input",
},
{
name: "azureApiKey",
title: Locale.ApiKey.Title,
description: Locale.ApiKey.SubTitle,
placeholder: Locale.ApiKey.Placeholder,
type: "input",
inputType: "password",
validators: ["required"],
},
{
name: "azureApiVersion",
title: Locale.ApiVerion.Title,
description: Locale.ApiVerion.SubTitle,
placeholder: "2023-08-01-preview",
type: "input",
validators: ["required"],
},
];

View File

@@ -0,0 +1,326 @@
import { settingItems, SettingKeys, modelConfigs, AzureMetas } from "./config";
import {
InternalChatRequestPayload,
IProviderTemplate,
} from "../../core/types";
import { getMessageTextContent } from "@/app/utils";
import {
EventStreamContentType,
fetchEventSource,
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import Locale from "@/app/locales";
export type AzureProviderSettingKeys = SettingKeys;
export const ROLES = ["system", "user", "assistant"] as const;
export type MessageRole = (typeof ROLES)[number];
export interface MultimodalContent {
type: "text" | "image_url";
text?: string;
image_url?: {
url: string;
};
}
export interface RequestMessage {
role: MessageRole;
content: string | MultimodalContent[];
}
interface RequestPayload {
messages: {
role: "system" | "user" | "assistant";
content: string | MultimodalContent[];
}[];
stream?: boolean;
model: string;
temperature: number;
presence_penalty: number;
frequency_penalty: number;
top_p: number;
max_tokens?: number;
}
export default class Azure
implements IProviderTemplate<SettingKeys, "azure", typeof AzureMetas>
{
name = "azure" as const;
metas = AzureMetas;
models = modelConfigs.map((c) => ({ ...c, providerTemplateName: this.name }));
providerMeta = {
displayName: "Azure",
settingItems,
};
readonly REQUEST_TIMEOUT_MS = 60000;
private path(payload: InternalChatRequestPayload<SettingKeys>): string {
const {
providerConfig: { azureUrl, azureApiVersion },
} = payload;
const path = makeAzurePath(AzureMetas.ChatPath, azureApiVersion);
let baseUrl = azureUrl;
if (!baseUrl) {
baseUrl = "/api/openai";
}
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, baseUrl.length - 1);
}
if (!baseUrl.startsWith("http") && !baseUrl.startsWith(AzureMetas.OpenAI)) {
baseUrl = "https://" + baseUrl;
}
console.log("[Proxy Endpoint] ", baseUrl, path);
return [baseUrl, path].join("/");
}
private getHeaders(payload: InternalChatRequestPayload<SettingKeys>) {
const { azureApiKey } = payload.providerConfig;
const headers: Record<string, string> = {
"Content-Type": "application/json",
Accept: "application/json",
};
const authHeader = "Authorization";
const makeBearer = (s: string) => `Bearer ${s.trim()}`;
const validString = (x?: string): x is string => Boolean(x && x.length > 0);
// when using google api in app, not set auth header
if (validString(azureApiKey)) {
headers[authHeader] = makeBearer(azureApiKey);
}
return headers;
}
private formatChatPayload(payload: InternalChatRequestPayload<SettingKeys>) {
const { messages, isVisionModel, model, stream, modelConfig } = payload;
const {
temperature,
presence_penalty,
frequency_penalty,
top_p,
max_tokens,
} = modelConfig;
const openAiMessages = messages.map((v) => ({
role: v.role,
content: isVisionModel ? v.content : getMessageTextContent(v),
}));
const requestPayload: RequestPayload = {
messages: openAiMessages,
stream,
model,
temperature,
presence_penalty,
frequency_penalty,
top_p,
};
// add max_tokens to vision model
if (isVisionModel) {
requestPayload["max_tokens"] = Math.max(max_tokens, 4000);
}
console.log("[Request] openai payload: ", requestPayload);
return {
headers: this.getHeaders(payload),
body: JSON.stringify(requestPayload),
method: "POST",
url: this.path(payload),
};
}
private readWholeMessageResponseBody(res: any) {
return {
message: res.choices?.at(0)?.message?.content ?? "",
};
}
private getTimer = (onabort: () => void = () => {}) => {
const controller = new AbortController();
// make a fetch request
const requestTimeoutId = setTimeout(
() => controller.abort(),
this.REQUEST_TIMEOUT_MS,
);
controller.signal.onabort = onabort;
return {
...controller,
clear: () => {
clearTimeout(requestTimeoutId);
},
};
};
async chat(payload: InternalChatRequestPayload<SettingKeys>) {
const requestPayload = this.formatChatPayload(payload);
const timer = this.getTimer();
// make a fetch request
const requestTimeoutId = setTimeout(
() => timer.abort(),
this.REQUEST_TIMEOUT_MS,
);
const res = await fetch(requestPayload.url, {
headers: {
...requestPayload.headers,
},
body: requestPayload.body,
method: requestPayload.method,
signal: timer.signal,
});
clearTimeout(requestTimeoutId);
const resJson = await res.json();
const message = this.readWholeMessageResponseBody(resJson);
return message;
}
streamChat(
payload: InternalChatRequestPayload<SettingKeys>,
onProgress: (message: string, chunk: string) => void,
onFinish: (message: string) => void,
onError: (err: Error) => void,
) {
const requestPayload = this.formatChatPayload(payload);
const timer = this.getTimer();
let responseText = "";
let remainText = "";
let finished = false;
// animate response to make it looks smooth
const animateResponseText = () => {
if (finished || timer.signal.aborted) {
responseText += remainText;
console.log("[Response Animation] finished");
if (responseText?.length === 0) {
onError(new Error("empty response from server"));
}
return;
}
if (remainText.length > 0) {
const fetchCount = Math.max(1, Math.round(remainText.length / 60));
const fetchText = remainText.slice(0, fetchCount);
responseText += fetchText;
remainText = remainText.slice(fetchCount);
onProgress(responseText, fetchText);
}
requestAnimationFrame(animateResponseText);
};
// start animaion
animateResponseText();
const finish = () => {
if (!finished) {
finished = true;
onFinish(responseText + remainText);
}
};
timer.signal.onabort = finish;
fetchEventSource(requestPayload.url, {
...requestPayload,
async onopen(res) {
timer.clear();
const contentType = res.headers.get("content-type");
console.log("[OpenAI] request response content type: ", contentType);
if (contentType?.startsWith("text/plain")) {
responseText = await res.clone().text();
return finish();
}
if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
if (extraInfo) {
responseTexts.push(extraInfo);
}
responseText = responseTexts.join("\n\n");
return finish();
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || finished) {
return finish();
}
const text = msg.data;
try {
const json = JSON.parse(text);
const choices = json.choices as Array<{
delta: { content: string };
}>;
const delta = choices[0]?.delta?.content;
const textmoderation = json?.prompt_filter_results;
if (delta) {
remainText += delta;
}
} catch (e) {
console.error("[Request] parse error", text, msg);
}
},
onclose() {
finish();
},
onerror(e) {
onError(e);
throw e;
},
openWhenHidden: true,
});
return timer;
}
}
function makeAzurePath(path: string, apiVersion: string) {
// should omit /v1 prefix
path = path.replaceAll("v1/", "");
// should add api-key to query string
path += `${path.includes("?") ? "&" : "?"}api-version=${apiVersion}`;
return path;
}

View File

@@ -0,0 +1,109 @@
import { getLocaleText } from "../../core/locale";
export default getLocaleText<
{
ApiKey: {
Title: string;
SubTitle: string;
Placeholder: string;
};
Endpoint: {
Title: string;
SubTitle: string;
};
ApiVerion: {
Title: string;
SubTitle: string;
};
},
"en"
>(
{
cn: {
ApiKey: {
Title: "接口密钥",
SubTitle: "使用自定义 Azure Key 绕过密码访问限制",
Placeholder: "Azure API Key",
},
Endpoint: {
Title: "接口地址",
SubTitle: "样例:",
},
ApiVerion: {
Title: "接口版本 (azure api version)",
SubTitle: "选择指定的部分版本",
},
},
en: {
ApiKey: {
Title: "Azure Api Key",
SubTitle: "Check your api key from Azure console",
Placeholder: "Azure Api Key",
},
Endpoint: {
Title: "Azure Endpoint",
SubTitle: "Example: ",
},
ApiVerion: {
Title: "Azure Api Version",
SubTitle: "Check your api version from azure console",
},
},
pt: {
ApiKey: {
Title: "Chave API Azure",
SubTitle: "Verifique sua chave API do console Azure",
Placeholder: "Chave API Azure",
},
Endpoint: {
Title: "Endpoint Azure",
SubTitle: "Exemplo: ",
},
ApiVerion: {
Title: "Versão API Azure",
SubTitle: "Verifique sua versão API do console Azure",
},
},
sk: {
ApiKey: {
Title: "API kľúč Azure",
SubTitle: "Skontrolujte svoj API kľúč v Azure konzole",
Placeholder: "API kľúč Azure",
},
Endpoint: {
Title: "Koncový bod Azure",
SubTitle: "Príklad: ",
},
ApiVerion: {
Title: "Verzia API Azure",
SubTitle: "Skontrolujte svoju verziu API v Azure konzole",
},
},
tw: {
ApiKey: {
Title: "介面金鑰",
SubTitle: "使用自定義 Azure Key 繞過密碼存取限制",
Placeholder: "Azure API Key",
},
Endpoint: {
Title: "介面(Endpoint) 地址",
SubTitle: "樣例:",
},
ApiVerion: {
Title: "介面版本 (azure api version)",
SubTitle: "選擇指定的部分版本",
},
},
},
"en",
);