feat: mix handlers of proxy server in providers

This commit is contained in:
Dean-YZG
2024-05-22 21:31:54 +08:00
parent 8093d1ffba
commit 8de8acdce8
23 changed files with 1570 additions and 420 deletions

View File

@@ -1,3 +1,5 @@
export * from "./shim";
export * from "../common/types";
export * from "./providerClient";
@@ -5,5 +7,3 @@ export * from "./providerClient";
export * from "./modelClient";
export * from "../common/locale";
export * from "./shim";

View File

@@ -3,14 +3,15 @@ import {
InternalChatHandlers,
Model,
ModelTemplate,
ProviderTemplate,
StandChatReponseMessage,
StandChatRequestPayload,
isSameOrigin,
modelNameRequestHeader,
} from "../common";
import * as ProviderTemplates from "@/app/client/providers";
import { nanoid } from "nanoid";
export type ProviderTemplate = IProviderTemplate<any, any, any>;
export type ProviderTemplateName =
(typeof ProviderTemplates)[keyof typeof ProviderTemplates]["prototype"]["name"];
@@ -38,6 +39,7 @@ const providerTemplates = Object.values(ProviderTemplates).reduce(
export class ProviderClient {
providerTemplate: IProviderTemplate<any, any, any>;
genFetch: (modelName: string) => typeof window.fetch;
static ProviderTemplates = providerTemplates;
@@ -61,6 +63,31 @@ export class ProviderClient {
constructor(private provider: Provider) {
const { providerTemplateName } = provider;
this.providerTemplate = this.getProviderTemplate(providerTemplateName);
this.genFetch =
(modelName: string) =>
(...args) => {
const req = new Request(...args);
const headers: Record<string, any> = {
...req.headers,
};
if (isSameOrigin(req.url)) {
headers[modelNameRequestHeader] = modelName;
}
return window.fetch(req.url, {
method: req.method,
keepalive: req.keepalive,
headers,
body: req.body,
redirect: req.redirect,
integrity: req.integrity,
signal: req.signal,
credentials: req.credentials,
mode: req.mode,
referrer: req.referrer,
referrerPolicy: req.referrerPolicy,
});
};
}
private getProviderTemplate(providerTemplateName: string) {
@@ -98,12 +125,15 @@ export class ProviderClient {
async chat(
payload: StandChatRequestPayload,
): Promise<StandChatReponseMessage> {
return this.providerTemplate.chat({
...payload,
stream: false,
isVisionModel: this.getModelConfig(payload.model)?.isVisionModel,
providerConfig: this.provider.providerConfig,
});
return this.providerTemplate.chat(
{
...payload,
stream: false,
isVisionModel: this.getModelConfig(payload.model)?.isVisionModel,
providerConfig: this.provider.providerConfig,
},
this.genFetch(payload.model),
);
}
streamChat(payload: StandChatRequestPayload, handlers: InternalChatHandlers) {
@@ -129,6 +159,7 @@ export class ProviderClient {
handlers.onFinish(message);
},
},
this.genFetch(payload.model),
);
timer.signal.onabort = () => {