mirror of https://github.com/bjdgyc/anylink.git
优化代码
This commit is contained in:
parent
772b1118eb
commit
bd6ee0b140
|
@ -159,7 +159,6 @@ func tplRequest(typ int, w io.Writer, data RequestData) {
|
||||||
_ = xml.EscapeText(buf, []byte(data.Banner))
|
_ = xml.EscapeText(buf, []byte(data.Banner))
|
||||||
data.Banner = buf.String()
|
data.Banner = buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
t, _ := template.New("auth_complete").Parse(auth_complete)
|
t, _ := template.New("auth_complete").Parse(auth_complete)
|
||||||
_ = t.Execute(w, data)
|
_ = t.Execute(w, data)
|
||||||
case tpl_otp:
|
case tpl_otp:
|
||||||
|
|
|
@ -2,26 +2,28 @@ package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/bjdgyc/anylink/base"
|
"github.com/bjdgyc/anylink/base"
|
||||||
"github.com/bjdgyc/anylink/dbdata"
|
"github.com/bjdgyc/anylink/dbdata"
|
||||||
|
"github.com/bjdgyc/anylink/pkg/utils"
|
||||||
"github.com/bjdgyc/anylink/sessdata"
|
"github.com/bjdgyc/anylink/sessdata"
|
||||||
)
|
)
|
||||||
|
|
||||||
var SessStore = NewSessionStore()
|
var SessStore = NewSessionStore()
|
||||||
|
|
||||||
|
const maxOtpErrCount = 3
|
||||||
|
|
||||||
type AuthSession struct {
|
type AuthSession struct {
|
||||||
ClientRequest *ClientRequest
|
ClientRequest *ClientRequest
|
||||||
UserActLog *dbdata.UserActLog
|
UserActLog *dbdata.UserActLog
|
||||||
|
OtpErrCount atomic.Uint32 // otp错误次数
|
||||||
}
|
}
|
||||||
|
|
||||||
// 存储临时会话信息
|
// 存储临时会话信息
|
||||||
|
@ -60,15 +62,17 @@ func (s *SessionStore) DeleteAuthSession(sessionID string) {
|
||||||
delete(s.session, sessionID)
|
delete(s.session, sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *AuthSession) AddOtpErrCount(i int) int {
|
||||||
|
newI := a.OtpErrCount.Add(uint32(i))
|
||||||
|
return int(newI)
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateSessionID() (string, error) {
|
func GenerateSessionID() (string, error) {
|
||||||
b := make([]byte, 32)
|
sessionID := utils.RandomRunes(32)
|
||||||
_, err := rand.Read(b)
|
if sessionID == "" {
|
||||||
if err != nil {
|
return "", fmt.Errorf("failed to generate session ID")
|
||||||
return "", fmt.Errorf("failed to generate session ID: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := sha256.Sum256(b)
|
|
||||||
sessionID := base64.URLEncoding.EncodeToString(hash[:])
|
|
||||||
return sessionID, nil
|
return sessionID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,14 +190,20 @@ func LinkAuth_otp(w http.ResponseWriter, r *http.Request) {
|
||||||
otpSecret := sessionData.ClientRequest.Auth.OtpSecret
|
otpSecret := sessionData.ClientRequest.Auth.OtpSecret
|
||||||
otp := cr.Auth.SecondaryPassword
|
otp := cr.Auth.SecondaryPassword
|
||||||
|
|
||||||
|
// 动态码错误
|
||||||
if !dbdata.CheckOtp(username, otp, otpSecret) {
|
if !dbdata.CheckOtp(username, otp, otpSecret) {
|
||||||
base.Warn("OTP 动态码错误", r.RemoteAddr)
|
if sessionData.AddOtpErrCount(1) > maxOtpErrCount {
|
||||||
|
http.Error(w, "TooManyError, please login again", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
base.Warn("OTP 动态码错误", username, r.RemoteAddr)
|
||||||
ua.Info = "OTP 动态码错误"
|
ua.Info = "OTP 动态码错误"
|
||||||
ua.Status = dbdata.UserAuthFail
|
ua.Status = dbdata.UserAuthFail
|
||||||
dbdata.UserActLogIns.Add(*ua, sessionData.ClientRequest.UserAgent)
|
dbdata.UserActLogIns.Add(*ua, sessionData.ClientRequest.UserAgent)
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
data := RequestData{Error: "OTP 动态码错误"}
|
data := RequestData{Error: "请求错误"}
|
||||||
if base.Cfg.DisplayError {
|
if base.Cfg.DisplayError {
|
||||||
data.Error = "OTP 动态码错误"
|
data.Error = "OTP 动态码错误"
|
||||||
}
|
}
|
||||||
|
@ -216,7 +226,7 @@ var auth_otp = `<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<error id="otp-verification" param1="{{.Error}}" param2="">验证失败: %s</error>
|
<error id="otp-verification" param1="{{.Error}}" param2="">验证失败: %s</error>
|
||||||
{{end}}
|
{{end}}
|
||||||
<form method="post" action="/otp-verification">
|
<form method="post" action="/otp-verification">
|
||||||
<input type="password" name="secondary_password" label="OTP"/>
|
<input type="password" name="secondary_password" label="OTPCode:"/>
|
||||||
</form>
|
</form>
|
||||||
</auth>
|
</auth>
|
||||||
</config-auth>`
|
</config-auth>`
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
crand "crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -83,9 +86,7 @@ func HumanByte(bf interface{}) string {
|
||||||
|
|
||||||
func RandomRunes(length int) string {
|
func RandomRunes(length int) string {
|
||||||
letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890")
|
letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890")
|
||||||
|
|
||||||
bytes := make([]rune, length)
|
bytes := make([]rune, length)
|
||||||
|
|
||||||
for i := range bytes {
|
for i := range bytes {
|
||||||
bytes[i] = letterRunes[rand.Intn(len(letterRunes))]
|
bytes[i] = letterRunes[rand.Intn(len(letterRunes))]
|
||||||
}
|
}
|
||||||
|
@ -93,6 +94,17 @@ func RandomRunes(length int) string {
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RandomHex(length int) string {
|
||||||
|
b := make([]byte, length)
|
||||||
|
_, err := crand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
func ParseName(name string) string {
|
func ParseName(name string) string {
|
||||||
name = strings.ReplaceAll(name, " ", "-")
|
name = strings.ReplaceAll(name, " ", "-")
|
||||||
name = strings.ReplaceAll(name, "'", "-")
|
name = strings.ReplaceAll(name, "'", "-")
|
||||||
|
|
|
@ -2,7 +2,6 @@ package sessdata
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -12,6 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/bjdgyc/anylink/base"
|
"github.com/bjdgyc/anylink/base"
|
||||||
"github.com/bjdgyc/anylink/dbdata"
|
"github.com/bjdgyc/anylink/dbdata"
|
||||||
|
"github.com/bjdgyc/anylink/pkg/utils"
|
||||||
mapset "github.com/deckarep/golang-set"
|
mapset "github.com/deckarep/golang-set"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -91,10 +91,6 @@ type Session struct {
|
||||||
CSess *ConnSession
|
CSess *ConnSession
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkSession() {
|
func checkSession() {
|
||||||
// 检测过期的session
|
// 检测过期的session
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -144,28 +140,16 @@ func CloseUserLimittimeSession() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenToken() string {
|
|
||||||
// 生成32位的 token
|
|
||||||
bToken := make([]byte, 32)
|
|
||||||
rand.Read(bToken)
|
|
||||||
return fmt.Sprintf("%x", bToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSession(token string) *Session {
|
func NewSession(token string) *Session {
|
||||||
if token == "" {
|
if token == "" {
|
||||||
btoken := make([]byte, 32)
|
token = utils.RandomHex(32)
|
||||||
rand.Read(btoken)
|
|
||||||
token = fmt.Sprintf("%x", btoken)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成 dtlsn session_id
|
// 生成 dtlsn session_id
|
||||||
dtlsid := make([]byte, 32)
|
|
||||||
rand.Read(dtlsid)
|
|
||||||
|
|
||||||
sess := &Session{
|
sess := &Session{
|
||||||
Sid: fmt.Sprintf("%d", time.Now().Unix()),
|
Sid: fmt.Sprintf("%d", time.Now().Unix()),
|
||||||
Token: token,
|
Token: token,
|
||||||
DtlsSid: fmt.Sprintf("%x", dtlsid),
|
DtlsSid: utils.RandomHex(32),
|
||||||
LastLogin: time.Now(),
|
LastLogin: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue