diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts
index 064590bf8..11f77120a 100644
--- a/app/client/platforms/google.ts
+++ b/app/client/platforms/google.ts
@@ -13,6 +13,13 @@ import {
   LLMUsage,
 } from "../api";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+import axios from "axios";
+
+const getImageBase64Data = async (url: string) => {
+  const response = await axios.get(url, { responseType: "arraybuffer" });
+  const base64 = Buffer.from(response.data, "binary").toString("base64");
+  return base64;
+};
 
 export class GeminiProApi implements LLMApi {
   toolAgentChat(options: AgentChatOptions): Promise<void> {
@@ -28,11 +35,32 @@ export class GeminiProApi implements LLMApi {
     );
   }
   async chat(options: ChatOptions): Promise<void> {
-    const apiClient = this;
-    const messages = options.messages.map((v) => ({
-      role: v.role.replace("assistant", "model").replace("system", "user"),
-      parts: [{ text: v.content }],
-    }));
+    const messages: any[] = [];
+    if (options.config.model.includes("vision")) {
+      for (const v of options.messages) {
+        const message: any = {
+          role: v.role.replace("assistant", "model").replace("system", "user"),
+          parts: [{ text: v.content }],
+        };
+        if (v.image_url) {
+          const base64Data = await getImageBase64Data(v.image_url);
+          message.parts.push({
+            inline_data: {
+              mime_type: "image/jpeg",
+              data: base64Data,
+            },
+          });
+        }
+        messages.push(message);
+      }
+    } else {
+      options.messages.forEach((v) =>
+        messages.push({
+          role: v.role.replace("assistant", "model").replace("system", "user"),
+          parts: [{ text: v.content }],
+        }),
+      );
+    }
 
     // google requires that role in neighboring messages must not be the same
     for (let i = 0; i < messages.length - 1; ) {
@@ -92,7 +120,9 @@ export class GeminiProApi implements LLMApi {
     const controller = new AbortController();
     options.onController?.(controller);
     try {
-      const chatPath = this.path(Google.ChatPath);
+      const chatPath = this.path(
+        Google.ChatPath.replace("{{model}}", options.config.model),
+      );
       const chatPayload = {
         method: "POST",
         body: JSON.stringify(requestPayload),
diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index 01659e93f..3decfcf5a 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -140,10 +140,9 @@ export class ChatGPTApi implements LLMApi {
       presence_penalty: modelConfig.presence_penalty,
       frequency_penalty: modelConfig.frequency_penalty,
       top_p: modelConfig.top_p,
-      max_tokens:
-        modelConfig.model == "gpt-4-vision-preview"
-          ? modelConfig.max_tokens
-          : null,
+      max_tokens: modelConfig.model.includes("vision")
+        ? modelConfig.max_tokens
+        : null,
       // max_tokens: Math.max(modelConfig.max_tokens, 1024), // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
     };
 
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index 42f38a13e..5ff0cd015 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -538,7 +538,7 @@ export function ChatActions(props: {
       }
     }
   };
-  if (currentModel === "gpt-4-vision-preview") {
+  if (currentModel.includes("vision")) {
     window.addEventListener("paste", onPaste);
     return () => {
       window.removeEventListener("paste", onPaste);
@@ -620,7 +620,7 @@ export function ChatActions(props: {
           icon={usePlugins ? <EnablePluginIcon /> : <DisablePluginIcon />}
         />
       )}
-      {currentModel == "gpt-4-vision-preview" && (
+      {currentModel.includes("vision") && (
@@ ... @@
                   defaultShow={i >= messages.length - 6}
                 />
-              {!isUser && message.model == "gpt-4-vision-preview" && (
+              {!isUser && message.model?.includes("vision") && (
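
Review note: getImageBase64Data relies on Node's Buffer, which only exists in the
browser if the bundler polyfills it (Next.js historically has); this helper runs
client-side in app/client. A minimal dependency-free sketch, assuming the image
URL is same-origin or CORS-readable, with the same name and signature as the
helper in the diff:

const getImageBase64Data = async (url: string): Promise<string> => {
  const response = await fetch(url);
  const blob = await response.blob();
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    // reader.result is a data URL ("data:image/jpeg;base64,..."),
    // so keep only the base64 payload after the comma
    reader.onload = () => resolve((reader.result as string).split(",")[1]);
    reader.onerror = reject;
    reader.readAsDataURL(blob);
  });
};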
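
Review note: the inline_data part hardcodes mime_type "image/jpeg", so pasted PNG
or WebP images would be mislabeled even though Gemini accepts those types. A
hypothetical variant (getImageInlineData is not in the diff) that reuses the same
axios call but trusts the server's Content-Type header:

import axios from "axios";

const getImageInlineData = async (url: string) => {
  const response = await axios.get(url, { responseType: "arraybuffer" });
  return {
    // fall back to image/jpeg only when the server omits Content-Type
    mime_type: response.headers["content-type"] ?? "image/jpeg",
    data: Buffer.from(response.data, "binary").toString("base64"),
  };
};

The returned object can be pushed as message.parts.push({ inline_data: ... }) in
place of the two hardcoded fields above.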
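
Review note: the comment "google requires that role in neighboring messages must
not be the same" refers to Gemini rejecting requests in which two consecutive
contents share a role, which can happen here because system messages are rewritten
to "user". The body of that loop is outside this hunk; a sketch of one way to
satisfy the constraint, assuming the messages shape built above, is to merge
same-role neighbors in place:

for (let i = 0; i < messages.length - 1; ) {
  if (messages[i].role === messages[i + 1].role) {
    // fold the next message's parts into the current one and drop it
    messages[i].parts = messages[i].parts.concat(messages[i + 1].parts);
    messages.splice(i + 1, 1);
  } else {
    i += 1;
  }
}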