From 712022d8c7838cab88fd3cc2ab2304164f4e3ed3 Mon Sep 17 00:00:00 2001 From: Hk-Gosuto Date: Sun, 7 Jul 2024 15:41:58 +0800 Subject: [PATCH] feat: optimize rag --- .env.template | 2 - README.md | 2 +- app/api/langchain-tools/myfiles_browser.ts | 78 ++++++++++++++++++ app/api/langchain-tools/nodejs_tools.ts | 8 +- app/api/langchain-tools/rag_search.ts | 79 ------------------- app/api/langchain/rag/store/route.ts | 56 ++++++++----- app/api/langchain/tool/agent/nodejs/route.ts | 25 ++++-- app/client/api.ts | 2 +- app/client/platforms/anthropic.ts | 2 +- app/client/platforms/google.ts | 2 +- app/client/platforms/openai.ts | 5 +- app/client/platforms/utils.ts | 16 +++- app/components/chat.tsx | 41 ++++------ app/constant.ts | 27 +++++++ app/store/chat.ts | 50 +++++++----- docs/images/rag-example-2.jpg | Bin 0 -> 127979 bytes docs/rag-cn.md | 38 +++++---- package.json | 1 + yarn.lock | 74 +++++++++++++++++ 19 files changed, 332 insertions(+), 176 deletions(-) create mode 100644 app/api/langchain-tools/myfiles_browser.ts delete mode 100644 app/api/langchain-tools/rag_search.ts create mode 100644 docs/images/rag-example-2.jpg diff --git a/.env.template b/.env.template index e9166ce30..3566fe36a 100644 --- a/.env.template +++ b/.env.template @@ -85,8 +85,6 @@ ANTHROPIC_API_KEY= ### anthropic claude Api version. (optional) ANTHROPIC_API_VERSION= - - ### anthropic claude Api url (optional) ANTHROPIC_URL= diff --git a/README.md b/README.md index d425d87ff..20490ab62 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ ## 主要功能 -- RAG 功能 (预览) +- RAG 功能 - 配置请参考文档[RAG 功能配置说明](./docs/rag-cn.md) - 除插件工具外,与原项目保持一致 [ChatGPT-Next-Web 主要功能](https://github.com/Yidadaa/ChatGPT-Next-Web#主要功能) diff --git a/app/api/langchain-tools/myfiles_browser.ts b/app/api/langchain-tools/myfiles_browser.ts new file mode 100644 index 000000000..906b5e8e1 --- /dev/null +++ b/app/api/langchain-tools/myfiles_browser.ts @@ -0,0 +1,78 @@ +import { Tool } from "@langchain/core/tools"; +import { CallbackManagerForToolRun } from "@langchain/core/callbacks/manager"; +import { BaseLanguageModel } from "langchain/dist/base_language"; +import { formatDocumentsAsString } from "langchain/util/document"; +import { Embeddings } from "langchain/dist/embeddings/base.js"; +import { getServerSideConfig } from "@/app/config/server"; +import { SupabaseVectorStore } from "@langchain/community/vectorstores/supabase"; +import { createClient } from "@supabase/supabase-js"; +import { z } from "zod"; +import { StructuredTool } from "@langchain/core/tools"; + +export class MyFilesBrowser extends StructuredTool { + static lc_name() { + return "MyFilesBrowser"; + } + + get lc_namespace() { + return [...super.lc_namespace, "myfilesbrowser"]; + } + + private sessionId: string; + private model: BaseLanguageModel; + private embeddings: Embeddings; + + constructor( + sessionId: string, + model: BaseLanguageModel, + embeddings: Embeddings, + ) { + super(); + this.sessionId = sessionId; + this.model = model; + this.embeddings = embeddings; + } + + schema = z.object({ + queries: z.array(z.string()).describe("A query list."), + }); + + /** @ignore */ + async _call({ queries }: z.infer) { + const serverConfig = getServerSideConfig(); + if (!serverConfig.isEnableRAG) + throw new Error("env ENABLE_RAG not configured"); + + const privateKey = process.env.SUPABASE_PRIVATE_KEY; + if (!privateKey) throw new Error(`Expected env var SUPABASE_PRIVATE_KEY`); + + const url = process.env.SUPABASE_URL; + if (!url) throw new Error(`Expected env var SUPABASE_URL`); + const client = createClient(url, privateKey); + const vectorStore = new SupabaseVectorStore(this.embeddings, { + client, + tableName: "documents", + queryName: "match_documents", + }); + + let context; + const returnCunt = serverConfig.ragReturnCount + ? parseInt(serverConfig.ragReturnCount, 10) + : 4; + console.log("[myfiles_browser]", { queries, returnCunt }); + let documents: any[] = []; + for (var i = 0; i < queries.length; i++) { + let results = await vectorStore.similaritySearch(queries[i], returnCunt, { + sessionId: this.sessionId, + }); + results.forEach((item) => documents.push(item)); + } + context = formatDocumentsAsString(documents); + console.log("[myfiles_browser]", { context }); + return context; + } + + name = "myfiles_browser"; + + description = `queries to a search over the file(s) uploaded in the current conversation and displays the results.`; +} diff --git a/app/api/langchain-tools/nodejs_tools.ts b/app/api/langchain-tools/nodejs_tools.ts index 632971715..de6529957 100644 --- a/app/api/langchain-tools/nodejs_tools.ts +++ b/app/api/langchain-tools/nodejs_tools.ts @@ -10,7 +10,7 @@ import { WolframAlphaTool } from "@/app/api/langchain-tools/wolframalpha"; import { BilibiliVideoInfoTool } from "./bilibili_vid_info"; import { BilibiliVideoSearchTool } from "./bilibili_vid_search"; import { BilibiliMusicRecognitionTool } from "./bilibili_music_recognition"; -import { RAGSearch } from "./rag_search"; +import { MyFilesBrowser } from "./myfiles_browser"; import { BilibiliVideoConclusionTool } from "./bilibili_vid_conclusion"; export class NodeJSTool { @@ -59,7 +59,7 @@ export class NodeJSTool { const bilibiliVideoSearchTool = new BilibiliVideoSearchTool(); const bilibiliVideoConclusionTool = new BilibiliVideoConclusionTool(); const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool(); - let tools = [ + let tools: any = [ calculatorTool, webBrowserTool, dallEAPITool, @@ -73,7 +73,9 @@ export class NodeJSTool { bilibiliVideoConclusionTool, ]; if (!!process.env.ENABLE_RAG) { - tools.push(new RAGSearch(this.sessionId, this.model, this.ragEmbeddings)); + tools.push( + new MyFilesBrowser(this.sessionId, this.model, this.ragEmbeddings), + ); } return tools; } diff --git a/app/api/langchain-tools/rag_search.ts b/app/api/langchain-tools/rag_search.ts deleted file mode 100644 index c3db3c4c3..000000000 --- a/app/api/langchain-tools/rag_search.ts +++ /dev/null @@ -1,79 +0,0 @@ -import { Tool } from "@langchain/core/tools"; -import { CallbackManagerForToolRun } from "@langchain/core/callbacks/manager"; -import { BaseLanguageModel } from "langchain/dist/base_language"; -import { formatDocumentsAsString } from "langchain/util/document"; -import { Embeddings } from "langchain/dist/embeddings/base.js"; -import { RunnableSequence } from "@langchain/core/runnables"; -import { StringOutputParser } from "@langchain/core/output_parsers"; -import { Pinecone } from "@pinecone-database/pinecone"; -import { PineconeStore } from "@langchain/pinecone"; -import { getServerSideConfig } from "@/app/config/server"; -import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant"; - -export class RAGSearch extends Tool { - static lc_name() { - return "RAGSearch"; - } - - get lc_namespace() { - return [...super.lc_namespace, "ragsearch"]; - } - - private sessionId: string; - private model: BaseLanguageModel; - private embeddings: Embeddings; - - constructor( - sessionId: string, - model: BaseLanguageModel, - embeddings: Embeddings, - ) { - super(); - this.sessionId = sessionId; - this.model = model; - this.embeddings = embeddings; - } - - /** @ignore */ - async _call(inputs: string, runManager?: CallbackManagerForToolRun) { - const serverConfig = getServerSideConfig(); - if (!serverConfig.isEnableRAG) - throw new Error("env ENABLE_RAG not configured"); - // const pinecone = new Pinecone(); - // const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!); - // const vectorStore = await PineconeStore.fromExistingIndex(this.embeddings, { - // pineconeIndex, - // }); - const vectorStore = await QdrantVectorStore.fromExistingCollection( - this.embeddings, - { - url: process.env.QDRANT_URL, - apiKey: process.env.QDRANT_API_KEY, - collectionName: this.sessionId, - }, - ); - - let context; - const returnCunt = serverConfig.ragReturnCount - ? parseInt(serverConfig.ragReturnCount, 10) - : 4; - console.log("[rag-search]", { inputs, returnCunt }); - // const results = await vectorStore.similaritySearch(inputs, returnCunt, { - // sessionId: this.sessionId, - // }); - const results = await vectorStore.similaritySearch(inputs, returnCunt); - context = formatDocumentsAsString(results); - console.log("[rag-search]", { context }); - return context; - // const input = `Text:${context}\n\nQuestion:${inputs}\n\nI need you to answer the question based on the text.`; - - // console.log("[rag-search]", input); - - // const chain = RunnableSequence.from([this.model, new StringOutputParser()]); - // return chain.invoke(input, runManager?.getChild()); - } - - name = "rag-search"; - - description = `It is used to query documents entered by the user.The input content is the keywords extracted from the user's question, and multiple keywords are separated by spaces and passed in.`; -} diff --git a/app/api/langchain/rag/store/route.ts b/app/api/langchain/rag/store/route.ts index 9ded033d9..ad97f4330 100644 --- a/app/api/langchain/rag/store/route.ts +++ b/app/api/langchain/rag/store/route.ts @@ -20,7 +20,10 @@ import { FileInfo } from "@/app/client/platforms/utils"; import mime from "mime"; import LocalFileStorage from "@/app/utils/local_file_storage"; import S3FileStorage from "@/app/utils/s3_file_storage"; -import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant"; +import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"; +import { SupabaseVectorStore } from "@langchain/community/vectorstores/supabase"; +import { createClient } from "@supabase/supabase-js"; +import { Embeddings } from "langchain/dist/embeddings/base"; interface RequestBody { sessionId: string; @@ -67,6 +70,11 @@ async function handle(req: NextRequest) { if (req.method === "OPTIONS") { return NextResponse.json({ body: "OK" }, { status: 200 }); } + const privateKey = process.env.SUPABASE_PRIVATE_KEY; + if (!privateKey) throw new Error(`Expected env var SUPABASE_PRIVATE_KEY`); + const url = process.env.SUPABASE_URL; + if (!url) throw new Error(`Expected env var SUPABASE_URL`); + try { const authResult = auth(req, ModelProvider.GPT); if (authResult.error) { @@ -81,18 +89,25 @@ async function handle(req: NextRequest) { const apiKey = getOpenAIApiKey(token); const baseUrl = getOpenAIBaseUrl(reqBody.baseUrl); const serverConfig = getServerSideConfig(); - // const pinecone = new Pinecone(); - // const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!); - const embeddings = new OpenAIEmbeddings( - { - modelName: serverConfig.ragEmbeddingModel, - openAIApiKey: apiKey, - }, - { basePath: baseUrl }, - ); + let embeddings: Embeddings; + if (process.env.OLLAMA_BASE_URL) { + embeddings = new OllamaEmbeddings({ + model: serverConfig.ragEmbeddingModel, + baseUrl: process.env.OLLAMA_BASE_URL, + }); + } else { + embeddings = new OpenAIEmbeddings( + { + modelName: serverConfig.ragEmbeddingModel, + openAIApiKey: apiKey, + }, + { basePath: baseUrl }, + ); + } // https://js.langchain.com/docs/integrations/vectorstores/pinecone // https://js.langchain.com/docs/integrations/vectorstores/qdrant // process files + let partial = ""; for (let i = 0; i < reqBody.fileInfos.length; i++) { const fileInfo = reqBody.fileInfos[i]; const contentType = mime.getType(fileInfo.fileName); @@ -134,26 +149,25 @@ async function handle(req: NextRequest) { chunkOverlap: chunkOverlap, }); const splits = await textSplitter.splitDocuments(docs); - const vectorStore = await QdrantVectorStore.fromDocuments( + const client = createClient(url, privateKey); + const vectorStore = await SupabaseVectorStore.fromDocuments( splits, embeddings, { - url: process.env.QDRANT_URL, - apiKey: process.env.QDRANT_API_KEY, - collectionName: reqBody.sessionId, + client, + tableName: "documents", + queryName: "match_documents", }, ); - // await PineconeStore.fromDocuments(splits, embeddings, { - // pineconeIndex, - // maxConcurrency: 5, - // }); - // const vectorStore = await PineconeStore.fromExistingIndex(embeddings, { - // pineconeIndex, - // }); + partial = splits + .slice(0, 2) + .map((v) => v.pageContent) + .join("\n"); } return NextResponse.json( { sessionId: reqBody.sessionId, + partial: partial, }, { status: 200, diff --git a/app/api/langchain/tool/agent/nodejs/route.ts b/app/api/langchain/tool/agent/nodejs/route.ts index e8169373b..85be0d930 100644 --- a/app/api/langchain/tool/agent/nodejs/route.ts +++ b/app/api/langchain/tool/agent/nodejs/route.ts @@ -4,6 +4,8 @@ import { auth } from "@/app/api/auth"; import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools"; import { ModelProvider } from "@/app/constant"; import { OpenAI, OpenAIEmbeddings } from "@langchain/openai"; +import { Embeddings } from "langchain/dist/embeddings/base"; +import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama"; async function handle(req: NextRequest) { if (req.method === "OPTIONS") { @@ -44,13 +46,22 @@ async function handle(req: NextRequest) { }, { basePath: baseUrl }, ); - const ragEmbeddings = new OpenAIEmbeddings( - { - modelName: process.env.RAG_EMBEDDING_MODEL ?? "text-embedding-3-large", - openAIApiKey: apiKey, - }, - { basePath: baseUrl }, - ); + let ragEmbeddings: Embeddings; + if (process.env.OLLAMA_BASE_URL) { + ragEmbeddings = new OllamaEmbeddings({ + model: process.env.RAG_EMBEDDING_MODEL, + baseUrl: process.env.OLLAMA_BASE_URL, + }); + } else { + ragEmbeddings = new OpenAIEmbeddings( + { + modelName: + process.env.RAG_EMBEDDING_MODEL ?? "text-embedding-3-large", + openAIApiKey: apiKey, + }, + { basePath: baseUrl }, + ); + } var dalleCallback = async (data: string) => { var response = new ResponseBody(); diff --git a/app/client/api.ts b/app/client/api.ts index 9d14d31d2..bd576950c 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -116,7 +116,7 @@ export abstract class LLMApi { abstract speech(options: SpeechOptions): Promise; abstract transcription(options: TranscriptionOptions): Promise; abstract toolAgentChat(options: AgentChatOptions): Promise; - abstract createRAGStore(options: CreateRAGStoreOptions): Promise; + abstract createRAGStore(options: CreateRAGStoreOptions): Promise; abstract usage(): Promise; abstract models(): Promise; } diff --git a/app/client/platforms/anthropic.ts b/app/client/platforms/anthropic.ts index c2b62ca04..9d09d0b2a 100644 --- a/app/client/platforms/anthropic.ts +++ b/app/client/platforms/anthropic.ts @@ -89,7 +89,7 @@ export class ClaudeApi implements LLMApi { toolAgentChat(options: AgentChatOptions): Promise { throw new Error("Method not implemented."); } - createRAGStore(options: CreateRAGStoreOptions): Promise { + createRAGStore(options: CreateRAGStoreOptions): Promise { throw new Error("Method not implemented."); } extractMessage(res: any) { diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts index 600b3f4f7..93c00b23d 100644 --- a/app/client/platforms/google.ts +++ b/app/client/platforms/google.ts @@ -20,7 +20,7 @@ import { } from "@/app/utils"; export class GeminiProApi implements LLMApi { - createRAGStore(options: CreateRAGStoreOptions): Promise { + createRAGStore(options: CreateRAGStoreOptions): Promise { throw new Error("Method not implemented."); } transcription(options: TranscriptionOptions): Promise { diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index fab0e533c..70a639b93 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -373,7 +373,7 @@ export class ChatGPTApi implements LLMApi { } } - async createRAGStore(options: CreateRAGStoreOptions): Promise { + async createRAGStore(options: CreateRAGStoreOptions): Promise { try { const accessStore = useAccessStore.getState(); const isAzure = accessStore.provider === ServiceProvider.Azure; @@ -395,9 +395,12 @@ export class ChatGPTApi implements LLMApi { }; const res = await fetch(path, chatPayload); if (res.status !== 200) throw new Error(await res.text()); + const resJson = await res.json(); + return resJson.partial; } catch (e) { console.log("[Request] failed to make a chat reqeust", e); options.onError?.(e as Error); + return ""; } } diff --git a/app/client/platforms/utils.ts b/app/client/platforms/utils.ts index 543b96a12..a6df0e3b2 100644 --- a/app/client/platforms/utils.ts +++ b/app/client/platforms/utils.ts @@ -1,10 +1,13 @@ -import { getHeaders } from "../api"; +import { getClientApi } from "@/app/utils"; +import { ClientApi, getHeaders } from "../api"; +import { ChatSession } from "@/app/store"; export interface FileInfo { originalFilename: string; fileName: string; filePath: string; size: number; + partial?: string; } export class FileApi { @@ -31,4 +34,15 @@ export class FileApi { filePath: resJson.filePath, }; } + + async uploadForRag(file: any, session: ChatSession): Promise { + var fileInfo = await this.upload(file); + var api: ClientApi = getClientApi(session.mask.modelConfig.model); + let partial = await api.llm.createRAGStore({ + chatSessionId: session.id, + fileInfos: [fileInfo], + }); + fileInfo.partial = partial; + return fileInfo; + } } diff --git a/app/components/chat.tsx b/app/components/chat.tsx index c7e4af82a..e3f67630b 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -1398,32 +1398,25 @@ function _Chat() { const fileInput = document.createElement("input"); fileInput.type = "file"; fileInput.accept = ".pdf,.txt,.md,.json,.csv,.docx,.srt,.mp3"; - fileInput.multiple = true; + fileInput.multiple = false; fileInput.onchange = (event: any) => { setUploading(true); - const files = event.target.files; + const file = event.target.files[0]; const api = new ClientApi(); const fileDatas: FileInfo[] = []; - for (let i = 0; i < files.length; i++) { - const file = event.target.files[i]; - api.file - .upload(file) - .then((fileInfo) => { - console.log(fileInfo); - fileDatas.push(fileInfo); - if ( - fileDatas.length === 3 || - fileDatas.length === files.length - ) { - setUploading(false); - res(fileDatas); - } - }) - .catch((e) => { - setUploading(false); - rej(e); - }); - } + api.file + .uploadForRag(file, session) + .then((fileInfo) => { + console.log(fileInfo); + fileDatas.push(fileInfo); + session.attachFiles.push(fileInfo); + setUploading(false); + res(fileDatas); + }) + .catch((e) => { + setUploading(false); + rej(e); + }); }; fileInput.click(); })), @@ -1694,7 +1687,7 @@ function _Chat() { parentRef={scrollRef} defaultShow={i >= messages.length - 6} /> - {message.fileInfos && message.fileInfos.length > 0 && ( + {/* {message.fileInfos && message.fileInfos.length > 0 && (