增加测试用例

This commit is contained in:
wsczx
2025-08-20 18:19:42 +08:00
parent 033a8e1749
commit 1fcbad6665
3 changed files with 547 additions and 1 deletions

View File

@@ -66,6 +66,9 @@ func (c *ClientCertData) GetStatus() int {
// 保存客户端证书
func (c *ClientCertData) Save() error {
if c.Id > 0 {
return Set(c) // 更新现有记录
}
return Add(c)
}
@@ -178,8 +181,22 @@ func GenerateClientCA() error {
// 生成客户端证书并保存到数据库
func GenerateClientCert(username, groupname string) (*ClientCertData, error) {
// 检查用户是否存在并验证组成员资格
user := &User{}
err := One("Username", username, user)
if err != nil {
if errors.Is(err, ErrNotFound) {
return nil, fmt.Errorf("用户不存在: %s", username)
}
return nil, fmt.Errorf("获取用户信息失败: %v", err)
}
// 检查用户是否属于指定组
if !slices.Contains(user.Groups, groupname) {
return nil, fmt.Errorf("用户 %s 不属于组 %s", username, groupname)
}
// 检查是否已存在证书记录
_, err := GetClientCert(username)
_, err = GetClientCert(username)
if err != nil {
if !errors.Is(err, ErrNotFound) {
return nil, fmt.Errorf("获取用户证书失败: %v", err)

View File

@@ -0,0 +1,194 @@
package dbdata
import (
"crypto/x509"
"encoding/pem"
"fmt"
"testing"
"github.com/bjdgyc/anylink/base"
"github.com/stretchr/testify/assert"
)
func TestGenerateClientCert(t *testing.T) {
base.Test()
ast := assert.New(t)
// 设置临时目录用于测试
tempDir := t.TempDir()
base.Cfg.ClientCertCAFile = tempDir + "/client_ca.pem"
base.Cfg.ClientCertCAKeyFile = tempDir + "/client_ca_key.pem"
preIpData()
defer closeIpdata()
// 使用 GenerateClientCA 生成 CA
err := GenerateClientCA()
ast.Nil(err, "生成客户端 CA 失败")
// 创建测试组
group := "cert-test-group"
dns := []ValData{{Val: "8.8.8.8"}}
g := Group{Name: group, Status: 1, ClientDns: dns}
err = SetGroup(&g)
ast.Nil(err)
// 创建测试用户
username := "cert-test-user"
u := User{Username: username, Groups: []string{group}, Status: 1}
err = SetUser(&u)
ast.Nil(err)
// 测试证书生成成功
certData, err := GenerateClientCert(username, group)
ast.Nil(err)
ast.NotNil(certData)
ast.Equal(username, certData.Username)
ast.Equal(group, certData.GroupName)
ast.Equal(CertStatusActive, certData.Status)
ast.NotEmpty(certData.Certificate)
ast.NotEmpty(certData.PrivateKey)
ast.NotEmpty(certData.SerialNumber)
// 测试重复生成证书失败
_, err = GenerateClientCert(username, group)
ast.NotNil(err)
ast.Contains(err.Error(), "已存在证书")
// 测试用户不属于指定组
_, err = GenerateClientCert(username, "nonexistent-group")
ast.NotNil(err)
ast.Contains(err.Error(), "不属于组")
// 测试用户不存在
_, err = GenerateClientCert("nonexistent-user", group)
ast.NotNil(err)
ast.Contains(err.Error(), "用户不存在")
}
func TestCertificateAuthFlow(t *testing.T) {
base.Test()
ast := assert.New(t)
preIpData()
defer closeIpdata()
// 设置测试环境
group := "auth-test-group"
username := "auth-test-user"
// 创建组和用户
dns := []ValData{{Val: "8.8.8.8"}}
g := Group{Name: group, Status: 1, ClientDns: dns}
err := SetGroup(&g)
ast.Nil(err)
u := User{Username: username, Groups: []string{group}, Status: 1}
err = SetUser(&u)
ast.Nil(err)
// 生成证书
certData, err := GenerateClientCert(username, group)
ast.Nil(err)
// 解析证书
cert, err := parseCertFromPEM(certData.Certificate)
ast.Nil(err)
// 证书验证
valid := ValidateClientCert(cert, "test-agent")
ast.True(valid)
// 测试证书状态变更
certData.Status = CertStatusDisabled
err = certData.UpdateStatus(CertStatusDisabled)
ast.Nil(err)
valid = ValidateClientCert(cert, "test-agent")
ast.False(valid)
}
func TestValidateClientCert(t *testing.T) {
base.Test()
ast := assert.New(t)
// 设置临时目录用于测试
tempDir := t.TempDir()
base.Cfg.ClientCertCAFile = tempDir + "/client_ca.pem"
base.Cfg.ClientCertCAKeyFile = tempDir + "/client_ca_key.pem"
preIpData()
defer closeIpdata()
// 初始化客户端 CA
err := GenerateClientCA()
ast.Nil(err, "初始化客户端 CA 失败")
// 创建测试组
group := "test-group"
dns := []ValData{{Val: "8.8.8.8"}}
g := Group{Name: group, Status: 1, ClientDns: dns}
err = SetGroup(&g)
ast.Nil(err)
// 创建测试用户
username := "test-user"
u := User{Username: username, Groups: []string{group}, Status: 1}
err = SetUser(&u)
ast.Nil(err)
// 生成客户端证书
certData, err := GenerateClientCert(username, group)
ast.Nil(err)
ast.NotNil(certData)
ast.Equal(username, certData.Username)
ast.Equal(group, certData.GroupName)
// 解析生成的证书
cert, err := parseCertFromPEM(certData.Certificate)
ast.Nil(err)
ast.Equal(username, cert.Subject.CommonName)
ast.Equal(group, cert.Subject.OrganizationalUnit[0])
// 测试证书验证成功
valid := ValidateClientCert(cert, "test-agent")
ast.True(valid)
// 测试用户不存在的情况
cert.Subject.CommonName = "nonexistent-user"
valid = ValidateClientCert(cert, "test-agent")
ast.False(valid)
// 测试用户被禁用的情况
cert.Subject.CommonName = username
u.Status = 0
err = SetUser(&u)
ast.Nil(err)
valid = ValidateClientCert(cert, "test-agent")
ast.False(valid)
// 恢复用户状态
u.Status = 1
err = SetUser(&u)
ast.Nil(err)
// 测试证书组不匹配的情况
cert.Subject.OrganizationalUnit[0] = "wrong-group"
valid = ValidateClientCert(cert, "test-agent")
ast.False(valid)
// 测试证书状态被禁用的情况
cert.Subject.OrganizationalUnit[0] = group
certData.Status = CertStatusDisabled
err = certData.Save()
ast.Nil(err)
valid = ValidateClientCert(cert, "test-agent")
ast.False(valid)
}
func parseCertFromPEM(certPEM string) (*x509.Certificate, error) {
block, _ := pem.Decode([]byte(certPEM))
if block == nil {
return nil, fmt.Errorf("failed to decode PEM block")
}
return x509.ParseCertificate(block.Bytes)
}