mirror of
https://github.com/bjdgyc/anylink.git
synced 2025-08-08 19:22:42 +08:00
新增用户策略的功能
This commit is contained in:
98
server/admin/api_policy.go
Normal file
98
server/admin/api_policy.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/bjdgyc/anylink/dbdata"
|
||||
)
|
||||
|
||||
func PolicyList(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
pageS := r.FormValue("page")
|
||||
page, _ := strconv.Atoi(pageS)
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
var pageSize = dbdata.PageSize
|
||||
|
||||
count := dbdata.CountAll(&dbdata.Policy{})
|
||||
|
||||
var datas []dbdata.Policy
|
||||
err := dbdata.Find(&datas, pageSize, page)
|
||||
if err != nil {
|
||||
RespError(w, RespInternalErr, err)
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"count": count,
|
||||
"page_size": pageSize,
|
||||
"datas": datas,
|
||||
}
|
||||
|
||||
RespSucess(w, data)
|
||||
}
|
||||
|
||||
func PolicyDetail(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
idS := r.FormValue("id")
|
||||
id, _ := strconv.Atoi(idS)
|
||||
if id < 1 {
|
||||
RespError(w, RespParamErr, "Id错误")
|
||||
return
|
||||
}
|
||||
|
||||
var data dbdata.Policy
|
||||
err := dbdata.One("Id", id, &data)
|
||||
if err != nil {
|
||||
RespError(w, RespInternalErr, err)
|
||||
return
|
||||
}
|
||||
|
||||
RespSucess(w, data)
|
||||
}
|
||||
|
||||
func PolicySet(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
RespError(w, RespInternalErr, err)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
v := &dbdata.Policy{}
|
||||
err = json.Unmarshal(body, v)
|
||||
if err != nil {
|
||||
RespError(w, RespInternalErr, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dbdata.SetPolicy(v)
|
||||
if err != nil {
|
||||
RespError(w, RespInternalErr, err)
|
||||
return
|
||||
}
|
||||
|
||||
RespSucess(w, nil)
|
||||
}
|
||||
|
||||
func PolicyDel(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
idS := r.FormValue("id")
|
||||
id, _ := strconv.Atoi(idS)
|
||||
if id < 1 {
|
||||
RespError(w, RespParamErr, "Id错误")
|
||||
return
|
||||
}
|
||||
|
||||
data := dbdata.Policy{Id: id}
|
||||
err := dbdata.Del(&data)
|
||||
if err != nil {
|
||||
RespError(w, RespInternalErr, err)
|
||||
return
|
||||
}
|
||||
RespSucess(w, nil)
|
||||
}
|
@@ -52,6 +52,10 @@ func StartAdmin() {
|
||||
r.HandleFunc("/user/ip_map/detail", UserIpMapDetail)
|
||||
r.HandleFunc("/user/ip_map/set", UserIpMapSet)
|
||||
r.HandleFunc("/user/ip_map/del", UserIpMapDel)
|
||||
r.HandleFunc("/user/policy/list", PolicyList)
|
||||
r.HandleFunc("/user/policy/detail", PolicyDetail)
|
||||
r.HandleFunc("/user/policy/set", PolicySet)
|
||||
r.HandleFunc("/user/policy/del", PolicyDel)
|
||||
|
||||
r.HandleFunc("/group/list", GroupList)
|
||||
r.HandleFunc("/group/names", GroupNames)
|
||||
|
@@ -25,7 +25,7 @@ func initDb() {
|
||||
}
|
||||
|
||||
// 初始化数据库
|
||||
err = xdb.Sync2(&User{}, &Setting{}, &Group{}, &IpMap{}, &AccessAudit{})
|
||||
err = xdb.Sync2(&User{}, &Setting{}, &Group{}, &IpMap{}, &AccessAudit{}, &Policy{})
|
||||
if err != nil {
|
||||
base.Fatal(err)
|
||||
}
|
||||
|
@@ -161,7 +161,7 @@ func SetGroup(g *Group) error {
|
||||
} else {
|
||||
_, ok := authRegistry[authType]
|
||||
if !ok {
|
||||
return errors.New("未知的认证方式: " + fmt.Sprintf("%s", g.Auth["type"]))
|
||||
return errors.New("未知的认证方式: " + authType)
|
||||
}
|
||||
auth := makeInstance(authType).(IUserAuth)
|
||||
err = auth.checkData(g.Auth)
|
||||
|
101
server/dbdata/policy.go
Normal file
101
server/dbdata/policy.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package dbdata
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GetPolicy(Username string) *Policy {
|
||||
policyData := &Policy{}
|
||||
err := One("Username", Username, policyData)
|
||||
if err != nil {
|
||||
return policyData
|
||||
}
|
||||
return policyData
|
||||
}
|
||||
|
||||
func SetPolicy(p *Policy) error {
|
||||
var err error
|
||||
if p.Username == "" {
|
||||
return errors.New("用户名错误")
|
||||
}
|
||||
|
||||
// 包含路由
|
||||
routeInclude := []ValData{}
|
||||
for _, v := range p.RouteInclude {
|
||||
if v.Val != "" {
|
||||
if v.Val == All {
|
||||
routeInclude = append(routeInclude, v)
|
||||
continue
|
||||
}
|
||||
|
||||
ipMask, _, err := parseIpNet(v.Val)
|
||||
if err != nil {
|
||||
return errors.New("RouteInclude 错误" + err.Error())
|
||||
}
|
||||
|
||||
v.IpMask = ipMask
|
||||
routeInclude = append(routeInclude, v)
|
||||
}
|
||||
}
|
||||
p.RouteInclude = routeInclude
|
||||
// 包含路由
|
||||
routeExclude := []ValData{}
|
||||
for _, v := range p.RouteExclude {
|
||||
if v.Val != "" {
|
||||
ipMask, _, err := parseIpNet(v.Val)
|
||||
if err != nil {
|
||||
return errors.New("RouteExclude 错误" + err.Error())
|
||||
}
|
||||
v.IpMask = ipMask
|
||||
routeExclude = append(routeExclude, v)
|
||||
}
|
||||
}
|
||||
p.RouteExclude = routeExclude
|
||||
|
||||
// DNS 判断
|
||||
clientDns := []ValData{}
|
||||
for _, v := range p.ClientDns {
|
||||
if v.Val != "" {
|
||||
ip := net.ParseIP(v.Val)
|
||||
if ip.String() != v.Val {
|
||||
return errors.New("DNS IP 错误")
|
||||
}
|
||||
clientDns = append(clientDns, v)
|
||||
}
|
||||
}
|
||||
if len(routeInclude) == 0 || (len(routeInclude) == 1 && routeInclude[0].Val == "all") {
|
||||
if len(clientDns) == 0 {
|
||||
return errors.New("默认路由,必须设置一个DNS")
|
||||
}
|
||||
}
|
||||
p.ClientDns = clientDns
|
||||
|
||||
// 域名拆分隧道,不能同时填写
|
||||
p.DsIncludeDomains = strings.TrimSpace(p.DsIncludeDomains)
|
||||
p.DsExcludeDomains = strings.TrimSpace(p.DsExcludeDomains)
|
||||
if p.DsIncludeDomains != "" && p.DsExcludeDomains != "" {
|
||||
return errors.New("包含/排除域名不能同时填写")
|
||||
}
|
||||
// 校验包含域名的格式
|
||||
err = CheckDomainNames(p.DsIncludeDomains)
|
||||
if err != nil {
|
||||
return errors.New("包含域名有误:" + err.Error())
|
||||
}
|
||||
// 校验排除域名的格式
|
||||
err = CheckDomainNames(p.DsExcludeDomains)
|
||||
if err != nil {
|
||||
return errors.New("排除域名有误:" + err.Error())
|
||||
}
|
||||
|
||||
p.UpdatedAt = time.Now()
|
||||
if p.Id > 0 {
|
||||
err = Set(p)
|
||||
} else {
|
||||
err = Add(p)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
@@ -68,3 +68,17 @@ type AccessAudit struct {
|
||||
DstPort uint16 `json:"dst_port" xorm:"not null"`
|
||||
CreatedAt time.Time `json:"created_at" xorm:"DateTime"`
|
||||
}
|
||||
|
||||
type Policy struct {
|
||||
Id int `json:"id" xorm:"pk autoincr not null"`
|
||||
Username string `json:"username" xorm:"varchar(60) not null unique"`
|
||||
AllowLan bool `json:"allow_lan" xorm:"Bool"`
|
||||
ClientDns []ValData `json:"client_dns" xorm:"Text"`
|
||||
RouteInclude []ValData `json:"route_include" xorm:"Text"`
|
||||
RouteExclude []ValData `json:"route_exclude" xorm:"Text"`
|
||||
DsExcludeDomains string `json:"ds_exclude_domains" xorm:"Text"`
|
||||
DsIncludeDomains string `json:"ds_include_domains" xorm:"Text"`
|
||||
Status int8 `json:"status" xorm:"Int"` // 1正常 0 禁用
|
||||
CreatedAt time.Time `json:"created_at" xorm:"DateTime created"`
|
||||
UpdatedAt time.Time `json:"updated_at" xorm:"DateTime updated"`
|
||||
}
|
||||
|
@@ -58,4 +58,12 @@ func TestCheckUser(t *testing.T) {
|
||||
ast.Equal("aaa Radius服务器连接异常, 请检测服务器和端口", err.Error())
|
||||
|
||||
}
|
||||
// 添加用户策略
|
||||
dns2 := []ValData{{Val: "8.8.8.8"}}
|
||||
route2 := []ValData{{Val: "192.168.2.1/24"}}
|
||||
p1 := Policy{Username: "aaa", Status: 1, ClientDns: dns2, RouteInclude: route2}
|
||||
err = SetPolicy(&p1)
|
||||
ast.Nil(err)
|
||||
err = CheckUser("aaa", u.PinCode, group)
|
||||
ast.Nil(err)
|
||||
}
|
||||
|
@@ -100,6 +100,9 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
//HttpSetHeader(w, "X-CSTP-Default-Domain", cSess.LocalIp)
|
||||
HttpSetHeader(w, "X-CSTP-Base-MTU", cstpBaseMtu)
|
||||
|
||||
// 设置用户策略
|
||||
SetUserPolicy(sess.Username, cSess.Group)
|
||||
|
||||
// 允许本地LAN访问vpn网络,必须放在路由的第一个
|
||||
if cSess.Group.AllowLan {
|
||||
HttpSetHeader(w, "X-CSTP-Split-Exclude", "0.0.0.0/255.255.255.255")
|
||||
@@ -209,3 +212,17 @@ func SetPostAuthXml(g *dbdata.Group, w http.ResponseWriter) error {
|
||||
HttpSetHeader(w, "X-CSTP-Post-Auth-XML", result.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// 设置用户策略, 覆盖Group的属性值
|
||||
func SetUserPolicy(username string, g *dbdata.Group) {
|
||||
userPolicy := dbdata.GetPolicy(username)
|
||||
if userPolicy.Id != 0 && userPolicy.Status == 1 {
|
||||
base.Debug(username + " use UserPolicy")
|
||||
g.AllowLan = userPolicy.AllowLan
|
||||
g.ClientDns = userPolicy.ClientDns
|
||||
g.RouteInclude = userPolicy.RouteInclude
|
||||
g.RouteExclude = userPolicy.RouteExclude
|
||||
g.DsExcludeDomains = userPolicy.DsExcludeDomains
|
||||
g.DsIncludeDomains = userPolicy.DsIncludeDomains
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user