From 2b0153807cf6294ea7b8bce9f2f4b58a71c94be4 Mon Sep 17 00:00:00 2001
From: licoy
Date: Tue, 9 Jul 2024 09:50:04 +0800
Subject: [PATCH] feat: Add Stability API server relay sending
---
app/api/auth.ts | 3 +
app/api/stability/[...path]/route.ts | 104 +++++++++++++++++++++++++++
app/components/sd-panel.tsx | 18 ++---
app/components/sd.tsx | 83 +++++++++++++++++----
app/components/ui-lib.tsx | 5 +-
app/config/server.ts | 3 +
app/constant.ts | 8 +++
app/locales/cn.ts | 2 +
app/locales/en.ts | 2 +
app/store/sd.ts | 23 +++++-
10 files changed, 220 insertions(+), 31 deletions(-)
create mode 100644 app/api/stability/[...path]/route.ts
diff --git a/app/api/auth.ts b/app/api/auth.ts
index b750f2d17..4162ec2d0 100644
--- a/app/api/auth.ts
+++ b/app/api/auth.ts
@@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
case ModelProvider.Claude:
systemApiKey = serverConfig.anthropicApiKey;
break;
+ case ModelProvider.Stability:
+ systemApiKey = serverConfig.stabilityApiKey;
+ break;
case ModelProvider.GPT:
default:
if (serverConfig.isAzure) {
diff --git a/app/api/stability/[...path]/route.ts b/app/api/stability/[...path]/route.ts
new file mode 100644
index 000000000..4b2bcc305
--- /dev/null
+++ b/app/api/stability/[...path]/route.ts
@@ -0,0 +1,104 @@
+import { NextRequest, NextResponse } from "next/server";
+import { getServerSideConfig } from "@/app/config/server";
+import { ModelProvider, STABILITY_BASE_URL } from "@/app/constant";
+import { auth } from "@/app/api/auth";
+
+async function handle(
+ req: NextRequest,
+ { params }: { params: { path: string[] } },
+) {
+ console.log("[Stability] params ", params);
+
+ if (req.method === "OPTIONS") {
+ return NextResponse.json({ body: "OK" }, { status: 200 });
+ }
+
+ const controller = new AbortController();
+
+ const serverConfig = getServerSideConfig();
+
+ let baseUrl = serverConfig.stabilityUrl || STABILITY_BASE_URL;
+
+ if (!baseUrl.startsWith("http")) {
+ baseUrl = `https://${baseUrl}`;
+ }
+
+ if (baseUrl.endsWith("/")) {
+ baseUrl = baseUrl.slice(0, -1);
+ }
+
+ let path = `${req.nextUrl.pathname}`.replaceAll("/api/stability/", "");
+
+ console.log("[Stability Proxy] ", path);
+ console.log("[Stability Base Url]", baseUrl);
+
+ const timeoutId = setTimeout(
+ () => {
+ controller.abort();
+ },
+ 10 * 60 * 1000,
+ );
+
+ const authResult = auth(req, ModelProvider.Stability);
+
+ if (authResult.error) {
+ return NextResponse.json(authResult, {
+ status: 401,
+ });
+ }
+
+ const bearToken = req.headers.get("Authorization") ?? "";
+ const token = bearToken.trim().replaceAll("Bearer ", "").trim();
+
+ const key = token ? token : serverConfig.stabilityApiKey;
+
+ if (!key) {
+ return NextResponse.json(
+ {
+ error: true,
+ message: `missing STABILITY_API_KEY in server env vars`,
+ },
+ {
+ status: 401,
+ },
+ );
+ }
+
+ const fetchUrl = `${baseUrl}/${path}`;
+ console.log("[Stability Url] ", fetchUrl);
+ const fetchOptions: RequestInit = {
+ headers: {
+ "Content-Type": req.headers.get("Content-Type") || "multipart/form-data",
+ Accept: req.headers.get("Accept") || "application/json",
+ Authorization: `Bearer ${key}`,
+ },
+ method: req.method,
+ body: req.body,
+ // to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body
+ redirect: "manual",
+ // @ts-ignore
+ duplex: "half",
+ signal: controller.signal,
+ };
+
+ try {
+ const res = await fetch(fetchUrl, fetchOptions);
+ // to prevent browser prompt for credentials
+ const newHeaders = new Headers(res.headers);
+ newHeaders.delete("www-authenticate");
+ // to disable nginx buffering
+ newHeaders.set("X-Accel-Buffering", "no");
+ return new Response(res.body, {
+ status: res.status,
+ statusText: res.statusText,
+ headers: newHeaders,
+ });
+ } finally {
+ clearTimeout(timeoutId);
+ }
+}
+
+export const GET = handle;
+export const POST = handle;
+
+export const runtime = "edge";
diff --git a/app/components/sd-panel.tsx b/app/components/sd-panel.tsx
index 8e1159cf9..043030152 100644
--- a/app/components/sd-panel.tsx
+++ b/app/components/sd-panel.tsx
@@ -307,22 +307,12 @@ export function SdPanel() {
model_name: currentModel.name,
status: "wait",
params: reqParams,
- created_at: new Date().toISOString(),
+ created_at: new Date().toLocaleString(),
img_data: "",
};
- sdListDb.add(data).then(
- (id) => {
- data = { ...data, id, status: "running" };
- sdListDb.update(data);
- execCountInc();
- sendSdTask(data, sdListDb, execCountInc);
- setParams(getModelParamBasicData(columns, params, true));
- },
- (error) => {
- console.error(error);
- showToast(`error: ` + error.message);
- },
- );
+ sendSdTask(data, sdListDb, execCountInc, () => {
+ setParams(getModelParamBasicData(columns, params, true));
+ });
};
return (
<>
diff --git a/app/components/sd.tsx b/app/components/sd.tsx
index a35d6a60e..6f3acde96 100644
--- a/app/components/sd.tsx
+++ b/app/components/sd.tsx
@@ -9,7 +9,6 @@ import {
copyToClipboard,
getMessageTextContent,
useMobileScreen,
- useWindowSize,
} from "@/app/utils";
import { useNavigate } from "react-router-dom";
import { useAppConfig } from "@/app/store";
@@ -22,14 +21,18 @@ import CopyIcon from "@/app/icons/copy.svg";
import PromptIcon from "@/app/icons/prompt.svg";
import ResetIcon from "@/app/icons/reload.svg";
import { useIndexedDB } from "react-indexed-db-hook";
-import { useSdStore } from "@/app/store/sd";
+import { sendSdTask, useSdStore } from "@/app/store/sd";
import locales from "@/app/locales";
import LoadingIcon from "../icons/three-dots.svg";
import ErrorIcon from "../icons/delete.svg";
import { Property } from "csstype";
-import { showConfirm } from "@/app/components/ui-lib";
+import {
+ showConfirm,
+ showImageModal,
+ showModal,
+} from "@/app/components/ui-lib";
-function openBase64ImgUrl(base64Data: string, contentType: string) {
+function getBase64ImgUrl(base64Data: string, contentType: string) {
const byteCharacters = atob(base64Data);
const byteNumbers = new Array(byteCharacters.length);
for (let i = 0; i < byteCharacters.length; i++) {
@@ -37,8 +40,7 @@ function openBase64ImgUrl(base64Data: string, contentType: string) {
}
const byteArray = new Uint8Array(byteNumbers);
const blob = new Blob([byteArray], { type: contentType });
- const blobUrl = URL.createObjectURL(blob);
- window.open(blobUrl);
+ return URL.createObjectURL(blob);
}
function getSdTaskStatus(item: any) {
@@ -69,7 +71,24 @@ function getSdTaskStatus(item: any) {
{locales.Sd.Status.Name}: {s}
- {item.status === "error" && - {item.error}}
+ {item.status === "error" && (
+ {
+ showModal({
+ title: locales.Sd.Detail,
+ children: (
+
+ {item.error}
+
+ ),
+ });
+ }}
+ >
+ {" "}
+ - {item.error}
+
+ )}
);
}
@@ -83,7 +102,7 @@ export function Sd() {
const scrollRef = useRef(null);
const sdListDb = useIndexedDB(StoreKey.SdList);
const [sdImages, setSdImages] = useState([]);
- const { execCount } = useSdStore();
+ const { execCount, execCountInc } = useSdStore();
useEffect(() => {
sdListDb.getAll().then((data) => {
@@ -145,7 +164,10 @@ export function Sd() {
src={`data:image/png;base64,${item.img_data}`}
alt={`${item.id}`}
onClick={(e) => {
- openBase64ImgUrl(item.img_data, "image/png");
+ showImageModal(
+ getBase64ImgUrl(item.img_data, "image/png"),
+ true,
+ );
}}
/>
) : item.status === "error" ? (
@@ -163,7 +185,20 @@ export function Sd() {
>
{locales.SdPanel.Prompt}:{" "}
-
+ {
+ showModal({
+ title: locales.Sd.Detail,
+ children: (
+
+ {item.params.prompt}
+
+ ),
+ });
+ }}
+ >
{item.params.prompt}
@@ -177,7 +212,21 @@ export function Sd() {
}
- onClick={() => console.log(1)}
+ onClick={() => {
+ showModal({
+ title: locales.Sd.GenerateParams,
+ children: (
+
+ {Object.keys(item.params).map((key) => (
+
+ {key}:
+ {item.params[key]}
+
+ ))}
+
+ ),
+ });
+ }}
/>
}
- onClick={() => console.log(1)}
+ onClick={() => {
+ const reqData = {
+ model: item.model,
+ model_name: item.model_name,
+ status: "wait",
+ params: { ...item.params },
+ created_at: new Date().toLocaleString(),
+ img_data: "",
+ };
+ sendSdTask(reqData, sdListDb, execCountInc);
+ }}
/>
+

{
anthropicApiVersion: process.env.ANTHROPIC_API_VERSION,
anthropicUrl: process.env.ANTHROPIC_URL,
+ stabilityUrl: process.env.STABILITY_URL,
+ stabilityApiKey: process.env.STABILITY_API_KEY,
+
gtmId: process.env.GTM_ID,
needCode: ACCESS_CODES.size > 0,
diff --git a/app/constant.ts b/app/constant.ts
index 29bda6b7f..4a6043e30 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -1,3 +1,5 @@
+import { stabilityRequestCall } from "@/app/store/sd";
+
export const OWNER = "Yidadaa";
export const REPO = "ChatGPT-Next-Web";
export const REPO_URL = `https://github.com/${OWNER}/${REPO}`;
@@ -13,6 +15,7 @@ export const OPENAI_BASE_URL = "https://api.openai.com";
export const ANTHROPIC_BASE_URL = "https://api.anthropic.com";
export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/";
+export const STABILITY_BASE_URL = "https://api.stability.ai";
export enum Path {
Home = "/",
@@ -79,6 +82,7 @@ export enum ModelProvider {
GPT = "GPT",
GeminiPro = "GeminiPro",
Claude = "Claude",
+ Stability = "Stability",
}
export const Anthropic = {
@@ -104,6 +108,10 @@ export const Google = {
ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
};
+export const StabilityPath = {
+ GeneratePath: "v2beta/stable-image/generate",
+};
+
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
// export const DEFAULT_SYSTEM_TEMPLATE = `
// You are ChatGPT, a large language model trained by {{ServiceProvider}}.
diff --git a/app/locales/cn.ts b/app/locales/cn.ts
index 22ff94205..b01ea9d27 100644
--- a/app/locales/cn.ts
+++ b/app/locales/cn.ts
@@ -534,6 +534,8 @@ const cn = {
Danger: {
Delete: "确认删除?",
},
+ GenerateParams: "生成参数",
+ Detail: "详情",
},
};
diff --git a/app/locales/en.ts b/app/locales/en.ts
index a47b1ef70..da2224f6a 100644
--- a/app/locales/en.ts
+++ b/app/locales/en.ts
@@ -540,6 +540,8 @@ const en: LocaleType = {
Danger: {
Delete: "Confirm to delete?",
},
+ GenerateParams: "Generate Params",
+ Detail: "Detail",
},
};
diff --git a/app/store/sd.ts b/app/store/sd.ts
index b811e0add..77d31d743 100644
--- a/app/store/sd.ts
+++ b/app/store/sd.ts
@@ -1,6 +1,7 @@
import { initDB, useIndexedDB } from "react-indexed-db-hook";
-import { StoreKey } from "@/app/constant";
+import { StabilityPath, StoreKey } from "@/app/constant";
import { create, StoreApi } from "zustand";
+import { showToast } from "@/app/components/ui-lib";
export const SdDbConfig = {
name: "@chatgpt-next-web/sd",
@@ -44,12 +45,28 @@ export const useSdStore = create
()((set) => ({
execCountInc: () => set((state) => ({ execCount: state.execCount + 1 })),
}));
-export function sendSdTask(data: any, db: any, inc: any) {
+export function sendSdTask(data: any, db: any, inc: any, okCall?: Function) {
+ db.add(data).then(
+ (id: number) => {
+ data = { ...data, id, status: "running" };
+ db.update(data);
+ inc();
+ stabilityRequestCall(data, db, inc);
+ okCall?.();
+ },
+ (error: any) => {
+ console.error(error);
+ showToast(`error: ` + error.message);
+ },
+ );
+}
+
+export function stabilityRequestCall(data: any, db: any, inc: any) {
const formData = new FormData();
for (let paramsKey in data.params) {
formData.append(paramsKey, data.params[paramsKey]);
}
- fetch("https://api.stability.ai/v2beta/stable-image/generate/" + data.model, {
+ fetch(`/api/stability/${StabilityPath.GeneratePath}/${data.model}`, {
method: "POST",
headers: {
Accept: "application/json",