feat: 1) Present 'maxtokens' as properties tied to a single model. 2) Remove the original author's implementation of the send verification logic and replace it with a user input validator. Pre-verification 3) Provides the ability to pull the 'User Visible modellist' provided by 'provider' 4) Provider-related parameters are passed in the constructor of 'providerClient'. Not passed in the 'chat' method

This commit is contained in:
Dean-YZG
2024-05-17 21:11:21 +08:00
parent 74a6e1260e
commit 8093d1ffba
30 changed files with 883 additions and 581 deletions

View File

@@ -1,15 +1,17 @@
import { settingItems, SettingKeys, modelConfigs, AzureMetas } from "./config";
import {
ChatHandlers,
InternalChatRequestPayload,
IProviderTemplate,
} from "../../core/types";
import { getMessageTextContent } from "@/app/utils";
ModelInfo,
getMessageTextContent,
} from "../../common";
import {
EventStreamContentType,
fetchEventSource,
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import Locale from "@/app/locales";
import { makeAzurePath, makeBearer, prettyObject, validString } from "./utils";
export type AzureProviderSettingKeys = SettingKeys;
@@ -43,13 +45,30 @@ interface RequestPayload {
max_tokens?: number;
}
interface ModelList {
object: "list";
data: Array<{
capabilities: {
fine_tune: boolean;
inference: boolean;
completion: boolean;
chat_completion: boolean;
embeddings: boolean;
};
lifecycle_status: "generally-available";
id: string;
created_at: number;
object: "model";
}>;
}
export default class Azure
implements IProviderTemplate<SettingKeys, "azure", typeof AzureMetas>
{
name = "azure" as const;
metas = AzureMetas;
models = modelConfigs.map((c) => ({ ...c, providerTemplateName: this.name }));
defaultModels = modelConfigs;
providerMeta = {
displayName: "Azure",
@@ -62,25 +81,11 @@ export default class Azure
const {
providerConfig: { azureUrl, azureApiVersion },
} = payload;
const path = makeAzurePath(AzureMetas.ChatPath, azureApiVersion!);
const path = makeAzurePath(AzureMetas.ChatPath, azureApiVersion);
console.log("[Proxy Endpoint] ", azureUrl, path);
let baseUrl = azureUrl;
if (!baseUrl) {
baseUrl = "/api/openai";
}
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, baseUrl.length - 1);
}
if (!baseUrl.startsWith("http") && !baseUrl.startsWith(AzureMetas.OpenAI)) {
baseUrl = "https://" + baseUrl;
}
console.log("[Proxy Endpoint] ", baseUrl, path);
return [baseUrl, path].join("/");
return [azureUrl!, path].join("/");
}
private getHeaders(payload: InternalChatRequestPayload<SettingKeys>) {
@@ -90,14 +95,9 @@ export default class Azure
"Content-Type": "application/json",
Accept: "application/json",
};
const authHeader = "Authorization";
const makeBearer = (s: string) => `Bearer ${s.trim()}`;
const validString = (x?: string): x is string => Boolean(x && x.length > 0);
// when using google api in app, not set auth header
if (validString(azureApiKey)) {
headers[authHeader] = makeBearer(azureApiKey);
headers["Authorization"] = makeBearer(azureApiKey);
}
return headers;
@@ -197,52 +197,12 @@ export default class Azure
streamChat(
payload: InternalChatRequestPayload<SettingKeys>,
onProgress: (message: string, chunk: string) => void,
onFinish: (message: string) => void,
onError: (err: Error) => void,
handlers: ChatHandlers,
) {
const requestPayload = this.formatChatPayload(payload);
const timer = this.getTimer();
let responseText = "";
let remainText = "";
let finished = false;
// animate response to make it looks smooth
const animateResponseText = () => {
if (finished || timer.signal.aborted) {
responseText += remainText;
console.log("[Response Animation] finished");
if (responseText?.length === 0) {
onError(new Error("empty response from server"));
}
return;
}
if (remainText.length > 0) {
const fetchCount = Math.max(1, Math.round(remainText.length / 60));
const fetchText = remainText.slice(0, fetchCount);
responseText += fetchText;
remainText = remainText.slice(fetchCount);
onProgress(responseText, fetchText);
}
requestAnimationFrame(animateResponseText);
};
// start animaion
animateResponseText();
const finish = () => {
if (!finished) {
finished = true;
onFinish(responseText + remainText);
}
};
timer.signal.onabort = finish;
fetchEventSource(requestPayload.url, {
...requestPayload,
async onopen(res) {
@@ -251,8 +211,8 @@ export default class Azure
console.log("[OpenAI] request response content type: ", contentType);
if (contentType?.startsWith("text/plain")) {
responseText = await res.clone().text();
return finish();
const responseText = await res.clone().text();
return handlers.onFlash(responseText);
}
if (
@@ -262,29 +222,29 @@ export default class Azure
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
const responseTexts = [];
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
if (extraInfo) {
responseTexts.push(extraInfo);
}
responseText = responseTexts.join("\n\n");
const responseText = responseTexts.join("\n\n");
return finish();
return handlers.onFlash(responseText);
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || finished) {
return finish();
if (msg.data === "[DONE]") {
return;
}
const text = msg.data;
try {
@@ -293,34 +253,41 @@ export default class Azure
delta: { content: string };
}>;
const delta = choices[0]?.delta?.content;
const textmoderation = json?.prompt_filter_results;
if (delta) {
remainText += delta;
handlers.onProgress(delta);
}
} catch (e) {
console.error("[Request] parse error", text, msg);
}
},
onclose() {
finish();
handlers.onFinish();
},
onerror(e) {
onError(e);
handlers.onError(e);
throw e;
},
openWhenHidden: true,
});
return timer;
}
}
function makeAzurePath(path: string, apiVersion: string) {
// should omit /v1 prefix
path = path.replaceAll("v1/", "");
// should add api-key to query string
path += `${path.includes("?") ? "&" : "?"}api-version=${apiVersion}`;
return path;
async getAvailableModels(
providerConfig: Record<SettingKeys, string>,
): Promise<ModelInfo[]> {
const { azureApiKey, azureUrl } = providerConfig;
const res = await fetch(`${azureUrl}/vi/models`, {
headers: {
Authorization: `Bearer ${azureApiKey}`,
},
method: "GET",
});
const data: ModelList = await res.json();
return data.data.map((o) => ({
name: o.id,
}));
}
}