mirror of
				https://github.com/bjdgyc/anylink.git
				synced 2025-11-04 11:06:22 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			192 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			192 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package dbdata
 | 
						||
 | 
						||
import (
 | 
						||
	"errors"
 | 
						||
	"fmt"
 | 
						||
	"net"
 | 
						||
	"regexp"
 | 
						||
	"strings"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"github.com/bjdgyc/anylink/base"
 | 
						||
)
 | 
						||
 | 
						||
// Values shared by group settings.
//
// Allow and Deny are the values accepted in GroupLinkAcl.Action.
// All is the RouteInclude value that SetGroup treats as the default
// (full) route, which in turn requires at least one client DNS server.
const (
	// Allow is the GroupLinkAcl.Action value "allow".
	Allow = "allow"
	// Deny is the GroupLinkAcl.Action value "deny".
	Deny = "deny"
	// All is the RouteInclude wildcard value "all".
	All = "all"
)
 | 
						||
 | 
						||
// GroupLinkAcl is one link access-control rule of a user group.
//
// Rules are matched top-down; the default is "allow * *".
type GroupLinkAcl struct {
	// Action is either Allow ("allow") or Deny ("deny").
	Action string     `json:"action"`
	// Val is the target network in CIDR notation, as entered by the admin.
	Val    string     `json:"val"`
	// Port is the port the rule applies to.
	// NOTE(review): the meaning of 0 is not visible here — presumably "any port"; confirm.
	Port   uint16     `json:"port"`
	// IpNet is the network parsed from Val; SetGroup fills it in.
	IpNet  *net.IPNet `json:"ip_net"`
	// Note is a descriptive note for the rule.
	Note   string     `json:"note"`
}
 | 
						||
 | 
						||
// ValData is a single entry of a group value list, used for client DNS
// servers and for included/excluded routes.
type ValData struct {
	// Val is the raw value entered by the admin (an IP, a CIDR, or All).
	Val    string `json:"val"`
	// IpMask is the "ip/dotted-mask" form of a CIDR Val; SetGroup fills it in for routes.
	IpMask string `json:"ip_mask"`
	// Note is a descriptive note for the entry.
	Note   string `json:"note"`
}
 | 
						||
 | 
						||
// type Group struct {
 | 
						||
// 	Id           int            `json:"id" xorm:"pk autoincr not null"`
 | 
						||
// 	Name         string         `json:"name" xorm:"not null unique"`
 | 
						||
// 	Note         string         `json:"note"`
 | 
						||
// 	AllowLan     bool           `json:"allow_lan"`
 | 
						||
// 	ClientDns    []ValData      `json:"client_dns"`
 | 
						||
// 	RouteInclude []ValData      `json:"route_include"`
 | 
						||
// 	RouteExclude []ValData      `json:"route_exclude"`
 | 
						||
// 	LinkAcl      []GroupLinkAcl `json:"link_acl"`
 | 
						||
// 	Bandwidth    int            `json:"bandwidth"` // 带宽限制
 | 
						||
// 	Status       int8           `json:"status"`    // 1正常
 | 
						||
// 	CreatedAt    time.Time      `json:"created_at"`
 | 
						||
// 	UpdatedAt    time.Time      `json:"updated_at"`
 | 
						||
// }
 | 
						||
 | 
						||
func GetGroupNames() []string {
 | 
						||
	var datas []Group
 | 
						||
	err := Find(&datas, 0, 0)
 | 
						||
	if err != nil {
 | 
						||
		base.Error(err)
 | 
						||
		return nil
 | 
						||
	}
 | 
						||
	var names []string
 | 
						||
	for _, v := range datas {
 | 
						||
		names = append(names, v.Name)
 | 
						||
	}
 | 
						||
	return names
 | 
						||
}
 | 
						||
 | 
						||
func SetGroup(g *Group) error {
 | 
						||
	var err error
 | 
						||
	if g.Name == "" {
 | 
						||
		return errors.New("用户组名错误")
 | 
						||
	}
 | 
						||
 | 
						||
	// 判断数据
 | 
						||
	routeInclude := []ValData{}
 | 
						||
	for _, v := range g.RouteInclude {
 | 
						||
		if v.Val != "" {
 | 
						||
			if v.Val == All {
 | 
						||
				routeInclude = append(routeInclude, v)
 | 
						||
				continue
 | 
						||
			}
 | 
						||
 | 
						||
			ipMask, _, err := parseIpNet(v.Val)
 | 
						||
			if err != nil {
 | 
						||
				return errors.New("RouteInclude 错误" + err.Error())
 | 
						||
			}
 | 
						||
 | 
						||
			v.IpMask = ipMask
 | 
						||
			routeInclude = append(routeInclude, v)
 | 
						||
		}
 | 
						||
	}
 | 
						||
	g.RouteInclude = routeInclude
 | 
						||
	routeExclude := []ValData{}
 | 
						||
	for _, v := range g.RouteExclude {
 | 
						||
		if v.Val != "" {
 | 
						||
			ipMask, _, err := parseIpNet(v.Val)
 | 
						||
			if err != nil {
 | 
						||
				return errors.New("RouteExclude 错误" + err.Error())
 | 
						||
			}
 | 
						||
			v.IpMask = ipMask
 | 
						||
			routeExclude = append(routeExclude, v)
 | 
						||
		}
 | 
						||
	}
 | 
						||
	g.RouteExclude = routeExclude
 | 
						||
	// 转换数据
 | 
						||
	linkAcl := []GroupLinkAcl{}
 | 
						||
	for _, v := range g.LinkAcl {
 | 
						||
		if v.Val != "" {
 | 
						||
			_, ipNet, err := parseIpNet(v.Val)
 | 
						||
			if err != nil {
 | 
						||
				return errors.New("GroupLinkAcl 错误" + err.Error())
 | 
						||
			}
 | 
						||
			v.IpNet = ipNet
 | 
						||
			linkAcl = append(linkAcl, v)
 | 
						||
		}
 | 
						||
	}
 | 
						||
	g.LinkAcl = linkAcl
 | 
						||
 | 
						||
	// DNS 判断
 | 
						||
	clientDns := []ValData{}
 | 
						||
	for _, v := range g.ClientDns {
 | 
						||
		if v.Val != "" {
 | 
						||
			ip := net.ParseIP(v.Val)
 | 
						||
			if ip.String() != v.Val {
 | 
						||
				return errors.New("DNS IP 错误")
 | 
						||
			}
 | 
						||
			clientDns = append(clientDns, v)
 | 
						||
		}
 | 
						||
	}
 | 
						||
	if len(routeInclude) == 0 || (len(routeInclude) == 1 && routeInclude[0].Val == "all") {
 | 
						||
		if len(clientDns) == 0 {
 | 
						||
			return errors.New("默认路由,必须设置一个DNS")
 | 
						||
		}
 | 
						||
	}
 | 
						||
	g.ClientDns = clientDns
 | 
						||
	// 域名拆分隧道,不能同时填写
 | 
						||
	if g.DsIncludeDomains != "" && g.DsExcludeDomains != "" {
 | 
						||
		return errors.New("包含/排除域名不能同时填写")
 | 
						||
	}
 | 
						||
	// 校验包含域名的格式
 | 
						||
	err = CheckDomainNames(g.DsIncludeDomains)
 | 
						||
	if err != nil {
 | 
						||
		return errors.New("包含域名有误:" + err.Error())
 | 
						||
	}
 | 
						||
	// 校验排除域名的格式
 | 
						||
	err = CheckDomainNames(g.DsExcludeDomains)
 | 
						||
	if err != nil {
 | 
						||
		return errors.New("排除域名有误:" + err.Error())
 | 
						||
	}
 | 
						||
 | 
						||
	g.UpdatedAt = time.Now()
 | 
						||
	if g.Id > 0 {
 | 
						||
		err = Set(g)
 | 
						||
	} else {
 | 
						||
		err = Add(g)
 | 
						||
	}
 | 
						||
 | 
						||
	return err
 | 
						||
}
 | 
						||
 | 
						||
// parseIpNet parses s as CIDR notation (e.g. "192.168.1.5/24").
//
// It returns:
//   - the host IP followed by the mask rendered as an IP
//     (e.g. "192.168.1.5/255.255.255.0"); the host bits are NOT cleared;
//   - the parsed network (e.g. 192.168.1.0/24);
//   - the error from net.ParseCIDR, in which case the other results are zero.
func parseIpNet(s string) (string, *net.IPNet, error) {
	hostIP, network, err := net.ParseCIDR(s)
	if err != nil {
		return "", nil, err
	}
	// Viewing the mask bytes as an IP prints it in address notation
	// (dotted-quad for IPv4, colon-hex for IPv6).
	return fmt.Sprintf("%s/%s", hostIP, net.IP(network.Mask)), network, nil
}
 | 
						||
func CheckDomainNames(domains string) error {
 | 
						||
	if domains == "" {
 | 
						||
		return nil
 | 
						||
	}
 | 
						||
	str_slice := strings.Split(domains, ",")
 | 
						||
	for _, val := range str_slice {
 | 
						||
		if val == "" {
 | 
						||
			return errors.New(val + " 请以逗号分隔域名")
 | 
						||
		}
 | 
						||
		if !ValidateDomainName(val) {
 | 
						||
			return errors.New(val + " 域名有误")
 | 
						||
		}
 | 
						||
	}
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
// domainNameRegexp matches two or more dot-separated labels. Each label starts
// with an ASCII letter or digit followed by up to 62 letters, digits or
// hyphens. It is compiled once here instead of on every ValidateDomainName call.
// NOTE(review): labels ending in '-' (e.g. "a-.com") are accepted, which
// RFC 1123 forbids; kept as-is for compatibility with stored settings.
var domainNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9][-a-zA-Z0-9]{0,62}(\.[a-zA-Z0-9][-a-zA-Z0-9]{0,62})+$`)

// ValidateDomainName reports whether domain is a syntactically acceptable
// domain name: at least two labels (so "localhost" is rejected), no empty
// labels, and a last label of at least two characters (so "example.c" and
// "example.com." are rejected).
func ValidateDomainName(domain string) bool {
	// Cheap pre-check on the final label before running the regexp.
	if pos := strings.LastIndex(domain, "."); pos != -1 && len(domain[pos+1:]) < 2 {
		return false
	}
	return domainNameRegexp.MatchString(domain)
}
 |