feat: mix handlers of proxy server in providers

This commit is contained in:
Dean-YZG
2024-05-22 21:31:54 +08:00
parent 8093d1ffba
commit 8de8acdce8
23 changed files with 1570 additions and 420 deletions

View File

@@ -1,8 +1,25 @@
import { SettingItem } from "../../common";
import Locale from "./locale";
export const preferredRegion: string | string[] = [
"bom1",
"cle1",
"cpt1",
"gru1",
"hnd1",
"iad1",
"icn1",
"kix1",
"pdx1",
"sfo1",
"sin1",
"syd1",
];
export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/";
export const GoogleMetas = {
ExampleEndpoint: "https://generativelanguage.googleapis.com/",
ExampleEndpoint: GEMINI_BASE_URL,
ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
};
@@ -32,13 +49,16 @@ export const modelConfigs = [
},
];
export const settingItems: SettingItem<SettingKeys>[] = [
export const settingItems: (
defaultEndpoint: string,
) => SettingItem<SettingKeys>[] = (defaultEndpoint) => [
{
name: "googleUrl",
title: Locale.Endpoint.Title,
description: Locale.Endpoint.SubTitle + GoogleMetas.ExampleEndpoint,
placeholder: GoogleMetas.ExampleEndpoint,
type: "input",
defaultValue: defaultEndpoint,
validators: [
async (v: any) => {
if (typeof v === "string") {
@@ -52,6 +72,7 @@ export const settingItems: SettingItem<SettingKeys>[] = [
return Locale.Endpoint.Error.EndWithBackslash;
}
},
"required",
],
},
{

View File

@@ -1,4 +1,11 @@
import { SettingKeys, modelConfigs, settingItems, GoogleMetas } from "./config";
import {
SettingKeys,
modelConfigs,
settingItems,
GoogleMetas,
GEMINI_BASE_URL,
preferredRegion,
} from "./config";
import {
ChatHandlers,
InternalChatRequestPayload,
@@ -8,7 +15,14 @@ import {
getMessageTextContent,
getMessageImages,
} from "../../common";
import { ensureProperEnding, makeBearer, validString } from "./utils";
import {
auth,
ensureProperEnding,
getTimer,
parseResp,
urlParamApikeyName,
} from "./utils";
import { NextResponse } from "next/server";
export type GoogleProviderSettingKeys = SettingKeys;
@@ -29,38 +43,38 @@ interface ModelList {
nextPageToken: string;
}
type ProviderTemplate = IProviderTemplate<
SettingKeys,
"azure",
typeof GoogleMetas
>;
export default class GoogleProvider
implements IProviderTemplate<SettingKeys, "google", typeof GoogleMetas>
{
allowedApiMethods: (
| "POST"
| "GET"
| "OPTIONS"
| "PUT"
| "PATCH"
| "DELETE"
)[] = ["GET", "POST"];
runtime = "edge" as const;
apiRouteRootName: "/api/provider/google" = "/api/provider/google";
preferredRegion = preferredRegion;
name = "google" as const;
metas = GoogleMetas;
providerMeta = {
displayName: "Google",
settingItems,
settingItems: settingItems(this.apiRouteRootName),
};
defaultModels = modelConfigs;
readonly REQUEST_TIMEOUT_MS = 60000;
private getHeaders(payload: InternalChatRequestPayload<SettingKeys>) {
const {
providerConfig: { googleApiKey },
context: { isApp },
} = payload;
const headers: Record<string, string> = {
"Content-Type": "application/json",
Accept: "application/json",
};
if (!isApp && validString(googleApiKey)) {
headers["Authorization"] = makeBearer(googleApiKey);
}
return headers;
}
private formatChatPayload(payload: InternalChatRequestPayload<SettingKeys>) {
const {
messages,
@@ -69,19 +83,16 @@ export default class GoogleProvider
stream,
modelConfig,
providerConfig,
context: { isApp },
} = payload;
const { googleUrl, googleApiKey } = providerConfig;
const { temperature, top_p, max_tokens } = modelConfig;
let multimodal = false;
const internalMessages = messages.map((v) => {
let parts: any[] = [{ text: getMessageTextContent(v) }];
if (isVisionModel) {
const images = getMessageImages(v);
if (images.length > 0) {
multimodal = true;
parts = parts.concat(
images.map((image) => {
const imageType = image.split(";")[0].split(":")[1];
@@ -145,16 +156,15 @@ export default class GoogleProvider
],
};
let googleChatPath = GoogleMetas.ChatPath(model);
let baseUrl = googleUrl ?? "/api/google/" + googleChatPath;
if (isApp) {
baseUrl += `?key=${googleApiKey}`;
}
const baseUrl = `${googleUrl}/${GoogleMetas.ChatPath(
model,
)}?${urlParamApikeyName}=${googleApiKey}`;
return {
headers: this.getHeaders(payload),
headers: {
"Content-Type": "application/json",
Accept: "application/json",
},
body: JSON.stringify(requestPayload),
method: "POST",
url: stream
@@ -162,46 +172,15 @@ export default class GoogleProvider
: baseUrl,
};
}
private readWholeMessageResponseBody(res: any) {
if (res?.promptFeedback?.blockReason) {
// being blocked
throw new Error(
"Message is being blocked for reason: " +
res.promptFeedback.blockReason,
);
}
return {
message:
res.candidates?.at(0)?.content?.parts?.at(0)?.text ||
res.error?.message ||
"",
};
}
private getTimer = () => {
const controller = new AbortController();
// make a fetch request
const requestTimeoutId = setTimeout(
() => controller.abort(),
this.REQUEST_TIMEOUT_MS,
);
return {
...controller,
clear: () => {
clearTimeout(requestTimeoutId);
},
};
};
streamChat(
payload: InternalChatRequestPayload<SettingKeys>,
handlers: ChatHandlers,
fetch: typeof window.fetch,
) {
const requestPayload = this.formatChatPayload(payload);
const timer = this.getTimer();
const timer = getTimer();
let existingTexts: string[] = [];
@@ -274,15 +253,10 @@ export default class GoogleProvider
async chat(
payload: InternalChatRequestPayload<SettingKeys>,
fetch: typeof window.fetch,
): Promise<StandChatReponseMessage> {
const requestPayload = this.formatChatPayload(payload);
const timer = this.getTimer();
// make a fetch request
const requestTimeoutId = setTimeout(
() => timer.abort(),
this.REQUEST_TIMEOUT_MS,
);
const timer = getTimer();
const res = await fetch(requestPayload.url, {
headers: {
@@ -293,10 +267,10 @@ export default class GoogleProvider
signal: timer.signal,
});
clearTimeout(requestTimeoutId);
timer.clear();
const resJson = await res.json();
const message = this.readWholeMessageResponseBody(resJson);
const message = parseResp(resJson);
return message;
}
@@ -315,4 +289,65 @@ export default class GoogleProvider
return data.models;
}
serverSideRequestHandler: ProviderTemplate["serverSideRequestHandler"] =
async (req, serverConfig) => {
const { googleUrl = GEMINI_BASE_URL } = serverConfig;
const controller = new AbortController();
const path = `${req.nextUrl.pathname}`.replaceAll(
this.apiRouteRootName,
"",
);
console.log("[Proxy] ", path);
console.log("[Base Url]", googleUrl);
const authResult = auth(req, serverConfig);
if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
});
}
const fetchUrl = `${googleUrl}/${path}?key=${authResult.apiKey}`;
const fetchOptions: RequestInit = {
headers: {
"Content-Type": "application/json",
"Cache-Control": "no-store",
},
method: req.method,
body: req.body,
// to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body
redirect: "manual",
// @ts-ignore
duplex: "half",
signal: controller.signal,
};
const timeoutId = setTimeout(
() => {
controller.abort();
},
10 * 60 * 1000,
);
try {
const res = await fetch(fetchUrl, fetchOptions);
// to prevent browser prompt for credentials
const newHeaders = new Headers(res.headers);
newHeaders.delete("www-authenticate");
// to disable nginx buffering
newHeaders.set("X-Accel-Buffering", "no");
return new NextResponse(res.body, {
status: res.status,
statusText: res.statusText,
headers: newHeaders,
});
} finally {
clearTimeout(timeoutId);
}
};
}

View File

@@ -1,3 +1,10 @@
import { NextRequest } from "next/server";
import { ServerConfig, getIP } from "../../common";
export const urlParamApikeyName = "key";
export const REQUEST_TIMEOUT_MS = 60000;
export const makeBearer = (s: string) => `Bearer ${s.trim()}`;
export const validString = (x?: string): x is string =>
Boolean(x && x.length > 0);
@@ -8,3 +15,73 @@ export function ensureProperEnding(str: string) {
}
return str;
}
export function auth(req: NextRequest, serverConfig: ServerConfig) {
let apiKey = req.nextUrl.searchParams.get(urlParamApikeyName);
const { hideUserApiKey, googleApiKey } = serverConfig;
console.log("[User IP] ", getIP(req));
console.log("[Time] ", new Date().toLocaleString());
if (hideUserApiKey && apiKey) {
return {
error: true,
message: "you are not allowed to access with your own api key",
};
}
if (apiKey) {
console.log("[Auth] use user api key");
return {
error: false,
apiKey,
};
}
if (googleApiKey) {
console.log("[Auth] use system api key");
return {
error: false,
apiKey: googleApiKey,
};
}
console.log("[Auth] admin did not provide an api key");
return {
error: true,
message: `missing api key`,
};
}
export function getTimer() {
const controller = new AbortController();
// make a fetch request
const requestTimeoutId = setTimeout(
() => controller.abort(),
REQUEST_TIMEOUT_MS,
);
return {
...controller,
clear: () => {
clearTimeout(requestTimeoutId);
},
};
}
export function parseResp(res: any) {
if (res?.promptFeedback?.blockReason) {
// being blocked
throw new Error(
"Message is being blocked for reason: " + res.promptFeedback.blockReason,
);
}
return {
message:
res.candidates?.at(0)?.content?.parts?.at(0)?.text ||
res.error?.message ||
"",
};
}