wip: doubao

This commit is contained in:
Dogtiti 2024-07-06 14:59:37 +08:00
parent 2d1f522aaf
commit 9b3b4494ba
10 changed files with 475 additions and 0 deletions

View File

@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
case ModelProvider.Claude: case ModelProvider.Claude:
systemApiKey = serverConfig.anthropicApiKey; systemApiKey = serverConfig.anthropicApiKey;
break; break;
case ModelProvider.Doubao:
systemApiKey = serverConfig.bytedanceApiKey;
break;
case ModelProvider.GPT: case ModelProvider.GPT:
default: default:
if (req.nextUrl.pathname.includes("azure/deployments")) { if (req.nextUrl.pathname.includes("azure/deployments")) {

View File

@ -0,0 +1,160 @@
import { getServerSideConfig } from "@/app/config/server";
import {
BYTEDANCE_BASE_URL,
ApiPath,
ModelProvider,
ServiceProvider,
} from "@/app/constant";
import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server";
import { auth } from "@/app/api/auth";
import { isModelAvailableInServer } from "@/app/utils/model";
const serverConfig = getServerSideConfig();
async function handle(
req: NextRequest,
{ params }: { params: { path: string[] } },
) {
console.log("[ByteDance Route] params ", params);
if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}
const authResult = auth(req, ModelProvider.Doubao);
if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
});
}
try {
const response = await request(req);
return response;
} catch (e) {
console.error("[ByteDance] ", e);
return NextResponse.json(prettyObject(e));
}
}
export const GET = handle;
export const POST = handle;
export const runtime = "edge";
export const preferredRegion = [
"arn1",
"bom1",
"cdg1",
"cle1",
"cpt1",
"dub1",
"fra1",
"gru1",
"hnd1",
"iad1",
"icn1",
"kix1",
"lhr1",
"pdx1",
"sfo1",
"sin1",
"syd1",
];
async function request(req: NextRequest) {
const controller = new AbortController();
let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.ByteDance, "");
let baseUrl = serverConfig.bytedanceUrl || BYTEDANCE_BASE_URL;
if (!baseUrl.startsWith("http")) {
baseUrl = `https://${baseUrl}`;
}
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, -1);
}
console.log("[Proxy] ", path);
console.log("[Base Url]", baseUrl);
const timeoutId = setTimeout(
() => {
controller.abort();
},
10 * 60 * 1000,
);
const fetchUrl = `${baseUrl}${path}`;
const fetchOptions: RequestInit = {
headers: {
"Content-Type": "application/json",
Authorization: req.headers.get("Authorization") ?? "",
},
method: req.method,
body: req.body,
redirect: "manual",
// @ts-ignore
duplex: "half",
signal: controller.signal,
};
// #1815 try to refuse some request to some models
if (serverConfig.customModels && req.body) {
try {
const clonedBody = await req.text();
fetchOptions.body = clonedBody;
const jsonBody = JSON.parse(clonedBody) as { model?: string };
// not undefined and is false
if (
isModelAvailableInServer(
serverConfig.customModels,
jsonBody?.model as string,
ServiceProvider.ByteDance as string,
)
) {
return NextResponse.json(
{
error: true,
message: `you are not allowed to use ${jsonBody?.model} model`,
},
{
status: 403,
},
);
}
} catch (e) {
console.error(`[ByteDance] filter`, e);
}
}
console.log("[ByteDance request]", fetchOptions.headers, req.method);
try {
const res = await fetch(fetchUrl, fetchOptions);
console.log(
"[ByteDance response]",
res.status,
" ",
res.headers,
res.url,
);
// to prevent browser prompt for credentials
const newHeaders = new Headers(res.headers);
newHeaders.delete("www-authenticate");
// to disable nginx buffering
newHeaders.set("X-Accel-Buffering", "no");
return new Response(res.body, {
status: res.status,
statusText: res.statusText,
headers: newHeaders,
});
} finally {
clearTimeout(timeoutId);
}
}

View File

@ -9,6 +9,8 @@ import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
import { ChatGPTApi } from "./platforms/openai"; import { ChatGPTApi } from "./platforms/openai";
import { GeminiProApi } from "./platforms/google"; import { GeminiProApi } from "./platforms/google";
import { ClaudeApi } from "./platforms/anthropic"; import { ClaudeApi } from "./platforms/anthropic";
import { DoubaoApi } from "./platforms/bytedance";
export const ROLES = ["system", "user", "assistant"] as const; export const ROLES = ["system", "user", "assistant"] as const;
export type MessageRole = (typeof ROLES)[number]; export type MessageRole = (typeof ROLES)[number];
@ -104,6 +106,9 @@ export class ClientApi {
case ModelProvider.Claude: case ModelProvider.Claude:
this.llm = new ClaudeApi(); this.llm = new ClaudeApi();
break; break;
case ModelProvider.Doubao:
this.llm = new DoubaoApi();
break;
default: default:
this.llm = new ChatGPTApi(); this.llm = new ChatGPTApi();
} }

View File

@ -0,0 +1,260 @@
"use client";
import {
ApiPath,
ByteDance,
DEFAULT_API_HOST,
REQUEST_TIMEOUT_MS,
} from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import {
ChatOptions,
getHeaders,
LLMApi,
LLMModel,
MultimodalContent,
} from "../api";
import Locale from "../../locales";
import {
EventStreamContentType,
fetchEventSource,
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import { getClientConfig } from "@/app/config/client";
import { getMessageTextContent, isVisionModel } from "@/app/utils";
export interface OpenAIListModelResponse {
object: string;
data: Array<{
id: string;
object: string;
root: string;
}>;
}
interface RequestPayload {
messages: {
role: "system" | "user" | "assistant";
content: string | MultimodalContent[];
}[];
stream?: boolean;
model: string;
temperature: number;
presence_penalty: number;
frequency_penalty: number;
top_p: number;
max_tokens?: number;
}
export class DoubaoApi implements LLMApi {
path(path: string): string {
const accessStore = useAccessStore.getState();
let baseUrl = "";
if (accessStore.useCustomConfig) {
baseUrl = accessStore.bytedanceUrl;
}
if (baseUrl.length === 0) {
const isApp = !!getClientConfig()?.isApp;
baseUrl = isApp
? DEFAULT_API_HOST + "/api/proxy/bytedance"
: ApiPath.ByteDance;
}
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, baseUrl.length - 1);
}
if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.ByteDance)) {
baseUrl = "https://" + baseUrl;
}
console.log("[Proxy Endpoint] ", baseUrl, path);
return [baseUrl, path].join("/");
}
extractMessage(res: any) {
return res.choices?.at(0)?.message?.content ?? "";
}
async chat(options: ChatOptions) {
const visionModel = isVisionModel(options.config.model);
const messages = options.messages.map((v) => ({
role: v.role,
content: visionModel ? v.content : getMessageTextContent(v),
}));
const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
...{
model: options.config.model,
},
};
const requestPayload: RequestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
frequency_penalty: modelConfig.frequency_penalty,
top_p: modelConfig.top_p,
};
console.log("[Request] ByteDance payload: ", requestPayload);
const shouldStream = !!options.config.stream;
const controller = new AbortController();
options.onController?.(controller);
try {
const chatPath = this.path(ByteDance.ChatPath);
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
signal: controller.signal,
headers: getHeaders(),
};
// make a fetch request
const requestTimeoutId = setTimeout(
() => controller.abort(),
REQUEST_TIMEOUT_MS,
);
if (shouldStream) {
let responseText = "";
let remainText = "";
let finished = false;
// animate response to make it looks smooth
function animateResponseText() {
if (finished || controller.signal.aborted) {
responseText += remainText;
console.log("[Response Animation] finished");
if (responseText?.length === 0) {
options.onError?.(new Error("empty response from server"));
}
return;
}
if (remainText.length > 0) {
const fetchCount = Math.max(1, Math.round(remainText.length / 60));
const fetchText = remainText.slice(0, fetchCount);
responseText += fetchText;
remainText = remainText.slice(fetchCount);
options.onUpdate?.(responseText, fetchText);
}
requestAnimationFrame(animateResponseText);
}
// start animaion
animateResponseText();
const finish = () => {
if (!finished) {
finished = true;
options.onFinish(responseText + remainText);
}
};
controller.signal.onabort = finish;
fetchEventSource(chatPath, {
...chatPayload,
async onopen(res) {
clearTimeout(requestTimeoutId);
const contentType = res.headers.get("content-type");
console.log(
"[ByteDance] request response content type: ",
contentType,
);
if (contentType?.startsWith("text/plain")) {
responseText = await res.clone().text();
return finish();
}
if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
if (extraInfo) {
responseTexts.push(extraInfo);
}
responseText = responseTexts.join("\n\n");
return finish();
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || finished) {
return finish();
}
const text = msg.data;
try {
const json = JSON.parse(text);
const choices = json.choices as Array<{
delta: { content: string };
}>;
const delta = choices[0]?.delta?.content;
if (delta) {
remainText += delta;
}
} catch (e) {
console.error("[Request] parse error", text, msg);
}
},
onclose() {
finish();
},
onerror(e) {
options.onError?.(e);
throw e;
},
openWhenHidden: true,
});
} else {
const res = await fetch(chatPath, chatPayload);
clearTimeout(requestTimeoutId);
const resJson = await res.json();
const message = this.extractMessage(resJson);
options.onFinish(message);
}
} catch (e) {
console.log("[Request] failed to make a chat request", e);
options.onError?.(e as Error);
}
}
async usage() {
return {
used: 0,
total: 0,
};
}
async models(): Promise<LLMModel[]> {
return [];
}
}
export { ByteDance };

View File

@ -321,6 +321,8 @@ export function PreviewActions(props: {
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else if (config.modelConfig.providerName == ServiceProvider.ByteDance) {
api = new ClientApi(ModelProvider.Doubao);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);
} }

View File

@ -175,6 +175,8 @@ export function useLoadData() {
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else if (config.modelConfig.providerName == ServiceProvider.ByteDance) {
api = new ClientApi(ModelProvider.Doubao);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);
} }

View File

@ -32,6 +32,10 @@ declare global {
GOOGLE_API_KEY?: string; GOOGLE_API_KEY?: string;
GOOGLE_URL?: string; GOOGLE_URL?: string;
// bytedance only
BYTEDANCE_URL?: string;
BYTEDANCE_API_KEY?: string;
// google tag manager // google tag manager
GTM_ID?: string; GTM_ID?: string;
@ -92,6 +96,7 @@ export const getServerSideConfig = () => {
const isAzure = !!process.env.AZURE_URL; const isAzure = !!process.env.AZURE_URL;
const isGoogle = !!process.env.GOOGLE_API_KEY; const isGoogle = !!process.env.GOOGLE_API_KEY;
const isAnthropic = !!process.env.ANTHROPIC_API_KEY; const isAnthropic = !!process.env.ANTHROPIC_API_KEY;
const isBytedance = !!process.env.BYTEDANCE_API_KEY;
// const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? ""; // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? "";
// const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim()); // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim());
@ -126,6 +131,10 @@ export const getServerSideConfig = () => {
gtmId: process.env.GTM_ID, gtmId: process.env.GTM_ID,
isBytedance,
bytedanceApiKey: getApiKey(process.env.BYTEDANCE_API_KEY),
bytedanceUrl: process.env.BYTEDANCE_URL,
needCode: ACCESS_CODES.size > 0, needCode: ACCESS_CODES.size > 0,
code: process.env.CODE, code: process.env.CODE,
codes: ACCESS_CODES, codes: ACCESS_CODES,

View File

@ -14,6 +14,8 @@ export const ANTHROPIC_BASE_URL = "https://api.anthropic.com";
export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/"; export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/";
export const BYTEDANCE_BASE_URL = "https://ark.cn-beijing.volces.com";
export enum Path { export enum Path {
Home = "/", Home = "/",
Chat = "/chat", Chat = "/chat",
@ -28,6 +30,7 @@ export enum ApiPath {
Azure = "/api/azure", Azure = "/api/azure",
OpenAI = "/api/openai", OpenAI = "/api/openai",
Anthropic = "/api/anthropic", Anthropic = "/api/anthropic",
ByteDance = "/api/bytedance",
} }
export enum SlotID { export enum SlotID {
@ -71,12 +74,14 @@ export enum ServiceProvider {
Azure = "Azure", Azure = "Azure",
Google = "Google", Google = "Google",
Anthropic = "Anthropic", Anthropic = "Anthropic",
ByteDance = "ByteDance",
} }
export enum ModelProvider { export enum ModelProvider {
GPT = "GPT", GPT = "GPT",
GeminiPro = "GeminiPro", GeminiPro = "GeminiPro",
Claude = "Claude", Claude = "Claude",
Doubao = "Doubao",
} }
export const Anthropic = { export const Anthropic = {
@ -104,6 +109,11 @@ export const Google = {
ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`, ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
}; };
export const ByteDance = {
ExampleEndpoint: "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
ChatPath: "/api/v3/chat/completions",
};
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
// export const DEFAULT_SYSTEM_TEMPLATE = ` // export const DEFAULT_SYSTEM_TEMPLATE = `
// You are ChatGPT, a large language model trained by {{ServiceProvider}}. // You are ChatGPT, a large language model trained by {{ServiceProvider}}.
@ -173,6 +183,8 @@ const anthropicModels = [
"claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620",
]; ];
const bytedanceModels = ["ep-20240520082937-424bw=Doubao-lite-4k"];
export const DEFAULT_MODELS = [ export const DEFAULT_MODELS = [
...openaiModels.map((name) => ({ ...openaiModels.map((name) => ({
name, name,
@ -210,6 +222,15 @@ export const DEFAULT_MODELS = [
providerType: "anthropic", providerType: "anthropic",
}, },
})), })),
...bytedanceModels.map((name) => ({
name,
available: true,
provider: {
id: "bytedance",
providerName: "ByteDance",
providerType: "bytedance",
},
})),
] as const; ] as const;
export const CHAT_PAGE_SIZE = 15; export const CHAT_PAGE_SIZE = 15;

View File

@ -47,6 +47,10 @@ const DEFAULT_ACCESS_STATE = {
anthropicApiVersion: "2023-06-01", anthropicApiVersion: "2023-06-01",
anthropicUrl: "", anthropicUrl: "",
// bytedance
bytedanceApiKey: "",
bytedanceUrl: "",
// server config // server config
needCode: true, needCode: true,
hideUserApiKey: false, hideUserApiKey: false,
@ -83,6 +87,10 @@ export const useAccessStore = createPersistStore(
return ensure(get(), ["anthropicApiKey"]); return ensure(get(), ["anthropicApiKey"]);
}, },
isValidByteDance() {
return ensure(get(), ["bytedanceApiKey"]);
},
isAuthorized() { isAuthorized() {
this.fetch(); this.fetch();
@ -92,6 +100,7 @@ export const useAccessStore = createPersistStore(
this.isValidAzure() || this.isValidAzure() ||
this.isValidGoogle() || this.isValidGoogle() ||
this.isValidAnthropic() || this.isValidAnthropic() ||
this.isValidByteDance() ||
!this.enabledAccessControl() || !this.enabledAccessControl() ||
(this.enabledAccessControl() && ensure(get(), ["accessCode"])) (this.enabledAccessControl() && ensure(get(), ["accessCode"]))
); );

View File

@ -368,6 +368,8 @@ export const useChatStore = createPersistStore(
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (modelConfig.providerName == ServiceProvider.Anthropic) { } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else if (modelConfig.providerName == ServiceProvider.ByteDance) {
api = new ClientApi(ModelProvider.Doubao);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);
} }
@ -552,6 +554,8 @@ export const useChatStore = createPersistStore(
api = new ClientApi(ModelProvider.GeminiPro); api = new ClientApi(ModelProvider.GeminiPro);
} else if (modelConfig.providerName == ServiceProvider.Anthropic) { } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude); api = new ClientApi(ModelProvider.Claude);
} else if (modelConfig.providerName == ServiceProvider.ByteDance) {
api = new ClientApi(ModelProvider.Doubao);
} else { } else {
api = new ClientApi(ModelProvider.GPT); api = new ClientApi(ModelProvider.GPT);
} }