From 11bd9861e5e9c44007d38eb56fb4c4894313c245 Mon Sep 17 00:00:00 2001 From: wsczx Date: Mon, 7 Oct 2024 17:11:56 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=BC=B9=E7=AA=97?= =?UTF-8?q?=E8=BE=93=E5=85=A5OTP=E5=8A=A8=E6=80=81=E7=A0=81=E7=9A=84?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/dbdata/user.go | 16 +-- server/handler/antiBruteForce.go | 7 +- server/handler/link_auth.go | 106 +++++++-------- server/handler/link_auth_otp.go | 222 +++++++++++++++++++++++++++++++ server/handler/link_base.go | 3 + server/handler/server.go | 1 + 6 files changed, 288 insertions(+), 67 deletions(-) create mode 100644 server/handler/link_auth_otp.go diff --git a/server/dbdata/user.go b/server/dbdata/user.go index a84ddaf..18db099 100644 --- a/server/dbdata/user.go +++ b/server/dbdata/user.go @@ -117,13 +117,13 @@ func checkLocalUser(name, pwd, group string) error { } // 判断otp信息 pinCode := pwd - if !v.DisableOtp { - pinCode = pwd[:pl-6] - otp := pwd[pl-6:] - if !checkOtp(name, otp, v.OtpSecret) { - return fmt.Errorf("%s %s", name, "动态码错误") - } - } + // if !v.DisableOtp { + // pinCode = pwd[:pl-6] + // otp := pwd[pl-6:] + // if !CheckOtp(name, otp, v.OtpSecret) { + // return fmt.Errorf("%s %s", name, "动态码错误") + // } + // } // 判断用户密码 if pinCode != v.PinCode { @@ -171,7 +171,7 @@ func init() { } // 判断令牌信息 -func checkOtp(name, otp, secret string) bool { +func CheckOtp(name, otp, secret string) bool { key := fmt.Sprintf("%s:%s", name, otp) userOtpMux.Lock() diff --git a/server/handler/antiBruteForce.go b/server/handler/antiBruteForce.go index 947a8a9..248f8d2 100644 --- a/server/handler/antiBruteForce.go +++ b/server/handler/antiBruteForce.go @@ -3,7 +3,6 @@ package handler import ( "encoding/xml" "io" - "log" "net" "net/http" "strings" @@ -85,21 +84,21 @@ func antiBruteForce(next http.Handler) http.Handler { // 检查全局 IP 锁定 if base.Cfg.MaxGlobalIPBanCount > 0 && lockManager.checkGlobalIPLock(ip, now) { - log.Printf("IP %s is globally locked. Try again later.", ip) + base.Warn("IP", ip, "is globally locked. Try again later.") http.Error(w, "Account globally locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests) return } // 检查全局用户锁定 if base.Cfg.MaxGlobalUserBanCount > 0 && lockManager.checkGlobalUserLock(username, now) { - log.Printf("User %s is globally locked. Try again later.", username) + base.Warn("User", username, "is globally locked. Try again later.") http.Error(w, "Account globally locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests) return } // 检查单个用户的 IP 锁定 if base.Cfg.MaxBanCount > 0 && lockManager.checkUserIPLock(username, ip, now) { - log.Printf("IP %s is locked for user %s. Try again later.", ip, username) + base.Warn("IP", ip, "is locked for user", username, "Try again later.") http.Error(w, "Account locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests) return } diff --git a/server/handler/link_auth.go b/server/handler/link_auth.go index 8924290..fe0be6d 100644 --- a/server/handler/link_auth.go +++ b/server/handler/link_auth.go @@ -3,11 +3,9 @@ package handler import ( "bytes" "context" - "crypto/md5" "encoding/xml" "fmt" "io" - "net" "net/http" "net/http/httputil" "strings" @@ -47,7 +45,10 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { } defer r.Body.Close() - cr := ClientRequest{} + cr := &ClientRequest{ + RemoteAddr: r.RemoteAddr, + UserAgent: userAgent, + } err = xml.Unmarshal(body, &cr) if err != nil { base.Error(err) @@ -78,7 +79,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { return } // 用户活动日志 - ua := dbdata.UserActLog{ + ua := &dbdata.UserActLog{ Username: cr.Auth.Username, GroupName: cr.GroupSelect, RemoteAddr: r.RemoteAddr, @@ -86,6 +87,11 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { DeviceType: cr.DeviceId.DeviceType, PlatformVersion: cr.DeviceId.PlatformVersion, } + + sessionData := &AuthSession{ + ClientRequest: cr, + UserActLog: ua, + } // TODO 用户密码校验 err = dbdata.CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect) if err != nil { @@ -93,7 +99,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { base.Warn(err, r.RemoteAddr) ua.Info = err.Error() ua.Status = dbdata.UserAuthFail - dbdata.UserActLogIns.Add(ua, userAgent) + dbdata.UserActLogIns.Add(*ua, userAgent) w.WriteHeader(http.StatusOK) data := RequestData{Group: cr.GroupSelect, Groups: dbdata.GetGroupNamesNormal(), Error: "用户名或密码错误"} @@ -104,72 +110,62 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { return } r = r.WithContext(context.WithValue(r.Context(), loginStatusKey, true)) // 传递登录成功状态 - dbdata.UserActLogIns.Add(ua, userAgent) - // if !ok { - // w.WriteHeader(http.StatusOK) - // data := RequestData{Group: cr.GroupSelect, Groups: base.Cfg.UserGroups, Error: "请先激活用户"} - // tplRequest(tpl_request, w, data) - // return - // } + dbdata.UserActLogIns.Add(*ua, userAgent) - // 创建新的session信息 - sess := sessdata.NewSession("") - sess.Username = cr.Auth.Username - sess.Group = cr.GroupSelect - oriMac := cr.MacAddressList.MacAddress - sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal - sess.UserAgent = userAgent - sess.DeviceType = ua.DeviceType - sess.PlatformVersion = ua.PlatformVersion - sess.RemoteAddr = r.RemoteAddr - // 获取客户端mac地址 - sess.UniqueMac = true - macHw, err := net.ParseMAC(oriMac) + v := &dbdata.User{} + err = dbdata.One("Username", cr.Auth.Username, v) if err != nil { - var sum [16]byte - if sess.UniqueIdGlobal != "" { - sum = md5.Sum([]byte(sess.UniqueIdGlobal)) - } else { - sum = md5.Sum([]byte(sess.Token)) - sess.UniqueMac = false - } - macHw = sum[0:5] // 5个byte - macHw = append([]byte{0x02}, macHw...) - sess.MacAddr = macHw.String() + base.Error("Failed to get TOTP secret for user:", cr.Auth.Username, err) + http.Error(w, "Failed to get TOTP secret", http.StatusInternalServerError) + return } - sess.MacHw = macHw - // 统一macAddr的格式 - sess.MacAddr = macHw.String() + // 用户otp验证 + if !v.DisableOtp { + sessionID, err := GenerateSessionID() + if err != nil { + base.Error("Failed to generate session ID: ", err) + http.Error(w, "Failed to generate session ID", http.StatusInternalServerError) + return + } - other := &dbdata.SettingOther{} - _ = dbdata.SettingGet(other) - rd := RequestData{SessionId: sess.Sid, SessionToken: sess.Sid + "@" + sess.Token, - Banner: other.Banner, ProfileName: base.Cfg.ProfileName, ProfileHash: profileHash, CertHash: certHash} - w.WriteHeader(http.StatusOK) - tplRequest(tpl_complete, w, rd) - base.Info("login", cr.Auth.Username, userAgent) + sessionData.ClientRequest.Auth.OtpSecret = v.OtpSecret + SessStore.SaveAuthSession(sessionID, sessionData) + + SetCookie(w, "auth-session-id", sessionID, 0) + + data := RequestData{} + w.WriteHeader(http.StatusOK) + tplRequest(tpl_otp, w, data) + return + } + + CreateSession(w, r, sessionData) } const ( tpl_request = iota tpl_complete + tpl_otp ) func tplRequest(typ int, w io.Writer, data RequestData) { - if typ == tpl_request { + switch typ { + case tpl_request: t, _ := template.New("auth_request").Parse(auth_request) _ = t.Execute(w, data) - return - } + case tpl_complete: + if data.Banner != "" { + buf := new(bytes.Buffer) + _ = xml.EscapeText(buf, []byte(data.Banner)) + data.Banner = buf.String() + } - if data.Banner != "" { - buf := new(bytes.Buffer) - _ = xml.EscapeText(buf, []byte(data.Banner)) - data.Banner = buf.String() + t, _ := template.New("auth_complete").Parse(auth_complete) + _ = t.Execute(w, data) + case tpl_otp: + t, _ := template.New("auth_otp").Parse(auth_otp) + _ = t.Execute(w, data) } - - t, _ := template.New("auth_complete").Parse(auth_complete) - _ = t.Execute(w, data) } // 设置输出信息 diff --git a/server/handler/link_auth_otp.go b/server/handler/link_auth_otp.go new file mode 100644 index 0000000..313aa63 --- /dev/null +++ b/server/handler/link_auth_otp.go @@ -0,0 +1,222 @@ +package handler + +import ( + "crypto/md5" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/xml" + "fmt" + "io" + "net" + "net/http" + "sync" + + "github.com/bjdgyc/anylink/base" + "github.com/bjdgyc/anylink/dbdata" + "github.com/bjdgyc/anylink/sessdata" +) + +var SessStore = NewSessionStore() + +type AuthSession struct { + ClientRequest *ClientRequest + UserActLog *dbdata.UserActLog +} + +// 存储临时会话信息 +type SessionStore struct { + session map[string]*AuthSession + mu sync.Mutex +} + +func NewSessionStore() *SessionStore { + return &SessionStore{ + session: make(map[string]*AuthSession), + } +} + +func (s *SessionStore) SaveAuthSession(sessionID string, session *AuthSession) { + s.mu.Lock() + defer s.mu.Unlock() + s.session[sessionID] = session +} + +func (s *SessionStore) GetAuthSession(sessionID string) (*AuthSession, error) { + s.mu.Lock() + defer s.mu.Unlock() + + session, exists := s.session[sessionID] + if !exists { + return nil, fmt.Errorf("auth session not found") + } + + return session, nil +} + +func (s *SessionStore) DeleteAuthSession(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.session, sessionID) +} + +func GenerateSessionID() (string, error) { + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + return "", fmt.Errorf("failed to generate session ID: %w", err) + } + + hash := sha256.Sum256(b) + sessionID := base64.URLEncoding.EncodeToString(hash[:]) + return sessionID, nil +} + +func SetCookie(w http.ResponseWriter, name, value string, maxAge int) { + cookie := &http.Cookie{ + Name: name, + Value: value, + MaxAge: maxAge, + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + } + http.SetCookie(w, cookie) +} + +func GetCookie(r *http.Request, name string) (string, error) { + cookie, err := r.Cookie(name) + if err != nil { + return "", fmt.Errorf("failed to get cookie: %v", err) + } + return cookie.Value, nil +} + +func DeleteCookie(w http.ResponseWriter, name string) { + cookie := &http.Cookie{ + Name: name, + Value: "", + MaxAge: -1, + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + } + http.SetCookie(w, cookie) +} +func CreateSession(w http.ResponseWriter, r *http.Request, authSession *AuthSession) { + cr := authSession.ClientRequest + ua := authSession.UserActLog + + sess := sessdata.NewSession("") + sess.Username = cr.Auth.Username + sess.Group = cr.GroupSelect + oriMac := cr.MacAddressList.MacAddress + sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal + sess.UserAgent = cr.UserAgent + sess.DeviceType = ua.DeviceType + sess.PlatformVersion = ua.PlatformVersion + sess.RemoteAddr = cr.RemoteAddr + // 获取客户端mac地址 + sess.UniqueMac = true + macHw, err := net.ParseMAC(oriMac) + if err != nil { + var sum [16]byte + if sess.UniqueIdGlobal != "" { + sum = md5.Sum([]byte(sess.UniqueIdGlobal)) + } else { + sum = md5.Sum([]byte(sess.Token)) + sess.UniqueMac = false + } + macHw = sum[0:5] // 5个byte + macHw = append([]byte{0x02}, macHw...) + sess.MacAddr = macHw.String() + } + sess.MacHw = macHw + // 统一macAddr的格式 + sess.MacAddr = macHw.String() + + other := &dbdata.SettingOther{} + dbdata.SettingGet(other) + rd := RequestData{ + SessionId: sess.Sid, + SessionToken: sess.Sid + "@" + sess.Token, + Banner: other.Banner, + ProfileName: base.Cfg.ProfileName, + ProfileHash: profileHash, + CertHash: certHash, + } + + w.WriteHeader(http.StatusOK) + tplRequest(tpl_complete, w, rd) + base.Info("login", cr.Auth.Username, cr.UserAgent) +} + +func LinkAuth_otp(w http.ResponseWriter, r *http.Request) { + sessionID, err := GetCookie(r, "auth-session-id") + if err != nil { + http.Error(w, "Invalid session, please login again", http.StatusUnauthorized) + return + } + + sessionData, err := SessStore.GetAuthSession(sessionID) + if err != nil { + http.Error(w, "Invalid session, please login again", http.StatusUnauthorized) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + defer r.Body.Close() + + cr := ClientRequest{} + err = xml.Unmarshal(body, &cr) + if err != nil { + base.Error(err) + w.WriteHeader(http.StatusBadRequest) + return + } + + ua := sessionData.UserActLog + username := sessionData.ClientRequest.Auth.Username + otpSecret := sessionData.ClientRequest.Auth.OtpSecret + otp := cr.Auth.SecondaryPassword + + if !dbdata.CheckOtp(username, otp, otpSecret) { + base.Warn("OTP 动态码错误", r.RemoteAddr) + ua.Info = "OTP 动态码错误" + ua.Status = dbdata.UserAuthFail + dbdata.UserActLogIns.Add(*ua, sessionData.ClientRequest.UserAgent) + + w.WriteHeader(http.StatusOK) + data := RequestData{Error: "OTP 动态码错误"} + if base.Cfg.DisplayError { + data.Error = "OTP 动态码错误" + } + tplRequest(tpl_otp, w, data) + return + } + CreateSession(w, r, sessionData) + + // 删除临时会话信息 + SessStore.DeleteAuthSession(sessionID) + // DeleteCookie(w, "auth-session-id") +} + +var auth_otp = ` + + + OTP 动态码验证 + 请输入您的 OTP 动态码 + {{if .Error}} + 验证失败: %s + {{end}} +
+ +
+
+
` diff --git a/server/handler/link_base.go b/server/handler/link_base.go index c46da15..95c3e5d 100644 --- a/server/handler/link_base.go +++ b/server/handler/link_base.go @@ -17,6 +17,8 @@ type ClientRequest struct { Version string `xml:"version"` // 客户端版本号 GroupAccess string `xml:"group-access"` // 请求的地址 GroupSelect string `xml:"group-select"` // 选择的组名 + RemoteAddr string `xml:"remote_addr"` + UserAgent string `xml:"user_agent"` SessionId string `xml:"session-id"` SessionToken string `xml:"session-token"` Auth auth `xml:"auth"` @@ -27,6 +29,7 @@ type ClientRequest struct { type auth struct { Username string `xml:"username"` Password string `xml:"password"` + OtpSecret string `xml:"otp_secret"` SecondaryPassword string `xml:"secondary_password"` } diff --git a/server/handler/server.go b/server/handler/server.go index 2fc5baf..bf0a140 100644 --- a/server/handler/server.go +++ b/server/handler/server.go @@ -114,6 +114,7 @@ func initRoute() http.Handler { r.Handle("/", antiBruteForce(http.HandlerFunc(LinkAuth))).Methods(http.MethodPost) r.HandleFunc("/CSCOSSLC/tunnel", LinkTunnel).Methods(http.MethodConnect) r.HandleFunc("/otp_qr", LinkOtpQr).Methods(http.MethodGet) + r.HandleFunc("/otp-verification", LinkAuth_otp) r.HandleFunc(fmt.Sprintf("/profile_%s.xml", base.Cfg.ProfileName), func(w http.ResponseWriter, r *http.Request) { b, _ := os.ReadFile(base.Cfg.Profile) w.Write(b) From fd383b92f5701a1a43e141b7e60bb67c92b56368 Mon Sep 17 00:00:00 2001 From: wsczx Date: Mon, 7 Oct 2024 17:31:21 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E7=AC=AC=E4=B8=89=E6=96=B9=E9=AA=8C=E8=AF=81=E6=96=B9=E5=BC=8F?= =?UTF-8?q?=E6=97=A0=E6=B3=95=E5=BB=BA=E7=AB=8B=E8=BF=9E=E6=8E=A5=E7=9A=84?= =?UTF-8?q?Bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/handler/link_auth.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/handler/link_auth.go b/server/handler/link_auth.go index fe0be6d..cabe728 100644 --- a/server/handler/link_auth.go +++ b/server/handler/link_auth.go @@ -115,8 +115,8 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { v := &dbdata.User{} err = dbdata.One("Username", cr.Auth.Username, v) if err != nil { - base.Error("Failed to get TOTP secret for user:", cr.Auth.Username, err) - http.Error(w, "Failed to get TOTP secret", http.StatusInternalServerError) + base.Info("正在使用第三方认证方式登录") + CreateSession(w, r, sessionData) return } // 用户otp验证 From 4c219a3127324a2ade6d42249bd1558288a49534 Mon Sep 17 00:00:00 2001 From: wsczx Date: Tue, 8 Oct 2024 00:15:59 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E5=88=A0=E9=99=A4CheckUser=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E5=8D=95=E5=85=83=E7=9A=84otp=E9=AA=8C=E8=AF=81?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/dbdata/user_test.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/server/dbdata/user_test.go b/server/dbdata/user_test.go index 2238837..ea54918 100644 --- a/server/dbdata/user_test.go +++ b/server/dbdata/user_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/xlzd/gotp" ) func TestCheckUser(t *testing.T) { @@ -30,10 +29,10 @@ func TestCheckUser(t *testing.T) { ast.Nil(err) // 验证 PinCode + OtpSecret - totp := gotp.NewDefaultTOTP(u.OtpSecret) - secret := totp.Now() - err = CheckUser("aaa", u.PinCode+secret, group) - ast.Nil(err) + // totp := gotp.NewDefaultTOTP(u.OtpSecret) + // secret := totp.Now() + // err = CheckUser("aaa", u.PinCode+secret, group) + // ast.Nil(err) // 单独验证密码 u.DisableOtp = true