From ffb7dbfb1c075446e1d401dda7915719f8b0d12e Mon Sep 17 00:00:00 2001
From: wsczx <wsc@wsczx.com>
Date: Sat, 9 Nov 2024 00:12:26 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=98=B2=E7=88=86=E5=8A=9F?=
 =?UTF-8?q?=E8=83=BD=E5=B9=B6=E5=8F=91=E7=9A=84=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/admin/lockmanager.go      | 11 +++++------
 server/handler/antiBruteForce.go |  6 ++----
 server/handler/link_auth.go      | 14 ++++----------
 server/handler/link_auth_otp.go  | 22 +++++++++++-----------
 4 files changed, 22 insertions(+), 31 deletions(-)

diff --git a/server/admin/lockmanager.go b/server/admin/lockmanager.go
index 968ae64..4db1985 100644
--- a/server/admin/lockmanager.go
+++ b/server/admin/lockmanager.go
@@ -2,7 +2,6 @@ package admin
 
 import (
 	"encoding/json"
-	"fmt"
 	"io"
 	"net"
 	"net/http"
@@ -89,7 +88,7 @@ func UnlockUser(w http.ResponseWriter, r *http.Request) {
 	}
 
 	if lockinfo.State == nil {
-		RespError(w, RespInternalErr, fmt.Errorf("未找到锁定用户!"))
+		RespError(w, RespInternalErr, "未找到锁定用户!")
 		return
 	}
 	lm := GetLockManager()
@@ -111,7 +110,7 @@ func UnlockUser(w http.ResponseWriter, r *http.Request) {
 	}
 
 	if state == nil || !state.Locked {
-		RespError(w, RespInternalErr, fmt.Errorf("锁定状态未找到或已解锁"))
+		RespError(w, RespInternalErr, "锁定状态未找到或已解锁")
 		return
 	}
 
@@ -238,14 +237,14 @@ func (lm *LockManager) CleanupExpiredLocks() {
 	defer lm.mu.Unlock()
 
 	for ip, state := range lm.ipLocks {
-		if !lm.CheckLockState(state, now, base.Cfg.GlobalIPLockTime) ||
+		if !lm.CheckLockState(state, now, base.Cfg.GlobalIPBanResetTime) ||
 			now.Sub(state.LastAttempt) > time.Duration(base.Cfg.GlobalLockStateExpirationTime)*time.Second {
 			delete(lm.ipLocks, ip)
 		}
 	}
 
 	for user, state := range lm.userLocks {
-		if !lm.CheckLockState(state, now, base.Cfg.GlobalUserLockTime) ||
+		if !lm.CheckLockState(state, now, base.Cfg.GlobalUserBanResetTime) ||
 			now.Sub(state.LastAttempt) > time.Duration(base.Cfg.GlobalLockStateExpirationTime)*time.Second {
 			delete(lm.userLocks, user)
 		}
@@ -253,7 +252,7 @@ func (lm *LockManager) CleanupExpiredLocks() {
 
 	for user, ipMap := range lm.ipUserLocks {
 		for ip, state := range ipMap {
-			if !lm.CheckLockState(state, now, base.Cfg.LockTime) ||
+			if !lm.CheckLockState(state, now, base.Cfg.BanResetTime) ||
 				now.Sub(state.LastAttempt) > time.Duration(base.Cfg.GlobalLockStateExpirationTime)*time.Second {
 				delete(ipMap, ip)
 				if len(ipMap) == 0 {
diff --git a/server/handler/antiBruteForce.go b/server/handler/antiBruteForce.go
index 7a5e82c..999a6fc 100644
--- a/server/handler/antiBruteForce.go
+++ b/server/handler/antiBruteForce.go
@@ -14,8 +14,6 @@ import (
 
 var lockManager = admin.GetLockManager()
 
-const loginStatusKey = "login_status"
-
 // 防爆破中间件
 func antiBruteForce(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -96,7 +94,7 @@ func antiBruteForce(next http.Handler) http.Handler {
 		next.ServeHTTP(w, r)
 
 		// 检查登录状态
-		Status, _ := lockManager.LoginStatus.Load(loginStatusKey)
+		Status, _ := lockManager.LoginStatus.Load(cr.SessionId)
 		loginStatus, _ := Status.(bool)
 
 		// 更新用户登录状态
@@ -105,6 +103,6 @@ func antiBruteForce(next http.Handler) http.Handler {
 		lockManager.UpdateUserIPLock(username, ip, now, loginStatus)
 
 		// 清除登录状态
-		lockManager.LoginStatus.Delete(loginStatusKey)
+		lockManager.LoginStatus.Delete(cr.SessionId)
 	})
 }
diff --git a/server/handler/link_auth.go b/server/handler/link_auth.go
index 666330d..158f061 100644
--- a/server/handler/link_auth.go
+++ b/server/handler/link_auth.go
@@ -94,7 +94,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
 	// TODO 用户密码校验
 	err = dbdata.CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect)
 	if err != nil {
-		lockManager.LoginStatus.Store(loginStatusKey, false) // 记录登录失败状态
+		lockManager.LoginStatus.Store(cr.SessionId, false) // 记录登录失败状态
 		base.Warn(err, r.RemoteAddr)
 		ua.Info = err.Error()
 		ua.Status = dbdata.UserAuthFail
@@ -119,18 +119,12 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
 	}
 	// 用户otp验证
 	if !v.DisableOtp {
-		lockManager.LoginStatus.Store(loginStatusKey, true) // 重置OTP验证计数
-		sessionID, err := GenerateSessionID()
-		if err != nil {
-			base.Error("Failed to generate session ID: ", err)
-			http.Error(w, "Failed to generate session ID", http.StatusInternalServerError)
-			return
-		}
+		lockManager.LoginStatus.Store(cr.SessionId, true) // 重置OTP验证计数
 
 		sessionData.ClientRequest.Auth.OtpSecret = v.OtpSecret
-		SessStore.SaveAuthSession(sessionID, sessionData)
+		SessStore.SaveAuthSession(cr.SessionId, sessionData)
 
-		SetCookie(w, "auth-session-id", sessionID, 0)
+		SetCookie(w, "auth-session-id", cr.SessionId, 0)
 
 		data := RequestData{}
 		w.WriteHeader(http.StatusOK)
diff --git a/server/handler/link_auth_otp.go b/server/handler/link_auth_otp.go
index 4db0f59..d36a64f 100644
--- a/server/handler/link_auth_otp.go
+++ b/server/handler/link_auth_otp.go
@@ -11,7 +11,6 @@ import (
 
 	"github.com/bjdgyc/anylink/base"
 	"github.com/bjdgyc/anylink/dbdata"
-	"github.com/bjdgyc/anylink/pkg/utils"
 	"github.com/bjdgyc/anylink/sessdata"
 )
 
@@ -66,14 +65,14 @@ func (s *SessionStore) DeleteAuthSession(sessionID string) {
 // 	return int(newI)
 // }
 
-func GenerateSessionID() (string, error) {
-	sessionID := utils.RandomRunes(32)
-	if sessionID == "" {
-		return "", fmt.Errorf("failed to generate session ID")
-	}
+// func GenerateSessionID() (string, error) {
+// 	sessionID := utils.RandomRunes(32)
+// 	if sessionID == "" {
+// 		return "", fmt.Errorf("failed to generate session ID")
+// 	}
 
-	return sessionID, nil
-}
+// 	return sessionID, nil
+// }
 
 func SetCookie(w http.ResponseWriter, name, value string, maxAge int) {
 	cookie := &http.Cookie{
@@ -109,11 +108,12 @@ func DeleteCookie(w http.ResponseWriter, name string) {
 	http.SetCookie(w, cookie)
 }
 func CreateSession(w http.ResponseWriter, r *http.Request, authSession *AuthSession) {
-	lockManager.LoginStatus.Store(loginStatusKey, true) // 更新登录成功状态
-
 	cr := authSession.ClientRequest
 	ua := authSession.UserActLog
 
+	lockManager.LoginStatus.Store(cr.SessionId, true) // 更新登录成功状态
+	lockManager.LoginStatus.Delete(cr.SessionId)      // 清除登录状态
+
 	sess := sessdata.NewSession("")
 	sess.Username = cr.Auth.Username
 	sess.Group = cr.GroupSelect
@@ -201,7 +201,7 @@ func LinkAuth_otp(w http.ResponseWriter, r *http.Request) {
 		// 	http.Error(w, "TooManyError, please login again", http.StatusBadRequest)
 		// 	return
 		// }
-		lockManager.LoginStatus.Store(loginStatusKey, false) // 记录登录失败状态
+		lockManager.LoginStatus.Store(cr.SessionId, false) // 记录登录失败状态
 
 		base.Warn("OTP 动态码错误", username, r.RemoteAddr)
 		ua.Info = "OTP 动态码错误"