@@ -18,10 +18,15 @@ type contextKey string
// 定义常量作为上下文的键
const loginStatusKey contextKey = "login_status"
const defaultGlobalLockStateExpirationTime = 3600
func init ( ) {
func initAntiBruteForce ( ) {
if base . Cfg . AntiBruteForce {
if base . Cfg . GlobalLockStateExpirationTime <= 0 {
base . Cfg . GlobalLockStateExpirationTime = defaultGlobalLockStateExpirationTime
}
lockManager . startCleanupTicker ( )
lockManager . initIPWhitelist ( )
}
}
@@ -56,6 +61,12 @@ func antiBruteForce(next http.Handler) http.Handler {
}
now := time . Now ( )
// 检查IP是否在白名单中
if lockManager . isWhitelisted ( ip ) {
r . Body = io . NopCloser ( strings . NewReader ( string ( body ) ) )
next . ServeHTTP ( w , r )
return
}
// // 速率限制
// lockManager.mu.RLock()
@@ -114,23 +125,69 @@ type LockState struct {
LockTime time . Time
LastAttempt time . Time
}
type IPWhitelists struct {
IP net . IP
CIDR * net . IPNet
}
type LockManager struct {
mu sync . RW Mutex
ipLocks map [ string ] * LockState // 全局IP锁定状态
userLocks map [ string ] * LockState // 全局用户锁定状态
ipUserLocks map [ string ] map [ string ] * LockState // 单用户IP锁定状态
mu sync . Mutex
ipLocks map [ string ] * LockState // 全局IP锁定状态
userLocks map [ string ] * LockState // 全局用户锁定状态
ipUserLocks map [ string ] map [ string ] * LockState // 单用户IP锁定状态
ipWhitelists [ ] IPWhitelists // 全局IP白名单, 包含IP地址和CIDR范围
// rateLimiter map[string]*rate.Limiter // 速率限制器
cleanupTicker * time . Ticker
}
var lockManager = & LockManager {
ipLocks : make ( map [ string ] * LockState ) ,
userLocks : make ( map [ string ] * LockState ) ,
ipUserLocks : make ( map [ string ] map [ string ] * LockState ) ,
ipLocks : make ( map [ string ] * LockState ) ,
userLocks : make ( map [ string ] * LockState ) ,
ipUserLocks : make ( map [ string ] map [ string ] * LockState ) ,
ipWhitelists : make ( [ ] IPWhitelists , 0 ) ,
// rateLimiter: make(map[string]*rate.Limiter),
}
// 初始化IP白名单
func ( lm * LockManager ) initIPWhitelist ( ) {
ipWhitelist := strings . Split ( base . Cfg . IPWhitelist , "," )
for _ , ipWhitelist := range ipWhitelist {
ipWhitelist = strings . TrimSpace ( ipWhitelist )
if ipWhitelist == "" {
continue
}
_ , ipNet , err := net . ParseCIDR ( ipWhitelist )
if err == nil {
lm . ipWhitelists = append ( lm . ipWhitelists , IPWhitelists { CIDR : ipNet } )
continue
}
ip := net . ParseIP ( ipWhitelist )
if ip != nil {
lm . ipWhitelists = append ( lm . ipWhitelists , IPWhitelists { IP : ip } )
continue
}
}
}
// 检查 IP 是否在白名单中
func ( lm * LockManager ) isWhitelisted ( ip string ) bool {
clientIP := net . ParseIP ( ip )
if clientIP == nil {
return false
}
for _ , ipWhitelist := range lm . ipWhitelists {
if ipWhitelist . CIDR != nil && ipWhitelist . CIDR . Contains ( clientIP ) {
return true
}
if ipWhitelist . IP != nil && ipWhitelist . IP . Equal ( clientIP ) {
return true
}
}
return false
}
func ( lm * LockManager ) startCleanupTicker ( ) {
lm . cleanupTicker = time . NewTicker ( 5 * time . Minute )
go func ( ) {
@@ -187,8 +244,8 @@ func (lm *LockManager) cleanupExpiredLocks() {
// 检查全局 IP 锁定
func ( lm * LockManager ) checkGlobalIPLock ( ip string , now time . Time ) bool {
lm . mu . R Lock( )
defer lm . mu . R Unlock( )
lm . mu . Lock ( )
defer lm . mu . Unlock ( )
state , exists := lm . ipLocks [ ip ]
if ! exists {
@@ -211,14 +268,13 @@ func (lm *LockManager) checkGlobalUserLock(username string, now time.Time) bool
if username == "" {
return false
}
lm . mu . R Lock( )
defer lm . mu . R Unlock( )
lm . mu . Lock ( )
defer lm . mu . Unlock ( )
state , exists := lm . userLocks [ username ]
if ! exists {
return false
}
// 如果超过时间窗口,重置失败计数
lm . resetLockStateIfExpired ( state , now , base . Cfg . GlobalUserBanResetTime )
@@ -235,8 +291,8 @@ func (lm *LockManager) checkUserIPLock(username, ip string, now time.Time) bool
if username == "" {
return false
}
lm . mu . R Lock( )
defer lm . mu . R Unlock( )
lm . mu . Lock ( )
defer lm . mu . Unlock ( )
userIPMap , userExists := lm . ipUserLocks [ username ]
if ! userExists {