Add vision support (#4076)

This commit is contained in:
TheRam_
2024-02-20 18:04:32 +08:00
committed by GitHub
parent 05b6d989b6
commit e2da3406d2
16 changed files with 650 additions and 73 deletions

View File

@@ -15,6 +15,7 @@ import ExportIcon from "../icons/share.svg";
import ReturnIcon from "../icons/return.svg";
import CopyIcon from "../icons/copy.svg";
import LoadingIcon from "../icons/three-dots.svg";
import LoadingButtonIcon from "../icons/loading.svg";
import PromptIcon from "../icons/prompt.svg";
import MaskIcon from "../icons/mask.svg";
import MaxIcon from "../icons/max.svg";
@@ -27,6 +28,7 @@ import PinIcon from "../icons/pin.svg";
import EditIcon from "../icons/rename.svg";
import ConfirmIcon from "../icons/confirm.svg";
import CancelIcon from "../icons/cancel.svg";
import ImageIcon from "../icons/image.svg";
import LightIcon from "../icons/light.svg";
import DarkIcon from "../icons/dark.svg";
@@ -53,6 +55,10 @@ import {
selectOrCopy,
autoGrowTextArea,
useMobileScreen,
getMessageTextContent,
getMessageImages,
isVisionModel,
compressImage,
} from "../utils";
import dynamic from "next/dynamic";
@@ -89,6 +95,7 @@ import { prettyObject } from "../utils/format";
import { ExportMessageModal } from "./exporter";
import { getClientConfig } from "../config/client";
import { useAllModels } from "../utils/hooks";
import { MultimodalContent } from "../client/api";
const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
loading: () => <LoadingIcon />,
@@ -406,10 +413,14 @@ function useScrollToBottom() {
}
export function ChatActions(props: {
uploadImage: () => void;
setAttachImages: (images: string[]) => void;
setUploading: (uploading: boolean) => void;
showPromptModal: () => void;
scrollToBottom: () => void;
showPromptHints: () => void;
hitBottom: boolean;
uploading: boolean;
}) {
const config = useAppConfig();
const navigate = useNavigate();
@@ -437,8 +448,16 @@ export function ChatActions(props: {
[allModels],
);
const [showModelSelector, setShowModelSelector] = useState(false);
const [showUploadImage, setShowUploadImage] = useState(false);
useEffect(() => {
const show = isVisionModel(currentModel);
setShowUploadImage(show);
if (!show) {
props.setAttachImages([]);
props.setUploading(false);
}
// if current model is not available
// switch to first available model
const isUnavaliableModel = !models.some((m) => m.name === currentModel);
@@ -475,6 +494,13 @@ export function ChatActions(props: {
/>
)}
{showUploadImage && (
<ChatAction
onClick={props.uploadImage}
text={Locale.Chat.InputActions.UploadImage}
icon={props.uploading ? <LoadingButtonIcon /> : <ImageIcon />}
/>
)}
<ChatAction
onClick={nextTheme}
text={Locale.Chat.InputActions.Theme[theme]}
@@ -610,6 +636,14 @@ export function EditMessageModal(props: { onClose: () => void }) {
);
}
/**
 * Clickable wrapper around the delete icon.
 *
 * @param props.deleteImage - callback invoked when the wrapper is clicked.
 */
export function DeleteImageButton(props: { deleteImage: () => void }) {
  const { deleteImage } = props;
  const wrapperClass = styles["delete-image"];

  return (
    <div className={wrapperClass} onClick={deleteImage}>
      <DeleteIcon />
    </div>
  );
}
function _Chat() {
type RenderMessage = ChatMessage & { preview?: boolean };
@@ -628,6 +662,8 @@ function _Chat() {
const [hitBottom, setHitBottom] = useState(true);
const isMobileScreen = useMobileScreen();
const navigate = useNavigate();
const [attachImages, setAttachImages] = useState<string[]>([]);
const [uploading, setUploading] = useState(false);
// prompt hints
const promptStore = usePromptStore();
@@ -705,7 +741,10 @@ function _Chat() {
return;
}
setIsLoading(true);
chatStore.onUserInput(userInput).then(() => setIsLoading(false));
chatStore
.onUserInput(userInput, attachImages)
.then(() => setIsLoading(false));
setAttachImages([]);
localStorage.setItem(LAST_INPUT_KEY, userInput);
setUserInput("");
setPromptHints([]);
@@ -783,9 +822,9 @@ function _Chat() {
};
const onRightClick = (e: any, message: ChatMessage) => {
// copy to clipboard
if (selectOrCopy(e.currentTarget, message.content)) {
if (selectOrCopy(e.currentTarget, getMessageTextContent(message))) {
if (userInput.length === 0) {
setUserInput(message.content);
setUserInput(getMessageTextContent(message));
}
e.preventDefault();
@@ -853,7 +892,9 @@ function _Chat() {
// resend the message
setIsLoading(true);
chatStore.onUserInput(userMessage.content).then(() => setIsLoading(false));
const textContent = getMessageTextContent(userMessage);
const images = getMessageImages(userMessage);
chatStore.onUserInput(textContent, images).then(() => setIsLoading(false));
inputRef.current?.focus();
};
@@ -1048,6 +1089,51 @@ function _Chat() {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
/**
 * Opens a native file picker, compresses the chosen images and appends them
 * to the currently attached images (capped at three in total).
 *
 * Differences from a naive "push as each compression finishes" approach:
 * - images are attached in the order the user selected them, not in the
 *   order the compressions happen to finish;
 * - only as many files as can possibly be kept are compressed;
 * - the `uploading` flag is always cleared, including when nothing was
 *   selected or a compression fails;
 * - a failure is logged instead of surfacing as an unhandled rejection of
 *   the click handler, and leaves the existing attachments untouched;
 * - the attachment list is updated from the latest state, so an image
 *   removed while the picker was open is not resurrected.
 */
async function uploadImage() {
  // Upper bound on attachments kept in the input panel.
  const maxAttachImages = 3;
  // Second argument forwarded to compressImage (presumably a size limit in
  // bytes — confirm against compressImage's definition).
  const compressLimit = 256 * 1024;

  // Wrap the callback-based file dialog; resolves with the selected files,
  // or an empty list when the dialog is dismissed.
  const selectedFiles = await new Promise<File[]>((resolve) => {
    const fileInput = document.createElement("input");
    fileInput.type = "file";
    fileInput.accept =
      "image/png, image/jpeg, image/webp, image/heic, image/heif";
    fileInput.multiple = true;
    fileInput.onchange = () => resolve(Array.from(fileInput.files ?? []));
    // Browsers that support the `cancel` event let us settle the promise
    // instead of leaving it pending forever when the user backs out.
    fileInput.oncancel = () => resolve([]);
    fileInput.click();
  });

  if (selectedFiles.length === 0) {
    return;
  }

  setUploading(true);
  try {
    // Promise.all keeps results in selection order.
    const dataUrls = await Promise.all(
      selectedFiles
        .slice(0, maxAttachImages)
        .map((file) => compressImage(file, compressLimit)),
    );
    // Existing attachments come first; anything beyond the cap is dropped.
    setAttachImages((prev: string[]) =>
      [...prev, ...dataUrls].slice(0, maxAttachImages),
    );
  } catch (e) {
    console.error("[Chat] failed to process selected image(s)", e);
  } finally {
    setUploading(false);
  }
}
return (
<div className={styles.chat} key={session.id}>
<div className="window-header" data-tauri-drag-region>
@@ -1154,15 +1240,29 @@ function _Chat() {
onClick={async () => {
const newMessage = await showPrompt(
Locale.Chat.Actions.Edit,
message.content,
getMessageTextContent(message),
10,
);
let newContent: string | MultimodalContent[] =
newMessage;
const images = getMessageImages(message);
if (images.length > 0) {
newContent = [{ type: "text", text: newMessage }];
for (let i = 0; i < images.length; i++) {
newContent.push({
type: "image_url",
image_url: {
url: images[i],
},
});
}
}
chatStore.updateCurrentSession((session) => {
const m = session.mask.context
.concat(session.messages)
.find((m) => m.id === message.id);
if (m) {
m.content = newMessage;
m.content = newContent;
}
});
}}
@@ -1217,7 +1317,11 @@ function _Chat() {
<ChatAction
text={Locale.Chat.Actions.Copy}
icon={<CopyIcon />}
onClick={() => copyToClipboard(message.content)}
onClick={() =>
copyToClipboard(
getMessageTextContent(message),
)
}
/>
</>
)}
@@ -1232,7 +1336,7 @@ function _Chat() {
)}
<div className={styles["chat-message-item"]}>
<Markdown
content={message.content}
content={getMessageTextContent(message)}
loading={
(message.preview || message.streaming) &&
message.content.length === 0 &&
@@ -1241,12 +1345,42 @@ function _Chat() {
onContextMenu={(e) => onRightClick(e, message)}
onDoubleClickCapture={() => {
if (!isMobileScreen) return;
setUserInput(message.content);
setUserInput(getMessageTextContent(message));
}}
fontSize={fontSize}
parentRef={scrollRef}
defaultShow={i >= messages.length - 6}
/>
{getMessageImages(message).length == 1 && (
<img
className={styles["chat-message-item-image"]}
src={getMessageImages(message)[0]}
alt=""
/>
)}
{getMessageImages(message).length > 1 && (
<div
className={styles["chat-message-item-images"]}
style={
{
"--image-count": getMessageImages(message).length,
} as React.CSSProperties
}
>
{getMessageImages(message).map((image, index) => {
return (
<img
className={
styles["chat-message-item-image-multi"]
}
key={index}
src={image}
alt=""
/>
);
})}
</div>
)}
</div>
<div className={styles["chat-message-action-date"]}>
@@ -1266,9 +1400,13 @@ function _Chat() {
<PromptHints prompts={promptHints} onPromptSelect={onPromptSelect} />
<ChatActions
uploadImage={uploadImage}
setAttachImages={setAttachImages}
setUploading={setUploading}
showPromptModal={() => setShowPromptModal(true)}
scrollToBottom={scrollToBottom}
hitBottom={hitBottom}
uploading={uploading}
showPromptHints={() => {
// Click again to close
if (promptHints.length > 0) {
@@ -1281,8 +1419,16 @@ function _Chat() {
onSearch("");
}}
/>
<div className={styles["chat-input-panel-inner"]}>
<label
className={`${styles["chat-input-panel-inner"]} ${
attachImages.length != 0
? styles["chat-input-panel-inner-attach"]
: ""
}`}
htmlFor="chat-input"
>
<textarea
id="chat-input"
ref={inputRef}
className={styles["chat-input"]}
placeholder={Locale.Chat.Input(submitKey)}
@@ -1297,6 +1443,29 @@ function _Chat() {
fontSize: config.fontSize,
}}
/>
{attachImages.length != 0 && (
<div className={styles["attach-images"]}>
{attachImages.map((image, index) => {
return (
<div
key={index}
className={styles["attach-image"]}
style={{ backgroundImage: `url("${image}")` }}
>
<div className={styles["attach-image-mask"]}>
<DeleteImageButton
deleteImage={() => {
setAttachImages(
attachImages.filter((_, i) => i !== index),
);
}}
/>
</div>
</div>
);
})}
</div>
)}
<IconButton
icon={<SendWhiteIcon />}
text={Locale.Chat.Send}
@@ -1304,7 +1473,7 @@ function _Chat() {
type="primary"
onClick={() => doSubmit(userInput)}
/>
</div>
</label>
</div>
{showExport && (