From fa2e046285b824b61ba7cbeaa1eca4c088efa8ad Mon Sep 17 00:00:00 2001 From: Hk-Gosuto Date: Mon, 25 Dec 2023 12:40:09 +0800 Subject: [PATCH] Merge remote --- .env.template | 2 +- README_CN.md | 2 +- app/api/common.ts | 5 ++-- app/api/file/upload/route.ts | 3 ++- app/api/google/[...path]/route.ts | 6 ++--- app/api/langchain/tool/agent/edge/route.ts | 3 ++- app/api/langchain/tool/agent/nodejs/route.ts | 3 ++- app/client/api.ts | 4 +-- app/client/platforms/google.ts | 26 ++++++++++++++------ app/client/platforms/utils.ts | 4 +-- app/components/chat.tsx | 4 ++- app/components/settings.tsx | 4 +-- app/config/server.ts | 4 +-- app/constant.ts | 4 +-- app/store/access.ts | 2 +- app/store/chat.ts | 2 +- 16 files changed, 47 insertions(+), 31 deletions(-) diff --git a/.env.template b/.env.template index 0f14aa85f..c7cf0bcdd 100644 --- a/.env.template +++ b/.env.template @@ -16,7 +16,7 @@ GOOGLE_API_KEY= # (optional) # Default: https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent # Googel Gemini Pro API url, set if you want to customize Google Gemini Pro API url. -GOOGLE_URL= +GOOGLE_BASE_URL= # Override openai api request base url. (optional) # Default: https://api.openai.com diff --git a/README_CN.md b/README_CN.md index cd0225c92..422802a2b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -110,7 +110,7 @@ Azure Api 版本,你可以在这里找到:[Azure 文档](https://learn.micro Google Gemini Pro 密钥. -### `GOOGLE_URL` (optional) +### `GOOGLE_BASE_URL` (optional) Google Gemini Pro Api Url. diff --git a/app/api/common.ts b/app/api/common.ts index a6f4c5721..9df835d17 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -9,15 +9,14 @@ const serverConfig = getServerSideConfig(); export async function requestOpenai(req: NextRequest) { const controller = new AbortController(); + let authValue = req.headers.get("Authorization") ?? ""; if (serverConfig.isAzure) { - const authValue = + authValue = req.headers .get("Authorization") ?.trim() .replaceAll("Bearer ", "") .trim() ?? ""; - } else { - const authValue = req.headers.get("Authorization") ?? ""; } const authHeaderName = serverConfig.isAzure ? "api-key" : "Authorization"; diff --git a/app/api/file/upload/route.ts b/app/api/file/upload/route.ts index 92c9cee89..7b37066a7 100644 --- a/app/api/file/upload/route.ts +++ b/app/api/file/upload/route.ts @@ -1,13 +1,14 @@ import { NextRequest, NextResponse } from "next/server"; import { auth } from "../../auth"; import S3FileStorage from "../../../utils/s3_file_storage"; +import { ModelProvider } from "@/app/constant"; async function handle(req: NextRequest) { if (req.method === "OPTIONS") { return NextResponse.json({ body: "OK" }, { status: 200 }); } - const authResult = auth(req); + const authResult = auth(req, ModelProvider.GPT); if (authResult.error) { return NextResponse.json(authResult, { status: 401, diff --git a/app/api/google/[...path]/route.ts b/app/api/google/[...path]/route.ts index 869bd5076..ef148d6d7 100644 --- a/app/api/google/[...path]/route.ts +++ b/app/api/google/[...path]/route.ts @@ -1,7 +1,7 @@ import { NextRequest, NextResponse } from "next/server"; import { auth } from "../../auth"; import { getServerSideConfig } from "@/app/config/server"; -import { GEMINI_BASE_URL, Google, ModelProvider } from "@/app/constant"; +import { GEMINI_BASE_URL, ModelProvider } from "@/app/constant"; async function handle( req: NextRequest, @@ -17,7 +17,7 @@ async function handle( const serverConfig = getServerSideConfig(); - let baseUrl = serverConfig.googleUrl || GEMINI_BASE_URL; + let baseUrl = serverConfig.googleBaseUrl || GEMINI_BASE_URL; if (!baseUrl.startsWith("http")) { baseUrl = `https://${baseUrl}`; @@ -63,7 +63,7 @@ async function handle( ); } - const fetchUrl = `${baseUrl}/${path}?key=${key}`; + const fetchUrl = `${baseUrl}/${path}?key=${key}&alt=sse`; const fetchOptions: RequestInit = { headers: { "Content-Type": "application/json", diff --git a/app/api/langchain/tool/agent/edge/route.ts b/app/api/langchain/tool/agent/edge/route.ts index f48099e79..4de6a73a8 100644 --- a/app/api/langchain/tool/agent/edge/route.ts +++ b/app/api/langchain/tool/agent/edge/route.ts @@ -4,13 +4,14 @@ import { auth } from "@/app/api/auth"; import { EdgeTool } from "../../../../langchain-tools/edge_tools"; import { OpenAI } from "langchain/llms/openai"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { ModelProvider } from "@/app/constant"; async function handle(req: NextRequest) { if (req.method === "OPTIONS") { return NextResponse.json({ body: "OK" }, { status: 200 }); } try { - const authResult = auth(req); + const authResult = auth(req, ModelProvider.GPT); if (authResult.error) { return NextResponse.json(authResult, { status: 401, diff --git a/app/api/langchain/tool/agent/nodejs/route.ts b/app/api/langchain/tool/agent/nodejs/route.ts index 201dbd49c..63cee7b53 100644 --- a/app/api/langchain/tool/agent/nodejs/route.ts +++ b/app/api/langchain/tool/agent/nodejs/route.ts @@ -5,13 +5,14 @@ import { EdgeTool } from "../../../../langchain-tools/edge_tools"; import { OpenAI } from "langchain/llms/openai"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools"; +import { ModelProvider } from "@/app/constant"; async function handle(req: NextRequest) { if (req.method === "OPTIONS") { return NextResponse.json({ body: "OK" }, { status: 200 }); } try { - const authResult = auth(req); + const authResult = auth(req, ModelProvider.GPT); if (authResult.error) { return NextResponse.json(authResult, { status: 401, diff --git a/app/client/api.ts b/app/client/api.ts index e8ebe806f..b5a72154c 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -115,9 +115,9 @@ export class ClientApi { constructor(provider: ModelProvider = ModelProvider.GPT) { if (provider === ModelProvider.GeminiPro) { this.llm = new GeminiProApi(); - return; + } else { + this.llm = new ChatGPTApi(); } - this.llm = new ChatGPTApi(); this.file = new FileApi(); } diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts index c35e93cb3..83619774d 100644 --- a/app/client/platforms/google.ts +++ b/app/client/platforms/google.ts @@ -1,5 +1,12 @@ import { Google, REQUEST_TIMEOUT_MS } from "@/app/constant"; -import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api"; +import { + AgentChatOptions, + ChatOptions, + getHeaders, + LLMApi, + LLMModel, + LLMUsage, +} from "../api"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { EventStreamContentType, @@ -10,6 +17,9 @@ import { getClientConfig } from "@/app/config/client"; import Locale from "../../locales"; import { getServerSideConfig } from "@/app/config/server"; export class GeminiProApi implements LLMApi { + toolAgentChat(options: AgentChatOptions): Promise { + throw new Error("Method not implemented."); + } extractMessage(res: any) { console.log("[Response] gemini-pro response: ", res); @@ -62,7 +72,7 @@ export class GeminiProApi implements LLMApi { console.log("[Request] google payload: ", requestPayload); // todo: support stream later - const shouldStream = false; + const shouldStream = true; const controller = new AbortController(); options.onController?.(controller); try { @@ -121,7 +131,7 @@ export class GeminiProApi implements LLMApi { clearTimeout(requestTimeoutId); const contentType = res.headers.get("content-type"); console.log( - "[OpenAI] request response content type: ", + "[Google] request response content type: ", contentType, ); @@ -164,13 +174,15 @@ export class GeminiProApi implements LLMApi { const text = msg.data; try { const json = JSON.parse(text) as { - choices: Array<{ - delta: { - content: string; + candidates: Array<{ + content: { + parts: Array<{ + text: string; + }>; }; }>; }; - const delta = json.choices[0]?.delta?.content; + const delta = json.candidates[0]?.content?.parts[0]?.text; if (delta) { remainText += delta; } diff --git a/app/client/platforms/utils.ts b/app/client/platforms/utils.ts index d796166ab..a4127d0fe 100644 --- a/app/client/platforms/utils.ts +++ b/app/client/platforms/utils.ts @@ -1,10 +1,10 @@ -import { getAuthHeaders } from "../api"; +import { getHeaders } from "../api"; export class FileApi { async upload(file: any): Promise { const formData = new FormData(); formData.append("file", file); - var headers = getAuthHeaders(); + var headers = getHeaders(); var res = await fetch("/api/file/upload", { method: "POST", body: formData, diff --git a/app/components/chat.tsx b/app/components/chat.tsx index de1287ed2..c1364701e 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -96,7 +96,7 @@ import { ExportMessageModal } from "./exporter"; import { getClientConfig } from "../config/client"; import { useAllModels } from "../utils/hooks"; import Image from "next/image"; -import { api } from "../client/api"; +import { ClientApi } from "../client/api"; const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { loading: () => , @@ -464,6 +464,7 @@ export function ChatActions(props: { const onImageSelected = async (e: any) => { const file = e.target.files[0]; + const api = new ClientApi(); const fileName = await api.file.upload(file); props.imageSelected({ fileName, @@ -494,6 +495,7 @@ export function ChatActions(props: { } const onPaste = (event: ClipboardEvent) => { const items = event.clipboardData?.items || []; + const api = new ClientApi(); for (let i = 0; i < items.length; i++) { if (items[i].type.indexOf("image") === -1) continue; const file = items[i].getAsFile(); diff --git a/app/components/settings.tsx b/app/components/settings.tsx index eb584c21d..dbf367ed0 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -1071,12 +1071,12 @@ export function Settings() { > accessStore.update( (access) => - (access.googleUrl = e.currentTarget.value), + (access.googleBaseUrl = e.currentTarget.value), ) } > diff --git a/app/config/server.ts b/app/config/server.ts index c6251a5c2..373c6d582 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -29,7 +29,7 @@ declare global { // google only GOOGLE_API_KEY?: string; - GOOGLE_URL?: string; + GOOGLE_BASE_URL?: string; } } } @@ -87,7 +87,7 @@ export const getServerSideConfig = () => { isGoogle, googleApiKey: process.env.GOOGLE_API_KEY, - googleUrl: process.env.GOOGLE_URL, + googleBaseUrl: process.env.GOOGLE_BASE_URL, needCode: ACCESS_CODES.size > 0, code: process.env.CODE, diff --git a/app/constant.ts b/app/constant.ts index 66b41870e..92184d1df 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -99,8 +99,8 @@ export const Azure = { export const Google = { ExampleEndpoint: - "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent", - ChatPath: "v1beta/models/gemini-pro:generateContent", + "https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent", + ChatPath: "v1/models/gemini-pro:generateContent", // /api/openai/v1/chat/completions }; diff --git a/app/store/access.ts b/app/store/access.ts index 67515a58a..eef999a60 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -30,7 +30,7 @@ const DEFAULT_ACCESS_STATE = { azureApiVersion: "2023-08-01-preview", // google ai studio - googleUrl: "", + googleBaseUrl: "", googleApiKey: "", googleApiVersion: "v1", diff --git a/app/store/chat.ts b/app/store/chat.ts index 3ad5f6220..447743bf4 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -662,7 +662,7 @@ export const useChatStore = createPersistStore( session.memoryPrompt = message; }, onFinish(message) { - console.log("[Memory] ", message); + // console.log("[Memory] ", message); get().updateCurrentSession((session) => { session.lastSummarizeIndex = lastSummarizeIndex; session.memoryPrompt = message; // Update the memory prompt for stored it in local storage