feat: Add Stability API server relay sending

This commit is contained in:
licoy 2024-07-09 09:50:04 +08:00
parent a16725ac17
commit 2b0153807c
10 changed files with 220 additions and 31 deletions

View File

@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
case ModelProvider.Claude: case ModelProvider.Claude:
systemApiKey = serverConfig.anthropicApiKey; systemApiKey = serverConfig.anthropicApiKey;
break; break;
case ModelProvider.Stability:
systemApiKey = serverConfig.stabilityApiKey;
break;
case ModelProvider.GPT: case ModelProvider.GPT:
default: default:
if (serverConfig.isAzure) { if (serverConfig.isAzure) {

View File

@ -0,0 +1,104 @@
import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "@/app/config/server";
import { ModelProvider, STABILITY_BASE_URL } from "@/app/constant";
import { auth } from "@/app/api/auth";
async function handle(
req: NextRequest,
{ params }: { params: { path: string[] } },
) {
console.log("[Stability] params ", params);
if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}
const controller = new AbortController();
const serverConfig = getServerSideConfig();
let baseUrl = serverConfig.stabilityUrl || STABILITY_BASE_URL;
if (!baseUrl.startsWith("http")) {
baseUrl = `https://${baseUrl}`;
}
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, -1);
}
let path = `${req.nextUrl.pathname}`.replaceAll("/api/stability/", "");
console.log("[Stability Proxy] ", path);
console.log("[Stability Base Url]", baseUrl);
const timeoutId = setTimeout(
() => {
controller.abort();
},
10 * 60 * 1000,
);
const authResult = auth(req, ModelProvider.Stability);
if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
});
}
const bearToken = req.headers.get("Authorization") ?? "";
const token = bearToken.trim().replaceAll("Bearer ", "").trim();
const key = token ? token : serverConfig.stabilityApiKey;
if (!key) {
return NextResponse.json(
{
error: true,
message: `missing STABILITY_API_KEY in server env vars`,
},
{
status: 401,
},
);
}
const fetchUrl = `${baseUrl}/${path}`;
console.log("[Stability Url] ", fetchUrl);
const fetchOptions: RequestInit = {
headers: {
"Content-Type": req.headers.get("Content-Type") || "multipart/form-data",
Accept: req.headers.get("Accept") || "application/json",
Authorization: `Bearer ${key}`,
},
method: req.method,
body: req.body,
// to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body
redirect: "manual",
// @ts-ignore
duplex: "half",
signal: controller.signal,
};
try {
const res = await fetch(fetchUrl, fetchOptions);
// to prevent browser prompt for credentials
const newHeaders = new Headers(res.headers);
newHeaders.delete("www-authenticate");
// to disable nginx buffering
newHeaders.set("X-Accel-Buffering", "no");
return new Response(res.body, {
status: res.status,
statusText: res.statusText,
headers: newHeaders,
});
} finally {
clearTimeout(timeoutId);
}
}
export const GET = handle;
export const POST = handle;
export const runtime = "edge";

View File

@ -307,22 +307,12 @@ export function SdPanel() {
model_name: currentModel.name, model_name: currentModel.name,
status: "wait", status: "wait",
params: reqParams, params: reqParams,
created_at: new Date().toISOString(), created_at: new Date().toLocaleString(),
img_data: "", img_data: "",
}; };
sdListDb.add(data).then( sendSdTask(data, sdListDb, execCountInc, () => {
(id) => {
data = { ...data, id, status: "running" };
sdListDb.update(data);
execCountInc();
sendSdTask(data, sdListDb, execCountInc);
setParams(getModelParamBasicData(columns, params, true)); setParams(getModelParamBasicData(columns, params, true));
}, });
(error) => {
console.error(error);
showToast(`error: ` + error.message);
},
);
}; };
return ( return (
<> <>

View File

@ -9,7 +9,6 @@ import {
copyToClipboard, copyToClipboard,
getMessageTextContent, getMessageTextContent,
useMobileScreen, useMobileScreen,
useWindowSize,
} from "@/app/utils"; } from "@/app/utils";
import { useNavigate } from "react-router-dom"; import { useNavigate } from "react-router-dom";
import { useAppConfig } from "@/app/store"; import { useAppConfig } from "@/app/store";
@ -22,14 +21,18 @@ import CopyIcon from "@/app/icons/copy.svg";
import PromptIcon from "@/app/icons/prompt.svg"; import PromptIcon from "@/app/icons/prompt.svg";
import ResetIcon from "@/app/icons/reload.svg"; import ResetIcon from "@/app/icons/reload.svg";
import { useIndexedDB } from "react-indexed-db-hook"; import { useIndexedDB } from "react-indexed-db-hook";
import { useSdStore } from "@/app/store/sd"; import { sendSdTask, useSdStore } from "@/app/store/sd";
import locales from "@/app/locales"; import locales from "@/app/locales";
import LoadingIcon from "../icons/three-dots.svg"; import LoadingIcon from "../icons/three-dots.svg";
import ErrorIcon from "../icons/delete.svg"; import ErrorIcon from "../icons/delete.svg";
import { Property } from "csstype"; import { Property } from "csstype";
import { showConfirm } from "@/app/components/ui-lib"; import {
showConfirm,
showImageModal,
showModal,
} from "@/app/components/ui-lib";
function openBase64ImgUrl(base64Data: string, contentType: string) { function getBase64ImgUrl(base64Data: string, contentType: string) {
const byteCharacters = atob(base64Data); const byteCharacters = atob(base64Data);
const byteNumbers = new Array(byteCharacters.length); const byteNumbers = new Array(byteCharacters.length);
for (let i = 0; i < byteCharacters.length; i++) { for (let i = 0; i < byteCharacters.length; i++) {
@ -37,8 +40,7 @@ function openBase64ImgUrl(base64Data: string, contentType: string) {
} }
const byteArray = new Uint8Array(byteNumbers); const byteArray = new Uint8Array(byteNumbers);
const blob = new Blob([byteArray], { type: contentType }); const blob = new Blob([byteArray], { type: contentType });
const blobUrl = URL.createObjectURL(blob); return URL.createObjectURL(blob);
window.open(blobUrl);
} }
function getSdTaskStatus(item: any) { function getSdTaskStatus(item: any) {
@ -69,7 +71,24 @@ function getSdTaskStatus(item: any) {
<span> <span>
{locales.Sd.Status.Name}: {s} {locales.Sd.Status.Name}: {s}
</span> </span>
{item.status === "error" && <span> - {item.error}</span>} {item.status === "error" && (
<span
className="clickable"
onClick={() => {
showModal({
title: locales.Sd.Detail,
children: (
<div style={{ color: color, userSelect: "text" }}>
{item.error}
</div>
),
});
}}
>
{" "}
- {item.error}
</span>
)}
</p> </p>
); );
} }
@ -83,7 +102,7 @@ export function Sd() {
const scrollRef = useRef<HTMLDivElement>(null); const scrollRef = useRef<HTMLDivElement>(null);
const sdListDb = useIndexedDB(StoreKey.SdList); const sdListDb = useIndexedDB(StoreKey.SdList);
const [sdImages, setSdImages] = useState([]); const [sdImages, setSdImages] = useState([]);
const { execCount } = useSdStore(); const { execCount, execCountInc } = useSdStore();
useEffect(() => { useEffect(() => {
sdListDb.getAll().then((data) => { sdListDb.getAll().then((data) => {
@ -145,7 +164,10 @@ export function Sd() {
src={`data:image/png;base64,${item.img_data}`} src={`data:image/png;base64,${item.img_data}`}
alt={`${item.id}`} alt={`${item.id}`}
onClick={(e) => { onClick={(e) => {
openBase64ImgUrl(item.img_data, "image/png"); showImageModal(
getBase64ImgUrl(item.img_data, "image/png"),
true,
);
}} }}
/> />
) : item.status === "error" ? ( ) : item.status === "error" ? (
@ -163,7 +185,20 @@ export function Sd() {
> >
<p className={styles["line-1"]}> <p className={styles["line-1"]}>
{locales.SdPanel.Prompt}:{" "} {locales.SdPanel.Prompt}:{" "}
<span title={item.params.prompt}> <span
className="clickable"
title={item.params.prompt}
onClick={() => {
showModal({
title: locales.Sd.Detail,
children: (
<div style={{ userSelect: "text" }}>
{item.params.prompt}
</div>
),
});
}}
>
{item.params.prompt} {item.params.prompt}
</span> </span>
</p> </p>
@ -177,7 +212,21 @@ export function Sd() {
<ChatAction <ChatAction
text={Locale.Sd.Actions.Params} text={Locale.Sd.Actions.Params}
icon={<PromptIcon />} icon={<PromptIcon />}
onClick={() => console.log(1)} onClick={() => {
showModal({
title: locales.Sd.GenerateParams,
children: (
<div style={{ userSelect: "text" }}>
{Object.keys(item.params).map((key) => (
<div key={key} style={{ margin: "10px" }}>
<strong>{key}: </strong>
{item.params[key]}
</div>
))}
</div>
),
});
}}
/> />
<ChatAction <ChatAction
text={Locale.Sd.Actions.Copy} text={Locale.Sd.Actions.Copy}
@ -194,7 +243,17 @@ export function Sd() {
<ChatAction <ChatAction
text={Locale.Sd.Actions.Retry} text={Locale.Sd.Actions.Retry}
icon={<ResetIcon />} icon={<ResetIcon />}
onClick={() => console.log(1)} onClick={() => {
const reqData = {
model: item.model,
model_name: item.model_name,
status: "wait",
params: { ...item.params },
created_at: new Date().toLocaleString(),
img_data: "",
};
sendSdTask(reqData, sdListDb, execCountInc);
}}
/> />
<ChatAction <ChatAction
text={Locale.Sd.Actions.Delete} text={Locale.Sd.Actions.Delete}

View File

@ -425,11 +425,12 @@ export function showPrompt(content: any, value = "", rows = 3) {
}); });
} }
export function showImageModal(img: string) { export function showImageModal(img: string, defaultMax?: boolean) {
showModal({ showModal({
title: Locale.Export.Image.Modal, title: Locale.Export.Image.Modal,
defaultMax: defaultMax,
children: ( children: (
<div> <div style={{ display: "flex", justifyContent: "center" }}>
<img <img
src={img} src={img}
alt="preview" alt="preview"

View File

@ -124,6 +124,9 @@ export const getServerSideConfig = () => {
anthropicApiVersion: process.env.ANTHROPIC_API_VERSION, anthropicApiVersion: process.env.ANTHROPIC_API_VERSION,
anthropicUrl: process.env.ANTHROPIC_URL, anthropicUrl: process.env.ANTHROPIC_URL,
stabilityUrl: process.env.STABILITY_URL,
stabilityApiKey: process.env.STABILITY_API_KEY,
gtmId: process.env.GTM_ID, gtmId: process.env.GTM_ID,
needCode: ACCESS_CODES.size > 0, needCode: ACCESS_CODES.size > 0,

View File

@ -1,3 +1,5 @@
import { stabilityRequestCall } from "@/app/store/sd";
export const OWNER = "Yidadaa"; export const OWNER = "Yidadaa";
export const REPO = "ChatGPT-Next-Web"; export const REPO = "ChatGPT-Next-Web";
export const REPO_URL = `https://github.com/${OWNER}/${REPO}`; export const REPO_URL = `https://github.com/${OWNER}/${REPO}`;
@ -13,6 +15,7 @@ export const OPENAI_BASE_URL = "https://api.openai.com";
export const ANTHROPIC_BASE_URL = "https://api.anthropic.com"; export const ANTHROPIC_BASE_URL = "https://api.anthropic.com";
export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/"; export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/";
export const STABILITY_BASE_URL = "https://api.stability.ai";
export enum Path { export enum Path {
Home = "/", Home = "/",
@ -79,6 +82,7 @@ export enum ModelProvider {
GPT = "GPT", GPT = "GPT",
GeminiPro = "GeminiPro", GeminiPro = "GeminiPro",
Claude = "Claude", Claude = "Claude",
Stability = "Stability",
} }
export const Anthropic = { export const Anthropic = {
@ -104,6 +108,10 @@ export const Google = {
ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`, ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
}; };
export const StabilityPath = {
GeneratePath: "v2beta/stable-image/generate",
};
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
// export const DEFAULT_SYSTEM_TEMPLATE = ` // export const DEFAULT_SYSTEM_TEMPLATE = `
// You are ChatGPT, a large language model trained by {{ServiceProvider}}. // You are ChatGPT, a large language model trained by {{ServiceProvider}}.

View File

@ -534,6 +534,8 @@ const cn = {
Danger: { Danger: {
Delete: "确认删除?", Delete: "确认删除?",
}, },
GenerateParams: "生成参数",
Detail: "详情",
}, },
}; };

View File

@ -540,6 +540,8 @@ const en: LocaleType = {
Danger: { Danger: {
Delete: "Confirm to delete?", Delete: "Confirm to delete?",
}, },
GenerateParams: "Generate Params",
Detail: "Detail",
}, },
}; };

View File

@ -1,6 +1,7 @@
import { initDB, useIndexedDB } from "react-indexed-db-hook"; import { initDB, useIndexedDB } from "react-indexed-db-hook";
import { StoreKey } from "@/app/constant"; import { StabilityPath, StoreKey } from "@/app/constant";
import { create, StoreApi } from "zustand"; import { create, StoreApi } from "zustand";
import { showToast } from "@/app/components/ui-lib";
export const SdDbConfig = { export const SdDbConfig = {
name: "@chatgpt-next-web/sd", name: "@chatgpt-next-web/sd",
@ -44,12 +45,28 @@ export const useSdStore = create<SdStore>()((set) => ({
execCountInc: () => set((state) => ({ execCount: state.execCount + 1 })), execCountInc: () => set((state) => ({ execCount: state.execCount + 1 })),
})); }));
export function sendSdTask(data: any, db: any, inc: any) { export function sendSdTask(data: any, db: any, inc: any, okCall?: Function) {
db.add(data).then(
(id: number) => {
data = { ...data, id, status: "running" };
db.update(data);
inc();
stabilityRequestCall(data, db, inc);
okCall?.();
},
(error: any) => {
console.error(error);
showToast(`error: ` + error.message);
},
);
}
export function stabilityRequestCall(data: any, db: any, inc: any) {
const formData = new FormData(); const formData = new FormData();
for (let paramsKey in data.params) { for (let paramsKey in data.params) {
formData.append(paramsKey, data.params[paramsKey]); formData.append(paramsKey, data.params[paramsKey]);
} }
fetch("https://api.stability.ai/v2beta/stable-image/generate/" + data.model, { fetch(`/api/stability/${StabilityPath.GeneratePath}/${data.model}`, {
method: "POST", method: "POST",
headers: { headers: {
Accept: "application/json", Accept: "application/json",