From 6ee80d32eab4ce53c2b4f214c427668f57004661 Mon Sep 17 00:00:00 2001 From: bjdgyc Date: Fri, 21 Apr 2023 11:39:51 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E8=AF=81=E4=B9=A6=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/admin/server.go | 4 ++-- server/base/log.go | 12 +++++++++++- server/dbdata/cert.go | 14 ++++++++++++-- server/handler/server.go | 13 ++++++++----- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/server/admin/server.go b/server/admin/server.go index b0cbbca..3cabb99 100644 --- a/server/admin/server.go +++ b/server/admin/server.go @@ -100,9 +100,9 @@ func StartAdmin() { for _, s := range cipherSuites { selectedCipherSuites = append(selectedCipherSuites, s.ID) } + if tlscert, _, err := dbdata.ParseCert(); err != nil { - base.Error(err) - return + base.Fatal("证书加载失败", err) } else { dbdata.LoadCertificate(tlscert) } diff --git a/server/base/log.go b/server/base/log.go index 36bb927..f95a9c8 100644 --- a/server/base/log.go +++ b/server/base/log.go @@ -10,7 +10,8 @@ import ( ) const ( - _Debug = iota + _Trace = iota + _Debug _Info _Warn _Error @@ -89,6 +90,7 @@ func GetBaseLog() *log.Logger { func logLevel2Int(l string) int { levels = map[int]string{ + _Trace: "Trace", _Debug: "Debug", _Info: "Info", _Warn: "Warn", @@ -109,6 +111,14 @@ func output(l int, s ...interface{}) { _ = baseLog.Output(3, lvl+fmt.Sprintln(s...)) } +func Trace(v ...interface{}) { + l := _Trace + if baseLevel > l { + return + } + output(l, v...) +} + func Debug(v ...interface{}) { l := _Debug if baseLevel > l { diff --git a/server/dbdata/cert.go b/server/dbdata/cert.go index 089a671..3b150ba 100644 --- a/server/dbdata/cert.go +++ b/server/dbdata/cert.go @@ -275,8 +275,10 @@ func ParseCert() (*tls.Certificate, *time.Time, error) { _, errCert := os.Stat(base.Cfg.CertFile) _, errKey := os.Stat(base.Cfg.CertKey) if os.IsNotExist(errCert) || os.IsNotExist(errKey) { - PrivateCert() - + err := PrivateCert() + if err != nil { + return nil, nil, err + } } cert, err := tls.LoadX509KeyPair(base.Cfg.CertFile, base.Cfg.CertKey) if err != nil || errors.Is(err, os.ErrNotExist) { @@ -353,6 +355,11 @@ func GetCertificateBySNI(commonName string) (*tls.Certificate, error) { return cert, nil } } + // 默认证书 兼容不支持 SNI 的客户端 + if cert, ok := nameToCertificate["default"]; ok { + return cert, nil + } + return getTempCertificate() } @@ -362,6 +369,9 @@ func LoadCertificate(cert *tls.Certificate) { // Copy from tls.Config BuildNameToCertificate() func buildNameToCertificate(cert *tls.Certificate) { + // 设置默认证书 + nameToCertificate["default"] = cert + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) if err != nil { return diff --git a/server/handler/server.go b/server/handler/server.go index 609f0e8..34fd016 100644 --- a/server/handler/server.go +++ b/server/handler/server.go @@ -50,15 +50,18 @@ func startTls() { MinVersion: tls.VersionTLS12, CipherSuites: selectedCipherSuites, GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + base.Trace("GetCertificate", chi.ServerName) return dbdata.GetCertificateBySNI(chi.ServerName) }, // InsecureSkipVerify: true, } srv := &http.Server{ - Addr: addr, - Handler: initRoute(), - TLSConfig: tlsConfig, - ErrorLog: base.GetBaseLog(), + Addr: addr, + Handler: initRoute(), + TLSConfig: tlsConfig, + ErrorLog: base.GetBaseLog(), + ReadTimeout: 60 * time.Second, + WriteTimeout: 60 * time.Second, } ln, err = net.Listen("tcp", addr) @@ -70,7 +73,7 @@ func startTls() { if base.Cfg.ProxyProtocol { ln = &proxyproto.Listener{ Listener: ln, - ReadHeaderTimeout: 40 * time.Second, + ReadHeaderTimeout: 30 * time.Second, } }