优化代码

This commit is contained in:
wsczx 2023-04-06 12:29:21 +08:00
parent b3e7212b03
commit bc7c61c337
5 changed files with 211 additions and 238 deletions

View File

@ -1,37 +1,16 @@
package admin package admin
import ( import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json" "encoding/json"
"encoding/pem"
"fmt" "fmt"
"io" "io"
"math/big"
"net"
"net/http" "net/http"
"os" "os"
"sync"
"time"
"github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/dbdata" "github.com/bjdgyc/anylink/dbdata"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
) )
type LeGoClient struct {
mutex sync.Mutex
Client *lego.Client
dbdata.LegoUserData
}
func CustomCert(w http.ResponseWriter, r *http.Request) { func CustomCert(w http.ResponseWriter, r *http.Request) {
cert, _, err := r.FormFile("cert") cert, _, err := r.FormFile("cert")
if err != nil { if err != nil {
@ -63,10 +42,7 @@ func CustomCert(w http.ResponseWriter, r *http.Request) {
RespError(w, RespInternalErr, err) RespError(w, RespInternalErr, err)
return return
} }
if tlscert, _, err := ParseCert(); err != nil { if tlscert, _, err := dbdata.ParseCert(); err != nil {
if err := PrivateCert(); err != nil {
base.Error(err)
}
RespError(w, RespInternalErr, fmt.Sprintf("证书不合法,请重新上传:%v", err)) RespError(w, RespInternalErr, fmt.Sprintf("证书不合法,请重新上传:%v", err))
return return
} else { } else {
@ -102,214 +78,16 @@ func CreatCert(w http.ResponseWriter, r *http.Request) {
RespError(w, RespInternalErr, err) RespError(w, RespInternalErr, err)
return return
} }
client := LeGoClient{} client := dbdata.LeGoClient{}
if err := client.NewClient(config); err != nil { if err := client.NewClient(config); err != nil {
base.Error(err) base.Error(err)
RespError(w, RespInternalErr, fmt.Sprintf("获取证书失败:%v", err)) RespError(w, RespInternalErr, fmt.Sprintf("获取证书失败:%v", err))
return return
} }
if err := client.GetCertificate(config.Domain); err != nil { if err := client.GetCert(config.Domain); err != nil {
base.Error(err) base.Error(err)
RespError(w, RespInternalErr, fmt.Sprintf("获取证书失败:%v", err)) RespError(w, RespInternalErr, fmt.Sprintf("获取证书失败:%v", err))
return return
} }
RespSucess(w, "生成证书成功") RespSucess(w, "生成证书成功")
} }
func ReNewCert() {
_, certtime, err := ParseCert()
if err != nil {
base.Error(err)
return
}
if certtime.AddDate(0, 0, -7).Before(time.Now()) {
config := &dbdata.SettingLetsEncrypt{}
if err := dbdata.SettingGet(config); err != nil {
base.Error(err)
return
}
if config.Domain == "" {
return
}
if config.Renew {
client := &LeGoClient{}
if err := client.NewClient(config); err != nil {
base.Error(err)
return
}
if err := client.RenewCert(base.Cfg.CertFile, base.Cfg.CertKey); err != nil {
base.Error(err)
return
}
base.Info("证书续期成功")
}
}
base.Info(fmt.Sprintf("证书过期时间:%s", certtime.Local().Format("2006-1-2 15:04:05")))
}
func (c *LeGoClient) NewClient(l *dbdata.SettingLetsEncrypt) error {
c.mutex.Lock()
defer c.mutex.Unlock()
legouser, err := c.GetUserData(l)
if err != nil {
return err
}
config := lego.NewConfig(legouser)
config.CADirURL = lego.LEDirectoryProduction
config.Certificate.KeyType = certcrypto.RSA2048
client, err := lego.NewClient(config)
if err != nil {
return err
}
Provider, err := dbdata.GetDNSProvider(l)
if err != nil {
return err
}
if err := client.Challenge.SetDNS01Provider(Provider, dns01.AddRecursiveNameservers([]string{"114.114.114.114", "114.114.115.115"})); err != nil {
return err
}
if legouser.Registration == nil {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil {
return err
}
legouser.Registration = reg
c.SaveUserData(legouser)
}
c.Client = client
return nil
}
func (c *LeGoClient) GetCertificate(domain string) error {
// 申请证书
certificates, err := c.Client.Certificate.Obtain(
certificate.ObtainRequest{
Domains: []string{domain},
Bundle: true,
})
if err != nil {
return err
}
// 保存证书
if err := SaveCertificate(certificates); err != nil {
return err
}
return nil
}
func (c *LeGoClient) RenewCert(certFile, keyFile string) error {
cert, err := LoadCertResource(certFile, keyFile)
if err != nil {
return err
}
// 续期证书
renewcert, err := c.Client.Certificate.Renew(certificate.Resource{
Certificate: cert.Certificate,
PrivateKey: cert.PrivateKey,
}, true, false, "")
if err != nil {
return err
}
// 保存更新证书
if err := SaveCertificate(renewcert); err != nil {
return err
}
return nil
}
func SaveCertificate(cert *certificate.Resource) error {
err := os.WriteFile(base.Cfg.CertFile, cert.Certificate, 0600)
if err != nil {
return err
}
err = os.WriteFile(base.Cfg.CertKey, cert.PrivateKey, 0600)
if err != nil {
return err
}
if tlscert, _, err := ParseCert(); err != nil {
return err
} else {
dbdata.TLSCert = tlscert
}
return nil
}
func LoadCertResource(certFile, keyFile string) (*certificate.Resource, error) {
cert, err := os.ReadFile(certFile)
if err != nil {
return nil, err
}
key, err := os.ReadFile(keyFile)
if err != nil {
return nil, err
}
return &certificate.Resource{
Certificate: cert,
PrivateKey: key,
}, nil
}
func ParseCert() (*tls.Certificate, *time.Time, error) {
_, certErr := os.Stat(base.Cfg.CertFile)
_, keyErr := os.Stat(base.Cfg.CertKey)
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
PrivateCert()
}
cert, err := tls.LoadX509KeyPair(base.Cfg.CertFile, base.Cfg.CertKey)
if err != nil {
return nil, nil, err
}
parseCert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil, nil, err
}
certtime := parseCert.NotAfter
return &cert, &certtime, nil
}
func PrivateCert() error {
// 创建一个RSA密钥对
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
pub := &priv.PublicKey
// 生成一个自签名证书
template := x509.Certificate{
SerialNumber: big.NewInt(1658),
Subject: pkix.Name{CommonName: "localhost"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 365),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv)
if err != nil {
return err
}
// 将证书编码为PEM格式并将其写入文件
certOut, _ := os.OpenFile(base.Cfg.CertFile, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
certOut.Close()
// 将私钥编码为PEM格式并将其写入文件
keyOut, _ := os.OpenFile(base.Cfg.CertKey, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
keyOut.Close()
cert, err := tls.LoadX509KeyPair(base.Cfg.CertFile, base.Cfg.CertKey)
if err != nil {
return err
}
dbdata.TLSCert = &cert
return nil
}
// func Scrypt(passwd string) string {
// salt := []byte{0xc8, 0x28, 0xf2, 0x58, 0xa7, 0x6a, 0xad, 0x7b}
// hashPasswd, err := scrypt.Key([]byte(passwd), salt, 1<<15, 8, 1, 32)
// if err != nil {
// return err.Error()
// }
// return base64.StdEncoding.EncodeToString(hashPasswd)
// }

View File

@ -100,8 +100,7 @@ func StartAdmin() {
for _, s := range cipherSuites { for _, s := range cipherSuites {
selectedCipherSuites = append(selectedCipherSuites, s.ID) selectedCipherSuites = append(selectedCipherSuites, s.ID)
} }
if tlscert, _, err := dbdata.ParseCert(); err != nil {
if tlscert, _, err := ParseCert(); err != nil {
base.Error(err) base.Error(err)
return return
} else { } else {
@ -114,10 +113,6 @@ func StartAdmin() {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
CipherSuites: selectedCipherSuites, CipherSuites: selectedCipherSuites,
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
// cert, err := tls.LoadX509KeyPair(base.Cfg.CertFile, base.Cfg.CertKey)
// if err != nil {
// return nil, err
// }
return dbdata.TLSCert, nil return dbdata.TLSCert, nil
}, },
} }

View File

@ -3,7 +3,7 @@ package cron
import ( import (
"time" "time"
"github.com/bjdgyc/anylink/admin" "github.com/bjdgyc/anylink/dbdata"
"github.com/bjdgyc/anylink/sessdata" "github.com/bjdgyc/anylink/sessdata"
"github.com/go-co-op/gocron" "github.com/go-co-op/gocron"
) )
@ -14,6 +14,6 @@ func Start() {
s.Cron("0 * * * *").Do(ClearStatsInfo) s.Cron("0 * * * *").Do(ClearStatsInfo)
s.Cron("0 * * * *").Do(ClearUserActLog) s.Cron("0 * * * *").Do(ClearUserActLog)
s.Every(1).Day().At("00:00").Do(sessdata.CloseUserLimittimeSession) s.Every(1).Day().At("00:00").Do(sessdata.CloseUserLimittimeSession)
s.Every(1).Day().At("00:00").Do(admin.ReNewCert) s.Every(1).Day().At("00:00").Do(dbdata.ReNewCert)
s.StartAsync() s.StartAsync()
} }

View File

@ -5,9 +5,24 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math/big"
"net"
"os"
"sync"
"time"
"github.com/bjdgyc/anylink/base"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/providers/dns/alidns" "github.com/go-acme/lego/v4/providers/dns/alidns"
"github.com/go-acme/lego/v4/providers/dns/cloudflare" "github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/providers/dns/tencentcloud" "github.com/go-acme/lego/v4/providers/dns/tencentcloud"
@ -18,7 +33,6 @@ import (
var TLSCert *tls.Certificate var TLSCert *tls.Certificate
type SettingLetsEncrypt struct { type SettingLetsEncrypt struct {
// LegoUser LegoUser
Domain string `json:"domain"` Domain string `json:"domain"`
Legomail string `json:"legomail"` Legomail string `json:"legomail"`
Name string `json:"name"` Name string `json:"name"`
@ -52,6 +66,13 @@ type LegoUser struct {
Key *ecdsa.PrivateKey Key *ecdsa.PrivateKey
} }
type LeGoClient struct {
mutex sync.Mutex
Client *lego.Client
Cert *certificate.Resource
LegoUserData
}
func GetDNSProvider(l *SettingLetsEncrypt) (Provider challenge.Provider, err error) { func GetDNSProvider(l *SettingLetsEncrypt) (Provider challenge.Provider, err error) {
switch l.Name { switch l.Name {
case "aliyun": case "aliyun":
@ -117,3 +138,186 @@ func (l *LegoUserData) GetUserData(d *SettingLetsEncrypt) (*LegoUser, error) {
Key: privateKey, Key: privateKey,
}, nil }, nil
} }
func ReNewCert() {
_, certtime, err := ParseCert()
if err != nil {
base.Error(err)
return
}
if certtime.AddDate(0, 0, -7).Before(time.Now()) {
config := &SettingLetsEncrypt{}
if err := SettingGet(config); err != nil {
base.Error(err)
return
}
if config.Renew {
client := &LeGoClient{}
if err := client.NewClient(config); err != nil {
base.Error(err)
return
}
if err := client.RenewCert(base.Cfg.CertFile, base.Cfg.CertKey); err != nil {
base.Error(err)
return
}
base.Info("证书续期成功")
}
} else {
base.Info(fmt.Sprintf("证书过期时间:%s", certtime.Local().Format("2006-1-2 15:04:05")))
}
}
func (c *LeGoClient) NewClient(l *SettingLetsEncrypt) error {
c.mutex.Lock()
defer c.mutex.Unlock()
legouser, err := c.GetUserData(l)
if err != nil {
return err
}
config := lego.NewConfig(legouser)
config.CADirURL = lego.LEDirectoryStaging
config.Certificate.KeyType = certcrypto.RSA2048
client, err := lego.NewClient(config)
if err != nil {
return err
}
Provider, err := GetDNSProvider(l)
if err != nil {
return err
}
if err := client.Challenge.SetDNS01Provider(Provider, dns01.AddRecursiveNameservers([]string{"114.114.114.114", "114.114.115.115"})); err != nil {
return err
}
if legouser.Registration == nil {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil {
return err
}
legouser.Registration = reg
c.SaveUserData(legouser)
}
c.Client = client
return nil
}
func (c *LeGoClient) GetCert(domain string) error {
// 申请证书
certificates, err := c.Client.Certificate.Obtain(
certificate.ObtainRequest{
Domains: []string{domain},
Bundle: true,
})
if err != nil {
return err
}
c.Cert = certificates
// 保存证书
if err := c.SaveCert(); err != nil {
return err
}
return nil
}
func (c *LeGoClient) RenewCert(certFile, keyFile string) error {
cert, err := os.ReadFile(certFile)
if err != nil {
return err
}
key, err := os.ReadFile(keyFile)
if err != nil {
return err
}
// 续期证书
renewcert, err := c.Client.Certificate.Renew(certificate.Resource{
Certificate: cert,
PrivateKey: key,
}, true, false, "")
if err != nil {
return err
}
c.Cert = renewcert
// 保存更新证书
if err := c.SaveCert(); err != nil {
return err
}
return nil
}
func (c *LeGoClient) SaveCert() error {
err := os.WriteFile(base.Cfg.CertFile, c.Cert.Certificate, 0600)
if err != nil {
return err
}
err = os.WriteFile(base.Cfg.CertKey, c.Cert.PrivateKey, 0600)
if err != nil {
return err
}
if tlscert, _, err := ParseCert(); err != nil {
return err
} else {
TLSCert = tlscert
}
return nil
}
func ParseCert() (*tls.Certificate, *time.Time, error) {
os.Stat(base.Cfg.CertFile)
os.Stat(base.Cfg.CertKey)
cert, err := tls.LoadX509KeyPair(base.Cfg.CertFile, base.Cfg.CertKey)
if err != nil || errors.Is(err, os.ErrNotExist) {
PrivateCert()
return nil, nil, err
}
parseCert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil, nil, err
}
return &cert, &parseCert.NotAfter, nil
}
func PrivateCert() error {
// 创建一个RSA密钥对
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
pub := &priv.PublicKey
// 生成一个自签名证书
template := x509.Certificate{
SerialNumber: big.NewInt(1658),
Subject: pkix.Name{CommonName: "localhost"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 365),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv)
if err != nil {
return err
}
// 将证书编码为PEM格式并将其写入文件
certOut, _ := os.OpenFile(base.Cfg.CertFile, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
certOut.Close()
// 将私钥编码为PEM格式并将其写入文件
keyOut, _ := os.OpenFile(base.Cfg.CertKey, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
keyOut.Close()
cert, err := tls.LoadX509KeyPair(base.Cfg.CertFile, base.Cfg.CertKey)
if err != nil {
return err
}
TLSCert = &cert
return nil
}
// func Scrypt(passwd string) string {
// salt := []byte{0xc8, 0x28, 0xf2, 0x58, 0xa7, 0x6a, 0xad, 0x7b}
// hashPasswd, err := scrypt.Key([]byte(passwd), salt, 1<<15, 8, 1, 32)
// if err != nil {
// return err.Error()
// }
// return base64.StdEncoding.EncodeToString(hashPasswd)
// }

View File

@ -50,10 +50,6 @@ func startTls() {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
CipherSuites: selectedCipherSuites, CipherSuites: selectedCipherSuites,
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
// cert, err := tls.LoadX509KeyPair(base.Cfg.CertFile, base.Cfg.CertKey)
// if err != nil {
// return nil, err
// }
return dbdata.TLSCert, nil return dbdata.TLSCert, nil
}, },
// InsecureSkipVerify: true, // InsecureSkipVerify: true,