feat: Add Stability API server relay sending
This commit is contained in:
parent
a16725ac17
commit
2b0153807c
|
@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
|
|||
case ModelProvider.Claude:
|
||||
systemApiKey = serverConfig.anthropicApiKey;
|
||||
break;
|
||||
case ModelProvider.Stability:
|
||||
systemApiKey = serverConfig.stabilityApiKey;
|
||||
break;
|
||||
case ModelProvider.GPT:
|
||||
default:
|
||||
if (serverConfig.isAzure) {
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
import { NextRequest, NextResponse } from "next/server";
|
||||
import { getServerSideConfig } from "@/app/config/server";
|
||||
import { ModelProvider, STABILITY_BASE_URL } from "@/app/constant";
|
||||
import { auth } from "@/app/api/auth";
|
||||
|
||||
async function handle(
|
||||
req: NextRequest,
|
||||
{ params }: { params: { path: string[] } },
|
||||
) {
|
||||
console.log("[Stability] params ", params);
|
||||
|
||||
if (req.method === "OPTIONS") {
|
||||
return NextResponse.json({ body: "OK" }, { status: 200 });
|
||||
}
|
||||
|
||||
const controller = new AbortController();
|
||||
|
||||
const serverConfig = getServerSideConfig();
|
||||
|
||||
let baseUrl = serverConfig.stabilityUrl || STABILITY_BASE_URL;
|
||||
|
||||
if (!baseUrl.startsWith("http")) {
|
||||
baseUrl = `https://${baseUrl}`;
|
||||
}
|
||||
|
||||
if (baseUrl.endsWith("/")) {
|
||||
baseUrl = baseUrl.slice(0, -1);
|
||||
}
|
||||
|
||||
let path = `${req.nextUrl.pathname}`.replaceAll("/api/stability/", "");
|
||||
|
||||
console.log("[Stability Proxy] ", path);
|
||||
console.log("[Stability Base Url]", baseUrl);
|
||||
|
||||
const timeoutId = setTimeout(
|
||||
() => {
|
||||
controller.abort();
|
||||
},
|
||||
10 * 60 * 1000,
|
||||
);
|
||||
|
||||
const authResult = auth(req, ModelProvider.Stability);
|
||||
|
||||
if (authResult.error) {
|
||||
return NextResponse.json(authResult, {
|
||||
status: 401,
|
||||
});
|
||||
}
|
||||
|
||||
const bearToken = req.headers.get("Authorization") ?? "";
|
||||
const token = bearToken.trim().replaceAll("Bearer ", "").trim();
|
||||
|
||||
const key = token ? token : serverConfig.stabilityApiKey;
|
||||
|
||||
if (!key) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: true,
|
||||
message: `missing STABILITY_API_KEY in server env vars`,
|
||||
},
|
||||
{
|
||||
status: 401,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
const fetchUrl = `${baseUrl}/${path}`;
|
||||
console.log("[Stability Url] ", fetchUrl);
|
||||
const fetchOptions: RequestInit = {
|
||||
headers: {
|
||||
"Content-Type": req.headers.get("Content-Type") || "multipart/form-data",
|
||||
Accept: req.headers.get("Accept") || "application/json",
|
||||
Authorization: `Bearer ${key}`,
|
||||
},
|
||||
method: req.method,
|
||||
body: req.body,
|
||||
// to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body
|
||||
redirect: "manual",
|
||||
// @ts-ignore
|
||||
duplex: "half",
|
||||
signal: controller.signal,
|
||||
};
|
||||
|
||||
try {
|
||||
const res = await fetch(fetchUrl, fetchOptions);
|
||||
// to prevent browser prompt for credentials
|
||||
const newHeaders = new Headers(res.headers);
|
||||
newHeaders.delete("www-authenticate");
|
||||
// to disable nginx buffering
|
||||
newHeaders.set("X-Accel-Buffering", "no");
|
||||
return new Response(res.body, {
|
||||
status: res.status,
|
||||
statusText: res.statusText,
|
||||
headers: newHeaders,
|
||||
});
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
}
|
||||
|
||||
export const GET = handle;
|
||||
export const POST = handle;
|
||||
|
||||
export const runtime = "edge";
|
|
@ -307,22 +307,12 @@ export function SdPanel() {
|
|||
model_name: currentModel.name,
|
||||
status: "wait",
|
||||
params: reqParams,
|
||||
created_at: new Date().toISOString(),
|
||||
created_at: new Date().toLocaleString(),
|
||||
img_data: "",
|
||||
};
|
||||
sdListDb.add(data).then(
|
||||
(id) => {
|
||||
data = { ...data, id, status: "running" };
|
||||
sdListDb.update(data);
|
||||
execCountInc();
|
||||
sendSdTask(data, sdListDb, execCountInc);
|
||||
setParams(getModelParamBasicData(columns, params, true));
|
||||
},
|
||||
(error) => {
|
||||
console.error(error);
|
||||
showToast(`error: ` + error.message);
|
||||
},
|
||||
);
|
||||
sendSdTask(data, sdListDb, execCountInc, () => {
|
||||
setParams(getModelParamBasicData(columns, params, true));
|
||||
});
|
||||
};
|
||||
return (
|
||||
<>
|
||||
|
|
|
@ -9,7 +9,6 @@ import {
|
|||
copyToClipboard,
|
||||
getMessageTextContent,
|
||||
useMobileScreen,
|
||||
useWindowSize,
|
||||
} from "@/app/utils";
|
||||
import { useNavigate } from "react-router-dom";
|
||||
import { useAppConfig } from "@/app/store";
|
||||
|
@ -22,14 +21,18 @@ import CopyIcon from "@/app/icons/copy.svg";
|
|||
import PromptIcon from "@/app/icons/prompt.svg";
|
||||
import ResetIcon from "@/app/icons/reload.svg";
|
||||
import { useIndexedDB } from "react-indexed-db-hook";
|
||||
import { useSdStore } from "@/app/store/sd";
|
||||
import { sendSdTask, useSdStore } from "@/app/store/sd";
|
||||
import locales from "@/app/locales";
|
||||
import LoadingIcon from "../icons/three-dots.svg";
|
||||
import ErrorIcon from "../icons/delete.svg";
|
||||
import { Property } from "csstype";
|
||||
import { showConfirm } from "@/app/components/ui-lib";
|
||||
import {
|
||||
showConfirm,
|
||||
showImageModal,
|
||||
showModal,
|
||||
} from "@/app/components/ui-lib";
|
||||
|
||||
function openBase64ImgUrl(base64Data: string, contentType: string) {
|
||||
function getBase64ImgUrl(base64Data: string, contentType: string) {
|
||||
const byteCharacters = atob(base64Data);
|
||||
const byteNumbers = new Array(byteCharacters.length);
|
||||
for (let i = 0; i < byteCharacters.length; i++) {
|
||||
|
@ -37,8 +40,7 @@ function openBase64ImgUrl(base64Data: string, contentType: string) {
|
|||
}
|
||||
const byteArray = new Uint8Array(byteNumbers);
|
||||
const blob = new Blob([byteArray], { type: contentType });
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
window.open(blobUrl);
|
||||
return URL.createObjectURL(blob);
|
||||
}
|
||||
|
||||
function getSdTaskStatus(item: any) {
|
||||
|
@ -69,7 +71,24 @@ function getSdTaskStatus(item: any) {
|
|||
<span>
|
||||
{locales.Sd.Status.Name}: {s}
|
||||
</span>
|
||||
{item.status === "error" && <span> - {item.error}</span>}
|
||||
{item.status === "error" && (
|
||||
<span
|
||||
className="clickable"
|
||||
onClick={() => {
|
||||
showModal({
|
||||
title: locales.Sd.Detail,
|
||||
children: (
|
||||
<div style={{ color: color, userSelect: "text" }}>
|
||||
{item.error}
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}}
|
||||
>
|
||||
{" "}
|
||||
- {item.error}
|
||||
</span>
|
||||
)}
|
||||
</p>
|
||||
);
|
||||
}
|
||||
|
@ -83,7 +102,7 @@ export function Sd() {
|
|||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const sdListDb = useIndexedDB(StoreKey.SdList);
|
||||
const [sdImages, setSdImages] = useState([]);
|
||||
const { execCount } = useSdStore();
|
||||
const { execCount, execCountInc } = useSdStore();
|
||||
|
||||
useEffect(() => {
|
||||
sdListDb.getAll().then((data) => {
|
||||
|
@ -145,7 +164,10 @@ export function Sd() {
|
|||
src={`data:image/png;base64,${item.img_data}`}
|
||||
alt={`${item.id}`}
|
||||
onClick={(e) => {
|
||||
openBase64ImgUrl(item.img_data, "image/png");
|
||||
showImageModal(
|
||||
getBase64ImgUrl(item.img_data, "image/png"),
|
||||
true,
|
||||
);
|
||||
}}
|
||||
/>
|
||||
) : item.status === "error" ? (
|
||||
|
@ -163,7 +185,20 @@ export function Sd() {
|
|||
>
|
||||
<p className={styles["line-1"]}>
|
||||
{locales.SdPanel.Prompt}:{" "}
|
||||
<span title={item.params.prompt}>
|
||||
<span
|
||||
className="clickable"
|
||||
title={item.params.prompt}
|
||||
onClick={() => {
|
||||
showModal({
|
||||
title: locales.Sd.Detail,
|
||||
children: (
|
||||
<div style={{ userSelect: "text" }}>
|
||||
{item.params.prompt}
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}}
|
||||
>
|
||||
{item.params.prompt}
|
||||
</span>
|
||||
</p>
|
||||
|
@ -177,7 +212,21 @@ export function Sd() {
|
|||
<ChatAction
|
||||
text={Locale.Sd.Actions.Params}
|
||||
icon={<PromptIcon />}
|
||||
onClick={() => console.log(1)}
|
||||
onClick={() => {
|
||||
showModal({
|
||||
title: locales.Sd.GenerateParams,
|
||||
children: (
|
||||
<div style={{ userSelect: "text" }}>
|
||||
{Object.keys(item.params).map((key) => (
|
||||
<div key={key} style={{ margin: "10px" }}>
|
||||
<strong>{key}: </strong>
|
||||
{item.params[key]}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<ChatAction
|
||||
text={Locale.Sd.Actions.Copy}
|
||||
|
@ -194,7 +243,17 @@ export function Sd() {
|
|||
<ChatAction
|
||||
text={Locale.Sd.Actions.Retry}
|
||||
icon={<ResetIcon />}
|
||||
onClick={() => console.log(1)}
|
||||
onClick={() => {
|
||||
const reqData = {
|
||||
model: item.model,
|
||||
model_name: item.model_name,
|
||||
status: "wait",
|
||||
params: { ...item.params },
|
||||
created_at: new Date().toLocaleString(),
|
||||
img_data: "",
|
||||
};
|
||||
sendSdTask(reqData, sdListDb, execCountInc);
|
||||
}}
|
||||
/>
|
||||
<ChatAction
|
||||
text={Locale.Sd.Actions.Delete}
|
||||
|
|
|
@ -425,11 +425,12 @@ export function showPrompt(content: any, value = "", rows = 3) {
|
|||
});
|
||||
}
|
||||
|
||||
export function showImageModal(img: string) {
|
||||
export function showImageModal(img: string, defaultMax?: boolean) {
|
||||
showModal({
|
||||
title: Locale.Export.Image.Modal,
|
||||
defaultMax: defaultMax,
|
||||
children: (
|
||||
<div>
|
||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||
<img
|
||||
src={img}
|
||||
alt="preview"
|
||||
|
|
|
@ -124,6 +124,9 @@ export const getServerSideConfig = () => {
|
|||
anthropicApiVersion: process.env.ANTHROPIC_API_VERSION,
|
||||
anthropicUrl: process.env.ANTHROPIC_URL,
|
||||
|
||||
stabilityUrl: process.env.STABILITY_URL,
|
||||
stabilityApiKey: process.env.STABILITY_API_KEY,
|
||||
|
||||
gtmId: process.env.GTM_ID,
|
||||
|
||||
needCode: ACCESS_CODES.size > 0,
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import { stabilityRequestCall } from "@/app/store/sd";
|
||||
|
||||
export const OWNER = "Yidadaa";
|
||||
export const REPO = "ChatGPT-Next-Web";
|
||||
export const REPO_URL = `https://github.com/${OWNER}/${REPO}`;
|
||||
|
@ -13,6 +15,7 @@ export const OPENAI_BASE_URL = "https://api.openai.com";
|
|||
export const ANTHROPIC_BASE_URL = "https://api.anthropic.com";
|
||||
|
||||
export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/";
|
||||
export const STABILITY_BASE_URL = "https://api.stability.ai";
|
||||
|
||||
export enum Path {
|
||||
Home = "/",
|
||||
|
@ -79,6 +82,7 @@ export enum ModelProvider {
|
|||
GPT = "GPT",
|
||||
GeminiPro = "GeminiPro",
|
||||
Claude = "Claude",
|
||||
Stability = "Stability",
|
||||
}
|
||||
|
||||
export const Anthropic = {
|
||||
|
@ -104,6 +108,10 @@ export const Google = {
|
|||
ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
|
||||
};
|
||||
|
||||
export const StabilityPath = {
|
||||
GeneratePath: "v2beta/stable-image/generate",
|
||||
};
|
||||
|
||||
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
|
||||
// export const DEFAULT_SYSTEM_TEMPLATE = `
|
||||
// You are ChatGPT, a large language model trained by {{ServiceProvider}}.
|
||||
|
|
|
@ -534,6 +534,8 @@ const cn = {
|
|||
Danger: {
|
||||
Delete: "确认删除?",
|
||||
},
|
||||
GenerateParams: "生成参数",
|
||||
Detail: "详情",
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
@ -540,6 +540,8 @@ const en: LocaleType = {
|
|||
Danger: {
|
||||
Delete: "Confirm to delete?",
|
||||
},
|
||||
GenerateParams: "Generate Params",
|
||||
Detail: "Detail",
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import { initDB, useIndexedDB } from "react-indexed-db-hook";
|
||||
import { StoreKey } from "@/app/constant";
|
||||
import { StabilityPath, StoreKey } from "@/app/constant";
|
||||
import { create, StoreApi } from "zustand";
|
||||
import { showToast } from "@/app/components/ui-lib";
|
||||
|
||||
export const SdDbConfig = {
|
||||
name: "@chatgpt-next-web/sd",
|
||||
|
@ -44,12 +45,28 @@ export const useSdStore = create<SdStore>()((set) => ({
|
|||
execCountInc: () => set((state) => ({ execCount: state.execCount + 1 })),
|
||||
}));
|
||||
|
||||
export function sendSdTask(data: any, db: any, inc: any) {
|
||||
export function sendSdTask(data: any, db: any, inc: any, okCall?: Function) {
|
||||
db.add(data).then(
|
||||
(id: number) => {
|
||||
data = { ...data, id, status: "running" };
|
||||
db.update(data);
|
||||
inc();
|
||||
stabilityRequestCall(data, db, inc);
|
||||
okCall?.();
|
||||
},
|
||||
(error: any) => {
|
||||
console.error(error);
|
||||
showToast(`error: ` + error.message);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
export function stabilityRequestCall(data: any, db: any, inc: any) {
|
||||
const formData = new FormData();
|
||||
for (let paramsKey in data.params) {
|
||||
formData.append(paramsKey, data.params[paramsKey]);
|
||||
}
|
||||
fetch("https://api.stability.ai/v2beta/stable-image/generate/" + data.model, {
|
||||
fetch(`/api/stability/${StabilityPath.GeneratePath}/${data.model}`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
|
|
Loading…
Reference in New Issue