mirror of
https://github.com/bjdgyc/anylink.git
synced 2025-09-17 08:57:16 +08:00
支持为用户生成不同组的证书
证书管理前端增加搜索功能
This commit is contained in:
@@ -158,19 +158,24 @@ func GenerateClientCert(w http.ResponseWriter, r *http.Request) {
|
||||
// 下载客户端 P12 证书
|
||||
func DownloadClientP12(w http.ResponseWriter, r *http.Request) {
|
||||
username := r.FormValue("username")
|
||||
groupname := r.FormValue("groupname")
|
||||
password := r.FormValue("password")
|
||||
|
||||
if username == "" {
|
||||
RespError(w, RespInternalErr, "用户名不能为空")
|
||||
return
|
||||
}
|
||||
if groupname == "" {
|
||||
RespError(w, RespInternalErr, "用户组不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// if password == "" {
|
||||
// password = "123456" // 默认密码
|
||||
// }
|
||||
|
||||
// 生成 P12 证书
|
||||
p12Data, err := dbdata.GenerateClientP12FromDB(username, password)
|
||||
p12Data, err := dbdata.GenerateClientP12FromDB(username, groupname, password)
|
||||
if err != nil {
|
||||
RespError(w, RespInternalErr, fmt.Sprintf("证书下载失败: %v", err))
|
||||
return
|
||||
@@ -190,8 +195,13 @@ func ChangeClientCertStatus(w http.ResponseWriter, r *http.Request) {
|
||||
RespError(w, RespInternalErr, "用户名不能为空")
|
||||
return
|
||||
}
|
||||
groupname := r.FormValue("groupname")
|
||||
if groupname == "" {
|
||||
RespError(w, RespInternalErr, "用户组不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
clientCert, err := dbdata.GetClientCert(username)
|
||||
clientCert, err := dbdata.GetClientCert(username, groupname)
|
||||
if err != nil {
|
||||
RespError(w, RespInternalErr, "证书不存在")
|
||||
return
|
||||
@@ -211,30 +221,6 @@ func ChangeClientCertStatus(w http.ResponseWriter, r *http.Request) {
|
||||
RespSucess(w, fmt.Sprintf("证书%s成功", statusText))
|
||||
}
|
||||
|
||||
// // 禁用客户端证书
|
||||
// func DisableClientCert(w http.ResponseWriter, r *http.Request) {
|
||||
// username := r.FormValue("username")
|
||||
// if username == "" {
|
||||
// RespError(w, RespInternalErr, "用户名不能为空")
|
||||
// return
|
||||
// }
|
||||
|
||||
// // 获取证书并禁用
|
||||
// clientCert, err := dbdata.GetClientCert(username)
|
||||
// if err != nil {
|
||||
// RespError(w, RespInternalErr, "证书不存在")
|
||||
// return
|
||||
// }
|
||||
|
||||
// err = clientCert.Disable()
|
||||
// if err != nil {
|
||||
// RespError(w, RespInternalErr, fmt.Sprintf("证书禁用失败: %v", err))
|
||||
// return
|
||||
// }
|
||||
|
||||
// RespSucess(w, "证书禁用成功")
|
||||
// }
|
||||
|
||||
// 删除客户端证书
|
||||
func DeleteClientCert(w http.ResponseWriter, r *http.Request) {
|
||||
username := r.FormValue("username")
|
||||
@@ -242,8 +228,13 @@ func DeleteClientCert(w http.ResponseWriter, r *http.Request) {
|
||||
RespError(w, RespInternalErr, "用户名不能为空")
|
||||
return
|
||||
}
|
||||
groupname := r.FormValue("groupname")
|
||||
if groupname == "" {
|
||||
RespError(w, RespInternalErr, "用户组不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
clientCert, err := dbdata.GetClientCert(username)
|
||||
clientCert, err := dbdata.GetClientCert(username, groupname)
|
||||
if err != nil {
|
||||
RespError(w, RespInternalErr, "证书不存在")
|
||||
return
|
||||
@@ -258,28 +249,6 @@ func DeleteClientCert(w http.ResponseWriter, r *http.Request) {
|
||||
RespSucess(w, "证书删除成功")
|
||||
}
|
||||
|
||||
// // 启用客户端证书
|
||||
// func EnableClientCert(w http.ResponseWriter, r *http.Request) {
|
||||
// username := r.FormValue("username")
|
||||
// if username == "" {
|
||||
// RespError(w, RespInternalErr, "用户名不能为空")
|
||||
// return
|
||||
// }
|
||||
|
||||
// clientCert, err := dbdata.GetClientCert(username)
|
||||
// if err != nil {
|
||||
// RespError(w, RespInternalErr, "证书不存在")
|
||||
// return
|
||||
// }
|
||||
|
||||
// if err := clientCert.Enable(); err != nil {
|
||||
// RespError(w, RespInternalErr, fmt.Sprintf("证书启用失败: %v", err))
|
||||
// return
|
||||
// }
|
||||
|
||||
// RespSucess(w, nil)
|
||||
// }
|
||||
|
||||
// 获取客户端证书列表
|
||||
func GetClientCertList(w http.ResponseWriter, r *http.Request) {
|
||||
pageSize := 10
|
||||
@@ -297,7 +266,12 @@ func GetClientCertList(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
certs, total, err := dbdata.GetClientCertList(pageSize, pageIndex)
|
||||
// 添加搜索参数
|
||||
username := r.FormValue("username")
|
||||
groupname := r.FormValue("groupname")
|
||||
status := r.FormValue("status")
|
||||
|
||||
certs, total, err := dbdata.GetClientCertList(pageSize, pageIndex, username, groupname, status)
|
||||
if err != nil {
|
||||
RespError(w, RespInternalErr, fmt.Sprintf("获取证书列表失败: %v", err))
|
||||
return
|
||||
|
@@ -12,6 +12,7 @@ import (
|
||||
"math/big"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -23,7 +24,7 @@ import (
|
||||
type ClientCertData struct {
|
||||
Id int `json:"id" xorm:"pk autoincr not null"`
|
||||
Username string `json:"username" xorm:"varchar(60) not null"`
|
||||
GroupName string `json:"groupname" xorm:"varchar(60)"`
|
||||
Groupname string `json:"groupname" xorm:"varchar(60) not null"`
|
||||
Status int `json:"status" xorm:"int default 0"`
|
||||
Certificate string `json:"certificate" xorm:"text not null"`
|
||||
PrivateKey string `json:"private_key" xorm:"text not null"`
|
||||
@@ -119,10 +120,24 @@ func (c *ClientCertData) CheckAndUpdateStatus() error {
|
||||
}
|
||||
|
||||
// 获取客户端证书列表
|
||||
func GetClientCertList(pageSize int, pageIndex int) ([]ClientCertData, int64, error) {
|
||||
func GetClientCertList(pageSize, pageIndex int, username, groupname, status string) ([]ClientCertData, int64, error) {
|
||||
var certs []ClientCertData
|
||||
session := GetXdb().NewSession()
|
||||
defer session.Close()
|
||||
|
||||
session = session.Where("1=1")
|
||||
// 添加搜索条件
|
||||
if username != "" {
|
||||
session.And("username LIKE ?", "%"+username+"%")
|
||||
}
|
||||
if groupname != "" {
|
||||
session.And("groupname LIKE ?", "%"+groupname+"%")
|
||||
}
|
||||
if status != "" {
|
||||
if statusInt, err := strconv.Atoi(status); err == nil {
|
||||
session.And("status = ?", statusInt)
|
||||
}
|
||||
}
|
||||
total, err := FindAndCount(session, &certs, pageSize, pageIndex)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("获取客户端证书列表失败: %v", err)
|
||||
@@ -131,11 +146,15 @@ func GetClientCertList(pageSize int, pageIndex int) ([]ClientCertData, int64, er
|
||||
}
|
||||
|
||||
// 获取客户端证书
|
||||
func GetClientCert(username string) (*ClientCertData, error) {
|
||||
clientCert := &ClientCertData{
|
||||
Username: username,
|
||||
func GetClientCert(username, groupname string) (*ClientCertData, error) {
|
||||
clientCert := &ClientCertData{}
|
||||
has, err := GetXdb().Where("username = ? AND groupname = ?", username, groupname).Get(clientCert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !has {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
err := One("Username", username, clientCert)
|
||||
return clientCert, err
|
||||
}
|
||||
|
||||
@@ -196,14 +215,14 @@ func GenerateClientCert(username, groupname string) (*ClientCertData, error) {
|
||||
return nil, fmt.Errorf("用户 %s 不属于组 %s", username, groupname)
|
||||
}
|
||||
// 检查是否已存在证书记录
|
||||
_, err = GetClientCert(username)
|
||||
_, err = GetClientCert(username, groupname)
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
return nil, fmt.Errorf("获取用户证书失败: %v", err)
|
||||
}
|
||||
} else {
|
||||
// 用户已有证书记录,不允许重复生成
|
||||
return nil, fmt.Errorf("用户 %s 已存在证书,请先删除现有证书", username)
|
||||
return nil, fmt.Errorf("用户 %s 已存在证书,所在组:%s,请先删除现有证书", username, groupname)
|
||||
}
|
||||
|
||||
// 确保客户端 CA 已加载
|
||||
@@ -250,7 +269,7 @@ func GenerateClientCert(username, groupname string) (*ClientCertData, error) {
|
||||
// 保存到数据库
|
||||
clientCertData := &ClientCertData{
|
||||
Username: username,
|
||||
GroupName: groupname,
|
||||
Groupname: groupname,
|
||||
Certificate: string(certPEM),
|
||||
PrivateKey: string(keyPEM),
|
||||
SerialNumber: template.SerialNumber.String(),
|
||||
@@ -267,9 +286,9 @@ func GenerateClientCert(username, groupname string) (*ClientCertData, error) {
|
||||
}
|
||||
|
||||
// 生成 PKCS#12 格式证书文件
|
||||
func GenerateClientP12FromDB(username string, password string) ([]byte, error) {
|
||||
func GenerateClientP12FromDB(username, groupname, password string) ([]byte, error) {
|
||||
// 从数据库获取证书
|
||||
clientCert, err := GetClientCert(username)
|
||||
clientCert, err := GetClientCert(username, groupname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -332,13 +351,13 @@ func ValidateClientCert(cert *x509.Certificate, userAgent string) bool {
|
||||
}
|
||||
|
||||
// 获取客户端证书记录
|
||||
clientCertData, err := GetClientCert(user.Username)
|
||||
clientCertData, err := GetClientCert(user.Username, cert.Subject.OrganizationalUnit[0])
|
||||
if err != nil {
|
||||
base.Error("证书验证失败:获取客户端证书失败:", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if clientCertData.GroupName != cert.Subject.OrganizationalUnit[0] {
|
||||
if clientCertData.Groupname != cert.Subject.OrganizationalUnit[0] {
|
||||
base.Error("证书验证失败:证书组名与用户组名不匹配")
|
||||
return false
|
||||
}
|
||||
|
@@ -44,7 +44,7 @@ func TestGenerateClientCert(t *testing.T) {
|
||||
ast.Nil(err)
|
||||
ast.NotNil(certData)
|
||||
ast.Equal(username, certData.Username)
|
||||
ast.Equal(group, certData.GroupName)
|
||||
ast.Equal(group, certData.Groupname)
|
||||
ast.Equal(CertStatusActive, certData.Status)
|
||||
ast.NotEmpty(certData.Certificate)
|
||||
ast.NotEmpty(certData.PrivateKey)
|
||||
@@ -141,7 +141,7 @@ func TestValidateClientCert(t *testing.T) {
|
||||
ast.Nil(err)
|
||||
ast.NotNil(certData)
|
||||
ast.Equal(username, certData.Username)
|
||||
ast.Equal(group, certData.GroupName)
|
||||
ast.Equal(group, certData.Groupname)
|
||||
|
||||
// 解析生成的证书
|
||||
cert, err := parseCertFromPEM(certData.Certificate)
|
||||
|
Reference in New Issue
Block a user