From f195ae2d3023c696d97230735e83d8992f98b5af Mon Sep 17 00:00:00 2001
From: wsczx <wsc@wsczx.com>
Date: Fri, 4 Oct 2024 00:17:56 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/handler/antiBruteForce.go | 70 ++++++++++++++------------------
 1 file changed, 30 insertions(+), 40 deletions(-)

diff --git a/server/handler/antiBruteForce.go b/server/handler/antiBruteForce.go
index 684a669..507856c 100644
--- a/server/handler/antiBruteForce.go
+++ b/server/handler/antiBruteForce.go
@@ -197,16 +197,13 @@ func (lm *LockManager) checkGlobalIPLock(ip string, now time.Time) bool {
 		return false
 	}
 
+	// 如果超过时间窗口,重置失败计数
+	lm.resetLockStateIfExpired(state, now, base.Cfg.GlobalIPBanResetTime)
+
 	if !state.LockTime.IsZero() && now.Before(state.LockTime) {
 		return true
 	}
 
-	// 如果超过时间窗口,重置失败计数
-	if now.Sub(state.LastAttempt) > time.Duration(base.Cfg.GlobalIPBanResetTime)*time.Second {
-		state.FailureCount = 0
-		state.LockTime = time.Time{}
-	}
-
 	return false
 }
 
@@ -224,16 +221,13 @@ func (lm *LockManager) checkGlobalUserLock(username string, now time.Time) bool
 		return false
 	}
 
+	// 如果超过时间窗口,重置失败计数
+	lm.resetLockStateIfExpired(state, now, base.Cfg.GlobalUserBanResetTime)
+
 	if !state.LockTime.IsZero() && now.Before(state.LockTime) {
 		return true
 	}
 
-	// 如果超过时间窗口,重置失败计数
-	if now.Sub(state.LastAttempt) > time.Duration(base.Cfg.GlobalUserBanResetTime)*time.Second {
-		state.FailureCount = 0
-		state.LockTime = time.Time{}
-	}
-
 	return false
 }
 
@@ -256,16 +250,13 @@ func (lm *LockManager) checkUserIPLock(username, ip string, now time.Time) bool
 		return false
 	}
 
+	// 如果超过时间窗口,重置失败计数
+	lm.resetLockStateIfExpired(state, now, base.Cfg.BanResetTime)
+
 	if !state.LockTime.IsZero() && now.Before(state.LockTime) {
 		return true
 	}
 
-	// 如果超过时间窗口,重置失败计数
-	if now.Sub(state.LastAttempt) > time.Duration(base.Cfg.BanResetTime)*time.Second {
-		state.FailureCount = 0
-		state.LockTime = time.Time{}
-	}
-
 	return false
 }
 
@@ -280,16 +271,7 @@ func (lm *LockManager) updateGlobalIPLock(ip string, now time.Time, success bool
 		lm.ipLocks[ip] = state
 	}
 
-	if success {
-		state.FailureCount = 0
-		state.LockTime = time.Time{}
-	} else {
-		state.FailureCount++
-		if state.FailureCount >= base.Cfg.MaxGlobalIPBanCount {
-			state.LockTime = now.Add(time.Duration(base.Cfg.GlobalIPLockTime) * time.Second)
-		}
-	}
-	state.LastAttempt = now
+	lm.updateLockState(state, now, success, base.Cfg.MaxGlobalIPBanCount, base.Cfg.GlobalIPLockTime)
 }
 
 // 更新全局用户锁定状态
@@ -307,16 +289,7 @@ func (lm *LockManager) updateGlobalUserLock(username string, now time.Time, succ
 		lm.userLocks[username] = state
 	}
 
-	if success {
-		state.FailureCount = 0
-		state.LockTime = time.Time{}
-	} else {
-		state.FailureCount++
-		if state.FailureCount >= base.Cfg.MaxGlobalUserBanCount {
-			state.LockTime = now.Add(time.Duration(base.Cfg.GlobalUserLockTime) * time.Second)
-		}
-	}
-	state.LastAttempt = now
+	lm.updateLockState(state, now, success, base.Cfg.MaxGlobalUserBanCount, base.Cfg.GlobalUserLockTime)
 }
 
 // 更新单个用户的 IP 锁定状态
@@ -340,14 +313,31 @@ func (lm *LockManager) updateUserIPLock(username, ip string, now time.Time, succ
 		userIPMap[ip] = state
 	}
 
+	lm.updateLockState(state, now, success, base.Cfg.MaxBanCount, base.Cfg.LockTime)
+}
+
+// 更新锁定状态
+func (lm *LockManager) updateLockState(state *LockState, now time.Time, success bool, maxBanCount, lockTime int) {
 	if success {
 		state.FailureCount = 0
 		state.LockTime = time.Time{}
 	} else {
 		state.FailureCount++
-		if state.FailureCount >= base.Cfg.MaxBanCount {
-			state.LockTime = now.Add(time.Duration(base.Cfg.LockTime) * time.Second)
+		if state.FailureCount >= maxBanCount {
+			state.LockTime = now.Add(time.Duration(lockTime) * time.Second)
 		}
 	}
 	state.LastAttempt = now
 }
+
+// 超过时间窗口时重置锁定状态
+func (lm *LockManager) resetLockStateIfExpired(state *LockState, now time.Time, resetTime int) {
+	if state == nil || state.LastAttempt.IsZero() {
+		return
+	}
+
+	if now.Sub(state.LastAttempt) > time.Duration(resetTime)*time.Second {
+		state.FailureCount = 0
+		state.LockTime = time.Time{}
+	}
+}