增加全局IP白名单功能

This commit is contained in:
wsczx 2024-10-04 16:02:24 +08:00
parent 59748fe395
commit c8cb9c163a
5 changed files with 77 additions and 16 deletions

View File

@ -85,7 +85,8 @@ type ServerConfig struct {
DisplayError bool `json:"display_error"` DisplayError bool `json:"display_error"`
ExcludeExportIp bool `json:"exclude_export_ip"` ExcludeExportIp bool `json:"exclude_export_ip"`
AntiBruteForce bool `json:"anti_brute_force"` AntiBruteForce bool `json:"anti_brute_force"`
IPWhitelist string `json:"ip_whitelist"`
MaxBanCount int `json:"max_ban_score"` MaxBanCount int `json:"max_ban_score"`
BanResetTime int `json:"ban_reset_time"` BanResetTime int `json:"ban_reset_time"`

View File

@ -73,6 +73,7 @@ var configs = []config{
{Typ: cfgBool, Name: "exclude_export_ip", Usage: "排除出口ip路由(出口ip不加密传输)", ValBool: true}, {Typ: cfgBool, Name: "exclude_export_ip", Usage: "排除出口ip路由(出口ip不加密传输)", ValBool: true},
{Typ: cfgBool, Name: "anti_brute_force", Usage: "是否开启防爆功能", ValBool: true}, {Typ: cfgBool, Name: "anti_brute_force", Usage: "是否开启防爆功能", ValBool: true},
{Typ: cfgStr, Name: "ip_whitelist", Usage: "全局IP白名单,多个用逗号分隔支持单IP和CIDR范围", ValStr: "192.168.90.1,172.16.0.0/24"},
{Typ: cfgInt, Name: "max_ban_score", Usage: "单位时间内最大尝试次数0为关闭该功能", ValInt: 5}, {Typ: cfgInt, Name: "max_ban_score", Usage: "单位时间内最大尝试次数0为关闭该功能", ValInt: 5},
{Typ: cfgInt, Name: "ban_reset_time", Usage: "设置单位时间(秒),超过则重置计数", ValInt: 10}, {Typ: cfgInt, Name: "ban_reset_time", Usage: "设置单位时间(秒),超过则重置计数", ValInt: 10},

View File

@ -55,6 +55,7 @@ iptables_nat = true
#防爆破全局开关 #防爆破全局开关
anti_brute_force = true anti_brute_force = true
ip_whitelist = "192.168.90.1,172.16.0.0/24"
#单位时间内最大尝试次数0为关闭该功能 #单位时间内最大尝试次数0为关闭该功能
max_ban_score = 5 max_ban_score = 5

View File

@ -18,10 +18,15 @@ type contextKey string
// 定义常量作为上下文的键 // 定义常量作为上下文的键
const loginStatusKey contextKey = "login_status" const loginStatusKey contextKey = "login_status"
const defaultGlobalLockStateExpirationTime = 3600
func init() { func initAntiBruteForce() {
if base.Cfg.AntiBruteForce { if base.Cfg.AntiBruteForce {
if base.Cfg.GlobalLockStateExpirationTime <= 0 {
base.Cfg.GlobalLockStateExpirationTime = defaultGlobalLockStateExpirationTime
}
lockManager.startCleanupTicker() lockManager.startCleanupTicker()
lockManager.initIPWhitelist()
} }
} }
@ -56,6 +61,12 @@ func antiBruteForce(next http.Handler) http.Handler {
} }
now := time.Now() now := time.Now()
// 检查IP是否在白名单中
if lockManager.isWhitelisted(ip) {
r.Body = io.NopCloser(strings.NewReader(string(body)))
next.ServeHTTP(w, r)
return
}
// // 速率限制 // // 速率限制
// lockManager.mu.RLock() // lockManager.mu.RLock()
@ -114,23 +125,69 @@ type LockState struct {
LockTime time.Time LockTime time.Time
LastAttempt time.Time LastAttempt time.Time
} }
type IPWhitelists struct {
IP net.IP
CIDR *net.IPNet
}
type LockManager struct { type LockManager struct {
mu sync.RWMutex mu sync.Mutex
ipLocks map[string]*LockState // 全局IP锁定状态 ipLocks map[string]*LockState // 全局IP锁定状态
userLocks map[string]*LockState // 全局用户锁定状态 userLocks map[string]*LockState // 全局用户锁定状态
ipUserLocks map[string]map[string]*LockState // 单用户IP锁定状态 ipUserLocks map[string]map[string]*LockState // 单用户IP锁定状态
ipWhitelists []IPWhitelists // 全局IP白名单包含IP地址和CIDR范围
// rateLimiter map[string]*rate.Limiter // 速率限制器 // rateLimiter map[string]*rate.Limiter // 速率限制器
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
} }
var lockManager = &LockManager{ var lockManager = &LockManager{
ipLocks: make(map[string]*LockState), ipLocks: make(map[string]*LockState),
userLocks: make(map[string]*LockState), userLocks: make(map[string]*LockState),
ipUserLocks: make(map[string]map[string]*LockState), ipUserLocks: make(map[string]map[string]*LockState),
ipWhitelists: make([]IPWhitelists, 0),
// rateLimiter: make(map[string]*rate.Limiter), // rateLimiter: make(map[string]*rate.Limiter),
} }
// 初始化IP白名单
func (lm *LockManager) initIPWhitelist() {
ipWhitelist := strings.Split(base.Cfg.IPWhitelist, ",")
for _, ipWhitelist := range ipWhitelist {
ipWhitelist = strings.TrimSpace(ipWhitelist)
if ipWhitelist == "" {
continue
}
_, ipNet, err := net.ParseCIDR(ipWhitelist)
if err == nil {
lm.ipWhitelists = append(lm.ipWhitelists, IPWhitelists{CIDR: ipNet})
continue
}
ip := net.ParseIP(ipWhitelist)
if ip != nil {
lm.ipWhitelists = append(lm.ipWhitelists, IPWhitelists{IP: ip})
continue
}
}
}
// 检查 IP 是否在白名单中
func (lm *LockManager) isWhitelisted(ip string) bool {
clientIP := net.ParseIP(ip)
if clientIP == nil {
return false
}
for _, ipWhitelist := range lm.ipWhitelists {
if ipWhitelist.CIDR != nil && ipWhitelist.CIDR.Contains(clientIP) {
return true
}
if ipWhitelist.IP != nil && ipWhitelist.IP.Equal(clientIP) {
return true
}
}
return false
}
func (lm *LockManager) startCleanupTicker() { func (lm *LockManager) startCleanupTicker() {
lm.cleanupTicker = time.NewTicker(5 * time.Minute) lm.cleanupTicker = time.NewTicker(5 * time.Minute)
go func() { go func() {
@ -187,8 +244,8 @@ func (lm *LockManager) cleanupExpiredLocks() {
// 检查全局 IP 锁定 // 检查全局 IP 锁定
func (lm *LockManager) checkGlobalIPLock(ip string, now time.Time) bool { func (lm *LockManager) checkGlobalIPLock(ip string, now time.Time) bool {
lm.mu.RLock() lm.mu.Lock()
defer lm.mu.RUnlock() defer lm.mu.Unlock()
state, exists := lm.ipLocks[ip] state, exists := lm.ipLocks[ip]
if !exists { if !exists {
@ -211,14 +268,13 @@ func (lm *LockManager) checkGlobalUserLock(username string, now time.Time) bool
if username == "" { if username == "" {
return false return false
} }
lm.mu.RLock() lm.mu.Lock()
defer lm.mu.RUnlock() defer lm.mu.Unlock()
state, exists := lm.userLocks[username] state, exists := lm.userLocks[username]
if !exists { if !exists {
return false return false
} }
// 如果超过时间窗口,重置失败计数 // 如果超过时间窗口,重置失败计数
lm.resetLockStateIfExpired(state, now, base.Cfg.GlobalUserBanResetTime) lm.resetLockStateIfExpired(state, now, base.Cfg.GlobalUserBanResetTime)
@ -235,8 +291,8 @@ func (lm *LockManager) checkUserIPLock(username, ip string, now time.Time) bool
if username == "" { if username == "" {
return false return false
} }
lm.mu.RLock() lm.mu.Lock()
defer lm.mu.RUnlock() defer lm.mu.Unlock()
userIPMap, userExists := lm.ipUserLocks[username] userIPMap, userExists := lm.ipUserLocks[username]
if !userExists { if !userExists {

View File

@ -17,6 +17,8 @@ func Start() {
sessdata.Start() sessdata.Start()
cron.Start() cron.Start()
initAntiBruteForce() //初始化防爆破定时器和IP白名单
// 开启服务器转发 // 开启服务器转发
err := execCmd([]string{"sysctl -w net.ipv4.ip_forward=1"}) err := execCmd([]string{"sysctl -w net.ipv4.ip_forward=1"})
if err != nil { if err != nil {