增加弹窗输入OTP动态码的功能

This commit is contained in:
wsczx 2024-10-07 17:11:56 +08:00
parent 1c6fc446c9
commit 11bd9861e5
6 changed files with 288 additions and 67 deletions

View File

@ -117,13 +117,13 @@ func checkLocalUser(name, pwd, group string) error {
} }
// 判断otp信息 // 判断otp信息
pinCode := pwd pinCode := pwd
if !v.DisableOtp { // if !v.DisableOtp {
pinCode = pwd[:pl-6] // pinCode = pwd[:pl-6]
otp := pwd[pl-6:] // otp := pwd[pl-6:]
if !checkOtp(name, otp, v.OtpSecret) { // if !CheckOtp(name, otp, v.OtpSecret) {
return fmt.Errorf("%s %s", name, "动态码错误") // return fmt.Errorf("%s %s", name, "动态码错误")
} // }
} // }
// 判断用户密码 // 判断用户密码
if pinCode != v.PinCode { if pinCode != v.PinCode {
@ -171,7 +171,7 @@ func init() {
} }
// 判断令牌信息 // 判断令牌信息
func checkOtp(name, otp, secret string) bool { func CheckOtp(name, otp, secret string) bool {
key := fmt.Sprintf("%s:%s", name, otp) key := fmt.Sprintf("%s:%s", name, otp)
userOtpMux.Lock() userOtpMux.Lock()

View File

@ -3,7 +3,6 @@ package handler
import ( import (
"encoding/xml" "encoding/xml"
"io" "io"
"log"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -85,21 +84,21 @@ func antiBruteForce(next http.Handler) http.Handler {
// 检查全局 IP 锁定 // 检查全局 IP 锁定
if base.Cfg.MaxGlobalIPBanCount > 0 && lockManager.checkGlobalIPLock(ip, now) { if base.Cfg.MaxGlobalIPBanCount > 0 && lockManager.checkGlobalIPLock(ip, now) {
log.Printf("IP %s is globally locked. Try again later.", ip) base.Warn("IP", ip, "is globally locked. Try again later.")
http.Error(w, "Account globally locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests) http.Error(w, "Account globally locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests)
return return
} }
// 检查全局用户锁定 // 检查全局用户锁定
if base.Cfg.MaxGlobalUserBanCount > 0 && lockManager.checkGlobalUserLock(username, now) { if base.Cfg.MaxGlobalUserBanCount > 0 && lockManager.checkGlobalUserLock(username, now) {
log.Printf("User %s is globally locked. Try again later.", username) base.Warn("User", username, "is globally locked. Try again later.")
http.Error(w, "Account globally locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests) http.Error(w, "Account globally locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests)
return return
} }
// 检查单个用户的 IP 锁定 // 检查单个用户的 IP 锁定
if base.Cfg.MaxBanCount > 0 && lockManager.checkUserIPLock(username, ip, now) { if base.Cfg.MaxBanCount > 0 && lockManager.checkUserIPLock(username, ip, now) {
log.Printf("IP %s is locked for user %s. Try again later.", ip, username) base.Warn("IP", ip, "is locked for user", username, "Try again later.")
http.Error(w, "Account locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests) http.Error(w, "Account locked due to too many failed attempts. Try again later.", http.StatusTooManyRequests)
return return
} }

View File

@ -3,11 +3,9 @@ package handler
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/md5"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"strings" "strings"
@ -47,7 +45,10 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
} }
defer r.Body.Close() defer r.Body.Close()
cr := ClientRequest{} cr := &ClientRequest{
RemoteAddr: r.RemoteAddr,
UserAgent: userAgent,
}
err = xml.Unmarshal(body, &cr) err = xml.Unmarshal(body, &cr)
if err != nil { if err != nil {
base.Error(err) base.Error(err)
@ -78,7 +79,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
return return
} }
// 用户活动日志 // 用户活动日志
ua := dbdata.UserActLog{ ua := &dbdata.UserActLog{
Username: cr.Auth.Username, Username: cr.Auth.Username,
GroupName: cr.GroupSelect, GroupName: cr.GroupSelect,
RemoteAddr: r.RemoteAddr, RemoteAddr: r.RemoteAddr,
@ -86,6 +87,11 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
DeviceType: cr.DeviceId.DeviceType, DeviceType: cr.DeviceId.DeviceType,
PlatformVersion: cr.DeviceId.PlatformVersion, PlatformVersion: cr.DeviceId.PlatformVersion,
} }
sessionData := &AuthSession{
ClientRequest: cr,
UserActLog: ua,
}
// TODO 用户密码校验 // TODO 用户密码校验
err = dbdata.CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect) err = dbdata.CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect)
if err != nil { if err != nil {
@ -93,7 +99,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
base.Warn(err, r.RemoteAddr) base.Warn(err, r.RemoteAddr)
ua.Info = err.Error() ua.Info = err.Error()
ua.Status = dbdata.UserAuthFail ua.Status = dbdata.UserAuthFail
dbdata.UserActLogIns.Add(ua, userAgent) dbdata.UserActLogIns.Add(*ua, userAgent)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
data := RequestData{Group: cr.GroupSelect, Groups: dbdata.GetGroupNamesNormal(), Error: "用户名或密码错误"} data := RequestData{Group: cr.GroupSelect, Groups: dbdata.GetGroupNamesNormal(), Error: "用户名或密码错误"}
@ -104,72 +110,62 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
return return
} }
r = r.WithContext(context.WithValue(r.Context(), loginStatusKey, true)) // 传递登录成功状态 r = r.WithContext(context.WithValue(r.Context(), loginStatusKey, true)) // 传递登录成功状态
dbdata.UserActLogIns.Add(ua, userAgent) dbdata.UserActLogIns.Add(*ua, userAgent)
// if !ok {
// w.WriteHeader(http.StatusOK)
// data := RequestData{Group: cr.GroupSelect, Groups: base.Cfg.UserGroups, Error: "请先激活用户"}
// tplRequest(tpl_request, w, data)
// return
// }
// 创建新的session信息 v := &dbdata.User{}
sess := sessdata.NewSession("") err = dbdata.One("Username", cr.Auth.Username, v)
sess.Username = cr.Auth.Username
sess.Group = cr.GroupSelect
oriMac := cr.MacAddressList.MacAddress
sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal
sess.UserAgent = userAgent
sess.DeviceType = ua.DeviceType
sess.PlatformVersion = ua.PlatformVersion
sess.RemoteAddr = r.RemoteAddr
// 获取客户端mac地址
sess.UniqueMac = true
macHw, err := net.ParseMAC(oriMac)
if err != nil { if err != nil {
var sum [16]byte base.Error("Failed to get TOTP secret for user:", cr.Auth.Username, err)
if sess.UniqueIdGlobal != "" { http.Error(w, "Failed to get TOTP secret", http.StatusInternalServerError)
sum = md5.Sum([]byte(sess.UniqueIdGlobal)) return
} else {
sum = md5.Sum([]byte(sess.Token))
sess.UniqueMac = false
}
macHw = sum[0:5] // 5个byte
macHw = append([]byte{0x02}, macHw...)
sess.MacAddr = macHw.String()
} }
sess.MacHw = macHw // 用户otp验证
// 统一macAddr的格式 if !v.DisableOtp {
sess.MacAddr = macHw.String() sessionID, err := GenerateSessionID()
if err != nil {
base.Error("Failed to generate session ID: ", err)
http.Error(w, "Failed to generate session ID", http.StatusInternalServerError)
return
}
other := &dbdata.SettingOther{} sessionData.ClientRequest.Auth.OtpSecret = v.OtpSecret
_ = dbdata.SettingGet(other) SessStore.SaveAuthSession(sessionID, sessionData)
rd := RequestData{SessionId: sess.Sid, SessionToken: sess.Sid + "@" + sess.Token,
Banner: other.Banner, ProfileName: base.Cfg.ProfileName, ProfileHash: profileHash, CertHash: certHash} SetCookie(w, "auth-session-id", sessionID, 0)
w.WriteHeader(http.StatusOK)
tplRequest(tpl_complete, w, rd) data := RequestData{}
base.Info("login", cr.Auth.Username, userAgent) w.WriteHeader(http.StatusOK)
tplRequest(tpl_otp, w, data)
return
}
CreateSession(w, r, sessionData)
} }
const ( const (
tpl_request = iota tpl_request = iota
tpl_complete tpl_complete
tpl_otp
) )
func tplRequest(typ int, w io.Writer, data RequestData) { func tplRequest(typ int, w io.Writer, data RequestData) {
if typ == tpl_request { switch typ {
case tpl_request:
t, _ := template.New("auth_request").Parse(auth_request) t, _ := template.New("auth_request").Parse(auth_request)
_ = t.Execute(w, data) _ = t.Execute(w, data)
return case tpl_complete:
} if data.Banner != "" {
buf := new(bytes.Buffer)
_ = xml.EscapeText(buf, []byte(data.Banner))
data.Banner = buf.String()
}
if data.Banner != "" { t, _ := template.New("auth_complete").Parse(auth_complete)
buf := new(bytes.Buffer) _ = t.Execute(w, data)
_ = xml.EscapeText(buf, []byte(data.Banner)) case tpl_otp:
data.Banner = buf.String() t, _ := template.New("auth_otp").Parse(auth_otp)
_ = t.Execute(w, data)
} }
t, _ := template.New("auth_complete").Parse(auth_complete)
_ = t.Execute(w, data)
} }
// 设置输出信息 // 设置输出信息

View File

@ -0,0 +1,222 @@
package handler
import (
"crypto/md5"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/xml"
"fmt"
"io"
"net"
"net/http"
"sync"
"github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/dbdata"
"github.com/bjdgyc/anylink/sessdata"
)
var SessStore = NewSessionStore()
type AuthSession struct {
ClientRequest *ClientRequest
UserActLog *dbdata.UserActLog
}
// 存储临时会话信息
type SessionStore struct {
session map[string]*AuthSession
mu sync.Mutex
}
func NewSessionStore() *SessionStore {
return &SessionStore{
session: make(map[string]*AuthSession),
}
}
func (s *SessionStore) SaveAuthSession(sessionID string, session *AuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.session[sessionID] = session
}
func (s *SessionStore) GetAuthSession(sessionID string) (*AuthSession, error) {
s.mu.Lock()
defer s.mu.Unlock()
session, exists := s.session[sessionID]
if !exists {
return nil, fmt.Errorf("auth session not found")
}
return session, nil
}
func (s *SessionStore) DeleteAuthSession(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.session, sessionID)
}
func GenerateSessionID() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
}
hash := sha256.Sum256(b)
sessionID := base64.URLEncoding.EncodeToString(hash[:])
return sessionID, nil
}
func SetCookie(w http.ResponseWriter, name, value string, maxAge int) {
cookie := &http.Cookie{
Name: name,
Value: value,
MaxAge: maxAge,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(w, cookie)
}
func GetCookie(r *http.Request, name string) (string, error) {
cookie, err := r.Cookie(name)
if err != nil {
return "", fmt.Errorf("failed to get cookie: %v", err)
}
return cookie.Value, nil
}
func DeleteCookie(w http.ResponseWriter, name string) {
cookie := &http.Cookie{
Name: name,
Value: "",
MaxAge: -1,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
}
http.SetCookie(w, cookie)
}
func CreateSession(w http.ResponseWriter, r *http.Request, authSession *AuthSession) {
cr := authSession.ClientRequest
ua := authSession.UserActLog
sess := sessdata.NewSession("")
sess.Username = cr.Auth.Username
sess.Group = cr.GroupSelect
oriMac := cr.MacAddressList.MacAddress
sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal
sess.UserAgent = cr.UserAgent
sess.DeviceType = ua.DeviceType
sess.PlatformVersion = ua.PlatformVersion
sess.RemoteAddr = cr.RemoteAddr
// 获取客户端mac地址
sess.UniqueMac = true
macHw, err := net.ParseMAC(oriMac)
if err != nil {
var sum [16]byte
if sess.UniqueIdGlobal != "" {
sum = md5.Sum([]byte(sess.UniqueIdGlobal))
} else {
sum = md5.Sum([]byte(sess.Token))
sess.UniqueMac = false
}
macHw = sum[0:5] // 5个byte
macHw = append([]byte{0x02}, macHw...)
sess.MacAddr = macHw.String()
}
sess.MacHw = macHw
// 统一macAddr的格式
sess.MacAddr = macHw.String()
other := &dbdata.SettingOther{}
dbdata.SettingGet(other)
rd := RequestData{
SessionId: sess.Sid,
SessionToken: sess.Sid + "@" + sess.Token,
Banner: other.Banner,
ProfileName: base.Cfg.ProfileName,
ProfileHash: profileHash,
CertHash: certHash,
}
w.WriteHeader(http.StatusOK)
tplRequest(tpl_complete, w, rd)
base.Info("login", cr.Auth.Username, cr.UserAgent)
}
func LinkAuth_otp(w http.ResponseWriter, r *http.Request) {
sessionID, err := GetCookie(r, "auth-session-id")
if err != nil {
http.Error(w, "Invalid session, please login again", http.StatusUnauthorized)
return
}
sessionData, err := SessStore.GetAuthSession(sessionID)
if err != nil {
http.Error(w, "Invalid session, please login again", http.StatusUnauthorized)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
defer r.Body.Close()
cr := ClientRequest{}
err = xml.Unmarshal(body, &cr)
if err != nil {
base.Error(err)
w.WriteHeader(http.StatusBadRequest)
return
}
ua := sessionData.UserActLog
username := sessionData.ClientRequest.Auth.Username
otpSecret := sessionData.ClientRequest.Auth.OtpSecret
otp := cr.Auth.SecondaryPassword
if !dbdata.CheckOtp(username, otp, otpSecret) {
base.Warn("OTP 动态码错误", r.RemoteAddr)
ua.Info = "OTP 动态码错误"
ua.Status = dbdata.UserAuthFail
dbdata.UserActLogIns.Add(*ua, sessionData.ClientRequest.UserAgent)
w.WriteHeader(http.StatusOK)
data := RequestData{Error: "OTP 动态码错误"}
if base.Cfg.DisplayError {
data.Error = "OTP 动态码错误"
}
tplRequest(tpl_otp, w, data)
return
}
CreateSession(w, r, sessionData)
// 删除临时会话信息
SessStore.DeleteAuthSession(sessionID)
// DeleteCookie(w, "auth-session-id")
}
var auth_otp = `<?xml version="1.0" encoding="UTF-8"?>
<config-auth client="vpn" type="auth-request" aggregate-auth-version="2">
<auth id="otp-verification">
<title>OTP 动态码验证</title>
<message>请输入您的 OTP 动态码</message>
{{if .Error}}
<error id="otp-verification" param1="{{.Error}}" param2="">验证失败: %s</error>
{{end}}
<form method="post" action="/otp-verification">
<input type="password" name="secondary_password" label="OTP"/>
</form>
</auth>
</config-auth>`

View File

@ -17,6 +17,8 @@ type ClientRequest struct {
Version string `xml:"version"` // 客户端版本号 Version string `xml:"version"` // 客户端版本号
GroupAccess string `xml:"group-access"` // 请求的地址 GroupAccess string `xml:"group-access"` // 请求的地址
GroupSelect string `xml:"group-select"` // 选择的组名 GroupSelect string `xml:"group-select"` // 选择的组名
RemoteAddr string `xml:"remote_addr"`
UserAgent string `xml:"user_agent"`
SessionId string `xml:"session-id"` SessionId string `xml:"session-id"`
SessionToken string `xml:"session-token"` SessionToken string `xml:"session-token"`
Auth auth `xml:"auth"` Auth auth `xml:"auth"`
@ -27,6 +29,7 @@ type ClientRequest struct {
type auth struct { type auth struct {
Username string `xml:"username"` Username string `xml:"username"`
Password string `xml:"password"` Password string `xml:"password"`
OtpSecret string `xml:"otp_secret"`
SecondaryPassword string `xml:"secondary_password"` SecondaryPassword string `xml:"secondary_password"`
} }

View File

@ -114,6 +114,7 @@ func initRoute() http.Handler {
r.Handle("/", antiBruteForce(http.HandlerFunc(LinkAuth))).Methods(http.MethodPost) r.Handle("/", antiBruteForce(http.HandlerFunc(LinkAuth))).Methods(http.MethodPost)
r.HandleFunc("/CSCOSSLC/tunnel", LinkTunnel).Methods(http.MethodConnect) r.HandleFunc("/CSCOSSLC/tunnel", LinkTunnel).Methods(http.MethodConnect)
r.HandleFunc("/otp_qr", LinkOtpQr).Methods(http.MethodGet) r.HandleFunc("/otp_qr", LinkOtpQr).Methods(http.MethodGet)
r.HandleFunc("/otp-verification", LinkAuth_otp)
r.HandleFunc(fmt.Sprintf("/profile_%s.xml", base.Cfg.ProfileName), func(w http.ResponseWriter, r *http.Request) { r.HandleFunc(fmt.Sprintf("/profile_%s.xml", base.Cfg.ProfileName), func(w http.ResponseWriter, r *http.Request) {
b, _ := os.ReadFile(base.Cfg.Profile) b, _ := os.ReadFile(base.Cfg.Profile)
w.Write(b) w.Write(b)