优化代码

This commit is contained in:
bjdgyc 2024-10-24 18:10:29 +08:00
parent 772b1118eb
commit bd6ee0b140
4 changed files with 39 additions and 34 deletions

View File

@ -159,7 +159,6 @@ func tplRequest(typ int, w io.Writer, data RequestData) {
_ = xml.EscapeText(buf, []byte(data.Banner)) _ = xml.EscapeText(buf, []byte(data.Banner))
data.Banner = buf.String() data.Banner = buf.String()
} }
t, _ := template.New("auth_complete").Parse(auth_complete) t, _ := template.New("auth_complete").Parse(auth_complete)
_ = t.Execute(w, data) _ = t.Execute(w, data)
case tpl_otp: case tpl_otp:

View File

@ -2,26 +2,28 @@ package handler
import ( import (
"crypto/md5" "crypto/md5"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"sync" "sync"
"sync/atomic"
"github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/dbdata" "github.com/bjdgyc/anylink/dbdata"
"github.com/bjdgyc/anylink/pkg/utils"
"github.com/bjdgyc/anylink/sessdata" "github.com/bjdgyc/anylink/sessdata"
) )
var SessStore = NewSessionStore() var SessStore = NewSessionStore()
const maxOtpErrCount = 3
type AuthSession struct { type AuthSession struct {
ClientRequest *ClientRequest ClientRequest *ClientRequest
UserActLog *dbdata.UserActLog UserActLog *dbdata.UserActLog
OtpErrCount atomic.Uint32 // otp错误次数
} }
// 存储临时会话信息 // 存储临时会话信息
@ -60,15 +62,17 @@ func (s *SessionStore) DeleteAuthSession(sessionID string) {
delete(s.session, sessionID) delete(s.session, sessionID)
} }
func (a *AuthSession) AddOtpErrCount(i int) int {
newI := a.OtpErrCount.Add(uint32(i))
return int(newI)
}
func GenerateSessionID() (string, error) { func GenerateSessionID() (string, error) {
b := make([]byte, 32) sessionID := utils.RandomRunes(32)
_, err := rand.Read(b) if sessionID == "" {
if err != nil { return "", fmt.Errorf("failed to generate session ID")
return "", fmt.Errorf("failed to generate session ID: %w", err)
} }
hash := sha256.Sum256(b)
sessionID := base64.URLEncoding.EncodeToString(hash[:])
return sessionID, nil return sessionID, nil
} }
@ -186,14 +190,20 @@ func LinkAuth_otp(w http.ResponseWriter, r *http.Request) {
otpSecret := sessionData.ClientRequest.Auth.OtpSecret otpSecret := sessionData.ClientRequest.Auth.OtpSecret
otp := cr.Auth.SecondaryPassword otp := cr.Auth.SecondaryPassword
// 动态码错误
if !dbdata.CheckOtp(username, otp, otpSecret) { if !dbdata.CheckOtp(username, otp, otpSecret) {
base.Warn("OTP 动态码错误", r.RemoteAddr) if sessionData.AddOtpErrCount(1) > maxOtpErrCount {
http.Error(w, "TooManyError, please login again", http.StatusBadRequest)
return
}
base.Warn("OTP 动态码错误", username, r.RemoteAddr)
ua.Info = "OTP 动态码错误" ua.Info = "OTP 动态码错误"
ua.Status = dbdata.UserAuthFail ua.Status = dbdata.UserAuthFail
dbdata.UserActLogIns.Add(*ua, sessionData.ClientRequest.UserAgent) dbdata.UserActLogIns.Add(*ua, sessionData.ClientRequest.UserAgent)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
data := RequestData{Error: "OTP 动态码错误"} data := RequestData{Error: "请求错误"}
if base.Cfg.DisplayError { if base.Cfg.DisplayError {
data.Error = "OTP 动态码错误" data.Error = "OTP 动态码错误"
} }
@ -216,7 +226,7 @@ var auth_otp = `<?xml version="1.0" encoding="UTF-8"?>
<error id="otp-verification" param1="{{.Error}}" param2="">验证失败: %s</error> <error id="otp-verification" param1="{{.Error}}" param2="">验证失败: %s</error>
{{end}} {{end}}
<form method="post" action="/otp-verification"> <form method="post" action="/otp-verification">
<input type="password" name="secondary_password" label="OTP"/> <input type="password" name="secondary_password" label="OTPCode:"/>
</form> </form>
</auth> </auth>
</config-auth>` </config-auth>`

View File

@ -1,7 +1,10 @@
package utils package utils
import ( import (
crand "crypto/rand"
"encoding/hex"
"fmt" "fmt"
"log"
"math/rand" "math/rand"
"strings" "strings"
"sync/atomic" "sync/atomic"
@ -83,9 +86,7 @@ func HumanByte(bf interface{}) string {
func RandomRunes(length int) string { func RandomRunes(length int) string {
letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890") letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890")
bytes := make([]rune, length) bytes := make([]rune, length)
for i := range bytes { for i := range bytes {
bytes[i] = letterRunes[rand.Intn(len(letterRunes))] bytes[i] = letterRunes[rand.Intn(len(letterRunes))]
} }
@ -93,6 +94,17 @@ func RandomRunes(length int) string {
return string(bytes) return string(bytes)
} }
func RandomHex(length int) string {
b := make([]byte, length)
_, err := crand.Read(b)
if err != nil {
log.Println(err)
return ""
}
return hex.EncodeToString(b)
}
func ParseName(name string) string { func ParseName(name string) string {
name = strings.ReplaceAll(name, " ", "-") name = strings.ReplaceAll(name, " ", "-")
name = strings.ReplaceAll(name, "'", "-") name = strings.ReplaceAll(name, "'", "-")

View File

@ -2,7 +2,6 @@ package sessdata
import ( import (
"fmt" "fmt"
"math/rand"
"net" "net"
"strconv" "strconv"
"strings" "strings"
@ -12,6 +11,7 @@ import (
"github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/dbdata" "github.com/bjdgyc/anylink/dbdata"
"github.com/bjdgyc/anylink/pkg/utils"
mapset "github.com/deckarep/golang-set" mapset "github.com/deckarep/golang-set"
) )
@ -91,10 +91,6 @@ type Session struct {
CSess *ConnSession CSess *ConnSession
} }
func init() {
rand.Seed(time.Now().UnixNano())
}
func checkSession() { func checkSession() {
// 检测过期的session // 检测过期的session
go func() { go func() {
@ -144,28 +140,16 @@ func CloseUserLimittimeSession() {
} }
} }
func GenToken() string {
// 生成32位的 token
bToken := make([]byte, 32)
rand.Read(bToken)
return fmt.Sprintf("%x", bToken)
}
func NewSession(token string) *Session { func NewSession(token string) *Session {
if token == "" { if token == "" {
btoken := make([]byte, 32) token = utils.RandomHex(32)
rand.Read(btoken)
token = fmt.Sprintf("%x", btoken)
} }
// 生成 dtlsn session_id // 生成 dtlsn session_id
dtlsid := make([]byte, 32)
rand.Read(dtlsid)
sess := &Session{ sess := &Session{
Sid: fmt.Sprintf("%d", time.Now().Unix()), Sid: fmt.Sprintf("%d", time.Now().Unix()),
Token: token, Token: token,
DtlsSid: fmt.Sprintf("%x", dtlsid), DtlsSid: utils.RandomHex(32),
LastLogin: time.Now(), LastLogin: time.Now(),
} }