// ChatGPT-Next-Web/app/store/sd.ts
//
// 164 lines
// 5.0 KiB
// TypeScript

import { getBearerToken } from '@/app/client/api';
import { getModelParamBasicData, models } from '@/app/components/sd/sd-panel';
import {
ACCESS_CODE_PREFIX,
ApiPath,
Stability,
StoreKey,
} from '@/app/constant';
import { base64Image2Blob, uploadImage } from '@/app/utils/chat';
import { createPersistStore } from '@/app/utils/store';
import { nanoid } from 'nanoid';
import { useAccessStore } from './access';
/** Model preselected out of the box: the `name`/`value` pair of the first entry in `models`. */
const defaultModel = (({ name, value }) => ({ name, value }))(models[0]);

/** Parameter values for the default model, as produced by `getModelParamBasicData`. */
const defaultParams = getModelParamBasicData(models[0].params({}), {});

/** Initial state handed to the persisted SD store. */
const DEFAULT_SD_STATE = {
  currentId: 0,
  draw: [],
  currentModel: defaultModel,
  currentParams: defaultParams,
};
/**
 * Persisted store backing the Stable Diffusion (Stability AI) panel.
 *
 * State:
 * - `currentId`: counter advanced by `getNextId()`. NOTE(review): it appears to be
 *   bumped mainly to nudge subscribers into refreshing — confirm against consumers.
 * - `draw`: generation task records, newest first. Each one carries a nanoid `id` and a
 *   `status` that starts as 'running' and becomes 'success' or 'error'.
 * - `currentModel` / `currentParams`: the panel's current selection.
 */
export const useSdStore = createPersistStore<
  {
    currentId: number;
    draw: any[];
    currentModel: typeof defaultModel;
    currentParams: any;
  },
  {
    getNextId: () => number;
    sendTask: (data: any, okCall?: Function) => void;
    stabilityRequestCall: (data: any) => void;
    updateDraw: (draw: any) => void;
    setCurrentModel: (model: any) => void;
    setCurrentParams: (data: any) => void;
  }
>(
  DEFAULT_SD_STATE,
  (set, get) => {
    /**
     * Best-effort readable message for a caught value.
     * `JSON.stringify(new Error(...))` yields "{}" because `message` is
     * non-enumerable, so Errors are unwrapped explicitly.
     */
    const describeError = (e: unknown): string =>
      e instanceof Error ? e.message : JSON.stringify(e) ?? String(e);

    // Methods call each other through `methods` rather than `this`, so they
    // keep working even when destructured off the store.
    const methods = {
      /** Advance `currentId` by one and return the new value. */
      getNextId(): number {
        // Compute the next id without mutating the current state object in place.
        const id = get().currentId + 1;
        set({ currentId: id });
        return id;
      },

      /**
       * Register a new generation task and start its request.
       *
       * @param data   task payload; expected to carry `model` and `params`
       *               (read by `stabilityRequestCall`). A fresh nanoid `id` and
       *               `status: 'running'` are added.
       * @param okCall optional callback invoked synchronously once the request
       *               has been dispatched (not when it completes).
       */
      sendTask(data: any, okCall?: Function): void {
        const task = { ...data, id: nanoid(), status: 'running' };
        set({ draw: [task, ...(get().draw ?? [])] });
        methods.getNextId();
        methods.stabilityRequestCall(task);
        okCall?.();
      },

      /**
       * POST the task's params as multipart form data to the Stability generate
       * endpoint and record the outcome on the matching `draw` entry.
       *
       * Auth: a custom endpoint/key wins when `useCustomConfig` is on. Otherwise,
       * when access control is enabled, the access code is sent as the bearer token.
       * On `finish_reason === 'SUCCESS'` the base64 PNG is uploaded via
       * `uploadImage` and its result is stored as `img_data`. Every failure path
       * marks the task `status: 'error'` with an `error` description.
       */
      stabilityRequestCall(data: any): void {
        const accessStore = useAccessStore.getState();
        let prefix: string = ApiPath.Stability as string;
        let bearerToken = '';
        if (accessStore.useCustomConfig) {
          prefix = accessStore.stabilityUrl || (ApiPath.Stability as string);
          bearerToken = getBearerToken(accessStore.stabilityApiKey);
        }
        if (!bearerToken && accessStore.enabledAccessControl()) {
          bearerToken = getBearerToken(
            ACCESS_CODE_PREFIX + accessStore.accessCode,
          );
        }
        const headers = {
          Accept: 'application/json',
          Authorization: bearerToken,
        };
        const path = `${prefix}/${Stability.GeneratePath}/${data.model}`;

        // Each own entry of `data.params` becomes one form field.
        const formData = new FormData();
        for (const [key, value] of Object.entries(data.params ?? {})) {
          formData.append(key, value as any);
        }

        const fail = (error: unknown) =>
          methods.updateDraw({ ...data, status: 'error', error });

        fetch(path, {
          method: 'POST',
          headers,
          body: formData,
        })
          .then((response) => response.json())
          .then((resData) => {
            if (resData.errors && resData.errors.length > 0) {
              fail(resData.errors[0]);
              methods.getNextId();
              return;
            }
            if (resData.finish_reason === 'SUCCESS') {
              uploadImage(base64Image2Blob(resData.image, 'image/png'))
                .then((img_data) => {
                  console.debug('uploadImage success', img_data);
                  methods.updateDraw({
                    ...data,
                    status: 'success',
                    img_data,
                  });
                })
                .catch((e: unknown) => {
                  console.error('uploadImage error', e);
                  fail(describeError(e));
                });
            } else {
              fail(JSON.stringify(resData));
            }
            methods.getNextId();
          })
          .catch((error: unknown) => {
            fail(describeError(error));
            console.error('Error:', error);
            methods.getNextId();
          });
      },

      /**
       * Replace the `draw` entry whose `id` matches `_draw.id`. If no entry
       * matches, nothing happens.
       */
      updateDraw(_draw: any): void {
        const draw = get().draw || [];
        const index = draw.findIndex((item) => item.id === _draw.id);
        if (index === -1) {
          return;
        }
        // Publish a new array: mutating the stored one in place would leave the
        // reference unchanged, hiding the update from reference-comparing selectors
        // and corrupting the previous state snapshot.
        const next = [...draw];
        next[index] = _draw;
        set({ draw: next });
      },

      /** Store the currently selected model. */
      setCurrentModel(model: any): void {
        set({ currentModel: model });
      },

      /** Store the currently selected generation parameters. */
      setCurrentParams(data: any): void {
        set({
          currentParams: data,
        });
      },
    };
    return methods;
  },
  {
    name: StoreKey.SdList,
    version: 1.0,
  },
);