diff --git a/app/api/auth.ts b/app/api/auth.ts
index b750f2d17..4162ec2d0 100644
--- a/app/api/auth.ts
+++ b/app/api/auth.ts
@@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
case ModelProvider.Claude:
systemApiKey = serverConfig.anthropicApiKey;
break;
+ case ModelProvider.Stability:
+ systemApiKey = serverConfig.stabilityApiKey;
+ break;
case ModelProvider.GPT:
default:
if (serverConfig.isAzure) {
diff --git a/app/api/stability/[...path]/route.ts b/app/api/stability/[...path]/route.ts
new file mode 100644
index 000000000..4b2bcc305
--- /dev/null
+++ b/app/api/stability/[...path]/route.ts
@@ -0,0 +1,104 @@
+import { NextRequest, NextResponse } from "next/server";
+import { getServerSideConfig } from "@/app/config/server";
+import { ModelProvider, STABILITY_BASE_URL } from "@/app/constant";
+import { auth } from "@/app/api/auth";
+
+async function handle(
+ req: NextRequest,
+ { params }: { params: { path: string[] } },
+) {
+ console.log("[Stability] params ", params);
+
+ if (req.method === "OPTIONS") {
+ return NextResponse.json({ body: "OK" }, { status: 200 });
+ }
+
+ const controller = new AbortController();
+
+ const serverConfig = getServerSideConfig();
+
+ let baseUrl = serverConfig.stabilityUrl || STABILITY_BASE_URL;
+
+ if (!baseUrl.startsWith("http")) {
+ baseUrl = `https://${baseUrl}`;
+ }
+
+ if (baseUrl.endsWith("/")) {
+ baseUrl = baseUrl.slice(0, -1);
+ }
+
+ let path = `${req.nextUrl.pathname}`.replaceAll("/api/stability/", "");
+
+ console.log("[Stability Proxy] ", path);
+ console.log("[Stability Base Url]", baseUrl);
+
+ const timeoutId = setTimeout(
+ () => {
+ controller.abort();
+ },
+ 10 * 60 * 1000,
+ );
+
+ const authResult = auth(req, ModelProvider.Stability);
+
+ if (authResult.error) {
+ return NextResponse.json(authResult, {
+ status: 401,
+ });
+ }
+
+ const bearToken = req.headers.get("Authorization") ?? "";
+ const token = bearToken.trim().replaceAll("Bearer ", "").trim();
+
+ const key = token ? token : serverConfig.stabilityApiKey;
+
+ if (!key) {
+ return NextResponse.json(
+ {
+ error: true,
+ message: `missing STABILITY_API_KEY in server env vars`,
+ },
+ {
+ status: 401,
+ },
+ );
+ }
+
+ const fetchUrl = `${baseUrl}/${path}`;
+ console.log("[Stability Url] ", fetchUrl);
+ const fetchOptions: RequestInit = {
+ headers: {
+ "Content-Type": req.headers.get("Content-Type") || "multipart/form-data",
+ Accept: req.headers.get("Accept") || "application/json",
+ Authorization: `Bearer ${key}`,
+ },
+ method: req.method,
+ body: req.body,
+ // to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body
+ redirect: "manual",
+ // @ts-ignore
+ duplex: "half",
+ signal: controller.signal,
+ };
+
+ try {
+ const res = await fetch(fetchUrl, fetchOptions);
+ // to prevent browser prompt for credentials
+ const newHeaders = new Headers(res.headers);
+ newHeaders.delete("www-authenticate");
+ // to disable nginx buffering
+ newHeaders.set("X-Accel-Buffering", "no");
+ return new Response(res.body, {
+ status: res.status,
+ statusText: res.statusText,
+ headers: newHeaders,
+ });
+ } finally {
+ clearTimeout(timeoutId);
+ }
+}
+
+export const GET = handle;
+export const POST = handle;
+
+export const runtime = "edge";
diff --git a/app/components/sd-panel.tsx b/app/components/sd-panel.tsx
index 8e1159cf9..043030152 100644
--- a/app/components/sd-panel.tsx
+++ b/app/components/sd-panel.tsx
@@ -307,22 +307,12 @@ export function SdPanel() {
model_name: currentModel.name,
status: "wait",
params: reqParams,
- created_at: new Date().toISOString(),
+ created_at: new Date().toLocaleString(),
img_data: "",
};
- sdListDb.add(data).then(
- (id) => {
- data = { ...data, id, status: "running" };
- sdListDb.update(data);
- execCountInc();
- sendSdTask(data, sdListDb, execCountInc);
- setParams(getModelParamBasicData(columns, params, true));
- },
- (error) => {
- console.error(error);
- showToast(`error: ` + error.message);
- },
- );
+ sendSdTask(data, sdListDb, execCountInc, () => {
+ setParams(getModelParamBasicData(columns, params, true));
+ });
};
return (
<>
diff --git a/app/components/sd.tsx b/app/components/sd.tsx
index a35d6a60e..6f3acde96 100644
--- a/app/components/sd.tsx
+++ b/app/components/sd.tsx
@@ -9,7 +9,6 @@ import {
copyToClipboard,
getMessageTextContent,
useMobileScreen,
- useWindowSize,
} from "@/app/utils";
import { useNavigate } from "react-router-dom";
import { useAppConfig } from "@/app/store";
@@ -22,14 +21,18 @@ import CopyIcon from "@/app/icons/copy.svg";
import PromptIcon from "@/app/icons/prompt.svg";
import ResetIcon from "@/app/icons/reload.svg";
import { useIndexedDB } from "react-indexed-db-hook";
-import { useSdStore } from "@/app/store/sd";
+import { sendSdTask, useSdStore } from "@/app/store/sd";
import locales from "@/app/locales";
import LoadingIcon from "../icons/three-dots.svg";
import ErrorIcon from "../icons/delete.svg";
import { Property } from "csstype";
-import { showConfirm } from "@/app/components/ui-lib";
+import {
+ showConfirm,
+ showImageModal,
+ showModal,
+} from "@/app/components/ui-lib";
-function openBase64ImgUrl(base64Data: string, contentType: string) {
+function getBase64ImgUrl(base64Data: string, contentType: string) {
const byteCharacters = atob(base64Data);
const byteNumbers = new Array(byteCharacters.length);
for (let i = 0; i < byteCharacters.length; i++) {
@@ -37,8 +40,7 @@ function openBase64ImgUrl(base64Data: string, contentType: string) {
}
const byteArray = new Uint8Array(byteNumbers);
const blob = new Blob([byteArray], { type: contentType });
- const blobUrl = URL.createObjectURL(blob);
- window.open(blobUrl);
+ return URL.createObjectURL(blob);
}
function getSdTaskStatus(item: any) {
@@ -69,7 +71,24 @@ function getSdTaskStatus(item: any) {
{locales.Sd.Status.Name}: {s}
- {item.status === "error" && - {item.error}}
+ {item.status === "error" && (
+ {
+ showModal({
+ title: locales.Sd.Detail,
+ children: (
+
{locales.SdPanel.Prompt}:{" "}
-
+ {
+ showModal({
+ title: locales.Sd.Detail,
+ children: (
+