1.修复防爆策略用户登录成功后没有重置计数的Bug

2.增加otp防爆
3.添加otp使用说明
4.优化代码
This commit is contained in:
wsczx
2024-10-26 09:13:02 +08:00
parent fdc755bd98
commit f8685490dc
6 changed files with 86 additions and 165 deletions

View File

@@ -12,11 +12,7 @@ import (
"github.com/bjdgyc/anylink/base"
)
// 自定义 contextKey 类型,避免键冲突
type contextKey string
// 定义常量作为上下文的键
const loginStatusKey contextKey = "login_status"
const loginStatusKey = "login_status"
const defaultGlobalLockStateExpirationTime = 3600
func initAntiBruteForce() {
@@ -53,6 +49,20 @@ func antiBruteForce(next http.Handler) http.Handler {
}
username := cr.Auth.Username
if r.URL.Path == "/otp-verification" {
sessionID, err := GetCookie(r, "auth-session-id")
if err != nil {
http.Error(w, "Invalid session, please login again", http.StatusUnauthorized)
return
}
sessionData, err := SessStore.GetAuthSession(sessionID)
if err != nil {
http.Error(w, "Invalid session, please login again", http.StatusUnauthorized)
return
}
username = sessionData.ClientRequest.Auth.Username
}
ip, _, err := net.SplitHostPort(r.RemoteAddr) // 提取纯 IP 地址,去掉端口号
if err != nil {
http.Error(w, "Unable to parse IP address", http.StatusInternalServerError)
@@ -109,13 +119,17 @@ func antiBruteForce(next http.Handler) http.Handler {
// 调用下一个处理器
next.ServeHTTP(w, r)
// 从 context 中获取登录状态
loginStatus, _ := r.Context().Value(loginStatusKey).(bool)
// 检查登录状态
Status, _ := lockManager.loginStatus.Load(loginStatusKey)
loginStatus, _ := Status.(bool)
// 更新用户登录状态
lockManager.updateGlobalIPLock(ip, now, loginStatus)
lockManager.updateGlobalUserLock(username, now, loginStatus)
lockManager.updateUserIPLock(username, ip, now, loginStatus)
// 清除登录状态
lockManager.loginStatus.Delete(loginStatusKey)
})
}
@@ -131,6 +145,7 @@ type IPWhitelists struct {
type LockManager struct {
mu sync.Mutex
loginStatus sync.Map // 登录状态
ipLocks map[string]*LockState // 全局IP锁定状态
userLocks map[string]*LockState // 全局用户锁定状态
ipUserLocks map[string]map[string]*LockState // 单用户IP锁定状态
@@ -140,6 +155,7 @@ type LockManager struct {
}
var lockManager = &LockManager{
loginStatus: sync.Map{},
ipLocks: make(map[string]*LockState),
userLocks: make(map[string]*LockState),
ipUserLocks: make(map[string]map[string]*LockState),
@@ -251,14 +267,7 @@ func (lm *LockManager) checkGlobalIPLock(ip string, now time.Time) bool {
return false
}
// 如果超过时间窗口,重置失败计数
lm.resetLockStateIfExpired(state, now, base.Cfg.GlobalIPBanResetTime)
if !state.LockTime.IsZero() && now.Before(state.LockTime) {
return true
}
return false
return lm.checkLockState(state, now, base.Cfg.GlobalIPBanResetTime)
}
// 检查全局用户锁定
@@ -274,14 +283,7 @@ func (lm *LockManager) checkGlobalUserLock(username string, now time.Time) bool
if !exists {
return false
}
// 如果超过时间窗口,重置失败计数
lm.resetLockStateIfExpired(state, now, base.Cfg.GlobalUserBanResetTime)
if !state.LockTime.IsZero() && now.Before(state.LockTime) {
return true
}
return false
return lm.checkLockState(state, now, base.Cfg.GlobalUserBanResetTime)
}
// 检查单个用户的 IP 锁定
@@ -303,14 +305,7 @@ func (lm *LockManager) checkUserIPLock(username, ip string, now time.Time) bool
return false
}
// 如果超过时间窗口,重置失败计数
lm.resetLockStateIfExpired(state, now, base.Cfg.BanResetTime)
if !state.LockTime.IsZero() && now.Before(state.LockTime) {
return true
}
return false
return lm.checkLockState(state, now, base.Cfg.BanResetTime)
}
// 更新全局 IP 锁定状态
@@ -383,22 +378,27 @@ func (lm *LockManager) updateLockState(state *LockState, now time.Time, success
state.LastAttempt = now
}
// 超过窗口时间和锁定时间时重置锁定状态
func (lm *LockManager) resetLockStateIfExpired(state *LockState, now time.Time, resetTime int) {
// 检查锁定状态
func (lm *LockManager) checkLockState(state *LockState, now time.Time, resetTime int) bool {
if state == nil || state.LastAttempt.IsZero() {
return
return false
}
// 如果超过锁定时间,重置锁定状态
if !state.LockTime.IsZero() && now.After(state.LockTime) {
state.FailureCount = 0
state.LockTime = time.Time{}
return
return false
}
// 如果超过窗口时间,重置失败计数
if now.Sub(state.LastAttempt) > time.Duration(resetTime)*time.Second {
state.FailureCount = 0
state.LockTime = time.Time{}
return false
}
// 如果锁定时间还在有效期内,继续锁定
if !state.LockTime.IsZero() && now.Before(state.LockTime) {
return true
}
return false
}

View File

@@ -2,7 +2,6 @@ package handler
import (
"bytes"
"context"
"encoding/xml"
"fmt"
"io"
@@ -95,7 +94,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
// TODO 用户密码校验
err = dbdata.CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect)
if err != nil {
r = r.WithContext(context.WithValue(r.Context(), loginStatusKey, false)) // 传递登录失败状态
lockManager.loginStatus.Store(loginStatusKey, false) // 记录登录失败状态
base.Warn(err, r.RemoteAddr)
ua.Info = err.Error()
ua.Status = dbdata.UserAuthFail
@@ -109,7 +108,6 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
tplRequest(tpl_request, w, data)
return
}
r = r.WithContext(context.WithValue(r.Context(), loginStatusKey, true)) // 传递登录成功状态
dbdata.UserActLogIns.Add(*ua, userAgent)
v := &dbdata.User{}
@@ -121,6 +119,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
}
// 用户otp验证
if !v.DisableOtp {
lockManager.loginStatus.Store(loginStatusKey, true) // 重置OTP验证计数
sessionID, err := GenerateSessionID()
if err != nil {
base.Error("Failed to generate session ID: ", err)

View File

@@ -110,6 +110,8 @@ func DeleteCookie(w http.ResponseWriter, name string) {
http.SetCookie(w, cookie)
}
func CreateSession(w http.ResponseWriter, r *http.Request, authSession *AuthSession) {
lockManager.loginStatus.Store(loginStatusKey, true) // 更新登录成功状态
cr := authSession.ClientRequest
ua := authSession.UserActLog
@@ -200,6 +202,7 @@ func LinkAuth_otp(w http.ResponseWriter, r *http.Request) {
http.Error(w, "TooManyError, please login again", http.StatusBadRequest)
return
}
lockManager.loginStatus.Store(loginStatusKey, false) // 记录登录失败状态
base.Warn("OTP 动态码错误", username, r.RemoteAddr)
ua.Info = "OTP 动态码错误"

View File

@@ -114,7 +114,7 @@ func initRoute() http.Handler {
r.Handle("/", antiBruteForce(http.HandlerFunc(LinkAuth))).Methods(http.MethodPost)
r.HandleFunc("/CSCOSSLC/tunnel", LinkTunnel).Methods(http.MethodConnect)
r.HandleFunc("/otp_qr", LinkOtpQr).Methods(http.MethodGet)
r.HandleFunc("/otp-verification", LinkAuth_otp)
r.Handle("/otp-verification", antiBruteForce(http.HandlerFunc(LinkAuth_otp))).Methods(http.MethodPost)
r.HandleFunc(fmt.Sprintf("/profile_%s.xml", base.Cfg.ProfileName), func(w http.ResponseWriter, r *http.Request) {
b, _ := os.ReadFile(base.Cfg.Profile)
w.Write(b)