diff --git a/app/api/provider/[...path]/route.ts b/app/api/provider/[...path]/route.ts new file mode 100644 index 000000000..f6bba4ca6 --- /dev/null +++ b/app/api/provider/[...path]/route.ts @@ -0,0 +1,93 @@ +import * as ProviderTemplates from "@/app/client/providers"; +import { getServerSideConfig } from "@/app/config/server"; +import { NextRequest, NextResponse } from "next/server"; +import { cloneDeep } from "lodash-es"; +import { + disableSystemApiKey, + makeUrlsUsable, + modelNameRequestHeader, +} from "@/app/client/common"; +import { collectModelTable } from "@/app/utils/model"; + +async function handle( + req: NextRequest, + { params }: { params: { path: string[] } }, +) { + const [providerName] = params.path; + const { headers } = req; + const serverConfig = getServerSideConfig(); + const modelName = headers.get(modelNameRequestHeader); + + const ProviderTemplate = Object.values(ProviderTemplates).find( + (t) => t.prototype.name === providerName, + ); + + if (!ProviderTemplate) { + return NextResponse.json( + { + error: true, + message: "No provider found: " + providerName, + }, + { + status: 404, + }, + ); + } + + // #1815 try to refuse gpt4 request + if (modelName && serverConfig.customModels) { + try { + const modelTable = collectModelTable([], serverConfig.customModels); + + // not undefined and is false + if (modelTable[modelName]?.available === false) { + return NextResponse.json( + { + error: true, + message: `you are not allowed to use ${modelName} model`, + }, + { + status: 403, + }, + ); + } + } catch (e) { + console.error("models filter", e); + } + } + + const config = disableSystemApiKey( + makeUrlsUsable(cloneDeep(serverConfig), [ + "anthropicUrl", + "azureUrl", + "googleUrl", + "baseUrl", + ]), + ["anthropicApiKey", "azureApiKey", "googleApiKey", "apiKey"], + serverConfig.needCode && + ProviderTemplate !== ProviderTemplates.NextChatProvider, // if it must take a access code in the req, do not provide system-keys for Non-nextchat providers + ); + + const request = Object.assign({}, req, { + subpath: params.path.join("/"), + }); + + return new ProviderTemplate().serverSideRequestHandler(request, config); +} + +export const GET = handle; +export const POST = handle; +export const PUT = handle; +export const PATCH = handle; +export const DELETE = handle; +export const OPTIONS = handle; + +export const runtime = "edge"; +export const preferredRegion = Array.from( + new Set( + Object.values(ProviderTemplates).reduce( + (arr, t) => [...arr, ...(t.prototype.preferredRegion ?? [])], + [] as string[], + ), + ), +); diff --git a/app/client/common/index.ts b/app/client/common/index.ts index 807aac0cd..f1414bf1a 100644 --- a/app/client/common/index.ts +++ b/app/client/common/index.ts @@ -3,3 +3,5 @@ export * from "./types"; export * from "./locale"; export * from "./utils"; + +export const modelNameRequestHeader = "x-nextchat-model-name"; diff --git a/app/client/common/types.ts b/app/client/common/types.ts index 49031ea7f..990559f26 100644 --- a/app/client/common/types.ts +++ b/app/client/common/types.ts @@ -1,4 +1,6 @@ import { RequestMessage } from "../api"; +import { getServerSideConfig } from "@/app/config/server"; +import { NextRequest, NextResponse } from "next/server"; export { type RequestMessage }; @@ -152,6 +154,9 @@ export type SettingItem = // ===================================== Provider Settings Types end ====================================== // ===================================== Provider Template Types start ====================================== + +export type ServerConfig = ReturnType; + export interface IProviderTemplate< SettingKeys extends string, NAME extends string, @@ -159,6 +164,12 @@ export interface IProviderTemplate< > { readonly name: NAME; + readonly apiRouteRootName: `/api/provider/${NAME}`; + + readonly allowedApiMethods: Array< + "GET" | "POST" | "PUT" | "PATCH" | "DELETE" | "OPTIONS" + >; + readonly metas: Meta; readonly providerMeta: { @@ -170,17 +181,31 @@ export interface IProviderTemplate< streamChat( payload: InternalChatRequestPayload, handlers: ChatHandlers, + fetch: typeof window.fetch, ): AbortController; chat( payload: InternalChatRequestPayload, + fetch: typeof window.fetch, ): Promise; getAvailableModels?( providerConfig: InternalChatRequestPayload["providerConfig"], ): Promise; + + readonly runtime: "edge"; + readonly preferredRegion: "auto" | "global" | "home" | string | string[]; + + serverSideRequestHandler( + req: NextRequest & { + subpath: string; + }, + serverConfig: ServerConfig, + ): Promise; } +export type ProviderTemplate = IProviderTemplate; + export interface Serializable { serialize(): Snapshot; } diff --git a/app/client/common/utils.ts b/app/client/common/utils.ts index de23c7825..bb9e579d1 100644 --- a/app/client/common/utils.ts +++ b/app/client/common/utils.ts @@ -1,4 +1,6 @@ -import { RequestMessage } from "./types"; +import { NextRequest } from "next/server"; +import { RequestMessage, ServerConfig } from "./types"; +import { cloneDeep } from "lodash-es"; export function getMessageTextContent(message: RequestMessage) { if (typeof message.content === "string") { @@ -24,3 +26,63 @@ export function getMessageImages(message: RequestMessage): string[] { } return urls; } + +export function getIP(req: NextRequest) { + let ip = req.ip ?? req.headers.get("x-real-ip"); + const forwardedFor = req.headers.get("x-forwarded-for"); + + if (!ip && forwardedFor) { + ip = forwardedFor.split(",").at(0) ?? ""; + } + + return ip; +} + +export function formatUrl(baseUrl?: string) { + if (baseUrl && !baseUrl.startsWith("http")) { + baseUrl = `https://${baseUrl}`; + } + if (baseUrl?.endsWith("/")) { + baseUrl = baseUrl.slice(0, -1); + } + + return baseUrl; +} + +function travel( + config: ServerConfig, + keys: Array, + handle: (prop: any) => any, +): ServerConfig { + const copiedConfig = cloneDeep(config); + keys.forEach((k) => { + copiedConfig[k] = handle(copiedConfig[k] as string) as never; + }); + return copiedConfig; +} + +export const makeUrlsUsable = ( + config: ServerConfig, + keys: Array, +) => travel(config, keys, formatUrl); + +export const disableSystemApiKey = ( + config: ServerConfig, + keys: Array, + forbidden: boolean, +) => + travel(config, keys, (p) => { + return forbidden ? undefined : p; + }); + +export function isSameOrigin(requestUrl: string) { + var a = document.createElement("a"); + a.href = requestUrl; + + // 检查协议、主机名和端口号是否与当前页面相同 + return ( + a.protocol === window.location.protocol && + a.hostname === window.location.hostname && + a.port === window.location.port + ); +} diff --git a/app/client/core/index.ts b/app/client/core/index.ts index 2ffc6679e..963227f5e 100644 --- a/app/client/core/index.ts +++ b/app/client/core/index.ts @@ -1,3 +1,5 @@ +export * from "./shim"; + export * from "../common/types"; export * from "./providerClient"; @@ -5,5 +7,3 @@ export * from "./providerClient"; export * from "./modelClient"; export * from "../common/locale"; - -export * from "./shim"; diff --git a/app/client/core/providerClient.ts b/app/client/core/providerClient.ts index 863527eaf..cc733bc8c 100644 --- a/app/client/core/providerClient.ts +++ b/app/client/core/providerClient.ts @@ -3,14 +3,15 @@ import { InternalChatHandlers, Model, ModelTemplate, + ProviderTemplate, StandChatReponseMessage, StandChatRequestPayload, + isSameOrigin, + modelNameRequestHeader, } from "../common"; import * as ProviderTemplates from "@/app/client/providers"; import { nanoid } from "nanoid"; -export type ProviderTemplate = IProviderTemplate; - export type ProviderTemplateName = (typeof ProviderTemplates)[keyof typeof ProviderTemplates]["prototype"]["name"]; @@ -38,6 +39,7 @@ const providerTemplates = Object.values(ProviderTemplates).reduce( export class ProviderClient { providerTemplate: IProviderTemplate; + genFetch: (modelName: string) => typeof window.fetch; static ProviderTemplates = providerTemplates; @@ -61,6 +63,31 @@ export class ProviderClient { constructor(private provider: Provider) { const { providerTemplateName } = provider; this.providerTemplate = this.getProviderTemplate(providerTemplateName); + this.genFetch = + (modelName: string) => + (...args) => { + const req = new Request(...args); + const headers: Record = { + ...req.headers, + }; + if (isSameOrigin(req.url)) { + headers[modelNameRequestHeader] = modelName; + } + + return window.fetch(req.url, { + method: req.method, + keepalive: req.keepalive, + headers, + body: req.body, + redirect: req.redirect, + integrity: req.integrity, + signal: req.signal, + credentials: req.credentials, + mode: req.mode, + referrer: req.referrer, + referrerPolicy: req.referrerPolicy, + }); + }; } private getProviderTemplate(providerTemplateName: string) { @@ -98,12 +125,15 @@ export class ProviderClient { async chat( payload: StandChatRequestPayload, ): Promise { - return this.providerTemplate.chat({ - ...payload, - stream: false, - isVisionModel: this.getModelConfig(payload.model)?.isVisionModel, - providerConfig: this.provider.providerConfig, - }); + return this.providerTemplate.chat( + { + ...payload, + stream: false, + isVisionModel: this.getModelConfig(payload.model)?.isVisionModel, + providerConfig: this.provider.providerConfig, + }, + this.genFetch(payload.model), + ); } streamChat(payload: StandChatRequestPayload, handlers: InternalChatHandlers) { @@ -129,6 +159,7 @@ export class ProviderClient { handlers.onFinish(message); }, }, + this.genFetch(payload.model), ); timer.signal.onabort = () => { diff --git a/app/client/providers/anthropic/config.ts b/app/client/providers/anthropic/config.ts index fe45a7aaf..60f328197 100644 --- a/app/client/providers/anthropic/config.ts +++ b/app/client/providers/anthropic/config.ts @@ -6,10 +6,11 @@ export type SettingKeys = | "anthropicApiKey" | "anthropicApiVersion"; +export const ANTHROPIC_BASE_URL = "https://api.anthropic.com"; + export const AnthropicMetas = { ChatPath: "v1/messages", - ChatPath1: "v1/complete", - ExampleEndpoint: "https://api.anthropic.com", + ExampleEndpoint: ANTHROPIC_BASE_URL, Vision: "2023-06-01", }; @@ -64,9 +65,29 @@ export const modelConfigs = [ }, ]; -const defaultEndpoint = "/api/anthropic"; +export const preferredRegion: string | string[] = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; -export const settingItems: SettingItem[] = [ +export const settingItems: ( + defaultEndpoint: string, +) => SettingItem[] = (defaultEndpoint) => [ { name: "anthropicUrl", title: Locale.Endpoint.Title, @@ -103,7 +124,7 @@ export const settingItems: SettingItem[] = [ name: "anthropicApiVersion", title: Locale.ApiVerion.Title, description: Locale.ApiVerion.SubTitle, - placeholder: AnthropicMetas.Vision, + defaultValue: AnthropicMetas.Vision, type: "input", // validators: ["required"], }, diff --git a/app/client/providers/anthropic/index.ts b/app/client/providers/anthropic/index.ts index 7d2f03350..a92d9485a 100644 --- a/app/client/providers/anthropic/index.ts +++ b/app/client/providers/anthropic/index.ts @@ -1,24 +1,33 @@ import { + ANTHROPIC_BASE_URL, AnthropicMetas, ClaudeMapper, SettingKeys, modelConfigs, + preferredRegion, settingItems, } from "./config"; import { ChatHandlers, InternalChatRequestPayload, IProviderTemplate, - getMessageTextContent, - RequestMessage, + ServerConfig, } from "../../common"; import { EventStreamContentType, fetchEventSource, } from "@fortaine/fetch-event-source"; import Locale from "@/app/locales"; -import { getAuthKey, trimEnd, prettyObject } from "./utils"; +import { + prettyObject, + getTimer, + authHeaderName, + auth, + parseResp, + formatMessage, +} from "./utils"; import { cloneDeep } from "lodash-es"; +import { NextRequest, NextResponse } from "next/server"; export type AnthropicProviderSettingKeys = SettingKeys; @@ -61,86 +70,32 @@ export interface ChatRequest { stream?: boolean; // Whether to incrementally stream the response using server-sent events. } -export default class AnthropicProvider - implements IProviderTemplate -{ +type ProviderTemplate = IProviderTemplate< + SettingKeys, + "anthropic", + typeof AnthropicMetas +>; + +export default class AnthropicProvider implements ProviderTemplate { + apiRouteRootName = "/api/provider/anthropic" as const; + allowedApiMethods: ["GET", "POST"] = ["GET", "POST"]; + + runtime = "edge" as const; + preferredRegion = preferredRegion; + name = "anthropic" as const; metas = AnthropicMetas; providerMeta = { displayName: "Anthropic", - settingItems, + settingItems: settingItems( + `${this.apiRouteRootName}//${AnthropicMetas.ChatPath}`, + ), }; defaultModels = modelConfigs; - readonly REQUEST_TIMEOUT_MS = 60000; - - private path(payload: InternalChatRequestPayload) { - const { - providerConfig: { anthropicUrl }, - } = payload; - - return `${trimEnd(anthropicUrl!)}/${AnthropicMetas.ChatPath}`; - } - - private formatMessage( - messages: RequestMessage[], - payload: InternalChatRequestPayload, - ) { - const { isVisionModel } = payload; - - return messages - .flat() - .filter((v) => { - if (!v.content) return false; - if (typeof v.content === "string" && !v.content.trim()) return false; - return true; - }) - .map((v) => { - const { role, content } = v; - const insideRole = ClaudeMapper[role] ?? "user"; - - if (!isVisionModel || typeof content === "string") { - return { - role: insideRole, - content: getMessageTextContent(v), - }; - } - return { - role: insideRole, - content: content - .filter((v) => v.image_url || v.text) - .map(({ type, text, image_url }) => { - if (type === "text") { - return { - type, - text: text!, - }; - } - const { url = "" } = image_url || {}; - const colonIndex = url.indexOf(":"); - const semicolonIndex = url.indexOf(";"); - const comma = url.indexOf(","); - - const mimeType = url.slice(colonIndex + 1, semicolonIndex); - const encodeType = url.slice(semicolonIndex + 1, comma); - const data = url.slice(comma + 1); - - return { - type: "image" as const, - source: { - type: encodeType, - media_type: mimeType, - data, - }, - }; - }), - }; - }); - } - private formatChatPayload(payload: InternalChatRequestPayload) { const { messages: outsideMessages, @@ -149,7 +104,8 @@ export default class AnthropicProvider modelConfig, providerConfig, } = payload; - const { anthropicApiKey, anthropicApiVersion } = providerConfig; + const { anthropicApiKey, anthropicApiVersion, anthropicUrl } = + providerConfig; const { temperature, top_p, max_tokens } = modelConfig; const keys = ["system", "user"]; @@ -172,7 +128,7 @@ export default class AnthropicProvider } } - const prompt = this.formatMessage(messages, payload); + const prompt = formatMessage(messages, payload.isVisionModel); const requestBody: AnthropicChatRequest = { messages: prompt, @@ -188,52 +144,84 @@ export default class AnthropicProvider headers: { "Content-Type": "application/json", Accept: "application/json", - "x-api-key": anthropicApiKey ?? "", + [authHeaderName]: anthropicApiKey ?? "", "anthropic-version": anthropicApiVersion ?? "", - Authorization: getAuthKey(anthropicApiKey), }, body: JSON.stringify(requestBody), method: "POST", - url: this.path(payload), + url: anthropicUrl!, }; } - private readWholeMessageResponseBody(res: any) { - return { - message: res?.content?.[0]?.text ?? "", - }; - } - - private getTimer = (onabort: () => void = () => {}) => { + private async request(req: NextRequest, serverConfig: ServerConfig) { const controller = new AbortController(); - // make a fetch request - const requestTimeoutId = setTimeout( - () => controller.abort(), - this.REQUEST_TIMEOUT_MS, + const authValue = req.headers.get(authHeaderName) ?? ""; + + const path = `${req.nextUrl.pathname}`.replaceAll( + this.apiRouteRootName, + "", ); - controller.signal.onabort = onabort; + const baseUrl = serverConfig.anthropicUrl || ANTHROPIC_BASE_URL; - return { - ...controller, - clear: () => { - clearTimeout(requestTimeoutId); + console.log("[Proxy] ", path); + console.log("[Base Url]", baseUrl); + + const timeoutId = setTimeout( + () => { + controller.abort(); }, - }; - }; - - async chat(payload: InternalChatRequestPayload) { - const requestPayload = this.formatChatPayload(payload); - - const timer = this.getTimer(); - - // make a fetch request - const requestTimeoutId = setTimeout( - () => timer.abort(), - this.REQUEST_TIMEOUT_MS, + 10 * 60 * 1000, ); + const fetchUrl = `${baseUrl}${path}`; + + const fetchOptions: RequestInit = { + headers: { + "Content-Type": "application/json", + "Cache-Control": "no-store", + [authHeaderName]: authValue, + "anthropic-version": + req.headers.get("anthropic-version") || + serverConfig.anthropicApiVersion || + AnthropicMetas.Vision, + }, + method: req.method, + body: req.body, + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }; + + console.log("[Anthropic request]", fetchOptions.headers, req.method); + try { + const res = await fetch(fetchUrl, fetchOptions); + + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + + return new NextResponse(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } + } + + async chat( + payload: InternalChatRequestPayload, + fetch: typeof window.fetch, + ) { + const requestPayload = this.formatChatPayload(payload); + const timer = getTimer(); + const res = await fetch(requestPayload.url, { headers: { ...requestPayload.headers, @@ -246,7 +234,7 @@ export default class AnthropicProvider timer.clear(); const resJson = await res.json(); - const message = this.readWholeMessageResponseBody(resJson); + const message = parseResp(resJson); return message; } @@ -254,13 +242,14 @@ export default class AnthropicProvider streamChat( payload: InternalChatRequestPayload, handlers: ChatHandlers, + fetch: typeof window.fetch, ) { const requestPayload = this.formatChatPayload(payload); - - const timer = this.getTimer(); + const timer = getTimer(); fetchEventSource(requestPayload.url, { ...requestPayload, + fetch, async onopen(res) { timer.clear(); const contentType = res.headers.get("content-type"); @@ -329,4 +318,39 @@ export default class AnthropicProvider return timer; } + + serverSideRequestHandler: ProviderTemplate["serverSideRequestHandler"] = + async (req, config) => { + const { subpath } = req; + const ALLOWD_PATH = [AnthropicMetas.ChatPath]; + + if (!ALLOWD_PATH.includes(subpath)) { + console.log("[Anthropic Route] forbidden path ", subpath); + return NextResponse.json( + { + error: true, + message: "you are not allowed to request " + subpath, + }, + { + status: 403, + }, + ); + } + + const authResult = auth(req, config); + + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + try { + const response = await this.request(req, config); + return response; + } catch (e) { + console.error("[Anthropic] ", e); + return NextResponse.json(prettyObject(e)); + } + }; } diff --git a/app/client/providers/anthropic/utils.ts b/app/client/providers/anthropic/utils.ts index 9a36f2d72..7f6f576f5 100644 --- a/app/client/providers/anthropic/utils.ts +++ b/app/client/providers/anthropic/utils.ts @@ -1,3 +1,15 @@ +import { NextRequest } from "next/server"; +import { + RequestMessage, + ServerConfig, + getIP, + getMessageTextContent, +} from "../../common"; +import { ClaudeMapper } from "./config"; + +export const REQUEST_TIMEOUT_MS = 60000; +export const authHeaderName = "x-api-key"; + export function trimEnd(s: string, end = " ") { if (end.length === 0) return s; @@ -12,17 +24,6 @@ export function bearer(value: string) { return `Bearer ${value.trim()}`; } -export function getAuthKey(apiKey = "") { - let authKey = ""; - - if (apiKey) { - // use user's api key first - authKey = bearer(apiKey); - } - - return authKey; -} - export function prettyObject(msg: any) { const obj = msg; if (typeof msg !== "string") { @@ -36,3 +37,115 @@ export function prettyObject(msg: any) { } return ["```json", msg, "```"].join("\n"); } + +export function getTimer() { + const controller = new AbortController(); + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + return { + ...controller, + clear: () => { + clearTimeout(requestTimeoutId); + }, + }; +} + +export function auth(req: NextRequest, serverConfig: ServerConfig) { + const apiKey = req.headers.get(authHeaderName); + + console.log("[User IP] ", getIP(req)); + console.log("[Time] ", new Date().toLocaleString()); + + if (serverConfig.hideUserApiKey && apiKey) { + return { + error: true, + message: "you are not allowed to access with your own api key", + }; + } + + if (apiKey) { + console.log("[Auth] use user api key"); + return { + error: false, + }; + } + + // if user does not provide an api key, inject system api key + const systemApiKey = serverConfig.anthropicApiKey; + + if (systemApiKey) { + console.log("[Auth] use system api key"); + req.headers.set(authHeaderName, systemApiKey); + } else { + console.log("[Auth] admin did not provide an api key"); + } + + return { + error: false, + }; +} + +export function parseResp(res: any) { + return { + message: res?.content?.[0]?.text ?? "", + }; +} + +export function formatMessage( + messages: RequestMessage[], + isVisionModel?: boolean, +) { + return messages + .flat() + .filter((v) => { + if (!v.content) return false; + if (typeof v.content === "string" && !v.content.trim()) return false; + return true; + }) + .map((v) => { + const { role, content } = v; + const insideRole = ClaudeMapper[role] ?? "user"; + + if (!isVisionModel || typeof content === "string") { + return { + role: insideRole, + content: getMessageTextContent(v), + }; + } + return { + role: insideRole, + content: content + .filter((v) => v.image_url || v.text) + .map(({ type, text, image_url }) => { + if (type === "text") { + return { + type, + text: text!, + }; + } + const { url = "" } = image_url || {}; + const colonIndex = url.indexOf(":"); + const semicolonIndex = url.indexOf(";"); + const comma = url.indexOf(","); + + const mimeType = url.slice(colonIndex + 1, semicolonIndex); + const encodeType = url.slice(semicolonIndex + 1, comma); + const data = url.slice(comma + 1); + + return { + type: "image" as const, + source: { + type: encodeType, + media_type: mimeType, + data, + }, + }; + }), + }; + }); +} diff --git a/app/client/providers/azure/config.ts b/app/client/providers/azure/config.ts index c26b29f8b..7b9b05e78 100644 --- a/app/client/providers/azure/config.ts +++ b/app/client/providers/azure/config.ts @@ -5,20 +5,44 @@ import { modelConfigs as openaiModelConfigs } from "../openai/config"; export const AzureMetas = { ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}", - ChatPath: "v1/chat/completions", + ChatPath: "chat/completions", + ListModelPath: "v1/models", }; export type SettingKeys = "azureUrl" | "azureApiKey" | "azureApiVersion"; +export const preferredRegion: string | string[] = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; + export const modelConfigs = openaiModelConfigs; -export const settingItems: SettingItem[] = [ +export const settingItems: ( + defaultEndpoint: string, +) => SettingItem[] = (defaultEndpoint) => [ { name: "azureUrl", title: Locale.Endpoint.Title, description: Locale.Endpoint.SubTitle + AzureMetas.ExampleEndpoint, placeholder: AzureMetas.ExampleEndpoint, type: "input", + defaultValue: defaultEndpoint, validators: [ async (v: any) => { if (typeof v === "string") { diff --git a/app/client/providers/azure/index.ts b/app/client/providers/azure/index.ts index 2d5ee112e..ca20cf2f1 100644 --- a/app/client/providers/azure/index.ts +++ b/app/client/providers/azure/index.ts @@ -1,17 +1,33 @@ -import { settingItems, SettingKeys, modelConfigs, AzureMetas } from "./config"; +import { + settingItems, + SettingKeys, + modelConfigs, + AzureMetas, + preferredRegion, +} from "./config"; import { ChatHandlers, InternalChatRequestPayload, IProviderTemplate, ModelInfo, getMessageTextContent, + ServerConfig, } from "../../common"; import { EventStreamContentType, fetchEventSource, } from "@fortaine/fetch-event-source"; import Locale from "@/app/locales"; -import { makeAzurePath, makeBearer, prettyObject, validString } from "./utils"; +import { + auth, + authHeaderName, + getHeaders, + getTimer, + makeAzurePath, + parseResp, + prettyObject, +} from "./utils"; +import { NextRequest, NextResponse } from "next/server"; export type AzureProviderSettingKeys = SettingKeys; @@ -62,9 +78,35 @@ interface ModelList { }>; } -export default class Azure - implements IProviderTemplate -{ +interface OpenAIListModelResponse { + object: string; + data: Array<{ + id: string; + object: string; + root: string; + }>; +} + +type ProviderTemplate = IProviderTemplate< + SettingKeys, + "azure", + typeof AzureMetas +>; + +export default class Azure implements ProviderTemplate { + apiRouteRootName: "/api/provider/azure" = "/api/provider/azure"; + allowedApiMethods: ( + | "POST" + | "GET" + | "OPTIONS" + | "PUT" + | "PATCH" + | "DELETE" + )[] = ["POST", "GET"]; + runtime = "edge" as const; + + preferredRegion = preferredRegion; + name = "azure" as const; metas = AzureMetas; @@ -72,46 +114,26 @@ export default class Azure providerMeta = { displayName: "Azure", - settingItems, + settingItems: settingItems( + `${this.apiRouteRootName}/${AzureMetas.ChatPath}`, + ), }; - readonly REQUEST_TIMEOUT_MS = 60000; - - private path(payload: InternalChatRequestPayload): string { + private formatChatPayload(payload: InternalChatRequestPayload) { const { + messages, + isVisionModel, + model, + stream, + modelConfig: { + temperature, + presence_penalty, + frequency_penalty, + top_p, + max_tokens, + }, providerConfig: { azureUrl, azureApiVersion }, } = payload; - const path = makeAzurePath(AzureMetas.ChatPath, azureApiVersion!); - - console.log("[Proxy Endpoint] ", azureUrl, path); - - return [azureUrl!, path].join("/"); - } - - private getHeaders(payload: InternalChatRequestPayload) { - const { azureApiKey } = payload.providerConfig; - - const headers: Record = { - "Content-Type": "application/json", - Accept: "application/json", - }; - - if (validString(azureApiKey)) { - headers["Authorization"] = makeBearer(azureApiKey); - } - - return headers; - } - - private formatChatPayload(payload: InternalChatRequestPayload) { - const { messages, isVisionModel, model, stream, modelConfig } = payload; - const { - temperature, - presence_penalty, - frequency_penalty, - top_p, - max_tokens, - } = modelConfig; const openAiMessages = messages.map((v) => ({ role: v.role, @@ -136,47 +158,105 @@ export default class Azure console.log("[Request] openai payload: ", requestPayload); return { - headers: this.getHeaders(payload), + headers: getHeaders(payload.providerConfig.azureApiKey), body: JSON.stringify(requestPayload), method: "POST", - url: this.path(payload), + url: `${azureUrl}?api-version=${azureApiVersion!}`, }; } - private readWholeMessageResponseBody(res: any) { - return { - message: res.choices?.at(0)?.message?.content ?? "", - }; - } - - private getTimer = (onabort: () => void = () => {}) => { + private async requestAzure(req: NextRequest, serverConfig: ServerConfig) { const controller = new AbortController(); - // make a fetch request - const requestTimeoutId = setTimeout( - () => controller.abort(), - this.REQUEST_TIMEOUT_MS, + const authValue = + req.headers + .get("Authorization") + ?.trim() + .replaceAll("Bearer ", "") + .trim() ?? ""; + + const { azureUrl, azureApiVersion } = serverConfig; + + if (!azureUrl) { + return NextResponse.json({ + error: true, + message: `missing AZURE_URL in server env vars`, + }); + } + + if (!azureApiVersion) { + return NextResponse.json({ + error: true, + message: `missing AZURE_API_VERSION in server env vars`, + }); + } + + let path = `${req.nextUrl.pathname}${req.nextUrl.search}`.replaceAll( + this.apiRouteRootName, + "", ); - controller.signal.onabort = onabort; + path = makeAzurePath(path, azureApiVersion); - return { - ...controller, - clear: () => { - clearTimeout(requestTimeoutId); + console.log("[Proxy] ", path); + console.log("[Base Url]", azureUrl); + + const fetchUrl = `${azureUrl}/${path}`; + + const timeoutId = setTimeout( + () => { + controller.abort(); }, - }; - }; + 10 * 60 * 1000, + ); - async chat(payload: InternalChatRequestPayload) { + const fetchOptions: RequestInit = { + headers: { + "Content-Type": "application/json", + "Cache-Control": "no-store", + [authHeaderName]: authValue, + }, + method: req.method, + body: req.body, + // to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }; + + try { + const res = await fetch(fetchUrl, fetchOptions); + + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + + // The latest version of the OpenAI API forced the content-encoding to be "br" in json response + // So if the streaming is disabled, we need to remove the content-encoding header + // Because Vercel uses gzip to compress the response, if we don't remove the content-encoding header + // The browser will try to decode the response with brotli and fail + newHeaders.delete("content-encoding"); + + return new NextResponse(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } + } + + async chat( + payload: InternalChatRequestPayload, + fetch: typeof window.fetch, + ) { const requestPayload = this.formatChatPayload(payload); - const timer = this.getTimer(); - // make a fetch request - const requestTimeoutId = setTimeout( - () => timer.abort(), - this.REQUEST_TIMEOUT_MS, - ); + const timer = getTimer(); const res = await fetch(requestPayload.url, { headers: { @@ -187,10 +267,10 @@ export default class Azure signal: timer.signal, }); - clearTimeout(requestTimeoutId); + timer.clear(); const resJson = await res.json(); - const message = this.readWholeMessageResponseBody(resJson); + const message = parseResp(resJson); return message; } @@ -198,13 +278,15 @@ export default class Azure streamChat( payload: InternalChatRequestPayload, handlers: ChatHandlers, + fetch: typeof window.fetch, ) { const requestPayload = this.formatChatPayload(payload); - const timer = this.getTimer(); + const timer = getTimer(); fetchEventSource(requestPayload.url, { ...requestPayload, + fetch, async onopen(res) { timer.clear(); const contentType = res.headers.get("content-type"); @@ -278,7 +360,7 @@ export default class Azure providerConfig: Record, ): Promise { const { azureApiKey, azureUrl } = providerConfig; - const res = await fetch(`${azureUrl}/vi/models`, { + const res = await fetch(`${azureUrl}/${AzureMetas.ListModelPath}`, { headers: { Authorization: `Bearer ${azureApiKey}`, }, @@ -290,4 +372,37 @@ export default class Azure name: o.id, })); } + + serverSideRequestHandler: ProviderTemplate["serverSideRequestHandler"] = + async (req, config) => { + const { subpath } = req; + const ALLOWD_PATH = [AzureMetas.ChatPath]; + + if (!ALLOWD_PATH.includes(subpath)) { + return NextResponse.json( + { + error: true, + message: "you are not allowed to request " + subpath, + }, + { + status: 403, + }, + ); + } + + const authResult = auth(req, config); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + try { + const response = await this.requestAzure(req, config); + + return response; + } catch (e) { + return NextResponse.json(prettyObject(e)); + } + }; } diff --git a/app/client/providers/azure/utils.ts b/app/client/providers/azure/utils.ts index fea7457c8..f1bdda4de 100644 --- a/app/client/providers/azure/utils.ts +++ b/app/client/providers/azure/utils.ts @@ -1,7 +1,29 @@ -export function makeAzurePath(path: string, apiVersion: string) { - // should omit /v1 prefix - path = path.replaceAll("v1/", ""); +import { NextRequest } from "next/server"; +import { ServerConfig, getIP } from "../../common"; +export const authHeaderName = "api-key"; +export const REQUEST_TIMEOUT_MS = 60000; + +export function getHeaders(azureApiKey?: string) { + const headers: Record = { + "Content-Type": "application/json", + Accept: "application/json", + }; + + if (validString(azureApiKey)) { + headers[authHeaderName] = makeBearer(azureApiKey); + } + + return headers; +} + +export function parseResp(res: any) { + return { + message: res.choices?.at(0)?.message?.content ?? "", + }; +} + +export function makeAzurePath(path: string, apiVersion: string) { // should add api-key to query string path += `${path.includes("?") ? "&" : "?"}api-version=${apiVersion}`; @@ -25,3 +47,64 @@ export function prettyObject(msg: any) { export const makeBearer = (s: string) => `Bearer ${s.trim()}`; export const validString = (x?: string): x is string => Boolean(x && x.length > 0); + +export function parseApiKey(bearToken: string) { + const token = bearToken.trim().replaceAll("Bearer ", "").trim(); + + return { + apiKey: token, + }; +} + +export function getTimer() { + const controller = new AbortController(); + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + return { + ...controller, + clear: () => { + clearTimeout(requestTimeoutId); + }, + }; +} + +export function auth(req: NextRequest, serverConfig: ServerConfig) { + const authToken = req.headers.get(authHeaderName) ?? ""; + + const { hideUserApiKey, apiKey: systemApiKey } = serverConfig; + + const { apiKey } = parseApiKey(authToken); + + console.log("[User IP] ", getIP(req)); + console.log("[Time] ", new Date().toLocaleString()); + + if (hideUserApiKey && apiKey) { + return { + error: true, + message: "you are not allowed to access with your own api key", + }; + } + + if (apiKey) { + console.log("[Auth] use user api key"); + return { + error: false, + }; + } + + if (systemApiKey) { + console.log("[Auth] use system api key"); + req.headers.set("Authorization", `Bearer ${systemApiKey}`); + } else { + console.log("[Auth] admin did not provide an api key"); + } + + return { + error: false, + }; +} diff --git a/app/client/providers/google/config.ts b/app/client/providers/google/config.ts index 3f7e02cdb..7cf4dc85d 100644 --- a/app/client/providers/google/config.ts +++ b/app/client/providers/google/config.ts @@ -1,8 +1,25 @@ import { SettingItem } from "../../common"; import Locale from "./locale"; +export const preferredRegion: string | string[] = [ + "bom1", + "cle1", + "cpt1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; + +export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/"; + export const GoogleMetas = { - ExampleEndpoint: "https://generativelanguage.googleapis.com/", + ExampleEndpoint: GEMINI_BASE_URL, ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`, }; @@ -32,13 +49,16 @@ export const modelConfigs = [ }, ]; -export const settingItems: SettingItem[] = [ +export const settingItems: ( + defaultEndpoint: string, +) => SettingItem[] = (defaultEndpoint) => [ { name: "googleUrl", title: Locale.Endpoint.Title, description: Locale.Endpoint.SubTitle + GoogleMetas.ExampleEndpoint, placeholder: GoogleMetas.ExampleEndpoint, type: "input", + defaultValue: defaultEndpoint, validators: [ async (v: any) => { if (typeof v === "string") { @@ -52,6 +72,7 @@ export const settingItems: SettingItem[] = [ return Locale.Endpoint.Error.EndWithBackslash; } }, + "required", ], }, { diff --git a/app/client/providers/google/index.ts b/app/client/providers/google/index.ts index a0ea89605..1003c8b40 100644 --- a/app/client/providers/google/index.ts +++ b/app/client/providers/google/index.ts @@ -1,4 +1,11 @@ -import { SettingKeys, modelConfigs, settingItems, GoogleMetas } from "./config"; +import { + SettingKeys, + modelConfigs, + settingItems, + GoogleMetas, + GEMINI_BASE_URL, + preferredRegion, +} from "./config"; import { ChatHandlers, InternalChatRequestPayload, @@ -8,7 +15,14 @@ import { getMessageTextContent, getMessageImages, } from "../../common"; -import { ensureProperEnding, makeBearer, validString } from "./utils"; +import { + auth, + ensureProperEnding, + getTimer, + parseResp, + urlParamApikeyName, +} from "./utils"; +import { NextResponse } from "next/server"; export type GoogleProviderSettingKeys = SettingKeys; @@ -29,38 +43,38 @@ interface ModelList { nextPageToken: string; } +type ProviderTemplate = IProviderTemplate< + SettingKeys, + "azure", + typeof GoogleMetas +>; + export default class GoogleProvider implements IProviderTemplate { + allowedApiMethods: ( + | "POST" + | "GET" + | "OPTIONS" + | "PUT" + | "PATCH" + | "DELETE" + )[] = ["GET", "POST"]; + runtime = "edge" as const; + + apiRouteRootName: "/api/provider/google" = "/api/provider/google"; + + preferredRegion = preferredRegion; + name = "google" as const; metas = GoogleMetas; providerMeta = { displayName: "Google", - settingItems, + settingItems: settingItems(this.apiRouteRootName), }; defaultModels = modelConfigs; - readonly REQUEST_TIMEOUT_MS = 60000; - - private getHeaders(payload: InternalChatRequestPayload) { - const { - providerConfig: { googleApiKey }, - context: { isApp }, - } = payload; - - const headers: Record = { - "Content-Type": "application/json", - Accept: "application/json", - }; - - if (!isApp && validString(googleApiKey)) { - headers["Authorization"] = makeBearer(googleApiKey); - } - - return headers; - } - private formatChatPayload(payload: InternalChatRequestPayload) { const { messages, @@ -69,19 +83,16 @@ export default class GoogleProvider stream, modelConfig, providerConfig, - context: { isApp }, } = payload; const { googleUrl, googleApiKey } = providerConfig; const { temperature, top_p, max_tokens } = modelConfig; - let multimodal = false; const internalMessages = messages.map((v) => { let parts: any[] = [{ text: getMessageTextContent(v) }]; if (isVisionModel) { const images = getMessageImages(v); if (images.length > 0) { - multimodal = true; parts = parts.concat( images.map((image) => { const imageType = image.split(";")[0].split(":")[1]; @@ -145,16 +156,15 @@ export default class GoogleProvider ], }; - let googleChatPath = GoogleMetas.ChatPath(model); - - let baseUrl = googleUrl ?? "/api/google/" + googleChatPath; - - if (isApp) { - baseUrl += `?key=${googleApiKey}`; - } + const baseUrl = `${googleUrl}/${GoogleMetas.ChatPath( + model, + )}?${urlParamApikeyName}=${googleApiKey}`; return { - headers: this.getHeaders(payload), + headers: { + "Content-Type": "application/json", + Accept: "application/json", + }, body: JSON.stringify(requestPayload), method: "POST", url: stream @@ -162,46 +172,15 @@ export default class GoogleProvider : baseUrl, }; } - private readWholeMessageResponseBody(res: any) { - if (res?.promptFeedback?.blockReason) { - // being blocked - throw new Error( - "Message is being blocked for reason: " + - res.promptFeedback.blockReason, - ); - } - return { - message: - res.candidates?.at(0)?.content?.parts?.at(0)?.text || - res.error?.message || - "", - }; - } - - private getTimer = () => { - const controller = new AbortController(); - - // make a fetch request - const requestTimeoutId = setTimeout( - () => controller.abort(), - this.REQUEST_TIMEOUT_MS, - ); - - return { - ...controller, - clear: () => { - clearTimeout(requestTimeoutId); - }, - }; - }; streamChat( payload: InternalChatRequestPayload, handlers: ChatHandlers, + fetch: typeof window.fetch, ) { const requestPayload = this.formatChatPayload(payload); - const timer = this.getTimer(); + const timer = getTimer(); let existingTexts: string[] = []; @@ -274,15 +253,10 @@ export default class GoogleProvider async chat( payload: InternalChatRequestPayload, + fetch: typeof window.fetch, ): Promise { const requestPayload = this.formatChatPayload(payload); - const timer = this.getTimer(); - - // make a fetch request - const requestTimeoutId = setTimeout( - () => timer.abort(), - this.REQUEST_TIMEOUT_MS, - ); + const timer = getTimer(); const res = await fetch(requestPayload.url, { headers: { @@ -293,10 +267,10 @@ export default class GoogleProvider signal: timer.signal, }); - clearTimeout(requestTimeoutId); + timer.clear(); const resJson = await res.json(); - const message = this.readWholeMessageResponseBody(resJson); + const message = parseResp(resJson); return message; } @@ -315,4 +289,65 @@ export default class GoogleProvider return data.models; } + + serverSideRequestHandler: ProviderTemplate["serverSideRequestHandler"] = + async (req, serverConfig) => { + const { googleUrl = GEMINI_BASE_URL } = serverConfig; + + const controller = new AbortController(); + + const path = `${req.nextUrl.pathname}`.replaceAll( + this.apiRouteRootName, + "", + ); + + console.log("[Proxy] ", path); + console.log("[Base Url]", googleUrl); + + const authResult = auth(req, serverConfig); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + const fetchUrl = `${googleUrl}/${path}?key=${authResult.apiKey}`; + const fetchOptions: RequestInit = { + headers: { + "Content-Type": "application/json", + "Cache-Control": "no-store", + }, + method: req.method, + body: req.body, + // to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }; + + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + + try { + const res = await fetch(fetchUrl, fetchOptions); + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + + return new NextResponse(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } + }; } diff --git a/app/client/providers/google/utils.ts b/app/client/providers/google/utils.ts index 78258ef27..2d2528167 100644 --- a/app/client/providers/google/utils.ts +++ b/app/client/providers/google/utils.ts @@ -1,3 +1,10 @@ +import { NextRequest } from "next/server"; +import { ServerConfig, getIP } from "../../common"; + +export const urlParamApikeyName = "key"; + +export const REQUEST_TIMEOUT_MS = 60000; + export const makeBearer = (s: string) => `Bearer ${s.trim()}`; export const validString = (x?: string): x is string => Boolean(x && x.length > 0); @@ -8,3 +15,73 @@ export function ensureProperEnding(str: string) { } return str; } + +export function auth(req: NextRequest, serverConfig: ServerConfig) { + let apiKey = req.nextUrl.searchParams.get(urlParamApikeyName); + + const { hideUserApiKey, googleApiKey } = serverConfig; + + console.log("[User IP] ", getIP(req)); + console.log("[Time] ", new Date().toLocaleString()); + + if (hideUserApiKey && apiKey) { + return { + error: true, + message: "you are not allowed to access with your own api key", + }; + } + + if (apiKey) { + console.log("[Auth] use user api key"); + return { + error: false, + apiKey, + }; + } + + if (googleApiKey) { + console.log("[Auth] use system api key"); + return { + error: false, + apiKey: googleApiKey, + }; + } + + console.log("[Auth] admin did not provide an api key"); + return { + error: true, + message: `missing api key`, + }; +} + +export function getTimer() { + const controller = new AbortController(); + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + return { + ...controller, + clear: () => { + clearTimeout(requestTimeoutId); + }, + }; +} + +export function parseResp(res: any) { + if (res?.promptFeedback?.blockReason) { + // being blocked + throw new Error( + "Message is being blocked for reason: " + res.promptFeedback.blockReason, + ); + } + return { + message: + res.candidates?.at(0)?.content?.parts?.at(0)?.text || + res.error?.message || + "", + }; +} diff --git a/app/client/providers/nextchat/config.ts b/app/client/providers/nextchat/config.ts index a9eab766b..67852854e 100644 --- a/app/client/providers/nextchat/config.ts +++ b/app/client/providers/nextchat/config.ts @@ -2,6 +2,8 @@ import { SettingItem } from "../../common"; import { isVisionModel } from "@/app/utils"; import Locale from "@/app/locales"; +export const OPENAI_BASE_URL = "https://api.openai.com"; + export const NextChatMetas = { ChatPath: "v1/chat/completions", UsagePath: "dashboard/billing/usage", @@ -9,6 +11,26 @@ export const NextChatMetas = { ListModelPath: "v1/models", }; +export const preferredRegion: string | string[] = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; + export type SettingKeys = "accessCode"; export const defaultModal = "gpt-3.5-turbo"; diff --git a/app/client/providers/nextchat/index.ts b/app/client/providers/nextchat/index.ts index a66001d95..5471a2796 100644 --- a/app/client/providers/nextchat/index.ts +++ b/app/client/providers/nextchat/index.ts @@ -3,13 +3,15 @@ import { settingItems, SettingKeys, NextChatMetas, + preferredRegion, + OPENAI_BASE_URL, } from "./config"; -import { ACCESS_CODE_PREFIX } from "@/app/constant"; import { ChatHandlers, getMessageTextContent, InternalChatRequestPayload, IProviderTemplate, + ServerConfig, StandChatReponseMessage, } from "../../common"; import { @@ -18,7 +20,8 @@ import { } from "@fortaine/fetch-event-source"; import { prettyObject } from "@/app/utils/format"; import Locale from "@/app/locales"; -import { makeBearer, validString } from "./utils"; +import { auth, authHeaderName, getHeaders, getTimer, parseResp } from "./utils"; +import { NextRequest, NextResponse } from "next/server"; export type NextChatProviderSettingKeys = SettingKeys; @@ -52,9 +55,27 @@ interface RequestPayload { max_tokens?: number; } +type ProviderTemplate = IProviderTemplate< + SettingKeys, + "azure", + typeof NextChatMetas +>; + export default class NextChatProvider implements IProviderTemplate { + apiRouteRootName: "/api/provider/nextchat" = "/api/provider/nextchat"; + allowedApiMethods: ( + | "POST" + | "GET" + | "OPTIONS" + | "PUT" + | "PATCH" + | "DELETE" + )[] = ["GET", "POST"]; + + runtime = "edge" as const; + preferredRegion = preferredRegion; name = "nextchat" as const; metas = NextChatMetas; @@ -65,33 +86,6 @@ export default class NextChatProvider settingItems, }; - readonly REQUEST_TIMEOUT_MS = 60000; - - private path(): string { - const path = NextChatMetas.ChatPath; - - let baseUrl = "/api/openai"; - - console.log("[Proxy Endpoint] ", baseUrl, path); - - return [baseUrl, path].join("/"); - } - - private getHeaders(payload: InternalChatRequestPayload) { - const { accessCode } = payload.providerConfig; - - const headers: Record = { - "Content-Type": "application/json", - Accept: "application/json", - }; - - if (validString(accessCode)) { - headers["Authorization"] = makeBearer(ACCESS_CODE_PREFIX + accessCode); - } - - return headers; - } - private formatChatPayload(payload: InternalChatRequestPayload) { const { messages, isVisionModel, model, stream, modelConfig } = payload; const { @@ -125,46 +119,106 @@ export default class NextChatProvider console.log("[Request] openai payload: ", requestPayload); return { - headers: this.getHeaders(payload), + headers: getHeaders(payload.providerConfig.accessCode!), body: JSON.stringify(requestPayload), method: "POST", - url: this.path(), + url: [this.apiRouteRootName, NextChatMetas.ChatPath].join("/"), }; } - private readWholeMessageResponseBody(res: any) { - return { - message: res.choices?.at(0)?.message?.content ?? "", - }; - } - - private getTimer = () => { + private async requestOpenai(req: NextRequest, serverConfig: ServerConfig) { + const { baseUrl = OPENAI_BASE_URL, openaiOrgId } = serverConfig; const controller = new AbortController(); + const authValue = req.headers.get(authHeaderName) ?? ""; - // make a fetch request - const requestTimeoutId = setTimeout( - () => controller.abort(), - this.REQUEST_TIMEOUT_MS, + const path = `${req.nextUrl.pathname}${req.nextUrl.search}`.replaceAll( + this.apiRouteRootName, + "", ); - return { - ...controller, - clear: () => { - clearTimeout(requestTimeoutId); + console.log("[Proxy] ", path); + console.log("[Base Url]", baseUrl); + + const timeoutId = setTimeout( + () => { + controller.abort(); }, + 10 * 60 * 1000, + ); + + const fetchUrl = `${baseUrl}/${path}`; + const fetchOptions: RequestInit = { + headers: { + "Content-Type": "application/json", + "Cache-Control": "no-store", + [authHeaderName]: authValue, + ...(openaiOrgId && { + "OpenAI-Organization": openaiOrgId, + }), + }, + method: req.method, + body: req.body, + // to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, }; - }; + + try { + const res = await fetch(fetchUrl, fetchOptions); + + // Extract the OpenAI-Organization header from the response + const openaiOrganizationHeader = res.headers.get("OpenAI-Organization"); + + // Check if serverConfig.openaiOrgId is defined and not an empty string + if (openaiOrgId && openaiOrgId.trim() !== "") { + // If openaiOrganizationHeader is present, log it; otherwise, log that the header is not present + console.log("[Org ID]", openaiOrganizationHeader); + } else { + console.log("[Org ID] is not set up."); + } + + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + + // Conditionally delete the OpenAI-Organization header from the response if [Org ID] is undefined or empty (not setup in ENV) + // Also, this is to prevent the header from being sent to the client + if (!openaiOrgId || openaiOrgId.trim() === "") { + newHeaders.delete("OpenAI-Organization"); + } + + // The latest version of the OpenAI API forced the content-encoding to be "br" in json response + // So if the streaming is disabled, we need to remove the content-encoding header + // Because Vercel uses gzip to compress the response, if we don't remove the content-encoding header + // The browser will try to decode the response with brotli and fail + newHeaders.delete("content-encoding"); + + return new NextResponse(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } + } streamChat( payload: InternalChatRequestPayload, handlers: ChatHandlers, + fetch: typeof window.fetch, ) { const requestPayload = this.formatChatPayload(payload); - const timer = this.getTimer(); + const timer = getTimer(); fetchEventSource(requestPayload.url, { ...requestPayload, + fetch, async onopen(res) { timer.clear(); const contentType = res.headers.get("content-type"); @@ -236,10 +290,11 @@ export default class NextChatProvider async chat( payload: InternalChatRequestPayload<"accessCode">, + fetch: typeof window.fetch, ): Promise { const requestPayload = this.formatChatPayload(payload); - const timer = this.getTimer(); + const timer = getTimer(); const res = await fetch(requestPayload.url, { headers: { @@ -253,8 +308,41 @@ export default class NextChatProvider timer.clear(); const resJson = await res.json(); - const message = this.readWholeMessageResponseBody(resJson); + const message = parseResp(resJson); return message; } + + serverSideRequestHandler: ProviderTemplate["serverSideRequestHandler"] = + async (req, config) => { + const { subpath } = req; + const ALLOWD_PATH = new Set(Object.values(NextChatMetas)); + + if (!ALLOWD_PATH.has(subpath)) { + return NextResponse.json( + { + error: true, + message: "you are not allowed to request " + subpath, + }, + { + status: 403, + }, + ); + } + + const authResult = auth(req, config); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + try { + const response = await this.requestOpenai(req, config); + + return response; + } catch (e) { + return NextResponse.json(prettyObject(e)); + } + }; } diff --git a/app/client/providers/nextchat/utils.ts b/app/client/providers/nextchat/utils.ts index 24f6ef4f0..397fadfff 100644 --- a/app/client/providers/nextchat/utils.ts +++ b/app/client/providers/nextchat/utils.ts @@ -1,3 +1,13 @@ +import { NextRequest } from "next/server"; +import { ServerConfig, getIP } from "../../common"; +import md5 from "spark-md5"; + +export const ACCESS_CODE_PREFIX = "nk-"; + +export const REQUEST_TIMEOUT_MS = 60000; + +export const authHeaderName = "Authorization"; + export const makeBearer = (s: string) => `Bearer ${s.trim()}`; export const validString = (x?: string): x is string => @@ -16,3 +26,87 @@ export function prettyObject(msg: any) { } return ["```json", msg, "```"].join("\n"); } + +export function getTimer() { + const controller = new AbortController(); + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + return { + ...controller, + clear: () => { + clearTimeout(requestTimeoutId); + }, + }; +} + +export function getHeaders(accessCode: string) { + const headers: Record = { + "Content-Type": "application/json", + Accept: "application/json", + [authHeaderName]: makeBearer(ACCESS_CODE_PREFIX + accessCode), + }; + + return headers; +} + +export function parseResp(res: { choices: { message: { content: any } }[] }) { + return { + message: res.choices?.[0]?.message?.content ?? "", + }; +} + +function parseApiKey(req: NextRequest) { + const authToken = req.headers.get("Authorization") ?? ""; + + return { + accessCode: + authToken.startsWith(ACCESS_CODE_PREFIX) && + authToken.slice(ACCESS_CODE_PREFIX.length), + }; +} + +export function auth(req: NextRequest, serverConfig: ServerConfig) { + // check if it is openai api key or user token + const { accessCode } = parseApiKey(req); + const { googleApiKey, apiKey, anthropicApiKey, azureApiKey, codes } = + serverConfig; + + const hashedCode = md5.hash(accessCode || "").trim(); + + console.log("[Auth] allowed hashed codes: ", [...codes]); + console.log("[Auth] got access code:", accessCode); + console.log("[Auth] hashed access code:", hashedCode); + console.log("[User IP] ", getIP(req)); + console.log("[Time] ", new Date().toLocaleString()); + + if (!codes.has(hashedCode)) { + return { + error: true, + message: !accessCode ? "empty access code" : "wrong access code", + }; + } + + const systemApiKey = googleApiKey || apiKey || anthropicApiKey || azureApiKey; + + if (systemApiKey) { + console.log("[Auth] use system api key"); + + return { + error: false, + accessCode, + systemApiKey, + }; + } + + console.log("[Auth] admin did not provide an api key"); + + return { + error: true, + message: `Server internal error`, + }; +} diff --git a/app/client/providers/openai/config.ts b/app/client/providers/openai/config.ts index 60c7073d4..34fa487cd 100644 --- a/app/client/providers/openai/config.ts +++ b/app/client/providers/openai/config.ts @@ -5,6 +5,26 @@ export const OPENAI_BASE_URL = "https://api.openai.com"; export const ROLES = ["system", "user", "assistant"] as const; +export const preferredRegion: string | string[] = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; + export const OpenaiMetas = { ChatPath: "v1/chat/completions", UsagePath: "dashboard/billing/usage", @@ -157,9 +177,9 @@ export const modelConfigs = [ }, ]; -const defaultEndpoint = "/api/openai"; - -export const settingItems: SettingItem[] = [ +export const settingItems: ( + defaultEndpoint: string, +) => SettingItem[] = (defaultEndpoint) => [ { name: "openaiUrl", title: Locale.Endpoint.Title, @@ -189,6 +209,6 @@ export const settingItems: SettingItem[] = [ placeholder: Locale.ApiKey.Placeholder, type: "input", inputType: "password", - validators: ["required"], + // validators: ["required"], }, ]; diff --git a/app/client/providers/openai/index.ts b/app/client/providers/openai/index.ts index 1d336494a..86df158f4 100644 --- a/app/client/providers/openai/index.ts +++ b/app/client/providers/openai/index.ts @@ -4,20 +4,32 @@ import { IProviderTemplate, ModelInfo, getMessageTextContent, + ServerConfig, } from "../../common"; import { EventStreamContentType, fetchEventSource, } from "@fortaine/fetch-event-source"; import Locale from "@/app/locales"; -import { makeBearer, validString, prettyObject } from "./utils"; +import { + authHeaderName, + prettyObject, + parseResp, + auth, + getTimer, + getHeaders, +} from "./utils"; import { modelConfigs, settingItems, SettingKeys, OpenaiMetas, ROLES, + OPENAI_BASE_URL, + preferredRegion, } from "./config"; +import { NextRequest, NextResponse } from "next/server"; +import { ModelList } from "./type"; export type OpenAIProviderSettingKeys = SettingKeys; @@ -49,66 +61,54 @@ interface RequestPayload { max_tokens?: number; } -interface ModelList { - object: "list"; - data: Array<{ - id: string; - object: "model"; - created: number; - owned_by: "system" | "openai-internal"; - }>; -} +type ProviderTemplate = IProviderTemplate< + SettingKeys, + "azure", + typeof OpenaiMetas +>; class OpenAIProvider implements IProviderTemplate { + apiRouteRootName: "/api/provider/openai" = "/api/provider/openai"; + allowedApiMethods: ( + | "POST" + | "GET" + | "OPTIONS" + | "PUT" + | "PATCH" + | "DELETE" + )[] = ["GET", "POST"]; + runtime = "edge" as const; + preferredRegion = preferredRegion; + name = "openai" as const; metas = OpenaiMetas; - readonly REQUEST_TIMEOUT_MS = 60000; - defaultModels = modelConfigs; providerMeta = { displayName: "OpenAI", - settingItems, + settingItems: settingItems( + `${this.apiRouteRootName}/${OpenaiMetas.ChatPath}`, + ), }; - private path(payload: InternalChatRequestPayload): string { + private formatChatPayload(payload: InternalChatRequestPayload) { const { + messages, + isVisionModel, + model, + stream, + modelConfig: { + temperature, + presence_penalty, + frequency_penalty, + top_p, + max_tokens, + }, providerConfig: { openaiUrl }, } = payload; - const path = OpenaiMetas.ChatPath; - - console.log("[Proxy Endpoint] ", openaiUrl, path); - - return [openaiUrl, path].join("/"); - } - - private getHeaders(payload: InternalChatRequestPayload) { - const { openaiApiKey } = payload.providerConfig; - - const headers: Record = { - "Content-Type": "application/json", - Accept: "application/json", - }; - - if (validString(openaiApiKey)) { - headers["Authorization"] = makeBearer(openaiApiKey); - } - - return headers; - } - - private formatChatPayload(payload: InternalChatRequestPayload) { - const { messages, isVisionModel, model, stream, modelConfig } = payload; - const { - temperature, - presence_penalty, - frequency_penalty, - top_p, - max_tokens, - } = modelConfig; const openAiMessages = messages.map((v) => ({ role: v.role, @@ -133,42 +133,101 @@ class OpenAIProvider console.log("[Request] openai payload: ", requestPayload); return { - headers: this.getHeaders(payload), + headers: getHeaders(payload.providerConfig.openaiApiKey), body: JSON.stringify(requestPayload), method: "POST", - url: this.path(payload), + url: openaiUrl!, }; } - private readWholeMessageResponseBody(res: { - choices: { message: { content: any } }[]; - }) { - return { - message: res.choices?.[0]?.message?.content ?? "", - }; - } - - private getTimer = () => { + private async requestOpenai(req: NextRequest, serverConfig: ServerConfig) { + const { baseUrl = OPENAI_BASE_URL, openaiOrgId } = serverConfig; const controller = new AbortController(); + const authValue = req.headers.get(authHeaderName) ?? ""; - // make a fetch request - const requestTimeoutId = setTimeout( - () => controller.abort(), - this.REQUEST_TIMEOUT_MS, + const path = `${req.nextUrl.pathname}${req.nextUrl.search}`.replaceAll( + this.apiRouteRootName, + "", ); - return { - ...controller, - clear: () => { - clearTimeout(requestTimeoutId); - }, - }; - }; + console.log("[Proxy] ", path); + console.log("[Base Url]", baseUrl); - async chat(payload: InternalChatRequestPayload) { + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + + const fetchUrl = `${baseUrl}/${path}`; + const fetchOptions: RequestInit = { + headers: { + "Content-Type": "application/json", + "Cache-Control": "no-store", + [authHeaderName]: authValue, + ...(openaiOrgId && { + "OpenAI-Organization": openaiOrgId, + }), + }, + method: req.method, + body: req.body, + // to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }; + + try { + const res = await fetch(fetchUrl, fetchOptions); + + // Extract the OpenAI-Organization header from the response + const openaiOrganizationHeader = res.headers.get("OpenAI-Organization"); + + // Check if serverConfig.openaiOrgId is defined and not an empty string + if (openaiOrgId && openaiOrgId.trim() !== "") { + // If openaiOrganizationHeader is present, log it; otherwise, log that the header is not present + console.log("[Org ID]", openaiOrganizationHeader); + } else { + console.log("[Org ID] is not set up."); + } + + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + + // Conditionally delete the OpenAI-Organization header from the response if [Org ID] is undefined or empty (not setup in ENV) + // Also, this is to prevent the header from being sent to the client + if (!openaiOrgId || openaiOrgId.trim() === "") { + newHeaders.delete("OpenAI-Organization"); + } + + // The latest version of the OpenAI API forced the content-encoding to be "br" in json response + // So if the streaming is disabled, we need to remove the content-encoding header + // Because Vercel uses gzip to compress the response, if we don't remove the content-encoding header + // The browser will try to decode the response with brotli and fail + newHeaders.delete("content-encoding"); + + return new NextResponse(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } + } + + async chat( + payload: InternalChatRequestPayload, + fetch: typeof window.fetch, + ) { const requestPayload = this.formatChatPayload(payload); - const timer = this.getTimer(); + const timer = getTimer(); const res = await fetch(requestPayload.url, { headers: { @@ -182,7 +241,7 @@ class OpenAIProvider timer.clear(); const resJson = await res.json(); - const message = this.readWholeMessageResponseBody(resJson); + const message = parseResp(resJson); return message; } @@ -190,13 +249,15 @@ class OpenAIProvider streamChat( payload: InternalChatRequestPayload, handlers: ChatHandlers, + fetch: typeof window.fetch, ) { const requestPayload = this.formatChatPayload(payload); - const timer = this.getTimer(); + const timer = getTimer(); fetchEventSource(requestPayload.url, { ...requestPayload, + fetch, async onopen(res) { timer.clear(); const contentType = res.headers.get("content-type"); @@ -270,7 +331,7 @@ class OpenAIProvider providerConfig: Record, ): Promise { const { openaiApiKey, openaiUrl } = providerConfig; - const res = await fetch(`${openaiUrl}/vi/models`, { + const res = await fetch(`${openaiUrl}/v1/models`, { headers: { Authorization: `Bearer ${openaiApiKey}`, }, @@ -282,6 +343,39 @@ class OpenAIProvider name: o.id, })); } + + serverSideRequestHandler: ProviderTemplate["serverSideRequestHandler"] = + async (req, config) => { + const { subpath } = req; + const ALLOWD_PATH = new Set(Object.values(OpenaiMetas)); + + if (!ALLOWD_PATH.has(subpath)) { + return NextResponse.json( + { + error: true, + message: "you are not allowed to request " + subpath, + }, + { + status: 403, + }, + ); + } + + const authResult = auth(req, config); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + try { + const response = await this.requestOpenai(req, config); + + return response; + } catch (e) { + return NextResponse.json(prettyObject(e)); + } + }; } export default OpenAIProvider; diff --git a/app/client/providers/openai/type.ts b/app/client/providers/openai/type.ts new file mode 100644 index 000000000..792ba844f --- /dev/null +++ b/app/client/providers/openai/type.ts @@ -0,0 +1,18 @@ +export interface ModelList { + object: "list"; + data: Array<{ + id: string; + object: "model"; + created: number; + owned_by: "system" | "openai-internal"; + }>; +} + +export interface OpenAIListModelResponse { + object: string; + data: Array<{ + id: string; + object: string; + root: string; + }>; +} diff --git a/app/client/providers/openai/utils.ts b/app/client/providers/openai/utils.ts index 24f6ef4f0..cb14b95a7 100644 --- a/app/client/providers/openai/utils.ts +++ b/app/client/providers/openai/utils.ts @@ -1,7 +1,21 @@ -export const makeBearer = (s: string) => `Bearer ${s.trim()}`; +import { NextRequest } from "next/server"; +import { ServerConfig, getIP } from "../../common"; -export const validString = (x?: string): x is string => - Boolean(x && x.length > 0); +export const REQUEST_TIMEOUT_MS = 60000; + +export const authHeaderName = "Authorization"; + +const makeBearer = (s: string) => `Bearer ${s.trim()}`; + +const validString = (x?: string): x is string => Boolean(x && x.length > 0); + +function parseApiKey(bearToken: string) { + const token = bearToken.trim().replaceAll("Bearer ", "").trim(); + + return { + apiKey: token, + }; +} export function prettyObject(msg: any) { const obj = msg; @@ -16,3 +30,74 @@ export function prettyObject(msg: any) { } return ["```json", msg, "```"].join("\n"); } + +export function parseResp(res: { choices: { message: { content: any } }[] }) { + return { + message: res.choices?.[0]?.message?.content ?? "", + }; +} + +export function auth(req: NextRequest, serverConfig: ServerConfig) { + const { hideUserApiKey, apiKey: systemApiKey } = serverConfig; + const authToken = req.headers.get(authHeaderName) ?? ""; + + const { apiKey } = parseApiKey(authToken); + + console.log("[User IP] ", getIP(req)); + console.log("[Time] ", new Date().toLocaleString()); + + if (hideUserApiKey && apiKey) { + return { + error: true, + message: "you are not allowed to access with your own api key", + }; + } + + if (apiKey) { + console.log("[Auth] use user api key"); + return { + error: false, + }; + } + + if (systemApiKey) { + console.log("[Auth] use system api key"); + req.headers.set(authHeaderName, `Bearer ${systemApiKey}`); + } else { + console.log("[Auth] admin did not provide an api key"); + } + + return { + error: false, + }; +} + +export function getTimer() { + const controller = new AbortController(); + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + return { + ...controller, + clear: () => { + clearTimeout(requestTimeoutId); + }, + }; +} + +export function getHeaders(openaiApiKey?: string) { + const headers: Record = { + "Content-Type": "application/json", + Accept: "application/json", + }; + + if (validString(openaiApiKey)) { + headers[authHeaderName] = makeBearer(openaiApiKey); + } + + return headers; +} diff --git a/app/config/server.ts b/app/config/server.ts index b7c85ce6a..5969768bb 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -55,7 +55,10 @@ const ACCESS_CODES = (function getAccessCodes(): Set { })(); function getApiKey(keys?: string) { - const apiKeyEnvVar = keys ?? ""; + if (!keys) { + return; + } + const apiKeyEnvVar = keys; const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim()); const randomIndex = Math.floor(Math.random() * apiKeys.length); const apiKey = apiKeys[randomIndex];