From a85db21e1f72ba784d63904c778487081a839336 Mon Sep 17 00:00:00 2001 From: glay Date: Sat, 23 Nov 2024 12:09:45 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81=EF=BC=8C?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=96=B9=E6=B3=95=E5=91=BD=E5=90=8D=E9=94=99?= =?UTF-8?q?=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/bedrock.ts | 23 +++++++++------- app/client/api.ts | 41 ++++++++++++++------------- app/client/platforms/bedrock.ts | 49 ++++++++++++++++++++++++--------- app/store/access.ts | 2 +- app/utils/aws.ts | 29 ++----------------- 5 files changed, 73 insertions(+), 71 deletions(-) diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index e6b039ae8..2154ee5ff 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -1,7 +1,9 @@ import { NextRequest, NextResponse } from "next/server"; +import { auth } from "./auth"; import { sign } from "../utils/aws"; import { getServerSideConfig } from "../config/server"; - +import { ModelProvider } from "@/app/constant"; +import { prettyObject } from "@/app/utils/format"; const ALLOWED_PATH = new Set(["chat", "models"]); function parseEventData(chunk: Uint8Array): any { @@ -189,7 +191,7 @@ async function requestBedrock(req: NextRequest) { let awsRegion = config.awsRegion; let awsAccessKey = config.awsAccessKey; let awsSecretKey = config.awsSecretKey; - let modelId = ""; + let modelId = req.headers.get("ModelID"); // If server-side credentials are not available, parse from Authorization header if (!awsRegion || !awsAccessKey || !awsSecretKey) { @@ -199,16 +201,15 @@ async function requestBedrock(req: NextRequest) { } const [_, credentials] = authHeader.split("Bearer "); - const [region, accessKey, secretKey, model] = credentials.split(","); + const [region, accessKey, secretKey] = credentials.split(":"); - if (!region || !accessKey || !secretKey || !model) { + if (!region || !accessKey || !secretKey) { throw new Error("Invalid Authorization header format"); } awsRegion = region; awsAccessKey = accessKey; awsSecretKey = secretKey; - modelId = model; } if (!awsRegion || !awsAccessKey || !awsSecretKey || !modelId) { @@ -329,14 +330,16 @@ export async function handle( { status: 403 }, ); } - + const authResult = auth(req, ModelProvider.Bedrock); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } try { return await requestBedrock(req); } catch (e) { console.error("Handler error:", e); - return NextResponse.json( - { error: true, msg: e instanceof Error ? e.message : "Unknown error" }, - { status: 500 }, - ); + return NextResponse.json(prettyObject(e)); } } diff --git a/app/client/api.ts b/app/client/api.ts index eb0e4270d..47e3b674e 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -280,12 +280,10 @@ export function getHeaders(ignoreHeaders: boolean = false) { accessStore.awsAccessKey && accessStore.awsSecretKey ? accessStore.awsRegion + - "," + + ":" + accessStore.awsAccessKey + - "," + - accessStore.awsSecretKey + - "," + - modelConfig.model + ":" + + accessStore.awsSecretKey : "" : accessStore.openaiApiKey; return { @@ -316,6 +314,7 @@ export function getHeaders(ignoreHeaders: boolean = false) { } const { + isBedrock, isGoogle, isAzure, isAnthropic, @@ -328,23 +327,23 @@ export function getHeaders(ignoreHeaders: boolean = false) { const authHeader = getAuthHeader(); - // if (isBedrock) { - // // Secure encryption of AWS credentials using the new encryption utility - // headers["X-Region"] = encrypt(accessStore.awsRegion); - // headers["X-Access-Key"] = encrypt(accessStore.awsAccessKey); - // headers["X-Secret-Key"] = encrypt(accessStore.awsSecretKey); - // } else { - const bearerToken = getBearerToken( - apiKey, - isAzure || isAnthropic || isGoogle, - ); - - if (bearerToken) { - headers[authHeader] = bearerToken; - } else if (isEnabledAccessControl && validString(accessStore.accessCode)) { - headers["Authorization"] = getBearerToken( - ACCESS_CODE_PREFIX + accessStore.accessCode, + if (isBedrock) { + if (apiKey) { + headers[authHeader] = getBearerToken(apiKey); + } + } else { + const bearerToken = getBearerToken( + apiKey, + isAzure || isAnthropic || isGoogle, ); + + if (bearerToken) { + headers[authHeader] = bearerToken; + } else if (isEnabledAccessControl && validString(accessStore.accessCode)) { + headers["Authorization"] = getBearerToken( + ACCESS_CODE_PREFIX + accessStore.accessCode, + ); + } } return headers; diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index a4a11c30b..d55173746 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -1,3 +1,4 @@ +"use client"; import { ChatOptions, getHeaders, @@ -16,6 +17,8 @@ import { } from "../../store"; import { preProcessImageContent, stream } from "../../utils/chat"; import { getMessageTextContent, isVisionModel } from "../../utils"; +import { ApiPath, BEDROCK_BASE_URL } from "../../constant"; +import { getClientConfig } from "../../config/client"; const ClaudeMapper = { assistant: "assistant", @@ -34,6 +37,35 @@ interface ToolDefinition { } export class BedrockApi implements LLMApi { + private disableListModels = true; + + path(path: string): string { + const accessStore = useAccessStore.getState(); + + let baseUrl = ""; + + if (accessStore.useCustomConfig) { + baseUrl = accessStore.bedrockUrl; + } + + if (baseUrl.length === 0) { + const isApp = !!getClientConfig()?.isApp; + const apiPath = ApiPath.Bedrock; + baseUrl = isApp ? BEDROCK_BASE_URL : apiPath; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, baseUrl.length - 1); + } + if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Bedrock)) { + baseUrl = "https://" + baseUrl; + } + + console.log("[Proxy Endpoint] ", baseUrl, path); + + return [baseUrl, path].join("/"); + } + speech(options: SpeechOptions): Promise { throw new Error("Speech not implemented for Bedrock."); } @@ -239,17 +271,9 @@ export class BedrockApi implements LLMApi { } try { - const apiEndpoint = "/api/bedrock/chat"; - // const headers = { - // "Content-Type": requestBody.contentType || "application/json", - // Accept: requestBody.accept || "application/json", - // "X-Region": accessStore.awsRegion, - // "X-Access-Key": accessStore.awsAccessKey, - // "X-Secret-Key": accessStore.awsSecretKey, - // "X-Model-Id": modelConfig.model, - // "X-Encryption-Key": accessStore.bedrockEncryptionKey, - // }; + const chatPath = this.path("chat"); const headers = getHeaders(); + headers.ModelID = modelConfig.model; if (options.config.stream) { let index = -1; @@ -261,7 +285,7 @@ export class BedrockApi implements LLMApi { ); return stream( - apiEndpoint, + chatPath, requestBody, headers, (tools as ToolDefinition[]).map((tool) => ({ @@ -367,7 +391,7 @@ export class BedrockApi implements LLMApi { options, ); } else { - const res = await fetch(apiEndpoint, { + const res = await fetch(chatPath, { method: "POST", headers, body: JSON.stringify(requestBody), @@ -375,7 +399,6 @@ export class BedrockApi implements LLMApi { const resJson = await res.json(); const message = this.extractMessage(resJson, modelConfig.model); - // console.log("Extracted message:", message); options.onFinish(message, res); } } catch (e) { diff --git a/app/store/access.ts b/app/store/access.ts index 5ec99b175..be35f8925 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -120,7 +120,7 @@ const DEFAULT_ACCESS_STATE = { chatglmApiKey: "", // aws bedrock - bedrokUrl: DEFAULT_BEDROCK_URL, + bedrockUrl: DEFAULT_BEDROCK_URL, awsRegion: "", awsAccessKey: "", awsSecretKey: "", diff --git a/app/utils/aws.ts b/app/utils/aws.ts index d88707328..bca9d699a 100644 --- a/app/utils/aws.ts +++ b/app/utils/aws.ts @@ -6,9 +6,7 @@ import { AES, enc } from "crypto-js"; import { getServerSideConfig } from "../config/server"; const serverConfig = getServerSideConfig(); -// console.info(serverConfig); const SECRET_KEY = serverConfig.bedrockEncryptionKey || ""; -// console.info("======SECRET_KEY:"+SECRET_KEY); if (serverConfig.isBedrock && !SECRET_KEY) { console.error("When use Bedrock modle,ENCRYPTION_KEY should been set!"); } @@ -26,18 +24,13 @@ export function encrypt(data: string): string { export function decrypt(encryptedData: string): string { if (!encryptedData) return ""; try { - // Try to decrypt const bytes = AES.decrypt(encryptedData, SECRET_KEY); const decrypted = bytes.toString(enc.Utf8); - - // If decryption results in empty string but input wasn't empty, - // the input might already be decrypted if (!decrypted && encryptedData) { return encryptedData; } return decrypted; } catch (error) { - // If decryption fails, the input might already be decrypted return encryptedData; } } @@ -91,32 +84,28 @@ function encodeURIComponent_RFC3986(str: string): string { /[!'()*]/g, (c) => "%" + c.charCodeAt(0).toString(16).toUpperCase(), ) - .replace(/[-_.~]/g, (c) => c); // RFC 3986 unreserved characters + .replace(/[-_.~]/g, (c) => c); } function encodeURI_RFC3986(uri: string): string { - // Handle empty or root path if (!uri || uri === "/") return ""; - // Split the path into segments, preserving empty segments for double slashes const segments = uri.split("/"); return segments .map((segment) => { if (!segment) return ""; - // Special handling for Bedrock model paths if (segment.includes("model/")) { const parts = segment.split(/(model\/)/); return parts .map((part) => { if (part === "model/") return part; - // Handle the model identifier part if (part.includes(".") || part.includes(":")) { return part .split(/([.:])/g) .map((subpart, i) => { - if (i % 2 === 1) return subpart; // Don't encode separators + if (i % 2 === 1) return subpart; return encodeURIComponent_RFC3986(subpart); }) .join(""); @@ -126,7 +115,6 @@ function encodeURI_RFC3986(uri: string): string { .join(""); } - // Handle invoke-with-response-stream without encoding if (segment === "invoke-with-response-stream") { return segment; } @@ -147,17 +135,14 @@ export async function sign({ }: SignParams): Promise> { const endpoint = new URL(url); const canonicalUri = "/" + encodeURI_RFC3986(endpoint.pathname.slice(1)); - const canonicalQueryString = endpoint.search.slice(1); // Remove leading '?' + const canonicalQueryString = endpoint.search.slice(1); - // Create a date stamp and time stamp in ISO8601 format const now = new Date(); const amzDate = now.toISOString().replace(/[:-]|\.\d{3}/g, ""); const dateStamp = amzDate.slice(0, 8); - // Calculate the hash of the payload const payloadHash = SHA256(body).toString(Hex); - // Define headers with normalized values const headers: Record = { accept: "application/vnd.amazon.eventstream", "content-type": "application/json", @@ -167,24 +152,20 @@ export async function sign({ "x-amzn-bedrock-accept": "*/*", }; - // Get sorted header keys (case-insensitive) const sortedHeaderKeys = Object.keys(headers).sort((a, b) => a.toLowerCase().localeCompare(b.toLowerCase()), ); - // Create canonical headers string with normalized values const canonicalHeaders = sortedHeaderKeys .map( (key) => `${key.toLowerCase()}:${normalizeHeaderValue(headers[key])}\n`, ) .join(""); - // Create signed headers string const signedHeaders = sortedHeaderKeys .map((key) => key.toLowerCase()) .join(";"); - // Create canonical request const canonicalRequest = [ method.toUpperCase(), canonicalUri, @@ -194,7 +175,6 @@ export async function sign({ payloadHash, ].join("\n"); - // Create the string to sign const algorithm = "AWS4-HMAC-SHA256"; const credentialScope = `${dateStamp}/${region}/${service}/aws4_request`; const stringToSign = [ @@ -204,18 +184,15 @@ export async function sign({ SHA256(canonicalRequest).toString(Hex), ].join("\n"); - // Calculate the signature const signingKey = getSigningKey(secretAccessKey, dateStamp, region, service); const signature = hmac(signingKey, stringToSign).toString(Hex); - // Create the authorization header const authorization = [ `${algorithm} Credential=${accessKeyId}/${credentialScope}`, `SignedHeaders=${signedHeaders}`, `Signature=${signature}`, ].join(", "); - // Return headers with proper casing for the request return { Accept: headers.accept, "Content-Type": headers["content-type"],