From 609a893feb4c2f49202a01d2adfa8d88b4577853 Mon Sep 17 00:00:00 2001 From: deny Date: Mon, 17 Apr 2023 11:07:39 +0000 Subject: [PATCH] =?UTF-8?q?feat:=E6=A0=B9=E6=8D=AESNI=E8=BF=94=E5=9B=9ESSL?= =?UTF-8?q?=E8=AF=81=E4=B9=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/admin/api_cert.go | 2 +- server/admin/server.go | 6 ++-- server/dbdata/cert.go | 71 ++++++++++++++++++++++++++++++++++++++-- server/handler/server.go | 4 +-- 4 files changed, 74 insertions(+), 9 deletions(-) diff --git a/server/admin/api_cert.go b/server/admin/api_cert.go index 5731c3a..6568d96 100755 --- a/server/admin/api_cert.go +++ b/server/admin/api_cert.go @@ -46,7 +46,7 @@ func CustomCert(w http.ResponseWriter, r *http.Request) { RespError(w, RespInternalErr, fmt.Sprintf("证书不合法,请重新上传:%v", err)) return } else { - dbdata.TLSCert = tlscert + dbdata.LoadCertificate(tlscert) } RespSucess(w, "上传成功") } diff --git a/server/admin/server.go b/server/admin/server.go index e7171af..b0cbbca 100644 --- a/server/admin/server.go +++ b/server/admin/server.go @@ -104,7 +104,7 @@ func StartAdmin() { base.Error(err) return } else { - dbdata.TLSCert = tlscert + dbdata.LoadCertificate(tlscert) } // 设置tls信息 @@ -112,8 +112,8 @@ func StartAdmin() { NextProtos: []string{"http/1.1"}, MinVersion: tls.VersionTLS12, CipherSuites: selectedCipherSuites, - GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { - return dbdata.TLSCert, nil + GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + return dbdata.GetCertificateBySNI(chi.ServerName) }, } srv := &http.Server{ diff --git a/server/dbdata/cert.go b/server/dbdata/cert.go index e1a9412..b9b2a6f 100755 --- a/server/dbdata/cert.go +++ b/server/dbdata/cert.go @@ -12,9 +12,11 @@ import ( "encoding/pem" "errors" "fmt" + "github.com/pion/dtls/v2/pkg/crypto/selfsign" "math/big" "net" "os" + "strings" "sync" "time" @@ -30,7 +32,14 @@ import ( "github.com/go-acme/lego/v4/registration" ) -var TLSCert *tls.Certificate +var nameToCertificate = make(map[string]*tls.Certificate) + +var tempCert *tls.Certificate + +func init() { + c, _ := selfsign.GenerateSelfSignedWithDNS("localhost") + tempCert = &c +} type SettingLetsEncrypt struct { Domain string `json:"domain"` @@ -200,6 +209,7 @@ func (c *LeGoClient) NewClient(l *SettingLetsEncrypt) error { c.Client = client return nil } + func (c *LeGoClient) GetCert(domain string) error { // 申请证书 certificates, err := c.Client.Certificate.Obtain( @@ -255,7 +265,7 @@ func (c *LeGoClient) SaveCert() error { if tlscert, _, err := ParseCert(); err != nil { return err } else { - TLSCert = tlscert + LoadCertificate(tlscert) } return nil } @@ -278,6 +288,7 @@ func ParseCert() (*tls.Certificate, *time.Time, error) { } return &cert, &parseCert.NotAfter, nil } + func PrivateCert() error { // 创建一个RSA密钥对 priv, _ := rsa.GenerateKey(rand.Reader, 2048) @@ -313,10 +324,64 @@ func PrivateCert() error { if err != nil { return err } - TLSCert = &cert + LoadCertificate(&cert) return nil } +func getTempCertificate() (*tls.Certificate, error) { + var err error + var cert tls.Certificate + if tempCert == nil { + cert, err = selfsign.GenerateSelfSignedWithDNS("localhost") + tempCert = &cert + } + return tempCert, err +} + +func GetCertificateBySNI(commonName string) (*tls.Certificate, error) { + // Copy from tls.Config getCertificate() + name := strings.ToLower(commonName) + if cert, ok := nameToCertificate[name]; ok { + return cert, nil + } + if len(name) > 0 { + labels := strings.Split(name, ".") + labels[0] = "*" + wildcardName := strings.Join(labels, ".") + if cert, ok := nameToCertificate[wildcardName]; ok { + return cert, nil + } + } + return getTempCertificate() +} + +func LoadCertificate(cert *tls.Certificate) { + buildNameToCertificate(cert) +} + +// Copy from tls.Config BuildNameToCertificate() +func buildNameToCertificate(cert *tls.Certificate) { + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return + } + startTime := x509Cert.NotBefore.String() + expiredTime := x509Cert.NotAfter.String() + if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { + commonName := x509Cert.Subject.CommonName + fmt.Printf("┏ Load Certificate: %s\n", commonName) + fmt.Printf("┠╌╌ Start Time: %s\n", startTime) + fmt.Printf("┖╌╌ Expired Time: %s\n", expiredTime) + nameToCertificate[commonName] = cert + } + for _, san := range x509Cert.DNSNames { + fmt.Printf("┏ Load Certificate: %s\n", san) + fmt.Printf("┠╌╌ Start Time: %s\n", startTime) + fmt.Printf("┖╌╌ Expired Time: %s\n", expiredTime) + nameToCertificate[san] = cert + } +} + // func Scrypt(passwd string) string { // salt := []byte{0xc8, 0x28, 0xf2, 0x58, 0xa7, 0x6a, 0xad, 0x7b} // hashPasswd, err := scrypt.Key([]byte(passwd), salt, 1<<15, 8, 1, 32) diff --git a/server/handler/server.go b/server/handler/server.go index 1197316..609f0e8 100644 --- a/server/handler/server.go +++ b/server/handler/server.go @@ -49,8 +49,8 @@ func startTls() { NextProtos: []string{"http/1.1"}, MinVersion: tls.VersionTLS12, CipherSuites: selectedCipherSuites, - GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { - return dbdata.TLSCert, nil + GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + return dbdata.GetCertificateBySNI(chi.ServerName) }, // InsecureSkipVerify: true, }