feat: support whisper

Hk-Gosuto 2024-03-18 16:02:39 +08:00
parent f10fa91432
commit bab838b9c6
6 changed files with 216 additions and 28 deletions

View File

@@ -53,6 +53,16 @@ export interface SpeechOptions {
   onController?: (controller: AbortController) => void;
 }
+
+export interface TranscriptionOptions {
+  model?: "whisper-1";
+  file: Blob;
+  language?: string;
+  prompt?: string;
+  response_format?: "json" | "text" | "srt" | "verbose_json" | "vtt";
+  temperature?: number;
+  onController?: (controller: AbortController) => void;
+}
 
 export interface ChatOptions {
   messages: RequestMessage[];
   config: LLMConfig;
@@ -94,6 +104,7 @@ export interface LLMModelProvider {
 export abstract class LLMApi {
   abstract chat(options: ChatOptions): Promise<void>;
   abstract speech(options: SpeechOptions): Promise<ArrayBuffer>;
+  abstract transcription(options: TranscriptionOptions): Promise<string>;
   abstract toolAgentChat(options: AgentChatOptions): Promise<void>;
   abstract usage(): Promise<LLMUsage>;
   abstract models(): Promise<LLMModel[]>;
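
For orientation, a minimal sketch of how a caller could drive the new transcription contract; the transcribeClip helper and its arguments are illustrative and not part of this commit.

// Illustrative caller of the new LLMApi.transcription contract (not in the diff).
// `api` is any concrete LLMApi implementation, `audioBlob` an already-recorded clip.
async function transcribeClip(api: LLMApi, audioBlob: Blob): Promise<string> {
  let controller: AbortController | undefined;
  const text = await api.transcription({
    file: audioBlob,
    model: "whisper-1",
    response_format: "json",
    onController: (c) => (controller = c), // keep a handle so the UI can cancel
  });
  return text;
}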

View File

@@ -7,6 +7,7 @@ import {
   LLMModel,
   LLMUsage,
   SpeechOptions,
+  TranscriptionOptions,
 } from "../api";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
 import { getClientConfig } from "@/app/config/client";
@@ -18,6 +19,9 @@ import {
 } from "@/app/utils";
 
 export class GeminiProApi implements LLMApi {
+  transcription(options: TranscriptionOptions): Promise<string> {
+    throw new Error("Method not implemented.");
+  }
   speech(options: SpeechOptions): Promise<ArrayBuffer> {
     throw new Error("Method not implemented.");
   }

View File

@@ -18,6 +18,7 @@ import {
   LLMUsage,
   MultimodalContent,
   SpeechOptions,
+  TranscriptionOptions,
 } from "../api";
 import Locale from "../../locales";
 import {
@@ -124,6 +125,47 @@ export class ChatGPTApi implements LLMApi {
     }
   }
 
+  async transcription(options: TranscriptionOptions): Promise<string> {
+    const formData = new FormData();
+    formData.append("file", options.file, "audio.wav");
+    formData.append("model", options.model ?? "whisper-1");
+    if (options.language) formData.append("language", options.language);
+    if (options.prompt) formData.append("prompt", options.prompt);
+    if (options.response_format)
+      formData.append("response_format", options.response_format);
+    if (options.temperature)
+      formData.append("temperature", options.temperature.toString());
+
+    console.log("[Request] openai audio transcriptions payload: ", options);
+
+    const controller = new AbortController();
+    options.onController?.(controller);
+
+    try {
+      const path = this.path(OpenaiPath.TranscriptionPath, options.model);
+      const payload = {
+        method: "POST",
+        body: formData,
+        signal: controller.signal,
+        headers: getHeaders(),
+      };
+
+      // make a fetch request
+      const requestTimeoutId = setTimeout(
+        () => controller.abort(),
+        REQUEST_TIMEOUT_MS,
+      );
+      const res = await fetch(path, payload);
+      clearTimeout(requestTimeoutId);
+      const json = await res.json();
+      return json.text;
+    } catch (e) {
+      console.log("[Request] failed to make an audio transcription request", e);
+      throw e;
+    }
+  }
+
   async chat(options: ChatOptions) {
     const visionModel = isVisionModel(options.config.model);
     const messages = options.messages.map((v) => ({
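
A brief usage sketch of the concrete OpenAI client (illustrative, not part of the diff); the import path mirrors the one app/utils/speech.ts uses below, and recordedBlob stands in for audio captured by a MediaRecorder.

import { ChatGPTApi } from "../client/platforms/openai"; // same path used by app/utils/speech.ts

// Illustrative: send a recorded clip to the Whisper endpoint and log the text.
async function demoTranscription(recordedBlob: Blob): Promise<string> {
  const api = new ChatGPTApi();
  const text = await api.transcription({
    file: recordedBlob,
    // model defaults to "whisper-1" inside transcription(); the rest is optional
    response_format: "json",
  });
  console.log("[Whisper] transcript:", text);
  return text;
}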

View File

@@ -108,6 +108,11 @@ import { useAllModels } from "../utils/hooks";
 import { ClientApi } from "../client/api";
 import { createTTSPlayer } from "../utils/audio";
 import { MultimodalContent } from "../client/api";
+import {
+  OpenAITranscriptionApi,
+  SpeechApi,
+  WebTranscriptionApi,
+} from "../utils/speech";
 
 const ttsPlayer = createTTSPlayer();
@@ -801,17 +806,19 @@
   };
 
   const [isListening, setIsListening] = useState(false);
-  const [recognition, setRecognition] = useState<any>(null);
+  const [speechApi, setSpeechApi] = useState<any>(null);
 
-  const startListening = () => {
-    if (recognition) {
-      recognition.start();
+  const startListening = async () => {
+    console.log(speechApi);
+    if (speechApi) {
+      await speechApi.start();
       setIsListening(true);
     }
   };
 
-  const stopListening = () => {
-    if (recognition) {
-      recognition.stop();
+  const stopListening = async () => {
+    if (speechApi) {
+      await speechApi.stop();
       setIsListening(false);
     }
   };
@@ -891,26 +898,11 @@
       }
     });
     // eslint-disable-next-line react-hooks/exhaustive-deps
-    if (typeof window !== "undefined") {
-      const SpeechRecognition =
-        (window as any).SpeechRecognition ||
-        (window as any).webkitSpeechRecognition;
-      const recognitionInstance = new SpeechRecognition();
-      recognitionInstance.continuous = true;
-      recognitionInstance.interimResults = true;
-      let lang = getSTTLang();
-      recognitionInstance.lang = lang;
-      recognitionInstance.onresult = (event: any) => {
-        const result = event.results[event.results.length - 1];
-        if (result.isFinal) {
-          if (!isListening) {
-            onRecognitionEnd(result[0].transcript);
-          }
-        }
-      };
-      setRecognition(recognitionInstance);
-    }
+    setSpeechApi(
+      new WebTranscriptionApi((transcription) =>
+        onRecognitionEnd(transcription),
+      ),
+    );
   }, []);
@@ -1700,7 +1692,9 @@
             }
             className={styles["chat-input-send"]}
             type="primary"
-            onClick={() => (isListening ? stopListening() : startListening())}
+            onClick={async () =>
+              isListening ? await stopListening() : await startListening()
+            }
           />
         ) : (
           <IconButton
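
The import block above also brings in OpenAITranscriptionApi and SpeechApi, even though this hunk only instantiates WebTranscriptionApi. A hedged sketch of how the two backends could be switched; the useWhisper flag and the helper are hypothetical, not something this commit adds.

// Hypothetical helper: pick the transcription backend for the chat component.
// `useWhisper` is an illustrative flag; onRecognitionEnd comes from the component.
function createSpeechApi(
  useWhisper: boolean,
  onRecognitionEnd: (transcription: string) => void,
): SpeechApi {
  return useWhisper
    ? new OpenAITranscriptionApi(onRecognitionEnd)
    : new WebTranscriptionApi(onRecognitionEnd);
}

// e.g. inside the effect: setSpeechApi(createSpeechApi(true, onRecognitionEnd));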

View File

@@ -80,6 +80,7 @@ export enum ModelProvider {
 export const OpenaiPath = {
   ChatPath: "v1/chat/completions",
   SpeechPath: "v1/audio/speech",
+  TranscriptionPath: "v1/audio/transcriptions",
   UsagePath: "dashboard/billing/usage",
   SubsPath: "dashboard/billing/subscription",
   ListModelPath: "v1/models",

app/utils/speech.ts (new file, 136 lines)
View File

@@ -0,0 +1,136 @@
import { ChatGPTApi } from "../client/platforms/openai";
import { getSTTLang } from "../locales";

export type TranscriptionCallback = (transcription: string) => void;

export abstract class SpeechApi {
  protected onTranscription: TranscriptionCallback = () => {};

  abstract isListening(): boolean;
  abstract start(): Promise<void>;
  abstract stop(): Promise<void>;

  onTranscriptionReceived(callback: TranscriptionCallback) {
    this.onTranscription = callback;
  }

  protected async getMediaStream(): Promise<MediaStream | null> {
    if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
      return await navigator.mediaDevices.getUserMedia({ audio: true });
    } else if (navigator.getUserMedia) {
      return new Promise((resolve, reject) => {
        navigator.getUserMedia({ audio: true }, resolve, reject);
      });
    } else {
      console.warn("The current browser does not support getUserMedia");
      return null;
    }
  }

  protected createRecorder(stream: MediaStream): MediaRecorder | null {
    if (MediaRecorder.isTypeSupported("audio/webm")) {
      return new MediaRecorder(stream, { mimeType: "audio/webm" });
    } else if (MediaRecorder.isTypeSupported("audio/ogg")) {
      return new MediaRecorder(stream, { mimeType: "audio/ogg" });
    } else {
      console.warn("The current browser does not support MediaRecorder");
      return null;
    }
  }
}

export class OpenAITranscriptionApi extends SpeechApi {
  private listeningStatus = false;
  private recorder: MediaRecorder | null = null;
  private audioChunks: Blob[] = [];

  isListening = () => this.listeningStatus;

  constructor(transcriptionCallback?: TranscriptionCallback) {
    super();
    if (transcriptionCallback) {
      this.onTranscriptionReceived(transcriptionCallback);
    }
  }

  async start(): Promise<void> {
    const stream = await this.getMediaStream();
    if (!stream) {
      console.error("Failed to get audio stream");
      return;
    }
    this.recorder = this.createRecorder(stream);
    if (!this.recorder) {
      console.error("Failed to create MediaRecorder");
      return;
    }
    this.audioChunks = [];
    this.recorder.addEventListener("dataavailable", (event) => {
      this.audioChunks.push(event.data);
    });
    this.recorder.start();
    this.listeningStatus = true;
  }

  async stop(): Promise<void> {
    if (!this.recorder || !this.listeningStatus) {
      return;
    }
    return new Promise((resolve) => {
      this.recorder!.addEventListener("stop", async () => {
        const audioBlob = new Blob(this.audioChunks, { type: "audio/wav" });
        const llm = new ChatGPTApi();
        const transcription = await llm.transcription({ file: audioBlob });
        this.onTranscription(transcription);
        this.listeningStatus = false;
        resolve();
      });
      this.recorder!.stop();
    });
  }
}

export class WebTranscriptionApi extends SpeechApi {
  private listeningStatus = false;
  private recognitionInstance: any | null = null;

  isListening = () => this.listeningStatus;

  constructor(transcriptionCallback?: TranscriptionCallback) {
    super();
    const SpeechRecognition =
      (window as any).SpeechRecognition ||
      (window as any).webkitSpeechRecognition;
    this.recognitionInstance = new SpeechRecognition();
    this.recognitionInstance.continuous = true;
    this.recognitionInstance.interimResults = true;
    this.recognitionInstance.lang = getSTTLang();
    if (transcriptionCallback) {
      this.onTranscriptionReceived(transcriptionCallback);
    }
    this.recognitionInstance.onresult = (event: any) => {
      const result = event.results[event.results.length - 1];
      // Deliver the final transcript through the registered callback once
      // recognition has stopped (listeningStatus is cleared by stop()).
      if (result.isFinal && !this.listeningStatus) {
        this.onTranscription(result[0].transcript);
      }
    };
  }

  async start(): Promise<void> {
    await this.recognitionInstance.start();
    this.listeningStatus = true;
  }

  async stop(): Promise<void> {
    await this.recognitionInstance.stop();
    this.listeningStatus = false;
  }
}
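
To close, a hedged usage sketch of the recorder-based backend; the push-to-talk handlers are illustrative, and the import path matches the one used in the chat component hunk above.

import { OpenAITranscriptionApi } from "../utils/speech"; // same path as in chat.tsx

// Illustrative push-to-talk wiring: start on press, stop on release, and
// receive the Whisper transcript via the callback registered below.
const stt = new OpenAITranscriptionApi((text) => {
  console.log("[STT] final transcript:", text);
});

export async function onRecordButtonDown() {
  if (!stt.isListening()) {
    await stt.start(); // requests mic access and begins buffering audio chunks
  }
}

export async function onRecordButtonUp() {
  if (stt.isListening()) {
    await stt.stop(); // stops the recorder, uploads the blob, fires the callback
  }
}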