mirror of https://github.com/bjdgyc/anylink.git
152 lines
3.3 KiB
Go
152 lines
3.3 KiB
Go
package dbdata
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"net"
|
||
"time"
|
||
|
||
"github.com/bjdgyc/anylink/base"
|
||
)
|
||
|
||
// String values recognised by the group configuration code in this file.
const (
	// Allow is the "allow" action value (see GroupLinkAcl.Action).
	Allow = "allow"

	// Deny is the "deny" action value (see GroupLinkAcl.Action).
	Deny = "deny"

	// All is the special route value "all"; SetGroup accepts it in
	// RouteInclude without parsing it as a CIDR.
	All = "all"
)
|
||
|
||
// GroupLinkAcl is one link ACL rule of a group.
//
// Rules are matched from top to bottom; the default is "allow * *".
type GroupLinkAcl struct {
	Action string     `json:"action"` // "allow" or "deny"
	Val    string     `json:"val"`    // CIDR text; SetGroup parses it into IpNet
	Port   uint16     `json:"port"`
	IpNet  *net.IPNet `json:"ip_net"` // set by SetGroup from Val
	Note   string     `json:"note"`
}
|
||
|
||
// ValData is a single value with an optional note. SetGroup uses it for the
// entries of a group's ClientDns, RouteInclude and RouteExclude lists.
type ValData struct {
	Val    string `json:"val"`     // raw value: an IP, a CIDR, or "all"
	IpMask string `json:"ip_mask"` // "ip/mask" text; SetGroup fills it for CIDR routes
	Note   string `json:"note"`
}
|
||
|
||
// NOTE(review): the Group definition below is commented out; the Group type
// used by the functions in this file is presumably declared elsewhere in the
// package — confirm before removing this reference copy.
//
// type Group struct {
// 	Id           int            `json:"id" xorm:"pk autoincr not null"`
// 	Name         string         `json:"name" xorm:"not null unique"`
// 	Note         string         `json:"note"`
// 	AllowLan     bool           `json:"allow_lan"`
// 	ClientDns    []ValData      `json:"client_dns"`
// 	RouteInclude []ValData      `json:"route_include"`
// 	RouteExclude []ValData      `json:"route_exclude"`
// 	LinkAcl      []GroupLinkAcl `json:"link_acl"`
// 	Bandwidth    int            `json:"bandwidth"` // bandwidth limit
// 	Status       int8           `json:"status"`    // 1 = normal
// 	CreatedAt    time.Time      `json:"created_at"`
// 	UpdatedAt    time.Time      `json:"updated_at"`
// }
|
||
|
||
func GetGroupNames() []string {
|
||
var datas []Group
|
||
err := Find(&datas, 0, 0)
|
||
if err != nil {
|
||
base.Error(err)
|
||
return nil
|
||
}
|
||
var names []string
|
||
for _, v := range datas {
|
||
names = append(names, v.Name)
|
||
}
|
||
return names
|
||
}
|
||
|
||
func SetGroup(g *Group) error {
|
||
var err error
|
||
if g.Name == "" {
|
||
return errors.New("用户组名错误")
|
||
}
|
||
|
||
// 判断数据
|
||
routeInclude := []ValData{}
|
||
for _, v := range g.RouteInclude {
|
||
if v.Val != "" {
|
||
if v.Val == All {
|
||
routeInclude = append(routeInclude, v)
|
||
continue
|
||
}
|
||
|
||
ipMask, _, err := parseIpNet(v.Val)
|
||
if err != nil {
|
||
return errors.New("RouteInclude 错误" + err.Error())
|
||
}
|
||
|
||
v.IpMask = ipMask
|
||
routeInclude = append(routeInclude, v)
|
||
}
|
||
}
|
||
g.RouteInclude = routeInclude
|
||
routeExclude := []ValData{}
|
||
for _, v := range g.RouteExclude {
|
||
if v.Val != "" {
|
||
ipMask, _, err := parseIpNet(v.Val)
|
||
if err != nil {
|
||
return errors.New("RouteExclude 错误" + err.Error())
|
||
}
|
||
v.IpMask = ipMask
|
||
routeExclude = append(routeExclude, v)
|
||
}
|
||
}
|
||
g.RouteExclude = routeExclude
|
||
// 转换数据
|
||
linkAcl := []GroupLinkAcl{}
|
||
for _, v := range g.LinkAcl {
|
||
if v.Val != "" {
|
||
_, ipNet, err := parseIpNet(v.Val)
|
||
if err != nil {
|
||
return errors.New("GroupLinkAcl 错误" + err.Error())
|
||
}
|
||
v.IpNet = ipNet
|
||
linkAcl = append(linkAcl, v)
|
||
}
|
||
}
|
||
g.LinkAcl = linkAcl
|
||
|
||
// DNS 判断
|
||
clientDns := []ValData{}
|
||
for _, v := range g.ClientDns {
|
||
if v.Val != "" {
|
||
ip := net.ParseIP(v.Val)
|
||
if ip.String() != v.Val {
|
||
return errors.New("DNS IP 错误")
|
||
}
|
||
clientDns = append(clientDns, v)
|
||
}
|
||
}
|
||
if len(routeInclude) == 0 || (len(routeInclude) == 1 && routeInclude[0].Val == "all") {
|
||
if len(clientDns) == 0 {
|
||
return errors.New("默认路由,必须设置一个DNS")
|
||
}
|
||
}
|
||
g.ClientDns = clientDns
|
||
|
||
g.UpdatedAt = time.Now()
|
||
if g.Id > 0 {
|
||
err = Set(g)
|
||
} else {
|
||
err = Add(g)
|
||
}
|
||
|
||
return err
|
||
}
|
||
|
||
// parseIpNet parses s as a CIDR string such as "192.168.1.10/24".
//
// It returns the text "<ip>/<mask>", where <ip> is the address exactly as
// given in s and <mask> is the network mask rendered like an IP address
// (e.g. "192.168.1.10/255.255.255.0"), together with the parsed *net.IPNet.
// If s is not a valid CIDR, the error from net.ParseCIDR is returned along
// with an empty string and a nil network.
func parseIpNet(s string) (string, *net.IPNet, error) {
	addr, network, parseErr := net.ParseCIDR(s)
	if parseErr != nil {
		return "", nil, parseErr
	}

	// Reinterpret the mask bytes as an IP so they print in address notation.
	maskAsIP := net.IP(network.Mask)
	return fmt.Sprintf("%s/%s", addr, maskAsIP), network, nil
}
|