@@ -3,6 +3,8 @@ package handler
import (
"encoding/xml"
"io"
"log"
"net"
"net/http"
"strings"
"sync"
@@ -11,58 +13,38 @@ import (
"github.com/bjdgyc/anylink/base"
)
// UserState 用于存储用户的登录状态
type UserState struct {
FailureCount int
LastAttempt time . Time
LockTime time . Time
}
// 自定义 contextKey 类型,避免键冲突
type contextKey string
// 定义常量作为上下文的键
const loginStatusKey contextKey = "login_status"
const defaultGlobalLockStateExpirationTime = 3600
// 用户状态映射
var userStates = make ( map [ string ] * UserState )
var mu sync . Mutex
func init ( ) {
go cleanupUserStates ( )
}
// 清理过期的登录状态
func cleanupUserStates ( ) {
ticker := time . NewTicker ( 1 * time . Minute )
defer ticker . Stop ( )
for range ticker . C {
mu . Lock ( )
now := time . Now ( )
for username , state := range userStates {
if now . Sub ( state . LastAttempt ) > time . Duration ( base . Cfg . UserStateExpiration ) * time . Second {
delete ( userStates , username )
}
func initAntiBruteForce ( ) {
if base . Cfg . AntiBruteForce {
if base . Cfg . GlobalLockStateExpirationTime <= 0 {
base . Cfg . GlobalLockStateExpirationTime = defaultGlobalLockStateExpirationTime
}
mu . Unlock ( )
lockManager . startCleanupTicker ( )
lockManager . initIPWhitelist ( )
}
}
// 防爆破中间件
func antiBruteForce ( next http . Handler ) http . Handler {
return http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
// 如果最大验证失败次数为0, 则不启用 防爆破功能
if base . Cfg . MaxBanCount == 0 {
// 防爆破功能全局开关
if ! base . Cfg . AntiBruteForce {
next . ServeHTTP ( w , r )
return
}
body , err := io . ReadAll ( r . Body )
if err != nil {
http . Error ( w , "Failed to read request body" , http . StatusBadRequest )
return
}
r . Body . Close ( )
defer r . Body . Close ( )
cr := ClientRequest { }
err = xml . Unmarshal ( body , & cr )
@@ -72,35 +54,55 @@ func antiBruteForce(next http.Handler) http.Handler {
}
username := cr . Auth . Username
// 更新用户登录状态
mu . Lock ( )
state , exists := userStates [ username ]
if ! exists {
state = & UserState { }
userStates [ username ] = state
}
// 检查是否已超过锁定时间
if ! state . LockTime . IsZero ( ) {
if time . Now ( ) . After ( state . LockTime ) {
// 如果已经超过了锁定时间,重置失败计数和锁定时间
state . FailureCount = 0
state . LockTime = time . Time { }
} else {
// 如果还在锁定时间内,返回错误信息
http . Error ( w , "Account locked due to too many failed attempts. Try again later." , http . StatusTooManyRequests )
mu . Unlock ( )
return
}
ip , _ , err := net . SplitHostPort ( r . RemoteAddr ) // 提取纯 IP 地址,去掉端口号
if err != nil {
http . Error ( w , "Unable to parse IP address" , http . StatusInternalServerError )
return
}
// 如果超过时间窗口,重置失败计数
if time . Since ( state . LastAttempt ) > time . Duration ( base . Cfg . BanResetTime ) * time . Second {
sta te . FailureCount = 0
now := time . Now ( )
// 检查IP是否在白名单中
if lockManager . isWhitelis ted ( ip ) {
r . Body = io . NopCloser ( strings . NewReader ( string ( body ) ) )
next . ServeHTTP ( w , r )
return
}
state . LastAttempt = time . Now ( )
mu . Unlock ( )
// // 速率限制
// lockManager.mu.RLock( )
// limiter, exists := lockManager.rateLimiter[ip]
// if !exists {
// limiter = rate.NewLimiter(rate.Limit(base.Cfg.RateLimit), base.Cfg.Burst)
// lockManager.rateLimiter[ip] = limiter
// }
// lockManager.mu.RUnlock()
// if !limiter.Allow() {
// log.Printf("Rate limit exceeded for IP %s. Try again later.", ip)
// http.Error(w, "Rate limit exceeded. Try again later.", http.StatusTooManyRequests)
// return
// }
// 检查全局 IP 锁定
if base . Cfg . MaxGlobalIPBanCount > 0 && lockManager . checkGlobalIPLock ( ip , now ) {
log . Printf ( "IP %s is globally locked. Try again later." , ip )
http . Error ( w , "Account globally locked due to too many failed attempts. Try again later." , http . StatusTooManyRequests )
return
}
// 检查全局用户锁定
if base . Cfg . MaxGlobalUserBanCount > 0 && lockManager . checkGlobalUserLock ( username , now ) {
log . Printf ( "User %s is globally locked. Try again later." , username )
http . Error ( w , "Account globally locked due to too many failed attempts. Try again later." , http . StatusTooManyRequests )
return
}
// 检查单个用户的 IP 锁定
if base . Cfg . MaxBanCount > 0 && lockManager . checkUserIPLock ( username , ip , now ) {
log . Printf ( "IP %s is locked for user %s. Try again later." , ip , username )
http . Error ( w , "Account locked due to too many failed attempts. Try again later." , http . StatusTooManyRequests )
return
}
// 重新设置请求体以便后续处理器可以访问
r . Body = io . NopCloser ( strings . NewReader ( string ( body ) ) )
@@ -109,23 +111,295 @@ func antiBruteForce(next http.Handler) http.Handler {
next . ServeHTTP ( w , r )
// 从 context 中获取登录状态
loginStatus , ok := r . Context ( ) . Value ( loginStatusKey ) . ( bool )
if ! ok {
// 如果没有找到登录状态,默认为登录失败
loginStatus = false
}
loginStatus , _ := r . Context ( ) . Value ( loginStatusKey ) . ( bool )
// 更新用户登录状态
mu . Lock ( )
defer mu . Unlock ( )
if ! loginStatus {
state . FailureCount ++
if state . FailureCount >= base . Cfg . MaxBanCount {
state . LockTime = time . Now ( ) . Add ( time . Duration ( base . Cfg . LockTime ) * time . Second )
}
} else {
state . FailureCount = 0 // 成功登录后重置
}
lockManager . updateGlobalIPLock ( ip , now , loginStatus )
lockManager . updateGlobalUserLock ( username , now , loginStatus )
lockManager . updateUserIPLock ( username , ip , now , loginStatus )
} )
}
type LockState struct {
FailureCount int
LockTime time . Time
LastAttempt time . Time
}
type IPWhitelists struct {
IP net . IP
CIDR * net . IPNet
}
type LockManager struct {
mu sync . Mutex
ipLocks map [ string ] * LockState // 全局IP锁定状态
userLocks map [ string ] * LockState // 全局用户锁定状态
ipUserLocks map [ string ] map [ string ] * LockState // 单用户IP锁定状态
ipWhitelists [ ] IPWhitelists // 全局IP白名单, 包含IP地址和CIDR范围
// rateLimiter map[string]*rate.Limiter // 速率限制器
cleanupTicker * time . Ticker
}
var lockManager = & LockManager {
ipLocks : make ( map [ string ] * LockState ) ,
userLocks : make ( map [ string ] * LockState ) ,
ipUserLocks : make ( map [ string ] map [ string ] * LockState ) ,
ipWhitelists : make ( [ ] IPWhitelists , 0 ) ,
// rateLimiter: make(map[string]*rate.Limiter),
}
// 初始化IP白名单
func ( lm * LockManager ) initIPWhitelist ( ) {
ipWhitelist := strings . Split ( base . Cfg . IPWhitelist , "," )
for _ , ipWhitelist := range ipWhitelist {
ipWhitelist = strings . TrimSpace ( ipWhitelist )
if ipWhitelist == "" {
continue
}
_ , ipNet , err := net . ParseCIDR ( ipWhitelist )
if err == nil {
lm . ipWhitelists = append ( lm . ipWhitelists , IPWhitelists { CIDR : ipNet } )
continue
}
ip := net . ParseIP ( ipWhitelist )
if ip != nil {
lm . ipWhitelists = append ( lm . ipWhitelists , IPWhitelists { IP : ip } )
continue
}
}
}
// 检查 IP 是否在白名单中
func ( lm * LockManager ) isWhitelisted ( ip string ) bool {
clientIP := net . ParseIP ( ip )
if clientIP == nil {
return false
}
for _ , ipWhitelist := range lm . ipWhitelists {
if ipWhitelist . CIDR != nil && ipWhitelist . CIDR . Contains ( clientIP ) {
return true
}
if ipWhitelist . IP != nil && ipWhitelist . IP . Equal ( clientIP ) {
return true
}
}
return false
}
func ( lm * LockManager ) startCleanupTicker ( ) {
lm . cleanupTicker = time . NewTicker ( 5 * time . Minute )
go func ( ) {
for range lm . cleanupTicker . C {
lm . cleanupExpiredLocks ( )
}
} ( )
}
// 定期清理过期的锁定
func ( lm * LockManager ) cleanupExpiredLocks ( ) {
now := time . Now ( )
var ipKeys , userKeys [ ] string
var IPuserKeys [ ] struct { user , ip string }
lm . mu . Lock ( )
for ip , state := range lm . ipLocks {
if now . Sub ( state . LastAttempt ) > time . Duration ( base . Cfg . GlobalLockStateExpirationTime ) * time . Second {
ipKeys = append ( ipKeys , ip )
}
}
for user , state := range lm . userLocks {
if now . Sub ( state . LastAttempt ) > time . Duration ( base . Cfg . GlobalLockStateExpirationTime ) * time . Second {
userKeys = append ( userKeys , user )
}
}
for user , ipMap := range lm . ipUserLocks {
for ip , state := range ipMap {
if now . Sub ( state . LastAttempt ) > time . Duration ( base . Cfg . GlobalLockStateExpirationTime ) * time . Second {
IPuserKeys = append ( IPuserKeys , struct { user , ip string } { user , ip } )
}
}
}
lm . mu . Unlock ( )
lm . mu . Lock ( )
for _ , ip := range ipKeys {
delete ( lm . ipLocks , ip )
}
for _ , user := range userKeys {
delete ( lm . userLocks , user )
}
for _ , key := range IPuserKeys {
delete ( lm . ipUserLocks [ key . user ] , key . ip )
if len ( lm . ipUserLocks [ key . user ] ) == 0 {
delete ( lm . ipUserLocks , key . user )
}
}
lm . mu . Unlock ( )
}
// 检查全局 IP 锁定
func ( lm * LockManager ) checkGlobalIPLock ( ip string , now time . Time ) bool {
lm . mu . Lock ( )
defer lm . mu . Unlock ( )
state , exists := lm . ipLocks [ ip ]
if ! exists {
return false
}
// 如果超过时间窗口,重置失败计数
lm . resetLockStateIfExpired ( state , now , base . Cfg . GlobalIPBanResetTime )
if ! state . LockTime . IsZero ( ) && now . Before ( state . LockTime ) {
return true
}
return false
}
// 检查全局用户锁定
func ( lm * LockManager ) checkGlobalUserLock ( username string , now time . Time ) bool {
// 我也不知道为什么cisco anyconnect每次连接会先传一个空用户请求····
if username == "" {
return false
}
lm . mu . Lock ( )
defer lm . mu . Unlock ( )
state , exists := lm . userLocks [ username ]
if ! exists {
return false
}
// 如果超过时间窗口,重置失败计数
lm . resetLockStateIfExpired ( state , now , base . Cfg . GlobalUserBanResetTime )
if ! state . LockTime . IsZero ( ) && now . Before ( state . LockTime ) {
return true
}
return false
}
// 检查单个用户的 IP 锁定
func ( lm * LockManager ) checkUserIPLock ( username , ip string , now time . Time ) bool {
// 我也不知道为什么cisco anyconnect每次连接会先传一个空用户请求····
if username == "" {
return false
}
lm . mu . Lock ( )
defer lm . mu . Unlock ( )
userIPMap , userExists := lm . ipUserLocks [ username ]
if ! userExists {
return false
}
state , ipExists := userIPMap [ ip ]
if ! ipExists {
return false
}
// 如果超过时间窗口,重置失败计数
lm . resetLockStateIfExpired ( state , now , base . Cfg . BanResetTime )
if ! state . LockTime . IsZero ( ) && now . Before ( state . LockTime ) {
return true
}
return false
}
// 更新全局 IP 锁定状态
func ( lm * LockManager ) updateGlobalIPLock ( ip string , now time . Time , success bool ) {
lm . mu . Lock ( )
defer lm . mu . Unlock ( )
state , exists := lm . ipLocks [ ip ]
if ! exists {
state = & LockState { }
lm . ipLocks [ ip ] = state
}
lm . updateLockState ( state , now , success , base . Cfg . MaxGlobalIPBanCount , base . Cfg . GlobalIPLockTime )
}
// 更新全局用户锁定状态
func ( lm * LockManager ) updateGlobalUserLock ( username string , now time . Time , success bool ) {
// 我也不知道为什么cisco anyconnect每次连接会先传一个空用户请求····
if username == "" {
return
}
lm . mu . Lock ( )
defer lm . mu . Unlock ( )
state , exists := lm . userLocks [ username ]
if ! exists {
state = & LockState { }
lm . userLocks [ username ] = state
}
lm . updateLockState ( state , now , success , base . Cfg . MaxGlobalUserBanCount , base . Cfg . GlobalUserLockTime )
}
// 更新单个用户的 IP 锁定状态
func ( lm * LockManager ) updateUserIPLock ( username , ip string , now time . Time , success bool ) {
// 我也不知道为什么cisco anyconnect每次连接会先传一个空用户请求····
if username == "" {
return
}
lm . mu . Lock ( )
defer lm . mu . Unlock ( )
userIPMap , userExists := lm . ipUserLocks [ username ]
if ! userExists {
userIPMap = make ( map [ string ] * LockState )
lm . ipUserLocks [ username ] = userIPMap
}
state , ipExists := userIPMap [ ip ]
if ! ipExists {
state = & LockState { }
userIPMap [ ip ] = state
}
lm . updateLockState ( state , now , success , base . Cfg . MaxBanCount , base . Cfg . LockTime )
}
// 更新锁定状态
func ( lm * LockManager ) updateLockState ( state * LockState , now time . Time , success bool , maxBanCount , lockTime int ) {
if success {
state . FailureCount = 0
state . LockTime = time . Time { }
} else {
state . FailureCount ++
if state . FailureCount >= maxBanCount {
state . LockTime = now . Add ( time . Duration ( lockTime ) * time . Second )
}
}
state . LastAttempt = now
}
// 超过窗口时间和锁定时间时重置锁定状态
func ( lm * LockManager ) resetLockStateIfExpired ( state * LockState , now time . Time , resetTime int ) {
if state == nil || state . LastAttempt . IsZero ( ) {
return
}
// 如果超过锁定时间,重置锁定状态
if ! state . LockTime . IsZero ( ) && now . After ( state . LockTime ) {
state . FailureCount = 0
state . LockTime = time . Time { }
return
}
// 如果超过窗口时间,重置失败计数
if now . Sub ( state . LastAttempt ) > time . Duration ( resetTime ) * time . Second {
state . FailureCount = 0
state . LockTime = time . Time { }
}
}