From 1fcbad6665a05c03b204b26e2bd88eb46be3b180 Mon Sep 17 00:00:00 2001 From: wsczx Date: Wed, 20 Aug 2025 18:19:42 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95=E7=94=A8?= =?UTF-8?q?=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/dbdata/cert_client.go | 19 +- server/dbdata/cert_client_test.go | 194 ++++++++++++++++ server/handler/link_auth_otp_test.go | 335 +++++++++++++++++++++++++++ 3 files changed, 547 insertions(+), 1 deletion(-) create mode 100644 server/dbdata/cert_client_test.go create mode 100644 server/handler/link_auth_otp_test.go diff --git a/server/dbdata/cert_client.go b/server/dbdata/cert_client.go index 4ddb890..758ba97 100644 --- a/server/dbdata/cert_client.go +++ b/server/dbdata/cert_client.go @@ -66,6 +66,9 @@ func (c *ClientCertData) GetStatus() int { // 保存客户端证书 func (c *ClientCertData) Save() error { + if c.Id > 0 { + return Set(c) // 更新现有记录 + } return Add(c) } @@ -178,8 +181,22 @@ func GenerateClientCA() error { // 生成客户端证书并保存到数据库 func GenerateClientCert(username, groupname string) (*ClientCertData, error) { + // 检查用户是否存在并验证组成员资格 + user := &User{} + err := One("Username", username, user) + if err != nil { + if errors.Is(err, ErrNotFound) { + return nil, fmt.Errorf("用户不存在: %s", username) + } + return nil, fmt.Errorf("获取用户信息失败: %v", err) + } + + // 检查用户是否属于指定组 + if !slices.Contains(user.Groups, groupname) { + return nil, fmt.Errorf("用户 %s 不属于组 %s", username, groupname) + } // 检查是否已存在证书记录 - _, err := GetClientCert(username) + _, err = GetClientCert(username) if err != nil { if !errors.Is(err, ErrNotFound) { return nil, fmt.Errorf("获取用户证书失败: %v", err) diff --git a/server/dbdata/cert_client_test.go b/server/dbdata/cert_client_test.go new file mode 100644 index 0000000..2027958 --- /dev/null +++ b/server/dbdata/cert_client_test.go @@ -0,0 +1,194 @@ +package dbdata + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "testing" + + "github.com/bjdgyc/anylink/base" + "github.com/stretchr/testify/assert" +) + +func TestGenerateClientCert(t *testing.T) { + base.Test() + ast := assert.New(t) + + // 设置临时目录用于测试 + tempDir := t.TempDir() + base.Cfg.ClientCertCAFile = tempDir + "/client_ca.pem" + base.Cfg.ClientCertCAKeyFile = tempDir + "/client_ca_key.pem" + + preIpData() + defer closeIpdata() + + // 使用 GenerateClientCA 生成 CA + err := GenerateClientCA() + ast.Nil(err, "生成客户端 CA 失败") + + // 创建测试组 + group := "cert-test-group" + dns := []ValData{{Val: "8.8.8.8"}} + g := Group{Name: group, Status: 1, ClientDns: dns} + err = SetGroup(&g) + ast.Nil(err) + + // 创建测试用户 + username := "cert-test-user" + u := User{Username: username, Groups: []string{group}, Status: 1} + err = SetUser(&u) + ast.Nil(err) + + // 测试证书生成成功 + certData, err := GenerateClientCert(username, group) + ast.Nil(err) + ast.NotNil(certData) + ast.Equal(username, certData.Username) + ast.Equal(group, certData.GroupName) + ast.Equal(CertStatusActive, certData.Status) + ast.NotEmpty(certData.Certificate) + ast.NotEmpty(certData.PrivateKey) + ast.NotEmpty(certData.SerialNumber) + + // 测试重复生成证书失败 + _, err = GenerateClientCert(username, group) + ast.NotNil(err) + ast.Contains(err.Error(), "已存在证书") + + // 测试用户不属于指定组 + _, err = GenerateClientCert(username, "nonexistent-group") + ast.NotNil(err) + ast.Contains(err.Error(), "不属于组") + + // 测试用户不存在 + _, err = GenerateClientCert("nonexistent-user", group) + ast.NotNil(err) + ast.Contains(err.Error(), "用户不存在") +} +func TestCertificateAuthFlow(t *testing.T) { + base.Test() + ast := assert.New(t) + + preIpData() + defer closeIpdata() + + // 设置测试环境 + group := "auth-test-group" + username := "auth-test-user" + + // 创建组和用户 + dns := []ValData{{Val: "8.8.8.8"}} + g := Group{Name: group, Status: 1, ClientDns: dns} + err := SetGroup(&g) + ast.Nil(err) + + u := User{Username: username, Groups: []string{group}, Status: 1} + err = SetUser(&u) + ast.Nil(err) + + // 生成证书 + certData, err := GenerateClientCert(username, group) + ast.Nil(err) + + // 解析证书 + cert, err := parseCertFromPEM(certData.Certificate) + ast.Nil(err) + + // 证书验证 + valid := ValidateClientCert(cert, "test-agent") + ast.True(valid) + + // 测试证书状态变更 + certData.Status = CertStatusDisabled + err = certData.UpdateStatus(CertStatusDisabled) + ast.Nil(err) + + valid = ValidateClientCert(cert, "test-agent") + ast.False(valid) +} + +func TestValidateClientCert(t *testing.T) { + base.Test() + ast := assert.New(t) + + // 设置临时目录用于测试 + tempDir := t.TempDir() + base.Cfg.ClientCertCAFile = tempDir + "/client_ca.pem" + base.Cfg.ClientCertCAKeyFile = tempDir + "/client_ca_key.pem" + + preIpData() + defer closeIpdata() + + // 初始化客户端 CA + err := GenerateClientCA() + ast.Nil(err, "初始化客户端 CA 失败") + + // 创建测试组 + group := "test-group" + dns := []ValData{{Val: "8.8.8.8"}} + g := Group{Name: group, Status: 1, ClientDns: dns} + err = SetGroup(&g) + ast.Nil(err) + + // 创建测试用户 + username := "test-user" + u := User{Username: username, Groups: []string{group}, Status: 1} + err = SetUser(&u) + ast.Nil(err) + + // 生成客户端证书 + certData, err := GenerateClientCert(username, group) + ast.Nil(err) + ast.NotNil(certData) + ast.Equal(username, certData.Username) + ast.Equal(group, certData.GroupName) + + // 解析生成的证书 + cert, err := parseCertFromPEM(certData.Certificate) + ast.Nil(err) + ast.Equal(username, cert.Subject.CommonName) + ast.Equal(group, cert.Subject.OrganizationalUnit[0]) + + // 测试证书验证成功 + valid := ValidateClientCert(cert, "test-agent") + ast.True(valid) + + // 测试用户不存在的情况 + cert.Subject.CommonName = "nonexistent-user" + valid = ValidateClientCert(cert, "test-agent") + ast.False(valid) + + // 测试用户被禁用的情况 + cert.Subject.CommonName = username + u.Status = 0 + err = SetUser(&u) + ast.Nil(err) + valid = ValidateClientCert(cert, "test-agent") + ast.False(valid) + + // 恢复用户状态 + u.Status = 1 + err = SetUser(&u) + ast.Nil(err) + + // 测试证书组不匹配的情况 + cert.Subject.OrganizationalUnit[0] = "wrong-group" + valid = ValidateClientCert(cert, "test-agent") + ast.False(valid) + + // 测试证书状态被禁用的情况 + cert.Subject.OrganizationalUnit[0] = group + certData.Status = CertStatusDisabled + err = certData.Save() + ast.Nil(err) + valid = ValidateClientCert(cert, "test-agent") + ast.False(valid) +} + +func parseCertFromPEM(certPEM string) (*x509.Certificate, error) { + block, _ := pem.Decode([]byte(certPEM)) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block") + } + return x509.ParseCertificate(block.Bytes) +} diff --git a/server/handler/link_auth_otp_test.go b/server/handler/link_auth_otp_test.go new file mode 100644 index 0000000..d573cda --- /dev/null +++ b/server/handler/link_auth_otp_test.go @@ -0,0 +1,335 @@ +package handler + +import ( + "bytes" + "encoding/xml" + "net/http" + "net/http/httptest" + "os" + "path" + "testing" + + "github.com/bjdgyc/anylink/base" + "github.com/bjdgyc/anylink/dbdata" + "github.com/stretchr/testify/assert" + "github.com/xlzd/gotp" +) + +func TestSessionStore(t *testing.T) { + ast := assert.New(t) + + // 测试会话存储基本功能 + store := NewSessionStore() + sessionID := "test-session-123" + + // 创建测试会话数据 + authSession := &AuthSession{ + ClientRequest: &ClientRequest{ + Auth: auth{ + Username: "test-user", + OtpSecret: "JBSWY3DPEHPK3PXP", + }, + GroupSelect: "test-group", + }, + UserActLog: &dbdata.UserActLog{ + Username: "test-user", + Status: dbdata.UserAuthSuccess, + }, + } + + // 测试保存会话 + store.SaveAuthSession(sessionID, authSession) + + // 测试获取会话 + retrievedSession, err := store.GetAuthSession(sessionID) + ast.Nil(err) + ast.NotNil(retrievedSession) + ast.Equal("test-user", retrievedSession.ClientRequest.Auth.Username) + + // 测试获取不存在的会话 + _, err = store.GetAuthSession("nonexistent-session") + ast.NotNil(err) + ast.Contains(err.Error(), "auth session not found") + + // 测试删除会话 + store.DeleteAuthSession(sessionID) + _, err = store.GetAuthSession(sessionID) + ast.NotNil(err) +} + +func TestGenerateSessionID(t *testing.T) { + ast := assert.New(t) + + // 测试会话ID生成 + sessionID, err := GenerateSessionID() + ast.Nil(err) + ast.NotEmpty(sessionID) + ast.Equal(32, len(sessionID)) + + // 测试生成的ID唯一性 + sessionID2, err := GenerateSessionID() + ast.Nil(err) + ast.NotEqual(sessionID, sessionID2) +} + +func TestCookieOperations(t *testing.T) { + ast := assert.New(t) + + // 测试设置和获取Cookie + w := httptest.NewRecorder() + SetCookie(w, "test-cookie", "test-value", 3600) + + cookies := w.Result().Cookies() + ast.Equal(1, len(cookies)) + ast.Equal("test-cookie", cookies[0].Name) + ast.Equal("test-value", cookies[0].Value) + ast.True(cookies[0].HttpOnly) + ast.True(cookies[0].Secure) + + // 测试从请求中获取Cookie + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookies[0]) + + value, err := GetCookie(req, "test-cookie") + ast.Nil(err) + ast.Equal("test-value", value) + + // 测试获取不存在的Cookie + _, err = GetCookie(req, "nonexistent-cookie") + ast.NotNil(err) + + // 测试删除Cookie + w2 := httptest.NewRecorder() + DeleteCookie(w2, "test-cookie") + deleteCookies := w2.Result().Cookies() + ast.Equal(1, len(deleteCookies)) + ast.Equal("test-cookie", deleteCookies[0].Name) + ast.Equal("", deleteCookies[0].Value) + ast.Equal(-1, deleteCookies[0].MaxAge) +} + +func TestLinkAuthOtp(t *testing.T) { + base.Test() + ast := assert.New(t) + + base.Cfg.DisplayError = true + + // 设置测试数据库 + preIpData() + defer closeIpdata() + + // 创建测试组 + group := "otp-test-group" + dns := []dbdata.ValData{{Val: "8.8.8.8"}} + g := dbdata.Group{Name: group, Status: 1, ClientDns: dns} + err := dbdata.SetGroup(&g) + ast.Nil(err) + + // 创建测试用户 + username := "otp-test-user" + otpSecret := "JBSWY3DPEHPK3PXP" + u := dbdata.User{ + Username: username, + Groups: []string{group}, + Status: 1, + OtpSecret: otpSecret, + } + err = dbdata.SetUser(&u) + ast.Nil(err) + + // 生成有效的OTP代码 + totp := gotp.NewDefaultTOTP(otpSecret) + validOtp := totp.Now() + + // 创建测试会话 + sessionID := "test-otp-session" + authSession := &AuthSession{ + ClientRequest: &ClientRequest{ + Auth: auth{ + Username: username, + OtpSecret: otpSecret, + }, + GroupSelect: group, + UserAgent: "test-agent", + }, + UserActLog: &dbdata.UserActLog{ + Username: username, + Status: dbdata.UserAuthSuccess, + }, + } + SessStore.SaveAuthSession(sessionID, authSession) + + // 测试成功的OTP验证 + t.Run("ValidOTP", func(t *testing.T) { + ast := assert.New(t) + + // 创建OTP验证请求 + clientReq := ClientRequest{ + Auth: auth{ + SecondaryPassword: validOtp, + }, + } + reqBody, _ := xml.Marshal(clientReq) + + req := httptest.NewRequest("POST", "/otp-verification", bytes.NewReader(reqBody)) + req.AddCookie(&http.Cookie{Name: "auth-session-id", Value: sessionID}) + w := httptest.NewRecorder() + + LinkAuth_otp(w, req) + + ast.Equal(http.StatusOK, w.Code) + // 验证会话已被删除 + _, err := SessStore.GetAuthSession(sessionID) + ast.NotNil(err) + }) + + // 测试无效的OTP代码 + t.Run("InvalidOTP", func(t *testing.T) { + ast := assert.New(t) + + // 重新创建会话(因为上一个测试中被删除了) + SessStore.SaveAuthSession(sessionID+"2", authSession) + + clientReq := ClientRequest{ + Auth: auth{ + SecondaryPassword: "123456", // 无效的OTP + }, + } + reqBody, _ := xml.Marshal(clientReq) + + req := httptest.NewRequest("POST", "/otp-verification", bytes.NewReader(reqBody)) + req.AddCookie(&http.Cookie{Name: "auth-session-id", Value: sessionID + "2"}) + w := httptest.NewRecorder() + + LinkAuth_otp(w, req) + + ast.Equal(http.StatusOK, w.Code) + // 验证响应包含错误信息 + ast.Contains(w.Body.String(), "OTP 动态码错误") + }) + + // 测试无效会话 + t.Run("InvalidSession", func(t *testing.T) { + ast := assert.New(t) + + clientReq := ClientRequest{ + Auth: auth{ + SecondaryPassword: validOtp, + }, + } + reqBody, _ := xml.Marshal(clientReq) + + req := httptest.NewRequest("POST", "/otp-verification", bytes.NewReader(reqBody)) + req.AddCookie(&http.Cookie{Name: "auth-session-id", Value: "invalid-session"}) + w := httptest.NewRecorder() + + LinkAuth_otp(w, req) + + ast.Equal(http.StatusUnauthorized, w.Code) + }) + + // 测试缺少会话Cookie + t.Run("MissingSessionCookie", func(t *testing.T) { + ast := assert.New(t) + + clientReq := ClientRequest{ + Auth: auth{ + SecondaryPassword: validOtp, + }, + } + reqBody, _ := xml.Marshal(clientReq) + + req := httptest.NewRequest("POST", "/otp-verification", bytes.NewReader(reqBody)) + w := httptest.NewRecorder() + + LinkAuth_otp(w, req) + + ast.Equal(http.StatusUnauthorized, w.Code) + }) +} + +func TestCreateSession(t *testing.T) { + base.Test() + ast := assert.New(t) + + preIpData() + defer closeIpdata() + + // 创建测试数据 + group := "session-test-group" + username := "session-test-user" + + dns := []dbdata.ValData{{Val: "8.8.8.8"}} + g := dbdata.Group{Name: group, Status: 1, ClientDns: dns} + err := dbdata.SetGroup(&g) + ast.Nil(err) + + u := dbdata.User{Username: username, Groups: []string{group}, Status: 1} + err = dbdata.SetUser(&u) + ast.Nil(err) + + // 创建认证会话数据 + authSession := &AuthSession{ + ClientRequest: &ClientRequest{ + Auth: auth{ + Username: username, + }, + GroupSelect: group, + UserAgent: "test-agent", + DeviceId: deviceId{ + UniqueIdGlobal: "test-device-id", + }, + MacAddressList: macAddressList{ + MacAddress: "00:11:22:33:44:55", + }, + RemoteAddr: "192.168.1.100", + }, + UserActLog: &dbdata.UserActLog{ + Username: username, + Status: dbdata.UserAuthSuccess, + DeviceType: "test-device", + PlatformVersion: "test-platform", + }, + } + + // 测试会话创建 + req := httptest.NewRequest("POST", "/", nil) + req.RemoteAddr = "192.168.1.100:12345" + w := httptest.NewRecorder() + + CreateSession(w, req, authSession) + + ast.Equal(http.StatusOK, w.Code) + // 验证响应包含会话信息 + ast.Contains(w.Body.String(), "session-token") +} + +func preIpData() { + // 设置测试模式 + base.Test() + + // 创建临时数据库文件 + tmpDb := path.Join(os.TempDir(), "anylink_otp_test.db") + + // 设置数据库配置 + base.Cfg.DbType = "sqlite3" + base.Cfg.DbSource = tmpDb + + // 设置其他必要的配置 + base.Cfg.Ipv4CIDR = "192.168.3.0/24" + base.Cfg.Ipv4Gateway = "192.168.3.1" + base.Cfg.Ipv4Start = "192.168.3.100" + base.Cfg.Ipv4End = "192.168.3.150" + base.Cfg.MaxClient = 100 + base.Cfg.MaxUserClient = 3 + base.Cfg.IpLease = 5 + + // 启动数据库 + dbdata.Start() +} + +func closeIpdata() { + _ = dbdata.Stop() + tmpDb := path.Join(os.TempDir(), "anylink_otp_test.db") + os.Remove(tmpDb) +}