mirror of https://github.com/bjdgyc/anylink.git
240 lines
6.3 KiB
Go
240 lines
6.3 KiB
Go
package dbdata
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/bjdgyc/anylink/base"
|
||
)
|
||
|
||
const (
|
||
Allow = "allow"
|
||
Deny = "deny"
|
||
All = "all"
|
||
)
|
||
|
||
type GroupLinkAcl struct {
|
||
// 自上而下匹配 默认 allow * *
|
||
Action string `json:"action"` // allow、deny
|
||
Val string `json:"val"`
|
||
Port uint16 `json:"port"`
|
||
IpNet *net.IPNet `json:"ip_net"`
|
||
Note string `json:"note"`
|
||
}
|
||
|
||
type ValData struct {
|
||
Val string `json:"val"`
|
||
IpMask string `json:"ip_mask"`
|
||
Note string `json:"note"`
|
||
}
|
||
|
||
type AuthRadius struct {
|
||
Addr string `json:"addr"`
|
||
Secret string `json:"secret"`
|
||
}
|
||
|
||
// type Group struct {
|
||
// Id int `json:"id" xorm:"pk autoincr not null"`
|
||
// Name string `json:"name" xorm:"varchar(60) not null unique"`
|
||
// Note string `json:"note" xorm:"varchar(255)"`
|
||
// AllowLan bool `json:"allow_lan" xorm:"Bool"`
|
||
// ClientDns []ValData `json:"client_dns" xorm:"Text"`
|
||
// RouteInclude []ValData `json:"route_include" xorm:"Text"`
|
||
// RouteExclude []ValData `json:"route_exclude" xorm:"Text"`
|
||
// DsExcludeDomains string `json:"ds_exclude_domains" xorm:"Text"`
|
||
// DsIncludeDomains string `json:"ds_include_domains" xorm:"Text"`
|
||
// LinkAcl []GroupLinkAcl `json:"link_acl" xorm:"Text"`
|
||
// Bandwidth int `json:"bandwidth" xorm:"Int"` // 带宽限制
|
||
// Auth map[string]interface{} `json:"auth" xorm:"not null default '{}' varchar(255)"` // 认证方式
|
||
// Status int8 `json:"status" xorm:"Int"` // 1正常
|
||
// CreatedAt time.Time `json:"created_at" xorm:"DateTime created"`
|
||
// UpdatedAt time.Time `json:"updated_at" xorm:"DateTime updated"`
|
||
// }
|
||
|
||
func GetGroupNames() []string {
|
||
var datas []Group
|
||
err := Find(&datas, 0, 0)
|
||
if err != nil {
|
||
base.Error(err)
|
||
return nil
|
||
}
|
||
var names []string
|
||
for _, v := range datas {
|
||
names = append(names, v.Name)
|
||
}
|
||
return names
|
||
}
|
||
|
||
func SetGroup(g *Group) error {
|
||
var err error
|
||
if g.Name == "" {
|
||
return errors.New("用户组名错误")
|
||
}
|
||
|
||
// 判断数据
|
||
routeInclude := []ValData{}
|
||
for _, v := range g.RouteInclude {
|
||
if v.Val != "" {
|
||
if v.Val == All {
|
||
routeInclude = append(routeInclude, v)
|
||
continue
|
||
}
|
||
|
||
ipMask, _, err := parseIpNet(v.Val)
|
||
if err != nil {
|
||
return errors.New("RouteInclude 错误" + err.Error())
|
||
}
|
||
|
||
v.IpMask = ipMask
|
||
routeInclude = append(routeInclude, v)
|
||
}
|
||
}
|
||
g.RouteInclude = routeInclude
|
||
routeExclude := []ValData{}
|
||
for _, v := range g.RouteExclude {
|
||
if v.Val != "" {
|
||
ipMask, _, err := parseIpNet(v.Val)
|
||
if err != nil {
|
||
return errors.New("RouteExclude 错误" + err.Error())
|
||
}
|
||
v.IpMask = ipMask
|
||
routeExclude = append(routeExclude, v)
|
||
}
|
||
}
|
||
g.RouteExclude = routeExclude
|
||
// 转换数据
|
||
linkAcl := []GroupLinkAcl{}
|
||
for _, v := range g.LinkAcl {
|
||
if v.Val != "" {
|
||
_, ipNet, err := parseIpNet(v.Val)
|
||
if err != nil {
|
||
return errors.New("GroupLinkAcl 错误" + err.Error())
|
||
}
|
||
v.IpNet = ipNet
|
||
linkAcl = append(linkAcl, v)
|
||
}
|
||
}
|
||
g.LinkAcl = linkAcl
|
||
|
||
// DNS 判断
|
||
clientDns := []ValData{}
|
||
for _, v := range g.ClientDns {
|
||
if v.Val != "" {
|
||
ip := net.ParseIP(v.Val)
|
||
if ip.String() != v.Val {
|
||
return errors.New("DNS IP 错误")
|
||
}
|
||
clientDns = append(clientDns, v)
|
||
}
|
||
}
|
||
if len(routeInclude) == 0 || (len(routeInclude) == 1 && routeInclude[0].Val == "all") {
|
||
if len(clientDns) == 0 {
|
||
return errors.New("默认路由,必须设置一个DNS")
|
||
}
|
||
}
|
||
g.ClientDns = clientDns
|
||
// 域名拆分隧道,不能同时填写
|
||
g.DsIncludeDomains = strings.TrimSpace(g.DsIncludeDomains)
|
||
g.DsExcludeDomains = strings.TrimSpace(g.DsExcludeDomains)
|
||
if g.DsIncludeDomains != "" && g.DsExcludeDomains != "" {
|
||
return errors.New("包含/排除域名不能同时填写")
|
||
}
|
||
// 校验包含域名的格式
|
||
err = CheckDomainNames(g.DsIncludeDomains)
|
||
if err != nil {
|
||
return errors.New("包含域名有误:" + err.Error())
|
||
}
|
||
// 校验排除域名的格式
|
||
err = CheckDomainNames(g.DsExcludeDomains)
|
||
if err != nil {
|
||
return errors.New("排除域名有误:" + err.Error())
|
||
}
|
||
// 处理认证方式的逻辑
|
||
defAuth := map[string]interface{}{
|
||
"type": "local",
|
||
}
|
||
if len(g.Auth) == 0 {
|
||
g.Auth = defAuth
|
||
}
|
||
switch g.Auth["type"] {
|
||
case "local":
|
||
g.Auth = defAuth
|
||
case "radius":
|
||
err = checkRadiusData(g.Auth)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
default:
|
||
return errors.New("#" + fmt.Sprintf("%s", g.Auth["type"]) + "#未知的认证类型")
|
||
}
|
||
|
||
g.UpdatedAt = time.Now()
|
||
if g.Id > 0 {
|
||
err = Set(g)
|
||
} else {
|
||
err = Add(g)
|
||
}
|
||
|
||
return err
|
||
}
|
||
|
||
func parseIpNet(s string) (string, *net.IPNet, error) {
|
||
ip, ipNet, err := net.ParseCIDR(s)
|
||
if err != nil {
|
||
return "", nil, err
|
||
}
|
||
|
||
mask := net.IP(ipNet.Mask)
|
||
ipMask := fmt.Sprintf("%s/%s", ip, mask)
|
||
|
||
return ipMask, ipNet, nil
|
||
}
|
||
|
||
func checkRadiusData(auth map[string]interface{}) error {
|
||
radisConf := AuthRadius{}
|
||
bodyBytes, err := json.Marshal(auth["radius"])
|
||
if err != nil {
|
||
return errors.New("Radius的密钥/服务器地址填写有误")
|
||
}
|
||
json.Unmarshal(bodyBytes, &radisConf)
|
||
if !ValidateIpPort(radisConf.Addr) {
|
||
return errors.New("Radius的服务器地址填写有误")
|
||
}
|
||
// freeradius官网最大8000字符, 这里限制200
|
||
if len(radisConf.Secret) < 8 || len(radisConf.Secret) > 200 {
|
||
return errors.New("Radius的密钥长度需在8~200个字符之间")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func CheckDomainNames(domains string) error {
|
||
if domains == "" {
|
||
return nil
|
||
}
|
||
str_slice := strings.Split(domains, ",")
|
||
for _, val := range str_slice {
|
||
if val == "" {
|
||
return errors.New(val + " 请以逗号分隔域名")
|
||
}
|
||
if !ValidateDomainName(val) {
|
||
return errors.New(val + " 域名有误")
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func ValidateDomainName(domain string) bool {
|
||
RegExp := regexp.MustCompile(`^([a-zA-Z0-9][-a-zA-Z0-9]{0,62}\.)+[A-Za-z]{2,18}$`)
|
||
return RegExp.MatchString(domain)
|
||
}
|
||
|
||
func ValidateIpPort(addr string) bool {
|
||
RegExp := regexp.MustCompile(`^(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\:([0-9]|[1-9]\d{1,3}|[1-5]\d{4}|6[0-5]{2}[0-3][0-5])$$`)
|
||
return RegExp.MatchString(addr)
|
||
}
|