优化代码

This commit is contained in:
bjdgyc 2024-10-24 18:10:29 +08:00
parent 772b1118eb
commit bd6ee0b140
4 changed files with 39 additions and 34 deletions

View File

@ -159,7 +159,6 @@ func tplRequest(typ int, w io.Writer, data RequestData) {
_ = xml.EscapeText(buf, []byte(data.Banner))
data.Banner = buf.String()
}
t, _ := template.New("auth_complete").Parse(auth_complete)
_ = t.Execute(w, data)
case tpl_otp:

View File

@ -2,26 +2,28 @@ package handler
import (
"crypto/md5"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/xml"
"fmt"
"io"
"net"
"net/http"
"sync"
"sync/atomic"
"github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/dbdata"
"github.com/bjdgyc/anylink/pkg/utils"
"github.com/bjdgyc/anylink/sessdata"
)
var SessStore = NewSessionStore()
const maxOtpErrCount = 3
type AuthSession struct {
ClientRequest *ClientRequest
UserActLog *dbdata.UserActLog
OtpErrCount atomic.Uint32 // otp错误次数
}
// 存储临时会话信息
@ -60,15 +62,17 @@ func (s *SessionStore) DeleteAuthSession(sessionID string) {
delete(s.session, sessionID)
}
func (a *AuthSession) AddOtpErrCount(i int) int {
newI := a.OtpErrCount.Add(uint32(i))
return int(newI)
}
func GenerateSessionID() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
sessionID := utils.RandomRunes(32)
if sessionID == "" {
return "", fmt.Errorf("failed to generate session ID")
}
hash := sha256.Sum256(b)
sessionID := base64.URLEncoding.EncodeToString(hash[:])
return sessionID, nil
}
@ -186,14 +190,20 @@ func LinkAuth_otp(w http.ResponseWriter, r *http.Request) {
otpSecret := sessionData.ClientRequest.Auth.OtpSecret
otp := cr.Auth.SecondaryPassword
// 动态码错误
if !dbdata.CheckOtp(username, otp, otpSecret) {
base.Warn("OTP 动态码错误", r.RemoteAddr)
if sessionData.AddOtpErrCount(1) > maxOtpErrCount {
http.Error(w, "TooManyError, please login again", http.StatusBadRequest)
return
}
base.Warn("OTP 动态码错误", username, r.RemoteAddr)
ua.Info = "OTP 动态码错误"
ua.Status = dbdata.UserAuthFail
dbdata.UserActLogIns.Add(*ua, sessionData.ClientRequest.UserAgent)
w.WriteHeader(http.StatusOK)
data := RequestData{Error: "OTP 动态码错误"}
data := RequestData{Error: "请求错误"}
if base.Cfg.DisplayError {
data.Error = "OTP 动态码错误"
}
@ -216,7 +226,7 @@ var auth_otp = `<?xml version="1.0" encoding="UTF-8"?>
<error id="otp-verification" param1="{{.Error}}" param2="">验证失败: %s</error>
{{end}}
<form method="post" action="/otp-verification">
<input type="password" name="secondary_password" label="OTP"/>
<input type="password" name="secondary_password" label="OTPCode:"/>
</form>
</auth>
</config-auth>`

View File

@ -1,7 +1,10 @@
package utils
import (
crand "crypto/rand"
"encoding/hex"
"fmt"
"log"
"math/rand"
"strings"
"sync/atomic"
@ -83,9 +86,7 @@ func HumanByte(bf interface{}) string {
func RandomRunes(length int) string {
letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890")
bytes := make([]rune, length)
for i := range bytes {
bytes[i] = letterRunes[rand.Intn(len(letterRunes))]
}
@ -93,6 +94,17 @@ func RandomRunes(length int) string {
return string(bytes)
}
func RandomHex(length int) string {
b := make([]byte, length)
_, err := crand.Read(b)
if err != nil {
log.Println(err)
return ""
}
return hex.EncodeToString(b)
}
func ParseName(name string) string {
name = strings.ReplaceAll(name, " ", "-")
name = strings.ReplaceAll(name, "'", "-")

View File

@ -2,7 +2,6 @@ package sessdata
import (
"fmt"
"math/rand"
"net"
"strconv"
"strings"
@ -12,6 +11,7 @@ import (
"github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/dbdata"
"github.com/bjdgyc/anylink/pkg/utils"
mapset "github.com/deckarep/golang-set"
)
@ -91,10 +91,6 @@ type Session struct {
CSess *ConnSession
}
func init() {
rand.Seed(time.Now().UnixNano())
}
func checkSession() {
// 检测过期的session
go func() {
@ -144,28 +140,16 @@ func CloseUserLimittimeSession() {
}
}
func GenToken() string {
// 生成32位的 token
bToken := make([]byte, 32)
rand.Read(bToken)
return fmt.Sprintf("%x", bToken)
}
func NewSession(token string) *Session {
if token == "" {
btoken := make([]byte, 32)
rand.Read(btoken)
token = fmt.Sprintf("%x", btoken)
token = utils.RandomHex(32)
}
// 生成 dtlsn session_id
dtlsid := make([]byte, 32)
rand.Read(dtlsid)
sess := &Session{
Sid: fmt.Sprintf("%d", time.Now().Unix()),
Token: token,
DtlsSid: fmt.Sprintf("%x", dtlsid),
DtlsSid: utils.RandomHex(32),
LastLogin: time.Now(),
}