feat: 1) Present 'maxtokens' as properties tied to a single model. 2) Remove the original author's implementation of the send verification logic and replace it with a user input validator. Pre-verification 3) Provides the ability to pull the 'User Visible modellist' provided by 'provider' 4) Provider-related parameters are passed in the constructor of 'providerClient'. Not passed in the 'chat' method

This commit is contained in:
Dean-YZG
2024-05-17 21:11:21 +08:00
parent 74a6e1260e
commit 8093d1ffba
30 changed files with 883 additions and 581 deletions

View File

@@ -1,12 +1,11 @@
import Locale from "./locale";
import { SettingItem } from "../../core/types";
import { SettingItem } from "../../common";
import { modelConfigs as openaiModelConfigs } from "../openai/config";
export const AzureMetas = {
ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
ChatPath: "v1/chat/completions",
OpenAI: "/api/openai",
};
export type SettingKeys = "azureUrl" | "azureApiKey" | "azureApiVersion";
@@ -20,6 +19,21 @@ export const settingItems: SettingItem<SettingKeys>[] = [
description: Locale.Endpoint.SubTitle + AzureMetas.ExampleEndpoint,
placeholder: AzureMetas.ExampleEndpoint,
type: "input",
validators: [
async (v: any) => {
if (typeof v === "string") {
try {
new URL(v);
} catch (e) {
return Locale.Endpoint.Error.IllegalURL;
}
}
if (typeof v === "string" && v.endsWith("/")) {
return Locale.Endpoint.Error.EndWithBackslash;
}
},
"required",
],
},
{
name: "azureApiKey",

View File

@@ -1,15 +1,17 @@
import { settingItems, SettingKeys, modelConfigs, AzureMetas } from "./config";
import {
ChatHandlers,
InternalChatRequestPayload,
IProviderTemplate,
} from "../../core/types";
import { getMessageTextContent } from "@/app/utils";
ModelInfo,
getMessageTextContent,
} from "../../common";
import {
EventStreamContentType,
fetchEventSource,
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import Locale from "@/app/locales";
import { makeAzurePath, makeBearer, prettyObject, validString } from "./utils";
export type AzureProviderSettingKeys = SettingKeys;
@@ -43,13 +45,30 @@ interface RequestPayload {
max_tokens?: number;
}
interface ModelList {
object: "list";
data: Array<{
capabilities: {
fine_tune: boolean;
inference: boolean;
completion: boolean;
chat_completion: boolean;
embeddings: boolean;
};
lifecycle_status: "generally-available";
id: string;
created_at: number;
object: "model";
}>;
}
export default class Azure
implements IProviderTemplate<SettingKeys, "azure", typeof AzureMetas>
{
name = "azure" as const;
metas = AzureMetas;
models = modelConfigs.map((c) => ({ ...c, providerTemplateName: this.name }));
defaultModels = modelConfigs;
providerMeta = {
displayName: "Azure",
@@ -62,25 +81,11 @@ export default class Azure
const {
providerConfig: { azureUrl, azureApiVersion },
} = payload;
const path = makeAzurePath(AzureMetas.ChatPath, azureApiVersion!);
const path = makeAzurePath(AzureMetas.ChatPath, azureApiVersion);
console.log("[Proxy Endpoint] ", azureUrl, path);
let baseUrl = azureUrl;
if (!baseUrl) {
baseUrl = "/api/openai";
}
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, baseUrl.length - 1);
}
if (!baseUrl.startsWith("http") && !baseUrl.startsWith(AzureMetas.OpenAI)) {
baseUrl = "https://" + baseUrl;
}
console.log("[Proxy Endpoint] ", baseUrl, path);
return [baseUrl, path].join("/");
return [azureUrl!, path].join("/");
}
private getHeaders(payload: InternalChatRequestPayload<SettingKeys>) {
@@ -90,14 +95,9 @@ export default class Azure
"Content-Type": "application/json",
Accept: "application/json",
};
const authHeader = "Authorization";
const makeBearer = (s: string) => `Bearer ${s.trim()}`;
const validString = (x?: string): x is string => Boolean(x && x.length > 0);
// when using google api in app, not set auth header
if (validString(azureApiKey)) {
headers[authHeader] = makeBearer(azureApiKey);
headers["Authorization"] = makeBearer(azureApiKey);
}
return headers;
@@ -197,52 +197,12 @@ export default class Azure
streamChat(
payload: InternalChatRequestPayload<SettingKeys>,
onProgress: (message: string, chunk: string) => void,
onFinish: (message: string) => void,
onError: (err: Error) => void,
handlers: ChatHandlers,
) {
const requestPayload = this.formatChatPayload(payload);
const timer = this.getTimer();
let responseText = "";
let remainText = "";
let finished = false;
// animate response to make it looks smooth
const animateResponseText = () => {
if (finished || timer.signal.aborted) {
responseText += remainText;
console.log("[Response Animation] finished");
if (responseText?.length === 0) {
onError(new Error("empty response from server"));
}
return;
}
if (remainText.length > 0) {
const fetchCount = Math.max(1, Math.round(remainText.length / 60));
const fetchText = remainText.slice(0, fetchCount);
responseText += fetchText;
remainText = remainText.slice(fetchCount);
onProgress(responseText, fetchText);
}
requestAnimationFrame(animateResponseText);
};
// start animaion
animateResponseText();
const finish = () => {
if (!finished) {
finished = true;
onFinish(responseText + remainText);
}
};
timer.signal.onabort = finish;
fetchEventSource(requestPayload.url, {
...requestPayload,
async onopen(res) {
@@ -251,8 +211,8 @@ export default class Azure
console.log("[OpenAI] request response content type: ", contentType);
if (contentType?.startsWith("text/plain")) {
responseText = await res.clone().text();
return finish();
const responseText = await res.clone().text();
return handlers.onFlash(responseText);
}
if (
@@ -262,29 +222,29 @@ export default class Azure
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
const responseTexts = [];
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
if (extraInfo) {
responseTexts.push(extraInfo);
}
responseText = responseTexts.join("\n\n");
const responseText = responseTexts.join("\n\n");
return finish();
return handlers.onFlash(responseText);
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || finished) {
return finish();
if (msg.data === "[DONE]") {
return;
}
const text = msg.data;
try {
@@ -293,34 +253,41 @@ export default class Azure
delta: { content: string };
}>;
const delta = choices[0]?.delta?.content;
const textmoderation = json?.prompt_filter_results;
if (delta) {
remainText += delta;
handlers.onProgress(delta);
}
} catch (e) {
console.error("[Request] parse error", text, msg);
}
},
onclose() {
finish();
handlers.onFinish();
},
onerror(e) {
onError(e);
handlers.onError(e);
throw e;
},
openWhenHidden: true,
});
return timer;
}
}
function makeAzurePath(path: string, apiVersion: string) {
// should omit /v1 prefix
path = path.replaceAll("v1/", "");
// should add api-key to query string
path += `${path.includes("?") ? "&" : "?"}api-version=${apiVersion}`;
return path;
async getAvailableModels(
providerConfig: Record<SettingKeys, string>,
): Promise<ModelInfo[]> {
const { azureApiKey, azureUrl } = providerConfig;
const res = await fetch(`${azureUrl}/vi/models`, {
headers: {
Authorization: `Bearer ${azureApiKey}`,
},
method: "GET",
});
const data: ModelList = await res.json();
return data.data.map((o) => ({
name: o.id,
}));
}
}

View File

@@ -1,4 +1,4 @@
import { getLocaleText } from "../../core/locale";
import { getLocaleText } from "../../common";
export default getLocaleText<
{
@@ -10,6 +10,10 @@ export default getLocaleText<
Endpoint: {
Title: string;
SubTitle: string;
Error: {
EndWithBackslash: string;
IllegalURL: string;
};
};
ApiVerion: {
Title: string;
@@ -29,6 +33,10 @@ export default getLocaleText<
Endpoint: {
Title: "接口地址",
SubTitle: "样例:",
Error: {
EndWithBackslash: "不能以「/」结尾",
IllegalURL: "请输入一个完整可用的url",
},
},
ApiVerion: {
@@ -46,6 +54,10 @@ export default getLocaleText<
Endpoint: {
Title: "Azure Endpoint",
SubTitle: "Example: ",
Error: {
EndWithBackslash: "Cannot end with '/'",
IllegalURL: "Please enter a complete available url",
},
},
ApiVerion: {
@@ -63,6 +75,10 @@ export default getLocaleText<
Endpoint: {
Title: "Endpoint Azure",
SubTitle: "Exemplo: ",
Error: {
EndWithBackslash: "Não é possível terminar com '/'",
IllegalURL: "Insira um URL completo disponível",
},
},
ApiVerion: {
@@ -80,6 +96,10 @@ export default getLocaleText<
Endpoint: {
Title: "Koncový bod Azure",
SubTitle: "Príklad: ",
Error: {
EndWithBackslash: "Nemôže končiť znakom „/“",
IllegalURL: "Zadajte úplnú dostupnú adresu URL",
},
},
ApiVerion: {
@@ -97,6 +117,10 @@ export default getLocaleText<
Endpoint: {
Title: "介面(Endpoint) 地址",
SubTitle: "樣例:",
Error: {
EndWithBackslash: "不能以「/」結尾",
IllegalURL: "請輸入一個完整可用的url",
},
},
ApiVerion: {

View File

@@ -0,0 +1,27 @@
export function makeAzurePath(path: string, apiVersion: string) {
// should omit /v1 prefix
path = path.replaceAll("v1/", "");
// should add api-key to query string
path += `${path.includes("?") ? "&" : "?"}api-version=${apiVersion}`;
return path;
}
export function prettyObject(msg: any) {
const obj = msg;
if (typeof msg !== "string") {
msg = JSON.stringify(msg, null, " ");
}
if (msg === "{}") {
return obj.toString();
}
if (msg.startsWith("```json")) {
return msg;
}
return ["```json", msg, "```"].join("\n");
}
export const makeBearer = (s: string) => `Bearer ${s.trim()}`;
export const validString = (x?: string): x is string =>
Boolean(x && x.length > 0);