Add vision support (#4076)

This commit is contained in:
TheRam_
2024-02-20 18:04:32 +08:00
committed by GitHub
parent 05b6d989b6
commit e2da3406d2
16 changed files with 650 additions and 73 deletions

View File

@@ -1,4 +1,4 @@
import { trimTopic } from "../utils";
import { trimTopic, getMessageTextContent } from "../utils";
import Locale, { getLang } from "../locales";
import { showToast } from "../components/ui-lib";
@@ -12,8 +12,9 @@ import {
ModelProvider,
StoreKey,
SUMMARIZE_MODEL,
GEMINI_SUMMARIZE_MODEL,
} from "../constant";
import { ClientApi, RequestMessage } from "../client/api";
import { ClientApi, RequestMessage, MultimodalContent } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token";
@@ -84,11 +85,20 @@ function createEmptySession(): ChatSession {
function getSummarizeModel(currentModel: string) {
// if it is using gpt-* models, force to use 3.5 to summarize
return currentModel.startsWith("gpt") ? SUMMARIZE_MODEL : currentModel;
if (currentModel.startsWith("gpt")) {
return SUMMARIZE_MODEL;
}
if (currentModel.startsWith("gemini-pro")) {
return GEMINI_SUMMARIZE_MODEL;
}
return currentModel;
}
function countMessages(msgs: ChatMessage[]) {
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
return msgs.reduce(
(pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
0,
);
}
function fillTemplateWith(input: string, modelConfig: ModelConfig) {
@@ -280,16 +290,36 @@ export const useChatStore = createPersistStore(
get().summarizeSession();
},
async onUserInput(content: string) {
async onUserInput(content: string, attachImages?: string[]) {
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
const userContent = fillTemplateWith(content, modelConfig);
console.log("[User Input] after template: ", userContent);
const userMessage: ChatMessage = createMessage({
let mContent: string | MultimodalContent[] = userContent;
if (attachImages && attachImages.length > 0) {
mContent = [
{
type: "text",
text: userContent,
},
];
mContent = mContent.concat(
attachImages.map((url) => {
return {
type: "image_url",
image_url: {
url: url,
},
};
}),
);
}
let userMessage: ChatMessage = createMessage({
role: "user",
content: userContent,
content: mContent,
});
const botMessage: ChatMessage = createMessage({
@@ -307,7 +337,7 @@ export const useChatStore = createPersistStore(
get().updateCurrentSession((session) => {
const savedUserMessage = {
...userMessage,
content,
content: mContent,
};
session.messages = session.messages.concat([
savedUserMessage,
@@ -461,7 +491,7 @@ export const useChatStore = createPersistStore(
) {
const msg = messages[i];
if (!msg || msg.isError) continue;
tokenCount += estimateTokenLength(msg.content);
tokenCount += estimateTokenLength(getMessageTextContent(msg));
reversedRecentMessages.push(msg);
}