优化代码,修改方法命名错误

This commit is contained in:
glay 2024-11-23 12:09:45 +08:00
parent b0c1ccd0a0
commit a85db21e1f
5 changed files with 73 additions and 71 deletions

View File

@ -1,7 +1,9 @@
import { NextRequest, NextResponse } from "next/server"; import { NextRequest, NextResponse } from "next/server";
import { auth } from "./auth";
import { sign } from "../utils/aws"; import { sign } from "../utils/aws";
import { getServerSideConfig } from "../config/server"; import { getServerSideConfig } from "../config/server";
import { ModelProvider } from "@/app/constant";
import { prettyObject } from "@/app/utils/format";
const ALLOWED_PATH = new Set(["chat", "models"]); const ALLOWED_PATH = new Set(["chat", "models"]);
function parseEventData(chunk: Uint8Array): any { function parseEventData(chunk: Uint8Array): any {
@ -189,7 +191,7 @@ async function requestBedrock(req: NextRequest) {
let awsRegion = config.awsRegion; let awsRegion = config.awsRegion;
let awsAccessKey = config.awsAccessKey; let awsAccessKey = config.awsAccessKey;
let awsSecretKey = config.awsSecretKey; let awsSecretKey = config.awsSecretKey;
let modelId = ""; let modelId = req.headers.get("ModelID");
// If server-side credentials are not available, parse from Authorization header // If server-side credentials are not available, parse from Authorization header
if (!awsRegion || !awsAccessKey || !awsSecretKey) { if (!awsRegion || !awsAccessKey || !awsSecretKey) {
@ -199,16 +201,15 @@ async function requestBedrock(req: NextRequest) {
} }
const [_, credentials] = authHeader.split("Bearer "); const [_, credentials] = authHeader.split("Bearer ");
const [region, accessKey, secretKey, model] = credentials.split(","); const [region, accessKey, secretKey] = credentials.split(":");
if (!region || !accessKey || !secretKey || !model) { if (!region || !accessKey || !secretKey) {
throw new Error("Invalid Authorization header format"); throw new Error("Invalid Authorization header format");
} }
awsRegion = region; awsRegion = region;
awsAccessKey = accessKey; awsAccessKey = accessKey;
awsSecretKey = secretKey; awsSecretKey = secretKey;
modelId = model;
} }
if (!awsRegion || !awsAccessKey || !awsSecretKey || !modelId) { if (!awsRegion || !awsAccessKey || !awsSecretKey || !modelId) {
@ -329,14 +330,16 @@ export async function handle(
{ status: 403 }, { status: 403 },
); );
} }
const authResult = auth(req, ModelProvider.Bedrock);
if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
});
}
try { try {
return await requestBedrock(req); return await requestBedrock(req);
} catch (e) { } catch (e) {
console.error("Handler error:", e); console.error("Handler error:", e);
return NextResponse.json( return NextResponse.json(prettyObject(e));
{ error: true, msg: e instanceof Error ? e.message : "Unknown error" },
{ status: 500 },
);
} }
} }

View File

@ -280,12 +280,10 @@ export function getHeaders(ignoreHeaders: boolean = false) {
accessStore.awsAccessKey && accessStore.awsAccessKey &&
accessStore.awsSecretKey accessStore.awsSecretKey
? accessStore.awsRegion + ? accessStore.awsRegion +
"," + ":" +
accessStore.awsAccessKey + accessStore.awsAccessKey +
"," + ":" +
accessStore.awsSecretKey + accessStore.awsSecretKey
"," +
modelConfig.model
: "" : ""
: accessStore.openaiApiKey; : accessStore.openaiApiKey;
return { return {
@ -316,6 +314,7 @@ export function getHeaders(ignoreHeaders: boolean = false) {
} }
const { const {
isBedrock,
isGoogle, isGoogle,
isAzure, isAzure,
isAnthropic, isAnthropic,
@ -328,23 +327,23 @@ export function getHeaders(ignoreHeaders: boolean = false) {
const authHeader = getAuthHeader(); const authHeader = getAuthHeader();
// if (isBedrock) { if (isBedrock) {
// // Secure encryption of AWS credentials using the new encryption utility if (apiKey) {
// headers["X-Region"] = encrypt(accessStore.awsRegion); headers[authHeader] = getBearerToken(apiKey);
// headers["X-Access-Key"] = encrypt(accessStore.awsAccessKey); }
// headers["X-Secret-Key"] = encrypt(accessStore.awsSecretKey); } else {
// } else { const bearerToken = getBearerToken(
const bearerToken = getBearerToken( apiKey,
apiKey, isAzure || isAnthropic || isGoogle,
isAzure || isAnthropic || isGoogle,
);
if (bearerToken) {
headers[authHeader] = bearerToken;
} else if (isEnabledAccessControl && validString(accessStore.accessCode)) {
headers["Authorization"] = getBearerToken(
ACCESS_CODE_PREFIX + accessStore.accessCode,
); );
if (bearerToken) {
headers[authHeader] = bearerToken;
} else if (isEnabledAccessControl && validString(accessStore.accessCode)) {
headers["Authorization"] = getBearerToken(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
} }
return headers; return headers;

View File

@ -1,3 +1,4 @@
"use client";
import { import {
ChatOptions, ChatOptions,
getHeaders, getHeaders,
@ -16,6 +17,8 @@ import {
} from "../../store"; } from "../../store";
import { preProcessImageContent, stream } from "../../utils/chat"; import { preProcessImageContent, stream } from "../../utils/chat";
import { getMessageTextContent, isVisionModel } from "../../utils"; import { getMessageTextContent, isVisionModel } from "../../utils";
import { ApiPath, BEDROCK_BASE_URL } from "../../constant";
import { getClientConfig } from "../../config/client";
const ClaudeMapper = { const ClaudeMapper = {
assistant: "assistant", assistant: "assistant",
@ -34,6 +37,35 @@ interface ToolDefinition {
} }
export class BedrockApi implements LLMApi { export class BedrockApi implements LLMApi {
private disableListModels = true;
path(path: string): string {
const accessStore = useAccessStore.getState();
let baseUrl = "";
if (accessStore.useCustomConfig) {
baseUrl = accessStore.bedrockUrl;
}
if (baseUrl.length === 0) {
const isApp = !!getClientConfig()?.isApp;
const apiPath = ApiPath.Bedrock;
baseUrl = isApp ? BEDROCK_BASE_URL : apiPath;
}
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, baseUrl.length - 1);
}
if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Bedrock)) {
baseUrl = "https://" + baseUrl;
}
console.log("[Proxy Endpoint] ", baseUrl, path);
return [baseUrl, path].join("/");
}
speech(options: SpeechOptions): Promise<ArrayBuffer> { speech(options: SpeechOptions): Promise<ArrayBuffer> {
throw new Error("Speech not implemented for Bedrock."); throw new Error("Speech not implemented for Bedrock.");
} }
@ -239,17 +271,9 @@ export class BedrockApi implements LLMApi {
} }
try { try {
const apiEndpoint = "/api/bedrock/chat"; const chatPath = this.path("chat");
// const headers = {
// "Content-Type": requestBody.contentType || "application/json",
// Accept: requestBody.accept || "application/json",
// "X-Region": accessStore.awsRegion,
// "X-Access-Key": accessStore.awsAccessKey,
// "X-Secret-Key": accessStore.awsSecretKey,
// "X-Model-Id": modelConfig.model,
// "X-Encryption-Key": accessStore.bedrockEncryptionKey,
// };
const headers = getHeaders(); const headers = getHeaders();
headers.ModelID = modelConfig.model;
if (options.config.stream) { if (options.config.stream) {
let index = -1; let index = -1;
@ -261,7 +285,7 @@ export class BedrockApi implements LLMApi {
); );
return stream( return stream(
apiEndpoint, chatPath,
requestBody, requestBody,
headers, headers,
(tools as ToolDefinition[]).map((tool) => ({ (tools as ToolDefinition[]).map((tool) => ({
@ -367,7 +391,7 @@ export class BedrockApi implements LLMApi {
options, options,
); );
} else { } else {
const res = await fetch(apiEndpoint, { const res = await fetch(chatPath, {
method: "POST", method: "POST",
headers, headers,
body: JSON.stringify(requestBody), body: JSON.stringify(requestBody),
@ -375,7 +399,6 @@ export class BedrockApi implements LLMApi {
const resJson = await res.json(); const resJson = await res.json();
const message = this.extractMessage(resJson, modelConfig.model); const message = this.extractMessage(resJson, modelConfig.model);
// console.log("Extracted message:", message);
options.onFinish(message, res); options.onFinish(message, res);
} }
} catch (e) { } catch (e) {

View File

@ -120,7 +120,7 @@ const DEFAULT_ACCESS_STATE = {
chatglmApiKey: "", chatglmApiKey: "",
// aws bedrock // aws bedrock
bedrokUrl: DEFAULT_BEDROCK_URL, bedrockUrl: DEFAULT_BEDROCK_URL,
awsRegion: "", awsRegion: "",
awsAccessKey: "", awsAccessKey: "",
awsSecretKey: "", awsSecretKey: "",

View File

@ -6,9 +6,7 @@ import { AES, enc } from "crypto-js";
import { getServerSideConfig } from "../config/server"; import { getServerSideConfig } from "../config/server";
const serverConfig = getServerSideConfig(); const serverConfig = getServerSideConfig();
// console.info(serverConfig);
const SECRET_KEY = serverConfig.bedrockEncryptionKey || ""; const SECRET_KEY = serverConfig.bedrockEncryptionKey || "";
// console.info("======SECRET_KEY:"+SECRET_KEY);
if (serverConfig.isBedrock && !SECRET_KEY) { if (serverConfig.isBedrock && !SECRET_KEY) {
console.error("When use Bedrock modle,ENCRYPTION_KEY should been set!"); console.error("When use Bedrock modle,ENCRYPTION_KEY should been set!");
} }
@ -26,18 +24,13 @@ export function encrypt(data: string): string {
export function decrypt(encryptedData: string): string { export function decrypt(encryptedData: string): string {
if (!encryptedData) return ""; if (!encryptedData) return "";
try { try {
// Try to decrypt
const bytes = AES.decrypt(encryptedData, SECRET_KEY); const bytes = AES.decrypt(encryptedData, SECRET_KEY);
const decrypted = bytes.toString(enc.Utf8); const decrypted = bytes.toString(enc.Utf8);
// If decryption results in empty string but input wasn't empty,
// the input might already be decrypted
if (!decrypted && encryptedData) { if (!decrypted && encryptedData) {
return encryptedData; return encryptedData;
} }
return decrypted; return decrypted;
} catch (error) { } catch (error) {
// If decryption fails, the input might already be decrypted
return encryptedData; return encryptedData;
} }
} }
@ -91,32 +84,28 @@ function encodeURIComponent_RFC3986(str: string): string {
/[!'()*]/g, /[!'()*]/g,
(c) => "%" + c.charCodeAt(0).toString(16).toUpperCase(), (c) => "%" + c.charCodeAt(0).toString(16).toUpperCase(),
) )
.replace(/[-_.~]/g, (c) => c); // RFC 3986 unreserved characters .replace(/[-_.~]/g, (c) => c);
} }
function encodeURI_RFC3986(uri: string): string { function encodeURI_RFC3986(uri: string): string {
// Handle empty or root path
if (!uri || uri === "/") return ""; if (!uri || uri === "/") return "";
// Split the path into segments, preserving empty segments for double slashes
const segments = uri.split("/"); const segments = uri.split("/");
return segments return segments
.map((segment) => { .map((segment) => {
if (!segment) return ""; if (!segment) return "";
// Special handling for Bedrock model paths
if (segment.includes("model/")) { if (segment.includes("model/")) {
const parts = segment.split(/(model\/)/); const parts = segment.split(/(model\/)/);
return parts return parts
.map((part) => { .map((part) => {
if (part === "model/") return part; if (part === "model/") return part;
// Handle the model identifier part
if (part.includes(".") || part.includes(":")) { if (part.includes(".") || part.includes(":")) {
return part return part
.split(/([.:])/g) .split(/([.:])/g)
.map((subpart, i) => { .map((subpart, i) => {
if (i % 2 === 1) return subpart; // Don't encode separators if (i % 2 === 1) return subpart;
return encodeURIComponent_RFC3986(subpart); return encodeURIComponent_RFC3986(subpart);
}) })
.join(""); .join("");
@ -126,7 +115,6 @@ function encodeURI_RFC3986(uri: string): string {
.join(""); .join("");
} }
// Handle invoke-with-response-stream without encoding
if (segment === "invoke-with-response-stream") { if (segment === "invoke-with-response-stream") {
return segment; return segment;
} }
@ -147,17 +135,14 @@ export async function sign({
}: SignParams): Promise<Record<string, string>> { }: SignParams): Promise<Record<string, string>> {
const endpoint = new URL(url); const endpoint = new URL(url);
const canonicalUri = "/" + encodeURI_RFC3986(endpoint.pathname.slice(1)); const canonicalUri = "/" + encodeURI_RFC3986(endpoint.pathname.slice(1));
const canonicalQueryString = endpoint.search.slice(1); // Remove leading '?' const canonicalQueryString = endpoint.search.slice(1);
// Create a date stamp and time stamp in ISO8601 format
const now = new Date(); const now = new Date();
const amzDate = now.toISOString().replace(/[:-]|\.\d{3}/g, ""); const amzDate = now.toISOString().replace(/[:-]|\.\d{3}/g, "");
const dateStamp = amzDate.slice(0, 8); const dateStamp = amzDate.slice(0, 8);
// Calculate the hash of the payload
const payloadHash = SHA256(body).toString(Hex); const payloadHash = SHA256(body).toString(Hex);
// Define headers with normalized values
const headers: Record<string, string> = { const headers: Record<string, string> = {
accept: "application/vnd.amazon.eventstream", accept: "application/vnd.amazon.eventstream",
"content-type": "application/json", "content-type": "application/json",
@ -167,24 +152,20 @@ export async function sign({
"x-amzn-bedrock-accept": "*/*", "x-amzn-bedrock-accept": "*/*",
}; };
// Get sorted header keys (case-insensitive)
const sortedHeaderKeys = Object.keys(headers).sort((a, b) => const sortedHeaderKeys = Object.keys(headers).sort((a, b) =>
a.toLowerCase().localeCompare(b.toLowerCase()), a.toLowerCase().localeCompare(b.toLowerCase()),
); );
// Create canonical headers string with normalized values
const canonicalHeaders = sortedHeaderKeys const canonicalHeaders = sortedHeaderKeys
.map( .map(
(key) => `${key.toLowerCase()}:${normalizeHeaderValue(headers[key])}\n`, (key) => `${key.toLowerCase()}:${normalizeHeaderValue(headers[key])}\n`,
) )
.join(""); .join("");
// Create signed headers string
const signedHeaders = sortedHeaderKeys const signedHeaders = sortedHeaderKeys
.map((key) => key.toLowerCase()) .map((key) => key.toLowerCase())
.join(";"); .join(";");
// Create canonical request
const canonicalRequest = [ const canonicalRequest = [
method.toUpperCase(), method.toUpperCase(),
canonicalUri, canonicalUri,
@ -194,7 +175,6 @@ export async function sign({
payloadHash, payloadHash,
].join("\n"); ].join("\n");
// Create the string to sign
const algorithm = "AWS4-HMAC-SHA256"; const algorithm = "AWS4-HMAC-SHA256";
const credentialScope = `${dateStamp}/${region}/${service}/aws4_request`; const credentialScope = `${dateStamp}/${region}/${service}/aws4_request`;
const stringToSign = [ const stringToSign = [
@ -204,18 +184,15 @@ export async function sign({
SHA256(canonicalRequest).toString(Hex), SHA256(canonicalRequest).toString(Hex),
].join("\n"); ].join("\n");
// Calculate the signature
const signingKey = getSigningKey(secretAccessKey, dateStamp, region, service); const signingKey = getSigningKey(secretAccessKey, dateStamp, region, service);
const signature = hmac(signingKey, stringToSign).toString(Hex); const signature = hmac(signingKey, stringToSign).toString(Hex);
// Create the authorization header
const authorization = [ const authorization = [
`${algorithm} Credential=${accessKeyId}/${credentialScope}`, `${algorithm} Credential=${accessKeyId}/${credentialScope}`,
`SignedHeaders=${signedHeaders}`, `SignedHeaders=${signedHeaders}`,
`Signature=${signature}`, `Signature=${signature}`,
].join(", "); ].join(", ");
// Return headers with proper casing for the request
return { return {
Accept: headers.accept, Accept: headers.accept,
"Content-Type": headers["content-type"], "Content-Type": headers["content-type"],