mirror of https://github.com/bjdgyc/anylink.git
增加用户验证防爆功能
This commit is contained in:
parent
7160c3cab7
commit
c5a76ba436
|
@ -84,6 +84,11 @@ type ServerConfig struct {
|
||||||
|
|
||||||
DisplayError bool `json:"display_error"`
|
DisplayError bool `json:"display_error"`
|
||||||
ExcludeExportIp bool `json:"exclude_export_ip"`
|
ExcludeExportIp bool `json:"exclude_export_ip"`
|
||||||
|
|
||||||
|
MaxBanCount int `json:"max_ban_score"`
|
||||||
|
BanResetTime int `json:"ban_reset_time"`
|
||||||
|
LockTime int `json:"lock_time"`
|
||||||
|
UserStateExpiration int `json:"user_state_expiration"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func initServerCfg() {
|
func initServerCfg() {
|
||||||
|
|
|
@ -71,6 +71,11 @@ var configs = []config{
|
||||||
|
|
||||||
{Typ: cfgBool, Name: "display_error", Usage: "客户端显示详细错误信息(线上环境慎开启)", ValBool: false},
|
{Typ: cfgBool, Name: "display_error", Usage: "客户端显示详细错误信息(线上环境慎开启)", ValBool: false},
|
||||||
{Typ: cfgBool, Name: "exclude_export_ip", Usage: "排除出口ip路由(出口ip不加密传输)", ValBool: true},
|
{Typ: cfgBool, Name: "exclude_export_ip", Usage: "排除出口ip路由(出口ip不加密传输)", ValBool: true},
|
||||||
|
|
||||||
|
{Typ: cfgInt, Name: "max_ban_score", Usage: "单位时间内最大尝试次数", ValInt: 5},
|
||||||
|
{Typ: cfgInt, Name: "ban_reset_time", Usage: "设置单位时间(秒),超过则重置计数", ValInt: 1},
|
||||||
|
{Typ: cfgInt, Name: "lock_time", Usage: "超过最大尝试次数后的锁定时长(秒)", ValInt: 300},
|
||||||
|
{Typ: cfgInt, Name: "user_state_expiration", Usage: "用户状态的保存周期(秒),超过则清空计数", ValInt: 900},
|
||||||
}
|
}
|
||||||
|
|
||||||
var envs = map[string]string{}
|
var envs = map[string]string{}
|
||||||
|
|
|
@ -53,8 +53,14 @@ ipv4_end = "192.168.90.200"
|
||||||
#是否自动添加nat
|
#是否自动添加nat
|
||||||
iptables_nat = true
|
iptables_nat = true
|
||||||
|
|
||||||
|
#单位时间内最大尝试次数
|
||||||
|
max_ban_score = 5
|
||||||
|
#设置单位时间(秒),超过则重置计数
|
||||||
|
ban_reset_time = 10
|
||||||
|
#超过最大尝试次数后的锁定时长(秒)
|
||||||
|
lock_time = 300
|
||||||
|
#用户状态的保存周期(秒),超过则清空计数
|
||||||
|
user_state_expiration = 900
|
||||||
|
|
||||||
#客户端显示详细错误信息(线上环境慎开启)
|
#客户端显示详细错误信息(线上环境慎开启)
|
||||||
display_error = true
|
display_error = true
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,130 @@
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bjdgyc/anylink/base"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserState 用于存储用户的登录状态
|
||||||
|
type UserState struct {
|
||||||
|
FailureCount int
|
||||||
|
LastAttempt time.Time
|
||||||
|
LockTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自定义 contextKey 类型,避免键冲突
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
// 定义常量作为上下文的键
|
||||||
|
const loginStatusKey contextKey = "login_status"
|
||||||
|
|
||||||
|
// 用户状态映射
|
||||||
|
var userStates = make(map[string]*UserState)
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
go cleanupUserStates()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理过期的登录状态
|
||||||
|
func cleanupUserStates() {
|
||||||
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
mu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
for username, state := range userStates {
|
||||||
|
if now.Sub(state.LastAttempt) > time.Duration(base.Cfg.UserStateExpiration)*time.Second {
|
||||||
|
delete(userStates, username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 防爆破中间件
|
||||||
|
func antiBruteForce(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 如果最大验证失败次数为0,则不启用防爆破功能
|
||||||
|
if base.Cfg.MaxBanCount == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Body.Close()
|
||||||
|
|
||||||
|
cr := ClientRequest{}
|
||||||
|
err = xml.Unmarshal(body, &cr)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to parse XML", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
username := cr.Auth.Username
|
||||||
|
|
||||||
|
// 更新用户登录状态
|
||||||
|
mu.Lock()
|
||||||
|
state, exists := userStates[username]
|
||||||
|
if !exists {
|
||||||
|
state = &UserState{}
|
||||||
|
userStates[username] = state
|
||||||
|
}
|
||||||
|
// 检查是否已超过锁定时间
|
||||||
|
if !state.LockTime.IsZero() {
|
||||||
|
if time.Now().After(state.LockTime) {
|
||||||
|
// 如果已经超过了锁定时间,重置失败计数和锁定时间
|
||||||
|
state.FailureCount = 0
|
||||||
|
state.LockTime = time.Time{}
|
||||||
|
} else {
|
||||||
|
// 如果还在锁定时间内,返回错误信息
|
||||||
|
http.Error(w, "Account locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests)
|
||||||
|
mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果超过时间窗口,重置失败计数
|
||||||
|
if time.Since(state.LastAttempt) > time.Duration(base.Cfg.BanResetTime)*time.Second {
|
||||||
|
state.FailureCount = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
state.LastAttempt = time.Now()
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
// 重新设置请求体以便后续处理器可以访问
|
||||||
|
r.Body = io.NopCloser(strings.NewReader(string(body)))
|
||||||
|
|
||||||
|
// 调用下一个处理器
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
// 从 context 中获取登录状态
|
||||||
|
loginStatus, ok := r.Context().Value(loginStatusKey).(bool)
|
||||||
|
if !ok {
|
||||||
|
// 如果没有找到登录状态,默认为登录失败
|
||||||
|
loginStatus = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户登录状态
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
if !loginStatus {
|
||||||
|
state.FailureCount++
|
||||||
|
if state.FailureCount >= base.Cfg.MaxBanCount {
|
||||||
|
state.LockTime = time.Now().Add(time.Duration(base.Cfg.LockTime) * time.Second)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
state.FailureCount = 0 // 成功登录后重置
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -88,6 +89,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
|
||||||
// TODO 用户密码校验
|
// TODO 用户密码校验
|
||||||
err = dbdata.CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect)
|
err = dbdata.CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
r = r.WithContext(context.WithValue(r.Context(), loginStatusKey, false)) // 传递登录失败状态
|
||||||
base.Warn(err, r.RemoteAddr)
|
base.Warn(err, r.RemoteAddr)
|
||||||
ua.Info = err.Error()
|
ua.Info = err.Error()
|
||||||
ua.Status = dbdata.UserAuthFail
|
ua.Status = dbdata.UserAuthFail
|
||||||
|
@ -101,6 +103,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
|
||||||
tplRequest(tpl_request, w, data)
|
tplRequest(tpl_request, w, data)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
r = r.WithContext(context.WithValue(r.Context(), loginStatusKey, true)) // 传递登录成功状态
|
||||||
dbdata.UserActLogIns.Add(ua, userAgent)
|
dbdata.UserActLogIns.Add(ua, userAgent)
|
||||||
// if !ok {
|
// if !ok {
|
||||||
// w.WriteHeader(http.StatusOK)
|
// w.WriteHeader(http.StatusOK)
|
||||||
|
|
|
@ -111,7 +111,7 @@ func initRoute() http.Handler {
|
||||||
})
|
})
|
||||||
|
|
||||||
r.HandleFunc("/", LinkHome).Methods(http.MethodGet)
|
r.HandleFunc("/", LinkHome).Methods(http.MethodGet)
|
||||||
r.HandleFunc("/", LinkAuth).Methods(http.MethodPost)
|
r.Handle("/", antiBruteForce(http.HandlerFunc(LinkAuth))).Methods(http.MethodPost)
|
||||||
r.HandleFunc("/CSCOSSLC/tunnel", LinkTunnel).Methods(http.MethodConnect)
|
r.HandleFunc("/CSCOSSLC/tunnel", LinkTunnel).Methods(http.MethodConnect)
|
||||||
r.HandleFunc("/otp_qr", LinkOtpQr).Methods(http.MethodGet)
|
r.HandleFunc("/otp_qr", LinkOtpQr).Methods(http.MethodGet)
|
||||||
r.HandleFunc(fmt.Sprintf("/profile_%s.xml", base.Cfg.ProfileName), func(w http.ResponseWriter, r *http.Request) {
|
r.HandleFunc(fmt.Sprintf("/profile_%s.xml", base.Cfg.ProfileName), func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
Loading…
Reference in New Issue