Use the b64_json response format for dall-e-3

This commit is contained in:
lloydzhou 2024-08-02 20:58:21 +08:00
parent 46cb48023e
commit 8c83fe23a1
1 changed file with 18 additions and 6 deletions

View File

@ -11,7 +11,11 @@ import {
} from "@/app/constant"; } from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { collectModelsWithDefaultModel } from "@/app/utils/model"; import { collectModelsWithDefaultModel } from "@/app/utils/model";
import { preProcessImageContent } from "@/app/utils/chat"; import {
preProcessImageContent,
uploadImage,
base64Image2Blob,
} from "@/app/utils/chat";
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
import { DalleSize } from "@/app/typing"; import { DalleSize } from "@/app/typing";
@ -63,6 +67,7 @@ export interface RequestPayload {
export interface DalleRequestPayload { export interface DalleRequestPayload {
model: string; model: string;
prompt: string; prompt: string;
response_format: "url" | "b64_json";
n: number; n: number;
size: DalleSize; size: DalleSize;
} }
@ -109,13 +114,18 @@ export class ChatGPTApi implements LLMApi {
return cloudflareAIGatewayUrl([baseUrl, path].join("/")); return cloudflareAIGatewayUrl([baseUrl, path].join("/"));
} }
extractMessage(res: any) { async extractMessage(res: any) {
if (res.error) { if (res.error) {
return "```\n" + JSON.stringify(res, null, 4) + "\n```"; return "```\n" + JSON.stringify(res, null, 4) + "\n```";
} }
// dalle3 model return url, just return // dalle3 returns either a url or b64_json; build an image message from it
if (res.data) { if (res.data) {
const url = res.data?.at(0)?.url ?? ""; let url = res.data?.at(0)?.url ?? "";
const b64_json = res.data?.at(0)?.b64_json ?? "";
if (!url && b64_json) {
// uploadImage
url = await uploadImage(base64Image2Blob(b64_json, "image/png"));
}
return [ return [
{ {
type: "image_url", type: "image_url",
@ -148,6 +158,8 @@ export class ChatGPTApi implements LLMApi {
requestPayload = { requestPayload = {
model: options.config.model, model: options.config.model,
prompt, prompt,
// URLs are only valid for 60 minutes after the image has been generated.
response_format: "b64_json", // use b64_json and save the image in CacheStorage
n: 1, n: 1,
size: options.config?.size ?? "1024x1024", size: options.config?.size ?? "1024x1024",
}; };
@ -227,7 +239,7 @@ export class ChatGPTApi implements LLMApi {
// make a fetch request // make a fetch request
const requestTimeoutId = setTimeout( const requestTimeoutId = setTimeout(
() => controller.abort(), () => controller.abort(),
REQUEST_TIMEOUT_MS, isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 with b64_json is slow, so double the timeout.
); );
if (shouldStream) { if (shouldStream) {
@ -358,7 +370,7 @@ export class ChatGPTApi implements LLMApi {
clearTimeout(requestTimeoutId); clearTimeout(requestTimeoutId);
const resJson = await res.json(); const resJson = await res.json();
const message = this.extractMessage(resJson); const message = await this.extractMessage(resJson);
options.onFinish(message); options.onFinish(message);
} }
} catch (e) { } catch (e) {