diff --git a/app/utils.ts b/app/utils.ts index fbe77c114..baf45abe5 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -288,12 +288,16 @@ export function showPlugins(provider: ServiceProvider, model: string) { } export function adapter(config: Record) { - const { baseURL, url, params, ...rest } = config; + const { baseURL, url, params, method, data, ...rest } = config; const path = baseURL ? `${baseURL}${url}` : url; const fetchUrl = params ? `${path}?${new URLSearchParams(params as any).toString()}` : path; - return fetch(fetchUrl as string, rest) + return fetch(fetchUrl as string, { + ...rest, + method, + body: method.toUpperCase() == "GET" ? undefined : data, + }) .then((res) => res.text()) .then((data) => ({ data })); } diff --git a/app/utils/chat.ts b/app/utils/chat.ts index 7f3bb23c5..359b2c53e 100644 --- a/app/utils/chat.ts +++ b/app/utils/chat.ts @@ -10,6 +10,7 @@ import { fetchEventSource, } from "@fortaine/fetch-event-source"; import { prettyObject } from "./format"; +import { fetch as tauriFetch } from "./stream"; export function compressImage(file: Blob, maxSize: number): Promise { return new Promise((resolve, reject) => { @@ -287,6 +288,7 @@ export function stream( REQUEST_TIMEOUT_MS, ); fetchEventSource(chatPath, { + fetch: tauriFetch, ...chatPayload, async onopen(res) { clearTimeout(requestTimeoutId); diff --git a/app/utils/stream.ts b/app/utils/stream.ts index 8f9ccfbaa..09b898431 100644 --- a/app/utils/stream.ts +++ b/app/utils/stream.ts @@ -1,100 +1,94 @@ -// using tauri register_uri_scheme_protocol, register `stream:` protocol +// using tauri command to send request // see src-tauri/src/stream.rs, and src-tauri/src/main.rs -// 1. window.fetch(`stream://localhost/${fetchUrl}`), get request_id -// 2. listen event: `stream-response` multi times to get response headers and body +// 1. invoke('stream_fetch', {url, method, headers, body}), get response with headers. +// 2. listen event: `stream-response` multi times to get body type ResponseEvent = { id: number; payload: { request_id: number; status?: number; - error?: string; - name?: string; - value?: string; chunk?: number[]; }; }; export function fetch(url: string, options?: RequestInit): Promise { if (window.__TAURI__) { - const tauriUri = window.__TAURI__.convertFileSrc(url, "stream"); - const { signal, ...rest } = options || {}; - return window - .fetch(tauriUri, rest) - .then((r) => r.text()) - .then((rid) => parseInt(rid)) - .then((request_id: number) => { - // 1. using event to get status and statusText and headers, and resolve it - let resolve: Function | undefined; - let reject: Function | undefined; - let status: number; - let writable: WritableStream | undefined; - let writer: WritableStreamDefaultWriter | undefined; - const headers = new Headers(); - let unlisten: Function | undefined; + const { signal, method = "GET", headers = {}, body = [] } = options || {}; + return window.__TAURI__ + .invoke("stream_fetch", { + method, + url, + headers, + // TODO FormData + body: + typeof body === "string" + ? Array.from(new TextEncoder().encode(body)) + : [], + }) + .then( + (res: { + request_id: number; + status: number; + status_text: string; + headers: Record; + }) => { + const { request_id, status, status_text: statusText, headers } = res; + console.log("send request_id", request_id, status, statusText); + let unlisten: Function | undefined; + const ts = new TransformStream(); + const writer = ts.writable.getWriter(); - if (signal) { - signal.addEventListener("abort", () => { - // Reject the promise with the abort reason. + const close = () => { unlisten && unlisten(); - reject && reject(signal.reason); - }); - } - // @ts-ignore 2. listen response multi times, and write to Response.body - window.__TAURI__.event - .listen("stream-response", (e: ResponseEvent) => { - const { id, payload } = e; - const { - request_id: rid, - status: _status, - name, - value, - error, - chunk, - } = payload; - if (request_id != rid) { - return; - } - /** - * 1. get status code - * 2. get headers - * 3. start get body, then resolve response - * 4. get body chunk - */ - if (error) { - unlisten && unlisten(); - return reject && reject(error); - } else if (_status) { - status = _status; - } else if (name && value) { - headers.append(name, value); - } else if (chunk) { - if (resolve) { - const ts = new TransformStream(); - writable = ts.writable; - writer = writable.getWriter(); - resolve(new Response(ts.readable, { status, headers })); - resolve = undefined; + writer.ready.then(() => { + try { + writer.releaseLock(); + } catch (e) { + console.error(e); } - writer && - writer.ready.then(() => { - writer && writer.write(new Uint8Array(chunk)); - }); - } else if (_status === 0) { - // end of body - unlisten && unlisten(); - writer && - writer.ready.then(() => { - writer && writer.releaseLock(); - writable && writable.close(); - }); - } - }) - .then((u: Function) => (unlisten = u)); - return new Promise( - (_resolve, _reject) => ([resolve, reject] = [_resolve, _reject]), - ); + ts.writable.close(); + }); + }; + + const response = new Response(ts.readable, { + status, + statusText, + headers, + }); + if (signal) { + signal.addEventListener("abort", () => close()); + } + // @ts-ignore 2. listen response multi times, and write to Response.body + window.__TAURI__.event + .listen("stream-response", (e: ResponseEvent) => { + const { id, payload } = e; + const { request_id: rid, chunk, status } = payload; + if (request_id != rid) { + return; + } + if (chunk) { + writer && + writer.ready.then(() => { + writer && writer.write(new Uint8Array(chunk)); + }); + } else if (status === 0) { + // end of body + close(); + } + }) + .then((u: Function) => (unlisten = u)); + return response; + }, + ) + .catch((e) => { + console.error("stream error", e); + throw e; }); } return window.fetch(url, options); } + +if (undefined !== window) { + window.tauriFetch = fetch; +} diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index e38208257..d04969c04 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -5,10 +5,8 @@ mod stream; fn main() { tauri::Builder::default() + .invoke_handler(tauri::generate_handler![stream::stream_fetch]) .plugin(tauri_plugin_window_state::Builder::default().build()) - .register_uri_scheme_protocol("stream", move |app_handle, request| { - stream::stream(app_handle, request) - }) .run(tauri::generate_context!()) .expect("error while running tauri application"); } diff --git a/src-tauri/src/stream.rs b/src-tauri/src/stream.rs index 5e84e0f00..514e62298 100644 --- a/src-tauri/src/stream.rs +++ b/src-tauri/src/stream.rs @@ -1,30 +1,25 @@ +// +// use std::error::Error; use futures_util::{StreamExt}; use reqwest::Client; -use tauri::{ Manager, AppHandle }; -use tauri::http::{Request, ResponseBuilder}; -use tauri::http::Response; +use reqwest::header::{HeaderName, HeaderMap}; static mut REQUEST_COUNTER: u32 = 0; #[derive(Clone, serde::Serialize)] -pub struct ErrorPayload { - request_id: u32, - error: String, -} - -#[derive(Clone, serde::Serialize)] -pub struct StatusPayload { +pub struct StreamResponse { request_id: u32, status: u16, + status_text: String, + headers: HashMap } #[derive(Clone, serde::Serialize)] -pub struct HeaderPayload { +pub struct EndPayload { request_id: u32, - name: String, - value: String, + status: u16, } #[derive(Clone, serde::Serialize)] @@ -33,64 +28,90 @@ pub struct ChunkPayload { chunk: bytes::Bytes, } -pub fn stream(app_handle: &AppHandle, request: &Request) -> Result> { +use std::collections::HashMap; + +#[derive(serde::Serialize)] +pub struct CustomResponse { + message: String, + other_val: usize, +} + +#[tauri::command] +pub async fn stream_fetch( + window: tauri::Window, + method: String, + url: String, + headers: HashMap, + body: Vec, +) -> Result { + let mut request_id = 0; let event_name = "stream-response"; unsafe { REQUEST_COUNTER += 1; request_id = REQUEST_COUNTER; } - let path = request.uri().to_string().replace("stream://localhost/", "").replace("http://stream.localhost/", ""); - let path = percent_encoding::percent_decode(path.as_bytes()) - .decode_utf8_lossy() - .to_string(); - // println!("path : {}", path); - let client = Client::new(); - let handle = app_handle.app_handle(); - // send http request - let body = reqwest::Body::from(request.body().clone()); - let response_future = client.request(request.method().clone(), path) - .headers(request.headers().clone()) - .body(body).send(); - // get response and emit to client - tauri::async_runtime::spawn(async move { - let res = response_future.await; + let mut _headers = HeaderMap::new(); + for (key, value) in headers { + _headers.insert(key.parse::().unwrap(), value.parse().unwrap()); + } + let body = bytes::Bytes::from(body); - match res { - Ok(res) => { - handle.emit_all(event_name, StatusPayload{ request_id, status: res.status().as_u16() }).unwrap(); - for (name, value) in res.headers() { - handle.emit_all(event_name, HeaderPayload { - request_id, - name: name.to_string(), - value: std::str::from_utf8(value.as_bytes()).unwrap().to_string() - }).unwrap(); - } + let response_future = Client::new().request( + method.parse::().map_err(|err| format!("failed to parse method: {}", err))?, + url.parse::().map_err(|err| format!("failed to parse url: {}", err))? + ).headers(_headers).body(body).send(); + + let res = response_future.await; + let response = match res { + Ok(res) => { + println!("Error: {:?}", res); + // get response and emit to client + // .register_uri_scheme_protocol("stream", move |app_handle, request| { + let mut headers = HashMap::new(); + for (name, value) in res.headers() { + headers.insert( + name.as_str().to_string(), + std::str::from_utf8(value.as_bytes()).unwrap().to_string() + ); + } + let status = res.status().as_u16(); + + tauri::async_runtime::spawn(async move { let mut stream = res.bytes_stream(); while let Some(chunk) = stream.next().await { match chunk { Ok(bytes) => { - handle.emit_all(event_name, ChunkPayload{ request_id, chunk: bytes }).unwrap(); + println!("chunk: {:?}", bytes); + window.emit(event_name, ChunkPayload{ request_id, chunk: bytes }).unwrap(); } Err(err) => { println!("Error: {:?}", err); } } } - handle.emit_all(event_name, StatusPayload { request_id, status: 0 }).unwrap(); - } - Err(err) => { - println!("Error: {:?}", err.source().expect("REASON").to_string()); - handle.emit_all(event_name, ErrorPayload { - request_id, - error: err.source().expect("REASON").to_string() - }).unwrap(); + window.emit(event_name, EndPayload { request_id, status: 0 }).unwrap(); + }); + + StreamResponse { + request_id, + status, + status_text: "OK".to_string(), + headers, } } - }); - return ResponseBuilder::new() - .header("Access-Control-Allow-Origin", "*") - .status(200).body(request_id.to_string().into()) + Err(err) => { + println!("Error: {:?}", err.source().expect("REASON").to_string()); + StreamResponse { + request_id, + status: 599, + status_text: err.source().expect("REASON").to_string(), + headers: HashMap::new(), + } + } + }; + Ok(response) } +