增加全局IP白名单功能

This commit is contained in:
wsczx
2024-10-04 16:02:24 +08:00
parent 59748fe395
commit c8cb9c163a
5 changed files with 77 additions and 16 deletions

View File

@@ -18,10 +18,15 @@ type contextKey string
// 定义常量作为上下文的键
const loginStatusKey contextKey = "login_status"
const defaultGlobalLockStateExpirationTime = 3600
func init() {
func initAntiBruteForce() {
if base.Cfg.AntiBruteForce {
if base.Cfg.GlobalLockStateExpirationTime <= 0 {
base.Cfg.GlobalLockStateExpirationTime = defaultGlobalLockStateExpirationTime
}
lockManager.startCleanupTicker()
lockManager.initIPWhitelist()
}
}
@@ -56,6 +61,12 @@ func antiBruteForce(next http.Handler) http.Handler {
}
now := time.Now()
// 检查IP是否在白名单中
if lockManager.isWhitelisted(ip) {
r.Body = io.NopCloser(strings.NewReader(string(body)))
next.ServeHTTP(w, r)
return
}
// // 速率限制
// lockManager.mu.RLock()
@@ -114,23 +125,69 @@ type LockState struct {
LockTime time.Time
LastAttempt time.Time
}
type IPWhitelists struct {
IP net.IP
CIDR *net.IPNet
}
type LockManager struct {
mu sync.RWMutex
ipLocks map[string]*LockState // 全局IP锁定状态
userLocks map[string]*LockState // 全局用户锁定状态
ipUserLocks map[string]map[string]*LockState // 单用户IP锁定状态
mu sync.Mutex
ipLocks map[string]*LockState // 全局IP锁定状态
userLocks map[string]*LockState // 全局用户锁定状态
ipUserLocks map[string]map[string]*LockState // 单用户IP锁定状态
ipWhitelists []IPWhitelists // 全局IP白名单包含IP地址和CIDR范围
// rateLimiter map[string]*rate.Limiter // 速率限制器
cleanupTicker *time.Ticker
}
var lockManager = &LockManager{
ipLocks: make(map[string]*LockState),
userLocks: make(map[string]*LockState),
ipUserLocks: make(map[string]map[string]*LockState),
ipLocks: make(map[string]*LockState),
userLocks: make(map[string]*LockState),
ipUserLocks: make(map[string]map[string]*LockState),
ipWhitelists: make([]IPWhitelists, 0),
// rateLimiter: make(map[string]*rate.Limiter),
}
// 初始化IP白名单
func (lm *LockManager) initIPWhitelist() {
ipWhitelist := strings.Split(base.Cfg.IPWhitelist, ",")
for _, ipWhitelist := range ipWhitelist {
ipWhitelist = strings.TrimSpace(ipWhitelist)
if ipWhitelist == "" {
continue
}
_, ipNet, err := net.ParseCIDR(ipWhitelist)
if err == nil {
lm.ipWhitelists = append(lm.ipWhitelists, IPWhitelists{CIDR: ipNet})
continue
}
ip := net.ParseIP(ipWhitelist)
if ip != nil {
lm.ipWhitelists = append(lm.ipWhitelists, IPWhitelists{IP: ip})
continue
}
}
}
// 检查 IP 是否在白名单中
func (lm *LockManager) isWhitelisted(ip string) bool {
clientIP := net.ParseIP(ip)
if clientIP == nil {
return false
}
for _, ipWhitelist := range lm.ipWhitelists {
if ipWhitelist.CIDR != nil && ipWhitelist.CIDR.Contains(clientIP) {
return true
}
if ipWhitelist.IP != nil && ipWhitelist.IP.Equal(clientIP) {
return true
}
}
return false
}
func (lm *LockManager) startCleanupTicker() {
lm.cleanupTicker = time.NewTicker(5 * time.Minute)
go func() {
@@ -187,8 +244,8 @@ func (lm *LockManager) cleanupExpiredLocks() {
// 检查全局 IP 锁定
func (lm *LockManager) checkGlobalIPLock(ip string, now time.Time) bool {
lm.mu.RLock()
defer lm.mu.RUnlock()
lm.mu.Lock()
defer lm.mu.Unlock()
state, exists := lm.ipLocks[ip]
if !exists {
@@ -211,14 +268,13 @@ func (lm *LockManager) checkGlobalUserLock(username string, now time.Time) bool
if username == "" {
return false
}
lm.mu.RLock()
defer lm.mu.RUnlock()
lm.mu.Lock()
defer lm.mu.Unlock()
state, exists := lm.userLocks[username]
if !exists {
return false
}
// 如果超过时间窗口,重置失败计数
lm.resetLockStateIfExpired(state, now, base.Cfg.GlobalUserBanResetTime)
@@ -235,8 +291,8 @@ func (lm *LockManager) checkUserIPLock(username, ip string, now time.Time) bool
if username == "" {
return false
}
lm.mu.RLock()
defer lm.mu.RUnlock()
lm.mu.Lock()
defer lm.mu.Unlock()
userIPMap, userExists := lm.ipUserLocks[username]
if !userExists {