mirror of https://github.com/bjdgyc/anylink.git
feat:根据SNI返回SSL证书
This commit is contained in:
parent
8798de0d6d
commit
609a893feb
|
@ -46,7 +46,7 @@ func CustomCert(w http.ResponseWriter, r *http.Request) {
|
|||
RespError(w, RespInternalErr, fmt.Sprintf("证书不合法,请重新上传:%v", err))
|
||||
return
|
||||
} else {
|
||||
dbdata.TLSCert = tlscert
|
||||
dbdata.LoadCertificate(tlscert)
|
||||
}
|
||||
RespSucess(w, "上传成功")
|
||||
}
|
||||
|
|
|
@ -104,7 +104,7 @@ func StartAdmin() {
|
|||
base.Error(err)
|
||||
return
|
||||
} else {
|
||||
dbdata.TLSCert = tlscert
|
||||
dbdata.LoadCertificate(tlscert)
|
||||
}
|
||||
|
||||
// 设置tls信息
|
||||
|
@ -112,8 +112,8 @@ func StartAdmin() {
|
|||
NextProtos: []string{"http/1.1"},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: selectedCipherSuites,
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return dbdata.TLSCert, nil
|
||||
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return dbdata.GetCertificateBySNI(chi.ServerName)
|
||||
},
|
||||
}
|
||||
srv := &http.Server{
|
||||
|
|
|
@ -12,9 +12,11 @@ import (
|
|||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -30,7 +32,14 @@ import (
|
|||
"github.com/go-acme/lego/v4/registration"
|
||||
)
|
||||
|
||||
var TLSCert *tls.Certificate
|
||||
var nameToCertificate = make(map[string]*tls.Certificate)
|
||||
|
||||
var tempCert *tls.Certificate
|
||||
|
||||
func init() {
|
||||
c, _ := selfsign.GenerateSelfSignedWithDNS("localhost")
|
||||
tempCert = &c
|
||||
}
|
||||
|
||||
type SettingLetsEncrypt struct {
|
||||
Domain string `json:"domain"`
|
||||
|
@ -200,6 +209,7 @@ func (c *LeGoClient) NewClient(l *SettingLetsEncrypt) error {
|
|||
c.Client = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LeGoClient) GetCert(domain string) error {
|
||||
// 申请证书
|
||||
certificates, err := c.Client.Certificate.Obtain(
|
||||
|
@ -255,7 +265,7 @@ func (c *LeGoClient) SaveCert() error {
|
|||
if tlscert, _, err := ParseCert(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
TLSCert = tlscert
|
||||
LoadCertificate(tlscert)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -278,6 +288,7 @@ func ParseCert() (*tls.Certificate, *time.Time, error) {
|
|||
}
|
||||
return &cert, &parseCert.NotAfter, nil
|
||||
}
|
||||
|
||||
func PrivateCert() error {
|
||||
// 创建一个RSA密钥对
|
||||
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
|
@ -313,10 +324,64 @@ func PrivateCert() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
TLSCert = &cert
|
||||
LoadCertificate(&cert)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTempCertificate() (*tls.Certificate, error) {
|
||||
var err error
|
||||
var cert tls.Certificate
|
||||
if tempCert == nil {
|
||||
cert, err = selfsign.GenerateSelfSignedWithDNS("localhost")
|
||||
tempCert = &cert
|
||||
}
|
||||
return tempCert, err
|
||||
}
|
||||
|
||||
func GetCertificateBySNI(commonName string) (*tls.Certificate, error) {
|
||||
// Copy from tls.Config getCertificate()
|
||||
name := strings.ToLower(commonName)
|
||||
if cert, ok := nameToCertificate[name]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
if len(name) > 0 {
|
||||
labels := strings.Split(name, ".")
|
||||
labels[0] = "*"
|
||||
wildcardName := strings.Join(labels, ".")
|
||||
if cert, ok := nameToCertificate[wildcardName]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
}
|
||||
return getTempCertificate()
|
||||
}
|
||||
|
||||
func LoadCertificate(cert *tls.Certificate) {
|
||||
buildNameToCertificate(cert)
|
||||
}
|
||||
|
||||
// Copy from tls.Config BuildNameToCertificate()
|
||||
func buildNameToCertificate(cert *tls.Certificate) {
|
||||
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
startTime := x509Cert.NotBefore.String()
|
||||
expiredTime := x509Cert.NotAfter.String()
|
||||
if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 {
|
||||
commonName := x509Cert.Subject.CommonName
|
||||
fmt.Printf("┏ Load Certificate: %s\n", commonName)
|
||||
fmt.Printf("┠╌╌ Start Time: %s\n", startTime)
|
||||
fmt.Printf("┖╌╌ Expired Time: %s\n", expiredTime)
|
||||
nameToCertificate[commonName] = cert
|
||||
}
|
||||
for _, san := range x509Cert.DNSNames {
|
||||
fmt.Printf("┏ Load Certificate: %s\n", san)
|
||||
fmt.Printf("┠╌╌ Start Time: %s\n", startTime)
|
||||
fmt.Printf("┖╌╌ Expired Time: %s\n", expiredTime)
|
||||
nameToCertificate[san] = cert
|
||||
}
|
||||
}
|
||||
|
||||
// func Scrypt(passwd string) string {
|
||||
// salt := []byte{0xc8, 0x28, 0xf2, 0x58, 0xa7, 0x6a, 0xad, 0x7b}
|
||||
// hashPasswd, err := scrypt.Key([]byte(passwd), salt, 1<<15, 8, 1, 32)
|
||||
|
|
|
@ -49,8 +49,8 @@ func startTls() {
|
|||
NextProtos: []string{"http/1.1"},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: selectedCipherSuites,
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return dbdata.TLSCert, nil
|
||||
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return dbdata.GetCertificateBySNI(chi.ServerName)
|
||||
},
|
||||
// InsecureSkipVerify: true,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue