mirror of https://github.com/bjdgyc/anylink.git
优化代码
This commit is contained in:
parent
772b1118eb
commit
bd6ee0b140
|
@ -159,7 +159,6 @@ func tplRequest(typ int, w io.Writer, data RequestData) {
|
|||
_ = xml.EscapeText(buf, []byte(data.Banner))
|
||||
data.Banner = buf.String()
|
||||
}
|
||||
|
||||
t, _ := template.New("auth_complete").Parse(auth_complete)
|
||||
_ = t.Execute(w, data)
|
||||
case tpl_otp:
|
||||
|
|
|
@ -2,26 +2,28 @@ package handler
|
|||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/bjdgyc/anylink/base"
|
||||
"github.com/bjdgyc/anylink/dbdata"
|
||||
"github.com/bjdgyc/anylink/pkg/utils"
|
||||
"github.com/bjdgyc/anylink/sessdata"
|
||||
)
|
||||
|
||||
var SessStore = NewSessionStore()
|
||||
|
||||
const maxOtpErrCount = 3
|
||||
|
||||
type AuthSession struct {
|
||||
ClientRequest *ClientRequest
|
||||
UserActLog *dbdata.UserActLog
|
||||
OtpErrCount atomic.Uint32 // otp错误次数
|
||||
}
|
||||
|
||||
// 存储临时会话信息
|
||||
|
@ -60,15 +62,17 @@ func (s *SessionStore) DeleteAuthSession(sessionID string) {
|
|||
delete(s.session, sessionID)
|
||||
}
|
||||
|
||||
func (a *AuthSession) AddOtpErrCount(i int) int {
|
||||
newI := a.OtpErrCount.Add(uint32(i))
|
||||
return int(newI)
|
||||
}
|
||||
|
||||
func GenerateSessionID() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate session ID: %w", err)
|
||||
sessionID := utils.RandomRunes(32)
|
||||
if sessionID == "" {
|
||||
return "", fmt.Errorf("failed to generate session ID")
|
||||
}
|
||||
|
||||
hash := sha256.Sum256(b)
|
||||
sessionID := base64.URLEncoding.EncodeToString(hash[:])
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
|
@ -186,14 +190,20 @@ func LinkAuth_otp(w http.ResponseWriter, r *http.Request) {
|
|||
otpSecret := sessionData.ClientRequest.Auth.OtpSecret
|
||||
otp := cr.Auth.SecondaryPassword
|
||||
|
||||
// 动态码错误
|
||||
if !dbdata.CheckOtp(username, otp, otpSecret) {
|
||||
base.Warn("OTP 动态码错误", r.RemoteAddr)
|
||||
if sessionData.AddOtpErrCount(1) > maxOtpErrCount {
|
||||
http.Error(w, "TooManyError, please login again", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
base.Warn("OTP 动态码错误", username, r.RemoteAddr)
|
||||
ua.Info = "OTP 动态码错误"
|
||||
ua.Status = dbdata.UserAuthFail
|
||||
dbdata.UserActLogIns.Add(*ua, sessionData.ClientRequest.UserAgent)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data := RequestData{Error: "OTP 动态码错误"}
|
||||
data := RequestData{Error: "请求错误"}
|
||||
if base.Cfg.DisplayError {
|
||||
data.Error = "OTP 动态码错误"
|
||||
}
|
||||
|
@ -216,7 +226,7 @@ var auth_otp = `<?xml version="1.0" encoding="UTF-8"?>
|
|||
<error id="otp-verification" param1="{{.Error}}" param2="">验证失败: %s</error>
|
||||
{{end}}
|
||||
<form method="post" action="/otp-verification">
|
||||
<input type="password" name="secondary_password" label="OTP"/>
|
||||
<input type="password" name="secondary_password" label="OTPCode:"/>
|
||||
</form>
|
||||
</auth>
|
||||
</config-auth>`
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
@ -83,9 +86,7 @@ func HumanByte(bf interface{}) string {
|
|||
|
||||
func RandomRunes(length int) string {
|
||||
letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890")
|
||||
|
||||
bytes := make([]rune, length)
|
||||
|
||||
for i := range bytes {
|
||||
bytes[i] = letterRunes[rand.Intn(len(letterRunes))]
|
||||
}
|
||||
|
@ -93,6 +94,17 @@ func RandomRunes(length int) string {
|
|||
return string(bytes)
|
||||
}
|
||||
|
||||
func RandomHex(length int) string {
|
||||
b := make([]byte, length)
|
||||
_, err := crand.Read(b)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func ParseName(name string) string {
|
||||
name = strings.ReplaceAll(name, " ", "-")
|
||||
name = strings.ReplaceAll(name, "'", "-")
|
||||
|
|
|
@ -2,7 +2,6 @@ package sessdata
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -12,6 +11,7 @@ import (
|
|||
|
||||
"github.com/bjdgyc/anylink/base"
|
||||
"github.com/bjdgyc/anylink/dbdata"
|
||||
"github.com/bjdgyc/anylink/pkg/utils"
|
||||
mapset "github.com/deckarep/golang-set"
|
||||
)
|
||||
|
||||
|
@ -91,10 +91,6 @@ type Session struct {
|
|||
CSess *ConnSession
|
||||
}
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func checkSession() {
|
||||
// 检测过期的session
|
||||
go func() {
|
||||
|
@ -144,28 +140,16 @@ func CloseUserLimittimeSession() {
|
|||
}
|
||||
}
|
||||
|
||||
func GenToken() string {
|
||||
// 生成32位的 token
|
||||
bToken := make([]byte, 32)
|
||||
rand.Read(bToken)
|
||||
return fmt.Sprintf("%x", bToken)
|
||||
}
|
||||
|
||||
func NewSession(token string) *Session {
|
||||
if token == "" {
|
||||
btoken := make([]byte, 32)
|
||||
rand.Read(btoken)
|
||||
token = fmt.Sprintf("%x", btoken)
|
||||
token = utils.RandomHex(32)
|
||||
}
|
||||
|
||||
// 生成 dtlsn session_id
|
||||
dtlsid := make([]byte, 32)
|
||||
rand.Read(dtlsid)
|
||||
|
||||
sess := &Session{
|
||||
Sid: fmt.Sprintf("%d", time.Now().Unix()),
|
||||
Token: token,
|
||||
DtlsSid: fmt.Sprintf("%x", dtlsid),
|
||||
DtlsSid: utils.RandomHex(32),
|
||||
LastLogin: time.Now(),
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue