diff --git a/server/admin/lockmanager.go b/server/admin/lockmanager.go index 4db1985..2d79f09 100644 --- a/server/admin/lockmanager.go +++ b/server/admin/lockmanager.go @@ -30,8 +30,8 @@ type IPWhitelists struct { } type LockManager struct { - mu sync.Mutex - LoginStatus sync.Map // 登录状态 + mu sync.Mutex + // LoginStatus sync.Map // 登录状态 ipLocks map[string]*LockState // 全局IP锁定状态 userLocks map[string]*LockState // 全局用户锁定状态 ipUserLocks map[string]map[string]*LockState // 单用户IP锁定状态 @@ -45,7 +45,7 @@ var once sync.Once func GetLockManager() *LockManager { once.Do(func() { lockmanager = &LockManager{ - LoginStatus: sync.Map{}, + // LoginStatus: sync.Map{}, ipLocks: make(map[string]*LockState), userLocks: make(map[string]*LockState), ipUserLocks: make(map[string]map[string]*LockState), @@ -409,3 +409,53 @@ func (lm *LockManager) Unlock(state *LockState) { state.LockTime = time.Time{} state.Locked = false } + +// 检查锁定状态 +func (lm *LockManager) CheckLocked(username, ipaddr string) bool { + if !base.Cfg.AntiBruteForce { + return true + } + + ip, _, err := net.SplitHostPort(ipaddr) // 提取纯 IP 地址,去掉端口号 + if err != nil { + return true + } + + now := time.Now() + // 检查IP是否在白名单中 + if lm.IsWhitelisted(ip) { + return true + } + + // 检查全局 IP 锁定 + if base.Cfg.MaxGlobalIPBanCount > 0 && lm.CheckGlobalIPLock(ip, now) { + base.Warn("IP", ip, "is globally locked. Try again later.") + return false + } + + // 检查全局用户锁定 + if base.Cfg.MaxGlobalUserBanCount > 0 && lm.CheckGlobalUserLock(username, now) { + base.Warn("User", username, "is globally locked. Try again later.") + return false + } + + // 检查单个用户的 IP 锁定 + if base.Cfg.MaxBanCount > 0 && lm.CheckUserIPLock(username, ip, now) { + base.Warn("IP", ip, "is locked for user", username, "Try again later.") + return false + } + + return true +} + +func (lm *LockManager) UpdateLock(username, ipaddr string, loginStatus bool) { + ip, _, err := net.SplitHostPort(ipaddr) // 提取纯 IP 地址,去掉端口号 + if err != nil { + return + } + now := time.Now() + // 更新用户登录状态 + lm.UpdateGlobalIPLock(ip, now, loginStatus) + lm.UpdateGlobalUserLock(username, now, loginStatus) + lm.UpdateUserIPLock(username, ip, now, loginStatus) +} diff --git a/server/handler/antiBruteForce.go b/server/handler/antiBruteForce.go deleted file mode 100644 index 999a6fc..0000000 --- a/server/handler/antiBruteForce.go +++ /dev/null @@ -1,108 +0,0 @@ -package handler - -import ( - "encoding/xml" - "io" - "net" - "net/http" - "strings" - "time" - - "github.com/bjdgyc/anylink/admin" - "github.com/bjdgyc/anylink/base" -) - -var lockManager = admin.GetLockManager() - -// 防爆破中间件 -func antiBruteForce(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 防爆破功能全局开关 - if !base.Cfg.AntiBruteForce { - next.ServeHTTP(w, r) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - cr := ClientRequest{} - err = xml.Unmarshal(body, &cr) - if err != nil { - http.Error(w, "Failed to parse XML", http.StatusBadRequest) - return - } - - username := cr.Auth.Username - if r.URL.Path == "/otp-verification" { - sessionID, err := GetCookie(r, "auth-session-id") - if err != nil { - http.Error(w, "Invalid session, please login again", http.StatusUnauthorized) - return - } - - sessionData, err := SessStore.GetAuthSession(sessionID) - if err != nil { - http.Error(w, "Invalid session, please login again", http.StatusUnauthorized) - return - } - username = sessionData.ClientRequest.Auth.Username - } - ip, _, err := net.SplitHostPort(r.RemoteAddr) // 提取纯 IP 地址,去掉端口号 - if err != nil { - http.Error(w, "Unable to parse IP address", http.StatusInternalServerError) - return - } - - now := time.Now() - // 检查IP是否在白名单中 - if lockManager.IsWhitelisted(ip) { - r.Body = io.NopCloser(strings.NewReader(string(body))) - next.ServeHTTP(w, r) - return - } - - // 检查全局 IP 锁定 - if base.Cfg.MaxGlobalIPBanCount > 0 && lockManager.CheckGlobalIPLock(ip, now) { - base.Warn("IP", ip, "is globally locked. Try again later.") - http.Error(w, "Account globally locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests) - return - } - - // 检查全局用户锁定 - if base.Cfg.MaxGlobalUserBanCount > 0 && lockManager.CheckGlobalUserLock(username, now) { - base.Warn("User", username, "is globally locked. Try again later.") - http.Error(w, "Account globally locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests) - return - } - - // 检查单个用户的 IP 锁定 - if base.Cfg.MaxBanCount > 0 && lockManager.CheckUserIPLock(username, ip, now) { - base.Warn("IP", ip, "is locked for user", username, "Try again later.") - http.Error(w, "Account locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests) - return - } - - // 重新设置请求体以便后续处理器可以访问 - r.Body = io.NopCloser(strings.NewReader(string(body))) - - // 调用下一个处理器 - next.ServeHTTP(w, r) - - // 检查登录状态 - Status, _ := lockManager.LoginStatus.Load(cr.SessionId) - loginStatus, _ := Status.(bool) - - // 更新用户登录状态 - lockManager.UpdateGlobalIPLock(ip, now, loginStatus) - lockManager.UpdateGlobalUserLock(username, now, loginStatus) - lockManager.UpdateUserIPLock(username, ip, now, loginStatus) - - // 清除登录状态 - lockManager.LoginStatus.Delete(cr.SessionId) - }) -} diff --git a/server/handler/link_auth.go b/server/handler/link_auth.go index 158f061..a36289f 100644 --- a/server/handler/link_auth.go +++ b/server/handler/link_auth.go @@ -77,6 +77,12 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) return } + + // 锁定状态判断 + if !lockManager.CheckLocked(cr.Auth.Username, r.RemoteAddr) { + http.Error(w, "Locked! Try again later.", http.StatusTooManyRequests) + return + } // 用户活动日志 ua := &dbdata.UserActLog{ Username: cr.Auth.Username, @@ -94,7 +100,8 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { // TODO 用户密码校验 err = dbdata.CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect) if err != nil { - lockManager.LoginStatus.Store(cr.SessionId, false) // 记录登录失败状态 + // lockManager.LoginStatus.Store(loginStatusKey, false) // 记录登录失败状态 + lockManager.UpdateLock(cr.Auth.Username, r.RemoteAddr, false) // 记录登录失败状态 base.Warn(err, r.RemoteAddr) ua.Info = err.Error() ua.Status = dbdata.UserAuthFail @@ -119,12 +126,18 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { } // 用户otp验证 if !v.DisableOtp { - lockManager.LoginStatus.Store(cr.SessionId, true) // 重置OTP验证计数 + lockManager.UpdateLock(cr.Auth.Username, r.RemoteAddr, true) // 重置OTP验证计数 + sessionID, err := GenerateSessionID() + if err != nil { + base.Error("Failed to generate session ID: ", err) + http.Error(w, "Failed to generate session ID", http.StatusInternalServerError) + return + } sessionData.ClientRequest.Auth.OtpSecret = v.OtpSecret - SessStore.SaveAuthSession(cr.SessionId, sessionData) + SessStore.SaveAuthSession(sessionID, sessionData) - SetCookie(w, "auth-session-id", cr.SessionId, 0) + SetCookie(w, "auth-session-id", sessionID, 0) data := RequestData{} w.WriteHeader(http.StatusOK) diff --git a/server/handler/link_auth_otp.go b/server/handler/link_auth_otp.go index d36a64f..dd085e5 100644 --- a/server/handler/link_auth_otp.go +++ b/server/handler/link_auth_otp.go @@ -9,12 +9,15 @@ import ( "net/http" "sync" + "github.com/bjdgyc/anylink/admin" "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/dbdata" + "github.com/bjdgyc/anylink/pkg/utils" "github.com/bjdgyc/anylink/sessdata" ) var SessStore = NewSessionStore() +var lockManager = admin.GetLockManager() // const maxOtpErrCount = 3 @@ -65,14 +68,13 @@ func (s *SessionStore) DeleteAuthSession(sessionID string) { // return int(newI) // } -// func GenerateSessionID() (string, error) { -// sessionID := utils.RandomRunes(32) -// if sessionID == "" { -// return "", fmt.Errorf("failed to generate session ID") -// } - -// return sessionID, nil -// } +func GenerateSessionID() (string, error) { + sessionID := utils.RandomRunes(32) + if sessionID == "" { + return "", fmt.Errorf("failed to generate session ID") + } + return sessionID, nil +} func SetCookie(w http.ResponseWriter, name, value string, maxAge int) { cookie := &http.Cookie{ @@ -111,8 +113,7 @@ func CreateSession(w http.ResponseWriter, r *http.Request, authSession *AuthSess cr := authSession.ClientRequest ua := authSession.UserActLog - lockManager.LoginStatus.Store(cr.SessionId, true) // 更新登录成功状态 - lockManager.LoginStatus.Delete(cr.SessionId) // 清除登录状态 + lockManager.UpdateLock(cr.Auth.Username, r.RemoteAddr, true) // 更新登录成功状态 sess := sessdata.NewSession("") sess.Username = cr.Auth.Username @@ -194,6 +195,11 @@ func LinkAuth_otp(w http.ResponseWriter, r *http.Request) { otpSecret := sessionData.ClientRequest.Auth.OtpSecret otp := cr.Auth.SecondaryPassword + // 锁定状态判断 + if !lockManager.CheckLocked(username, r.RemoteAddr) { + http.Error(w, "Locked! Try again later.", http.StatusTooManyRequests) + return + } // 动态码错误 if !dbdata.CheckOtp(username, otp, otpSecret) { // if sessionData.AddOtpErrCount(1) > maxOtpErrCount { @@ -201,7 +207,7 @@ func LinkAuth_otp(w http.ResponseWriter, r *http.Request) { // http.Error(w, "TooManyError, please login again", http.StatusBadRequest) // return // } - lockManager.LoginStatus.Store(cr.SessionId, false) // 记录登录失败状态 + lockManager.UpdateLock(username, r.RemoteAddr, false) // 记录登录失败状态 base.Warn("OTP 动态码错误", username, r.RemoteAddr) ua.Info = "OTP 动态码错误" diff --git a/server/handler/server.go b/server/handler/server.go index c257aad..8fb734f 100644 --- a/server/handler/server.go +++ b/server/handler/server.go @@ -111,10 +111,10 @@ func initRoute() http.Handler { }) r.HandleFunc("/", LinkHome).Methods(http.MethodGet) - r.Handle("/", antiBruteForce(http.HandlerFunc(LinkAuth))).Methods(http.MethodPost) + r.HandleFunc("/", LinkAuth).Methods(http.MethodPost) r.HandleFunc("/CSCOSSLC/tunnel", LinkTunnel).Methods(http.MethodConnect) r.HandleFunc("/otp_qr", LinkOtpQr).Methods(http.MethodGet) - r.Handle("/otp-verification", antiBruteForce(http.HandlerFunc(LinkAuth_otp))).Methods(http.MethodPost) + r.HandleFunc("/otp-verification", LinkAuth_otp).Methods(http.MethodPost) r.HandleFunc(fmt.Sprintf("/profile_%s.xml", base.Cfg.ProfileName), func(w http.ResponseWriter, r *http.Request) { b, _ := os.ReadFile(base.Cfg.Profile) w.Write(b)