import { getBearerToken } from '@/app/client/api';
import { getModelParamBasicData, models } from '@/app/components/sd/sd-panel';
import {
  ACCESS_CODE_PREFIX,
  ApiPath,
  Stability,
  StoreKey,
} from '@/app/constant';
import { base64Image2Blob, uploadImage } from '@/app/utils/chat';
import { createPersistStore } from '@/app/utils/store';
import { nanoid } from 'nanoid';
import { useAccessStore } from './access';

/** Model selected in a fresh store: the first entry of `models`. */
const defaultModel = {
  name: models[0].name,
  value: models[0].value,
};

/** Parameters selected in a fresh store, derived from the first model's param list. */
const defaultParams = getModelParamBasicData(models[0].params({}), {});

/** Initial state handed to `createPersistStore`. */
const DEFAULT_SD_STATE = {
  currentId: 0,
  draw: [],
  currentModel: defaultModel,
  currentParams: defaultParams,
};

/**
 * Turns a rejection reason into a displayable string.
 *
 * `JSON.stringify` on an `Error` yields `"{}"` because `message` is not an
 * enumerable property, so anything carrying a string `message` is reported
 * through it; other values are serialized as JSON.
 */
function describeError(error: unknown): string {
  if (error instanceof Error) {
    return error.message;
  }
  if (typeof error === 'object' && error !== null && 'message' in error) {
    const message = (error as { message: unknown }).message;
    if (typeof message === 'string') {
      return message;
    }
  }
  return JSON.stringify(error) ?? String(error);
}

/**
 * Store for Stable Diffusion draw tasks, created via `createPersistStore`
 * under the key `StoreKey.SdList`.
 *
 * State:
 * - `currentId`: counter bumped by `getNextId`.
 * - `draw`: task list, newest first (new tasks are prepended by `sendTask`).
 * - `currentModel` / `currentParams`: the currently selected model and params.
 */
export const useSdStore = createPersistStore<
  {
    currentId: number;
    draw: any[];
    currentModel: typeof defaultModel;
    currentParams: any;
  },
  {
    getNextId: () => number;
    sendTask: (data: any, okCall?: Function) => void;
    updateDraw: (draw: any) => void;
    setCurrentModel: (model: any) => void;
    setCurrentParams: (data: any) => void;
  }
>(
  DEFAULT_SD_STATE,
  (set, _get) => {
    // Actions call each other through `methods` rather than `this`, so they
    // keep working when destructured from the store or passed as callbacks.
    const methods = {
      /**
       * Increments `currentId` and returns the new value.
       * The next id is computed from a read and written through `set`,
       * never by mutating the live state object.
       */
      getNextId(): number {
        const id = _get().currentId + 1;
        set({ currentId: id });
        return id;
      },

      /**
       * Registers a new task as `running`, prepends it to `draw`, and starts
       * the generation request.
       *
       * @param data - Task description; `model` and `params` are read by
       *   `stabilityRequestCall`.
       * @param okCall - Optional callback invoked (with no arguments) once the
       *   request has been started, not when it has finished.
       */
      sendTask(data: any, okCall?: Function): void {
        const task = { ...data, id: nanoid(), status: 'running' };
        set({ draw: [task, ..._get().draw] });
        methods.getNextId();
        methods.stabilityRequestCall(task);
        okCall?.();
      },

      /**
       * Posts the task's params as multipart form data to the generation
       * endpoint and records the outcome on the task via `updateDraw`.
       *
       * @param data - Task previously added to `draw`; must carry `id`,
       *   `model` and (optionally) `params`.
       */
      stabilityRequestCall(data: any): void {
        const accessStore = useAccessStore.getState();
        let prefix: string = ApiPath.Stability as string;
        let bearerToken = '';

        // A custom configuration may override both the endpoint and the key.
        if (accessStore.useCustomConfig) {
          prefix = accessStore.stabilityUrl || (ApiPath.Stability as string);
          bearerToken = getBearerToken(accessStore.stabilityApiKey);
        }

        // Without a key, fall back to the access code when access control is on.
        if (!bearerToken && accessStore.enabledAccessControl()) {
          bearerToken = getBearerToken(
            ACCESS_CODE_PREFIX + accessStore.accessCode,
          );
        }

        const headers = {
          Accept: 'application/json',
          Authorization: bearerToken,
        };
        const path = `${prefix}/${Stability.GeneratePath}/${data.model}`;

        const formData = new FormData();
        for (const [key, value] of Object.entries(data.params ?? {})) {
          formData.append(key, value as string | Blob);
        }

        fetch(path, {
          method: 'POST',
          headers,
          body: formData,
        })
          .then((response) => response.json())
          .then((resData) => {
            if (resData.errors && resData.errors.length > 0) {
              methods.updateDraw({
                ...data,
                status: 'error',
                error: resData.errors[0],
              });
              methods.getNextId();
              return;
            }

            if (resData.finish_reason === 'SUCCESS') {
              uploadImage(base64Image2Blob(resData.image, 'image/png'))
                .then((img_data) => {
                  console.debug('uploadImage success', img_data, methods);
                  methods.updateDraw({
                    ...data,
                    status: 'success',
                    img_data,
                  });
                })
                .catch((e) => {
                  console.error('uploadImage error', e);
                  methods.updateDraw({
                    ...data,
                    status: 'error',
                    error: describeError(e),
                  });
                });
            } else {
              methods.updateDraw({
                ...data,
                status: 'error',
                error: JSON.stringify(resData),
              });
            }

            // NOTE(review): runs before the upload above settles; presumably
            // the bump only signals "something changed" — confirm with callers.
            methods.getNextId();
          })
          .catch((error) => {
            methods.updateDraw({
              ...data,
              status: 'error',
              error: describeError(error),
            });
            console.error('Error:', error);
            methods.getNextId();
          });
      },

      /**
       * Replaces the first task whose `id` equals `_draw.id`; does nothing
       * when no task matches.
       *
       * A new array is written so that consumers comparing `draw` by
       * reference observe the change.
       */
      updateDraw(_draw: any): void {
        const draw = _get().draw || [];
        const index = draw.findIndex((item) => item.id === _draw.id);
        if (index === -1) {
          return;
        }
        set({
          draw: draw.map((item, i) => (i === index ? _draw : item)),
        });
      },

      /** Stores `model` as the currently selected model. */
      setCurrentModel(model: any): void {
        set({ currentModel: model });
      },

      /** Stores `data` as the currently selected params. */
      setCurrentParams(data: any): void {
        set({
          currentParams: data,
        });
      },
    };

    return methods;
  },
  {
    name: StoreKey.SdList,
    version: 1.0,
  },
);