add dalle3 model

lloydzhou 2024-08-02 18:00:42 +08:00
parent b3219f57c8
commit ac599aa47c
5 changed files with 113 additions and 27 deletions

View File

@@ -33,6 +33,7 @@ import {
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
+  isDalle3 as _isDalle3,
 } from "@/app/utils";

 export interface OpenAIListModelResponse {
@@ -58,6 +59,13 @@ export interface RequestPayload {
   max_tokens?: number;
 }

+export interface DalleRequestPayload {
+  model: string;
+  prompt: string;
+  n: number;
+  size: "1024x1024" | "1792x1024" | "1024x1792";
+}
+
 export class ChatGPTApi implements LLMApi {
   private disableListModels = true;
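
For reference, a request body matching the new DalleRequestPayload, as it would be serialized for OpenAI's images/generations endpoint (values are illustrative):

const payload: DalleRequestPayload = {
  model: "dall-e-3",
  prompt: "a watercolor lighthouse at dusk",
  n: 1,
  size: "1024x1024",
};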
@@ -101,19 +109,25 @@ export class ChatGPTApi implements LLMApi {
   }

   extractMessage(res: any) {
+    if (res.error) {
+      return "```\n" + JSON.stringify(res, null, 4) + "\n```";
+    }
+    // dalle3 model return url, just return
+    if (res.data) {
+      const url = res.data?.at(0)?.url ?? "";
+      return [
+        {
+          type: "image_url",
+          image_url: {
+            url,
+          },
+        },
+      ];
+    }
     return res.choices?.at(0)?.message?.content ?? "";
   }

   async chat(options: ChatOptions) {
-    const visionModel = isVisionModel(options.config.model);
-    const messages: ChatOptions["messages"] = [];
-    for (const v of options.messages) {
-      const content = visionModel
-        ? await preProcessImageContent(v.content)
-        : getMessageTextContent(v);
-      messages.push({ role: v.role, content });
-    }
     const modelConfig = {
       ...useAppConfig.getState().modelConfig,
       ...useChatStore.getState().currentSession().mask.modelConfig,
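
extractMessage now distinguishes three response shapes: an error object (returned as fenced JSON), an images response (returned as a multimodal content array), and a plain chat completion. A sketch of the two new cases, with field names following OpenAI's API and illustrative values:

// images success -> [{ type: "image_url", image_url: { url } }]
const imageRes = {
  created: 1722578442,
  data: [{ url: "https://example.com/generated.png" }],
};

// error -> the whole response, pretty-printed inside a fenced code block
const errorRes = {
  error: { message: "quota exceeded", type: "insufficient_quota" },
};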
@@ -123,26 +137,48 @@ export class ChatGPTApi implements LLMApi {
       },
     };

-    const requestPayload: RequestPayload = {
-      messages,
-      stream: options.config.stream,
-      model: modelConfig.model,
-      temperature: modelConfig.temperature,
-      presence_penalty: modelConfig.presence_penalty,
-      frequency_penalty: modelConfig.frequency_penalty,
-      top_p: modelConfig.top_p,
-      // max_tokens: Math.max(modelConfig.max_tokens, 1024),
-      // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
-    };
+    let requestPayload: RequestPayload | DalleRequestPayload;

-    // add max_tokens to vision model
-    if (visionModel && modelConfig.model.includes("preview")) {
-      requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+    const isDalle3 = _isDalle3(options.config.model);
+    if (isDalle3) {
+      const prompt = getMessageTextContent(options.messages.slice(-1)?.pop());
+      requestPayload = {
+        model: options.config.model,
+        prompt,
+        n: 1,
+        size: options.config?.size ?? "1024x1024",
+      };
+    } else {
+      const visionModel = isVisionModel(options.config.model);
+      const messages: ChatOptions["messages"] = [];
+      for (const v of options.messages) {
+        const content = visionModel
+          ? await preProcessImageContent(v.content)
+          : getMessageTextContent(v);
+        messages.push({ role: v.role, content });
+      }
+
+      requestPayload = {
+        messages,
+        stream: options.config.stream,
+        model: modelConfig.model,
+        temperature: modelConfig.temperature,
+        presence_penalty: modelConfig.presence_penalty,
+        frequency_penalty: modelConfig.frequency_penalty,
+        top_p: modelConfig.top_p,
+        // max_tokens: Math.max(modelConfig.max_tokens, 1024),
+        // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
+      };
+
+      // add max_tokens to vision model
+      if (visionModel && modelConfig.model.includes("preview")) {
+        requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+      }
     }

     console.log("[Request] openai payload: ", requestPayload);

-    const shouldStream = !!options.config.stream;
+    const shouldStream = !isDalle3 && !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
@@ -168,13 +204,15 @@
           model?.provider?.providerName === ServiceProvider.Azure,
       );
       chatPath = this.path(
-        Azure.ChatPath(
+        (isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
           (model?.displayName ?? model?.name) as string,
           useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
         ),
       );
     } else {
-      chatPath = this.path(OpenaiPath.ChatPath);
+      chatPath = this.path(
+        isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
+      );
     }

     const chatPayload = {
       method: "POST",

View File

@@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg";
 import BottomIcon from "../icons/bottom.svg";
 import StopIcon from "../icons/pause.svg";
 import RobotIcon from "../icons/robot.svg";
+import SizeIcon from "../icons/size.svg";
 import PluginIcon from "../icons/plugin.svg";

 import {
@@ -60,6 +61,7 @@
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
+  isDalle3,
 } from "../utils";

 import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
@@ -481,6 +483,11 @@ export function ChatActions(props: {
   const [showPluginSelector, setShowPluginSelector] = useState(false);
   const [showUploadImage, setShowUploadImage] = useState(false);

+  const [showSizeSelector, setShowSizeSelector] = useState(false);
+  const dalle3Sizes = ["1024x1024", "1792x1024", "1024x1792"];
+  const currentSize =
+    chatStore.currentSession().mask.modelConfig?.size || "1024x1024";
+
   useEffect(() => {
     const show = isVisionModel(currentModel);
     setShowUploadImage(show);
@@ -624,6 +631,33 @@
           />
         )}

+        {isDalle3(currentModel) && (
+          <ChatAction
+            onClick={() => setShowSizeSelector(true)}
+            text={currentSize}
+            icon={<SizeIcon />}
+          />
+        )}
+
+        {showSizeSelector && (
+          <Selector
+            defaultSelectedValue={currentSize}
+            items={dalle3Sizes.map((m) => ({
+              title: m,
+              value: m,
+            }))}
+            onClose={() => setShowSizeSelector(false)}
+            onSelection={(s) => {
+              if (s.length === 0) return;
+              const size = s[0];
+              chatStore.updateCurrentSession((session) => {
+                session.mask.modelConfig.size = size;
+              });
+              showToast(size);
+            }}
+          />
+        )}
+
         <ChatAction
           onClick={() => setShowPluginSelector(true)}
           text={Locale.Plugin.Name}
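
The selector persists the chosen size on the session's mask, which presumes a size field on the model config; the diff does not show the ModelConfig type itself, so a minimal sketch of the assumed addition:

// assumed shape; "DalleSize" is a hypothetical alias name
type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
interface WithDalleSize {
  size?: DalleSize; // read back as modelConfig?.size || "1024x1024"
}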

View File

@@ -146,6 +146,7 @@ export const Anthropic = {
 export const OpenaiPath = {
   ChatPath: "v1/chat/completions",
+  ImagePath: "v1/images/generations",
   UsagePath: "dashboard/billing/usage",
   SubsPath: "dashboard/billing/subscription",
   ListModelPath: "v1/models",
@@ -154,7 +155,10 @@ export const OpenaiPath = {
 export const Azure = {
   ChatPath: (deployName: string, apiVersion: string) =>
     `deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
-  ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
+  // https://<your_resource_name>.openai.azure.com/openai/deployments/<your_deployment_name>/images/generations?api-version=<api_version>
+  ImagePath: (deployName: string, apiVersion: string) =>
+    `deployments/${deployName}/images/generations?api-version=${apiVersion}`,
+  ExampleEndpoint: "https://{resource-url}/openai",
 };

 export const Google = {
@@ -256,6 +260,7 @@ const openaiModels = [
   "gpt-4-vision-preview",
   "gpt-4-turbo-2024-04-09",
   "gpt-4-1106-preview",
+  "dall-e-3",
 ];

 const googleModels = [

View File

@@ -26,6 +26,7 @@ import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
 import { collectModelsWithDefaultModel } from "../utils/model";
 import { useAccessStore } from "./access";
+import { isDalle3 } from "../utils";

 export type ChatMessage = RequestMessage & {
   date: string;
@@ -541,6 +542,10 @@
     const config = useAppConfig.getState();
     const session = get().currentSession();
     const modelConfig = session.mask.modelConfig;
+    // skip summarize when using dalle3?
+    if (isDalle3(modelConfig.model)) {
+      return;
+    }

     const api: ClientApi = getClientApi(modelConfig.providerName);

View File

@@ -265,3 +265,7 @@ export function isVisionModel(model: string) {
     visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
   );
 }
+
+export function isDalle3(model: string) {
+  return "dall-e-3" === model;
+}
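
Since isDalle3 is an exact string comparison, only the bare model id matches:

isDalle3("dall-e-3");    // true
isDalle3("dall-e-2");    // false
isDalle3("dall-e-3-hd"); // false: no prefix or suffix variants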