add dalle3 model
This commit is contained in:
parent
b3219f57c8
commit
ac599aa47c
|
@ -33,6 +33,7 @@ import {
|
|||
getMessageTextContent,
|
||||
getMessageImages,
|
||||
isVisionModel,
|
||||
isDalle3 as _isDalle3,
|
||||
} from "@/app/utils";
|
||||
|
||||
export interface OpenAIListModelResponse {
|
||||
|
@ -58,6 +59,13 @@ export interface RequestPayload {
|
|||
max_tokens?: number;
|
||||
}
|
||||
|
||||
export interface DalleRequestPayload {
|
||||
model: string;
|
||||
prompt: string;
|
||||
n: number;
|
||||
size: "1024x1024" | "1792x1024" | "1024x1792";
|
||||
}
|
||||
|
||||
export class ChatGPTApi implements LLMApi {
|
||||
private disableListModels = true;
|
||||
|
||||
|
@ -101,19 +109,25 @@ export class ChatGPTApi implements LLMApi {
|
|||
}
|
||||
|
||||
extractMessage(res: any) {
|
||||
if (res.error) {
|
||||
return "```\n" + JSON.stringify(res, null, 4) + "\n```";
|
||||
}
|
||||
// dalle3 model return url, just return
|
||||
if (res.data) {
|
||||
const url = res.data?.at(0)?.url ?? "";
|
||||
return [
|
||||
{
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url,
|
||||
},
|
||||
},
|
||||
];
|
||||
}
|
||||
return res.choices?.at(0)?.message?.content ?? "";
|
||||
}
|
||||
|
||||
async chat(options: ChatOptions) {
|
||||
const visionModel = isVisionModel(options.config.model);
|
||||
const messages: ChatOptions["messages"] = [];
|
||||
for (const v of options.messages) {
|
||||
const content = visionModel
|
||||
? await preProcessImageContent(v.content)
|
||||
: getMessageTextContent(v);
|
||||
messages.push({ role: v.role, content });
|
||||
}
|
||||
|
||||
const modelConfig = {
|
||||
...useAppConfig.getState().modelConfig,
|
||||
...useChatStore.getState().currentSession().mask.modelConfig,
|
||||
|
@ -123,26 +137,48 @@ export class ChatGPTApi implements LLMApi {
|
|||
},
|
||||
};
|
||||
|
||||
const requestPayload: RequestPayload = {
|
||||
messages,
|
||||
stream: options.config.stream,
|
||||
model: modelConfig.model,
|
||||
temperature: modelConfig.temperature,
|
||||
presence_penalty: modelConfig.presence_penalty,
|
||||
frequency_penalty: modelConfig.frequency_penalty,
|
||||
top_p: modelConfig.top_p,
|
||||
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
|
||||
// Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
|
||||
};
|
||||
let requestPayload: RequestPayload | DalleRequestPayload;
|
||||
|
||||
// add max_tokens to vision model
|
||||
if (visionModel && modelConfig.model.includes("preview")) {
|
||||
requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
|
||||
const isDalle3 = _isDalle3(options.config.model);
|
||||
if (isDalle3) {
|
||||
const prompt = getMessageTextContent(options.messages.slice(-1)?.pop());
|
||||
requestPayload = {
|
||||
model: options.config.model,
|
||||
prompt,
|
||||
n: 1,
|
||||
size: options.config?.size ?? "1024x1024",
|
||||
};
|
||||
} else {
|
||||
const visionModel = isVisionModel(options.config.model);
|
||||
const messages: ChatOptions["messages"] = [];
|
||||
for (const v of options.messages) {
|
||||
const content = visionModel
|
||||
? await preProcessImageContent(v.content)
|
||||
: getMessageTextContent(v);
|
||||
messages.push({ role: v.role, content });
|
||||
}
|
||||
|
||||
requestPayload = {
|
||||
messages,
|
||||
stream: options.config.stream,
|
||||
model: modelConfig.model,
|
||||
temperature: modelConfig.temperature,
|
||||
presence_penalty: modelConfig.presence_penalty,
|
||||
frequency_penalty: modelConfig.frequency_penalty,
|
||||
top_p: modelConfig.top_p,
|
||||
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
|
||||
// Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
|
||||
};
|
||||
|
||||
// add max_tokens to vision model
|
||||
if (visionModel && modelConfig.model.includes("preview")) {
|
||||
requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
|
||||
}
|
||||
}
|
||||
|
||||
console.log("[Request] openai payload: ", requestPayload);
|
||||
|
||||
const shouldStream = !!options.config.stream;
|
||||
const shouldStream = !isDalle3 && !!options.config.stream;
|
||||
const controller = new AbortController();
|
||||
options.onController?.(controller);
|
||||
|
||||
|
@ -168,13 +204,15 @@ export class ChatGPTApi implements LLMApi {
|
|||
model?.provider?.providerName === ServiceProvider.Azure,
|
||||
);
|
||||
chatPath = this.path(
|
||||
Azure.ChatPath(
|
||||
(isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
|
||||
(model?.displayName ?? model?.name) as string,
|
||||
useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
|
||||
),
|
||||
);
|
||||
} else {
|
||||
chatPath = this.path(OpenaiPath.ChatPath);
|
||||
chatPath = this.path(
|
||||
isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
|
||||
);
|
||||
}
|
||||
const chatPayload = {
|
||||
method: "POST",
|
||||
|
|
|
@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg";
|
|||
import BottomIcon from "../icons/bottom.svg";
|
||||
import StopIcon from "../icons/pause.svg";
|
||||
import RobotIcon from "../icons/robot.svg";
|
||||
import SizeIcon from "../icons/size.svg";
|
||||
import PluginIcon from "../icons/plugin.svg";
|
||||
|
||||
import {
|
||||
|
@ -60,6 +61,7 @@ import {
|
|||
getMessageTextContent,
|
||||
getMessageImages,
|
||||
isVisionModel,
|
||||
isDalle3,
|
||||
} from "../utils";
|
||||
|
||||
import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
|
||||
|
@ -481,6 +483,11 @@ export function ChatActions(props: {
|
|||
const [showPluginSelector, setShowPluginSelector] = useState(false);
|
||||
const [showUploadImage, setShowUploadImage] = useState(false);
|
||||
|
||||
const [showSizeSelector, setShowSizeSelector] = useState(false);
|
||||
const dalle3Sizes = ["1024x1024", "1792x1024", "1024x1792"];
|
||||
const currentSize =
|
||||
chatStore.currentSession().mask.modelConfig?.size || "1024x1024";
|
||||
|
||||
useEffect(() => {
|
||||
const show = isVisionModel(currentModel);
|
||||
setShowUploadImage(show);
|
||||
|
@ -624,6 +631,33 @@ export function ChatActions(props: {
|
|||
/>
|
||||
)}
|
||||
|
||||
{isDalle3(currentModel) && (
|
||||
<ChatAction
|
||||
onClick={() => setShowSizeSelector(true)}
|
||||
text={currentSize}
|
||||
icon={<SizeIcon />}
|
||||
/>
|
||||
)}
|
||||
|
||||
{showSizeSelector && (
|
||||
<Selector
|
||||
defaultSelectedValue={currentSize}
|
||||
items={dalle3Sizes.map((m) => ({
|
||||
title: m,
|
||||
value: m,
|
||||
}))}
|
||||
onClose={() => setShowSizeSelector(false)}
|
||||
onSelection={(s) => {
|
||||
if (s.length === 0) return;
|
||||
const size = s[0];
|
||||
chatStore.updateCurrentSession((session) => {
|
||||
session.mask.modelConfig.size = size;
|
||||
});
|
||||
showToast(size);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
<ChatAction
|
||||
onClick={() => setShowPluginSelector(true)}
|
||||
text={Locale.Plugin.Name}
|
||||
|
|
|
@ -146,6 +146,7 @@ export const Anthropic = {
|
|||
|
||||
export const OpenaiPath = {
|
||||
ChatPath: "v1/chat/completions",
|
||||
ImagePath: "v1/images/generations",
|
||||
UsagePath: "dashboard/billing/usage",
|
||||
SubsPath: "dashboard/billing/subscription",
|
||||
ListModelPath: "v1/models",
|
||||
|
@ -154,7 +155,10 @@ export const OpenaiPath = {
|
|||
export const Azure = {
|
||||
ChatPath: (deployName: string, apiVersion: string) =>
|
||||
`deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
|
||||
ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
|
||||
// https://<your_resource_name>.openai.azure.com/openai/deployments/<your_deployment_name>/images/generations?api-version=<api_version>
|
||||
ImagePath: (deployName: string, apiVersion: string) =>
|
||||
`deployments/${deployName}/images/generations?api-version=${apiVersion}`,
|
||||
ExampleEndpoint: "https://{resource-url}/openai",
|
||||
};
|
||||
|
||||
export const Google = {
|
||||
|
@ -256,6 +260,7 @@ const openaiModels = [
|
|||
"gpt-4-vision-preview",
|
||||
"gpt-4-turbo-2024-04-09",
|
||||
"gpt-4-1106-preview",
|
||||
"dall-e-3",
|
||||
];
|
||||
|
||||
const googleModels = [
|
||||
|
|
|
@ -26,6 +26,7 @@ import { nanoid } from "nanoid";
|
|||
import { createPersistStore } from "../utils/store";
|
||||
import { collectModelsWithDefaultModel } from "../utils/model";
|
||||
import { useAccessStore } from "./access";
|
||||
import { isDalle3 } from "../utils";
|
||||
|
||||
export type ChatMessage = RequestMessage & {
|
||||
date: string;
|
||||
|
@ -541,6 +542,10 @@ export const useChatStore = createPersistStore(
|
|||
const config = useAppConfig.getState();
|
||||
const session = get().currentSession();
|
||||
const modelConfig = session.mask.modelConfig;
|
||||
// skip summarize when using dalle3?
|
||||
if (isDalle3(modelConfig.model)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const api: ClientApi = getClientApi(modelConfig.providerName);
|
||||
|
||||
|
|
|
@ -265,3 +265,7 @@ export function isVisionModel(model: string) {
|
|||
visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
|
||||
);
|
||||
}
|
||||
|
||||
export function isDalle3(model: string) {
|
||||
return "dall-e-3" === model;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue