From 2b0153807cf6294ea7b8bce9f2f4b58a71c94be4 Mon Sep 17 00:00:00 2001 From: licoy Date: Tue, 9 Jul 2024 09:50:04 +0800 Subject: [PATCH] feat: Add Stability API server relay sending --- app/api/auth.ts | 3 + app/api/stability/[...path]/route.ts | 104 +++++++++++++++++++++++++++ app/components/sd-panel.tsx | 18 ++--- app/components/sd.tsx | 83 +++++++++++++++++---- app/components/ui-lib.tsx | 5 +- app/config/server.ts | 3 + app/constant.ts | 8 +++ app/locales/cn.ts | 2 + app/locales/en.ts | 2 + app/store/sd.ts | 23 +++++- 10 files changed, 220 insertions(+), 31 deletions(-) create mode 100644 app/api/stability/[...path]/route.ts diff --git a/app/api/auth.ts b/app/api/auth.ts index b750f2d17..4162ec2d0 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) { case ModelProvider.Claude: systemApiKey = serverConfig.anthropicApiKey; break; + case ModelProvider.Stability: + systemApiKey = serverConfig.stabilityApiKey; + break; case ModelProvider.GPT: default: if (serverConfig.isAzure) { diff --git a/app/api/stability/[...path]/route.ts b/app/api/stability/[...path]/route.ts new file mode 100644 index 000000000..4b2bcc305 --- /dev/null +++ b/app/api/stability/[...path]/route.ts @@ -0,0 +1,104 @@ +import { NextRequest, NextResponse } from "next/server"; +import { getServerSideConfig } from "@/app/config/server"; +import { ModelProvider, STABILITY_BASE_URL } from "@/app/constant"; +import { auth } from "@/app/api/auth"; + +async function handle( + req: NextRequest, + { params }: { params: { path: string[] } }, +) { + console.log("[Stability] params ", params); + + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + + const controller = new AbortController(); + + const serverConfig = getServerSideConfig(); + + let baseUrl = serverConfig.stabilityUrl || STABILITY_BASE_URL; + + if (!baseUrl.startsWith("http")) { + baseUrl = `https://${baseUrl}`; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, -1); + } + + let path = `${req.nextUrl.pathname}`.replaceAll("/api/stability/", ""); + + console.log("[Stability Proxy] ", path); + console.log("[Stability Base Url]", baseUrl); + + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + + const authResult = auth(req, ModelProvider.Stability); + + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + const bearToken = req.headers.get("Authorization") ?? ""; + const token = bearToken.trim().replaceAll("Bearer ", "").trim(); + + const key = token ? token : serverConfig.stabilityApiKey; + + if (!key) { + return NextResponse.json( + { + error: true, + message: `missing STABILITY_API_KEY in server env vars`, + }, + { + status: 401, + }, + ); + } + + const fetchUrl = `${baseUrl}/${path}`; + console.log("[Stability Url] ", fetchUrl); + const fetchOptions: RequestInit = { + headers: { + "Content-Type": req.headers.get("Content-Type") || "multipart/form-data", + Accept: req.headers.get("Accept") || "application/json", + Authorization: `Bearer ${key}`, + }, + method: req.method, + body: req.body, + // to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }; + + try { + const res = await fetch(fetchUrl, fetchOptions); + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + return new Response(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } +} + +export const GET = handle; +export const POST = handle; + +export const runtime = "edge"; diff --git a/app/components/sd-panel.tsx b/app/components/sd-panel.tsx index 8e1159cf9..043030152 100644 --- a/app/components/sd-panel.tsx +++ b/app/components/sd-panel.tsx @@ -307,22 +307,12 @@ export function SdPanel() { model_name: currentModel.name, status: "wait", params: reqParams, - created_at: new Date().toISOString(), + created_at: new Date().toLocaleString(), img_data: "", }; - sdListDb.add(data).then( - (id) => { - data = { ...data, id, status: "running" }; - sdListDb.update(data); - execCountInc(); - sendSdTask(data, sdListDb, execCountInc); - setParams(getModelParamBasicData(columns, params, true)); - }, - (error) => { - console.error(error); - showToast(`error: ` + error.message); - }, - ); + sendSdTask(data, sdListDb, execCountInc, () => { + setParams(getModelParamBasicData(columns, params, true)); + }); }; return ( <> diff --git a/app/components/sd.tsx b/app/components/sd.tsx index a35d6a60e..6f3acde96 100644 --- a/app/components/sd.tsx +++ b/app/components/sd.tsx @@ -9,7 +9,6 @@ import { copyToClipboard, getMessageTextContent, useMobileScreen, - useWindowSize, } from "@/app/utils"; import { useNavigate } from "react-router-dom"; import { useAppConfig } from "@/app/store"; @@ -22,14 +21,18 @@ import CopyIcon from "@/app/icons/copy.svg"; import PromptIcon from "@/app/icons/prompt.svg"; import ResetIcon from "@/app/icons/reload.svg"; import { useIndexedDB } from "react-indexed-db-hook"; -import { useSdStore } from "@/app/store/sd"; +import { sendSdTask, useSdStore } from "@/app/store/sd"; import locales from "@/app/locales"; import LoadingIcon from "../icons/three-dots.svg"; import ErrorIcon from "../icons/delete.svg"; import { Property } from "csstype"; -import { showConfirm } from "@/app/components/ui-lib"; +import { + showConfirm, + showImageModal, + showModal, +} from "@/app/components/ui-lib"; -function openBase64ImgUrl(base64Data: string, contentType: string) { +function getBase64ImgUrl(base64Data: string, contentType: string) { const byteCharacters = atob(base64Data); const byteNumbers = new Array(byteCharacters.length); for (let i = 0; i < byteCharacters.length; i++) { @@ -37,8 +40,7 @@ function openBase64ImgUrl(base64Data: string, contentType: string) { } const byteArray = new Uint8Array(byteNumbers); const blob = new Blob([byteArray], { type: contentType }); - const blobUrl = URL.createObjectURL(blob); - window.open(blobUrl); + return URL.createObjectURL(blob); } function getSdTaskStatus(item: any) { @@ -69,7 +71,24 @@ function getSdTaskStatus(item: any) { {locales.Sd.Status.Name}: {s} - {item.status === "error" && - {item.error}} + {item.status === "error" && ( + { + showModal({ + title: locales.Sd.Detail, + children: ( +
+ {item.error} +
+ ), + }); + }} + > + {" "} + - {item.error} +
+ )}

); } @@ -83,7 +102,7 @@ export function Sd() { const scrollRef = useRef(null); const sdListDb = useIndexedDB(StoreKey.SdList); const [sdImages, setSdImages] = useState([]); - const { execCount } = useSdStore(); + const { execCount, execCountInc } = useSdStore(); useEffect(() => { sdListDb.getAll().then((data) => { @@ -145,7 +164,10 @@ export function Sd() { src={`data:image/png;base64,${item.img_data}`} alt={`${item.id}`} onClick={(e) => { - openBase64ImgUrl(item.img_data, "image/png"); + showImageModal( + getBase64ImgUrl(item.img_data, "image/png"), + true, + ); }} /> ) : item.status === "error" ? ( @@ -163,7 +185,20 @@ export function Sd() { >

{locales.SdPanel.Prompt}:{" "} - + { + showModal({ + title: locales.Sd.Detail, + children: ( +

+ {item.params.prompt} +
+ ), + }); + }} + > {item.params.prompt}

@@ -177,7 +212,21 @@ export function Sd() { } - onClick={() => console.log(1)} + onClick={() => { + showModal({ + title: locales.Sd.GenerateParams, + children: ( +
+ {Object.keys(item.params).map((key) => ( +
+ {key}: + {item.params[key]} +
+ ))} +
+ ), + }); + }} /> } - onClick={() => console.log(1)} + onClick={() => { + const reqData = { + model: item.model, + model_name: item.model_name, + status: "wait", + params: { ...item.params }, + created_at: new Date().toLocaleString(), + img_data: "", + }; + sendSdTask(reqData, sdListDb, execCountInc); + }} /> +
preview { anthropicApiVersion: process.env.ANTHROPIC_API_VERSION, anthropicUrl: process.env.ANTHROPIC_URL, + stabilityUrl: process.env.STABILITY_URL, + stabilityApiKey: process.env.STABILITY_API_KEY, + gtmId: process.env.GTM_ID, needCode: ACCESS_CODES.size > 0, diff --git a/app/constant.ts b/app/constant.ts index 29bda6b7f..4a6043e30 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -1,3 +1,5 @@ +import { stabilityRequestCall } from "@/app/store/sd"; + export const OWNER = "Yidadaa"; export const REPO = "ChatGPT-Next-Web"; export const REPO_URL = `https://github.com/${OWNER}/${REPO}`; @@ -13,6 +15,7 @@ export const OPENAI_BASE_URL = "https://api.openai.com"; export const ANTHROPIC_BASE_URL = "https://api.anthropic.com"; export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/"; +export const STABILITY_BASE_URL = "https://api.stability.ai"; export enum Path { Home = "/", @@ -79,6 +82,7 @@ export enum ModelProvider { GPT = "GPT", GeminiPro = "GeminiPro", Claude = "Claude", + Stability = "Stability", } export const Anthropic = { @@ -104,6 +108,10 @@ export const Google = { ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`, }; +export const StabilityPath = { + GeneratePath: "v2beta/stable-image/generate", +}; + export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang // export const DEFAULT_SYSTEM_TEMPLATE = ` // You are ChatGPT, a large language model trained by {{ServiceProvider}}. diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 22ff94205..b01ea9d27 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -534,6 +534,8 @@ const cn = { Danger: { Delete: "确认删除?", }, + GenerateParams: "生成参数", + Detail: "详情", }, }; diff --git a/app/locales/en.ts b/app/locales/en.ts index a47b1ef70..da2224f6a 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -540,6 +540,8 @@ const en: LocaleType = { Danger: { Delete: "Confirm to delete?", }, + GenerateParams: "Generate Params", + Detail: "Detail", }, }; diff --git a/app/store/sd.ts b/app/store/sd.ts index b811e0add..77d31d743 100644 --- a/app/store/sd.ts +++ b/app/store/sd.ts @@ -1,6 +1,7 @@ import { initDB, useIndexedDB } from "react-indexed-db-hook"; -import { StoreKey } from "@/app/constant"; +import { StabilityPath, StoreKey } from "@/app/constant"; import { create, StoreApi } from "zustand"; +import { showToast } from "@/app/components/ui-lib"; export const SdDbConfig = { name: "@chatgpt-next-web/sd", @@ -44,12 +45,28 @@ export const useSdStore = create()((set) => ({ execCountInc: () => set((state) => ({ execCount: state.execCount + 1 })), })); -export function sendSdTask(data: any, db: any, inc: any) { +export function sendSdTask(data: any, db: any, inc: any, okCall?: Function) { + db.add(data).then( + (id: number) => { + data = { ...data, id, status: "running" }; + db.update(data); + inc(); + stabilityRequestCall(data, db, inc); + okCall?.(); + }, + (error: any) => { + console.error(error); + showToast(`error: ` + error.message); + }, + ); +} + +export function stabilityRequestCall(data: any, db: any, inc: any) { const formData = new FormData(); for (let paramsKey in data.params) { formData.append(paramsKey, data.params[paramsKey]); } - fetch("https://api.stability.ai/v2beta/stable-image/generate/" + data.model, { + fetch(`/api/stability/${StabilityPath.GeneratePath}/${data.model}`, { method: "POST", headers: { Accept: "application/json",