兼容不支持SNI的情况

This commit is contained in:
bjdgyc 2023-04-21 15:57:12 +08:00
parent c05ec9ab36
commit 91ce4752f3
3 changed files with 40 additions and 25 deletions

View File

@ -10,12 +10,12 @@ import (
) )
const ( const (
_Trace = iota LogLevelTrace = iota
_Debug LogLevelDebug
_Info LogLevelInfo
_Warn LogLevelWarn
_Error LogLevelError
_Fatal LogLevelFatal
) )
var ( var (
@ -88,16 +88,20 @@ func GetBaseLog() *log.Logger {
return baseLog return baseLog
} }
func GetLogLevel() int {
return baseLevel
}
func logLevel2Int(l string) int { func logLevel2Int(l string) int {
levels = map[int]string{ levels = map[int]string{
_Trace: "Trace", LogLevelTrace: "Trace",
_Debug: "Debug", LogLevelDebug: "Debug",
_Info: "Info", LogLevelInfo: "Info",
_Warn: "Warn", LogLevelWarn: "Warn",
_Error: "Error", LogLevelError: "Error",
_Fatal: "Fatal", LogLevelFatal: "Fatal",
} }
lvl := _Info lvl := LogLevelInfo
for k, v := range levels { for k, v := range levels {
if strings.EqualFold(strings.ToLower(l), strings.ToLower(v)) { if strings.EqualFold(strings.ToLower(l), strings.ToLower(v)) {
lvl = k lvl = k
@ -112,7 +116,7 @@ func output(l int, s ...interface{}) {
} }
func Trace(v ...interface{}) { func Trace(v ...interface{}) {
l := _Trace l := LogLevelTrace
if baseLevel > l { if baseLevel > l {
return return
} }
@ -120,7 +124,7 @@ func Trace(v ...interface{}) {
} }
func Debug(v ...interface{}) { func Debug(v ...interface{}) {
l := _Debug l := LogLevelDebug
if baseLevel > l { if baseLevel > l {
return return
} }
@ -128,7 +132,7 @@ func Debug(v ...interface{}) {
} }
func Info(v ...interface{}) { func Info(v ...interface{}) {
l := _Info l := LogLevelInfo
if baseLevel > l { if baseLevel > l {
return return
} }
@ -136,7 +140,7 @@ func Info(v ...interface{}) {
} }
func Warn(v ...interface{}) { func Warn(v ...interface{}) {
l := _Warn l := LogLevelWarn
if baseLevel > l { if baseLevel > l {
return return
} }
@ -144,7 +148,7 @@ func Warn(v ...interface{}) {
} }
func Error(v ...interface{}) { func Error(v ...interface{}) {
l := _Error l := LogLevelError
if baseLevel > l { if baseLevel > l {
return return
} }
@ -152,7 +156,7 @@ func Error(v ...interface{}) {
} }
func Fatal(v ...interface{}) { func Fatal(v ...interface{}) {
l := _Fatal l := LogLevelFatal
if baseLevel > l { if baseLevel > l {
return return
} }

View File

@ -33,9 +33,12 @@ import (
"github.com/go-acme/lego/v4/registration" "github.com/go-acme/lego/v4/registration"
) )
var nameToCertificate = make(map[string]*tls.Certificate) var (
// nameToCertificate mutex
var tempCert *tls.Certificate ntcMux sync.RWMutex
nameToCertificate = make(map[string]*tls.Certificate)
tempCert *tls.Certificate
)
func init() { func init() {
c, _ := selfsign.GenerateSelfSignedWithDNS("localhost") c, _ := selfsign.GenerateSelfSignedWithDNS("localhost")
@ -342,6 +345,9 @@ func getTempCertificate() (*tls.Certificate, error) {
} }
func GetCertificateBySNI(commonName string) (*tls.Certificate, error) { func GetCertificateBySNI(commonName string) (*tls.Certificate, error) {
ntcMux.RLock()
defer ntcMux.RUnlock()
// Copy from tls.Config getCertificate() // Copy from tls.Config getCertificate()
name := strings.ToLower(commonName) name := strings.ToLower(commonName)
if cert, ok := nameToCertificate[name]; ok { if cert, ok := nameToCertificate[name]; ok {
@ -369,6 +375,9 @@ func LoadCertificate(cert *tls.Certificate) {
// Copy from tls.Config BuildNameToCertificate() // Copy from tls.Config BuildNameToCertificate()
func buildNameToCertificate(cert *tls.Certificate) { func buildNameToCertificate(cert *tls.Certificate) {
ntcMux.Lock()
defer ntcMux.Unlock()
// TODO 设置默认证书 // TODO 设置默认证书
nameToCertificate["default"] = cert nameToCertificate["default"] = cert

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/http/httputil"
"strings" "strings"
"text/template" "text/template"
@ -19,9 +20,10 @@ var profileHash = ""
func LinkAuth(w http.ResponseWriter, r *http.Request) { func LinkAuth(w http.ResponseWriter, r *http.Request) {
// TODO 调试信息输出 // TODO 调试信息输出
// hd, _ := httputil.DumpRequest(r, true) if base.GetLogLevel() == base.LogLevelTrace {
// base.Debug("DumpRequest: ", string(hd)) hd, _ := httputil.DumpRequest(r, true)
base.Trace("LinkAuth: ", string(hd))
}
// 判断anyconnect客户端 // 判断anyconnect客户端
userAgent := strings.ToLower(r.UserAgent()) userAgent := strings.ToLower(r.UserAgent())
xAggregateAuth := r.Header.Get("X-Aggregate-Auth") xAggregateAuth := r.Header.Get("X-Aggregate-Auth")