ChatGPT-Next-Web/app/api/bedrock.ts

289 lines
7.0 KiB
TypeScript

import { ModelProvider } from "../constant";
import { prettyObject } from "../utils/format";
import { NextRequest, NextResponse } from "next/server";
import { auth } from "./auth";
import {
BedrockRuntimeClient,
InvokeModelCommand,
ValidationException,
} from "@aws-sdk/client-bedrock-runtime";
import { validateModelId } from "./bedrock/utils";
import {
ConverseRequest,
formatRequestBody,
parseModelResponse,
} from "./bedrock/models";
/**
 * One element of a multi-part response `content` array.
 *
 * Within this file only items whose `type` is `"text"` are read; their
 * `text` values are concatenated into the plain-text reply.
 */
interface ContentItem {
  /** Kind of content carried by this item (e.g. `"text"`). */
  type: string;
  /** Text payload; may be absent. */
  text?: string;
  /** Image reference; may be absent. */
  image_url?: { url: string };
}
/** Sub-paths (joined with "/") that this route is willing to serve. */
const ALLOWED_PATH = new Set<string>(["invoke", "converse"]);
/**
 * Entry point for the Bedrock API route.
 *
 * Flow:
 *  1. Answer `OPTIONS` requests immediately with 200.
 *  2. Reject any sub-path not listed in `ALLOWED_PATH` with 403.
 *  3. Run `auth` for the Bedrock provider; respond 401 when it reports an error.
 *  4. Dispatch to the converse or invoke handler.
 *
 * Errors thrown by the handlers are converted to JSON responses:
 * `ValidationException` becomes 400, anything else becomes 500.
 *
 * @param req - Incoming request.
 * @param params - Route params; `params.path` holds the path segments.
 * @returns The response produced by the selected handler or an error response.
 */
export async function handle(
  req: NextRequest,
  { params }: { params: { path: string[] } },
) {
  console.log("[Bedrock Route] params ", params);

  // Preflight requests need no further processing.
  if (req.method === "OPTIONS") {
    return NextResponse.json({ body: "OK" }, { status: 200 });
  }

  const subpath = params.path.join("/");
  const pathAllowed = ALLOWED_PATH.has(subpath);
  if (!pathAllowed) {
    console.log("[Bedrock Route] forbidden path ", subpath);
    const forbiddenBody = {
      error: true,
      msg: "you are not allowed to request " + subpath,
    };
    return NextResponse.json(forbiddenBody, { status: 403 });
  }

  const authResult = auth(req, ModelProvider.Bedrock);
  if (authResult.error) {
    return NextResponse.json(authResult, { status: 401 });
  }

  // Anything that is not "converse" is "invoke" at this point.
  const routeHandler =
    subpath === "converse" ? handleConverseRequest : handleInvokeRequest;

  try {
    // Await inside the try so handler rejections reach the catch below.
    return await routeHandler(req);
  } catch (e) {
    console.error("[Bedrock] ", e);

    if (!(e instanceof ValidationException)) {
      // Generic failure: surface the message plus a formatted dump.
      const genericBody = {
        error: true,
        message: e instanceof Error ? e.message : "Unknown error",
        details: prettyObject(e),
      };
      return NextResponse.json(genericBody, { status: 500 });
    }

    // Model validation problems are reported as a client error.
    const validationBody = {
      error: true,
      message:
        "Model validation error. If using a Llama model, please provide a valid inference profile ARN.",
      details: e.message,
    };
    return NextResponse.json(validationBody, { status: 400 });
  }
}
/**
 * Handles the "converse" sub-path: forwards a `ConverseRequest` body to
 * Bedrock via `InvokeModelCommand` and returns the parsed model response
 * as JSON.
 *
 * AWS settings are read from request headers: `X-Region` (defaults to
 * "us-east-1"), `X-Access-Key`, `X-Secret-Key` and optional `X-Session-Token`.
 *
 * @param req - Incoming request whose JSON body is a `ConverseRequest`.
 * @returns A 401 JSON response when the access key or secret key is missing,
 *   otherwise a JSON response with the output of `parseModelResponse`.
 * @throws ValidationException when `validateModelId` returns a truthy value
 *   (used as the error message). Any other error, including an abort caused
 *   by the timeout, is logged and rethrown for the caller to handle.
 */
async function handleConverseRequest(req: NextRequest) {
  const region = req.headers.get("X-Region") || "us-east-1";
  const accessKeyId = req.headers.get("X-Access-Key") || "";
  const secretAccessKey = req.headers.get("X-Secret-Key") || "";
  const sessionToken = req.headers.get("X-Session-Token");

  if (!accessKeyId || !secretAccessKey) {
    return NextResponse.json(
      {
        error: true,
        message: "Missing AWS credentials",
      },
      {
        status: 401,
      },
    );
  }

  console.log("[Bedrock] Using region:", region);

  const client = new BedrockRuntimeClient({
    region,
    credentials: {
      accessKeyId,
      secretAccessKey,
      sessionToken: sessionToken || undefined,
    },
  });

  // Abort the upstream call if it has not finished within 10 minutes.
  const controller = new AbortController();
  const timeoutId = setTimeout(
    () => {
      controller.abort();
    },
    10 * 60 * 1000,
  );

  try {
    const body = (await req.json()) as ConverseRequest;
    const { modelId } = body;

    // Validate model ID
    const validationError = validateModelId(modelId);
    if (validationError) {
      throw new ValidationException({
        message: validationError,
        $metadata: {},
      });
    }

    console.log("[Bedrock] Invoking model:", modelId);
    // NOTE(review): this writes full message content to the server log.
    console.log("[Bedrock] Messages:", body.messages);

    const requestBody = formatRequestBody(body);
    const jsonString = JSON.stringify(requestBody);
    const input = {
      modelId,
      contentType: "application/json",
      accept: "application/json",
      body: Uint8Array.from(Buffer.from(jsonString)),
    };

    // Log the readable request body instead of the encoded bytes.
    console.log("[Bedrock] Request input:", {
      ...input,
      body: requestBody,
    });

    const command = new InvokeModelCommand(input);
    // Pass the signal so the timeout above actually cancels the request;
    // without it the AbortController has no effect.
    const response = await client.send(command, {
      abortSignal: controller.signal,
    });
    console.log("[Bedrock] Got response");

    // Parse and format the response based on model type
    const responseBody = new TextDecoder().decode(response.body);
    const formattedResponse = parseModelResponse(responseBody, modelId);
    return NextResponse.json(formattedResponse);
  } catch (e) {
    console.error("[Bedrock] Request error:", e);
    throw e; // Let the main error handler deal with it
  } finally {
    clearTimeout(timeoutId);
  }
}
/**
 * Handles the "invoke" sub-path: sends `messages` to the given `model` on
 * Bedrock with a fixed inference configuration and returns the text of the
 * model's reply as `text/plain`.
 *
 * AWS settings are read from request headers: `X-Region` (defaults to
 * "us-east-1"), `X-Access-Key`, `X-Secret-Key` and optional `X-Session-Token`.
 *
 * @param req - Incoming request whose JSON body carries `messages` and `model`.
 * @returns A 401 JSON response when the access key or secret key is missing,
 *   otherwise a plain-text response. The text is the concatenation of all
 *   `"text"` items when the parsed `content` is an array, the content itself
 *   when it is a string, and an empty string otherwise.
 * @throws ValidationException when `validateModelId` returns a truthy value
 *   (used as the error message). Any other error, including an abort caused
 *   by the timeout, is logged and rethrown for the caller to handle.
 */
async function handleInvokeRequest(req: NextRequest) {
  const region = req.headers.get("X-Region") || "us-east-1";
  const accessKeyId = req.headers.get("X-Access-Key") || "";
  const secretAccessKey = req.headers.get("X-Secret-Key") || "";
  const sessionToken = req.headers.get("X-Session-Token");

  if (!accessKeyId || !secretAccessKey) {
    return NextResponse.json(
      {
        error: true,
        message: "Missing AWS credentials",
      },
      {
        status: 401,
      },
    );
  }

  const client = new BedrockRuntimeClient({
    region,
    credentials: {
      accessKeyId,
      secretAccessKey,
      sessionToken: sessionToken || undefined,
    },
  });

  // Abort the upstream call if it has not finished within 10 minutes.
  const controller = new AbortController();
  const timeoutId = setTimeout(
    () => {
      controller.abort();
    },
    10 * 60 * 1000,
  );

  try {
    const body = await req.json();
    const { messages, model } = body;

    // Validate model ID
    const validationError = validateModelId(model);
    if (validationError) {
      throw new ValidationException({
        message: validationError,
        $metadata: {},
      });
    }

    console.log("[Bedrock] Invoking model:", model);
    // NOTE(review): this writes full message content to the server log.
    console.log("[Bedrock] Messages:", messages);

    // This endpoint always uses the same sampling parameters.
    const requestBody = formatRequestBody({
      modelId: model,
      messages,
      inferenceConfig: {
        maxTokens: 2048,
        temperature: 0.7,
        topP: 0.9,
      },
    });
    const jsonString = JSON.stringify(requestBody);
    const input = {
      modelId: model,
      contentType: "application/json",
      accept: "application/json",
      body: Uint8Array.from(Buffer.from(jsonString)),
    };

    // Log the readable request body instead of the encoded bytes.
    console.log("[Bedrock] Request input:", {
      ...input,
      body: requestBody,
    });

    const command = new InvokeModelCommand(input);
    // Pass the signal so the timeout above actually cancels the request;
    // without it the AbortController has no effect.
    const response = await client.send(command, {
      abortSignal: controller.signal,
    });
    console.log("[Bedrock] Got response");

    // Parse and format the response
    const responseBody = new TextDecoder().decode(response.body);
    const formattedResponse = parseModelResponse(responseBody, model);

    // Extract text content from the response
    let textContent = "";
    if (formattedResponse.content && Array.isArray(formattedResponse.content)) {
      textContent = formattedResponse.content
        .filter((item: ContentItem) => item.type === "text")
        .map((item: ContentItem) => item.text || "")
        .join("");
    } else if (typeof formattedResponse.content === "string") {
      textContent = formattedResponse.content;
    }

    // Return plain text response
    return new NextResponse(textContent, {
      headers: {
        "Content-Type": "text/plain",
      },
    });
  } catch (e) {
    console.error("[Bedrock] Request error:", e);
    throw e; // Let the main error handler deal with it
  } finally {
    clearTimeout(timeoutId);
  }
}