diff --git a/server/dbdata/group.go b/server/dbdata/group.go index 4effbd2..52524d1 100644 --- a/server/dbdata/group.go +++ b/server/dbdata/group.go @@ -1,7 +1,6 @@ package dbdata import ( - "encoding/json" "errors" "fmt" "net" @@ -33,11 +32,6 @@ type ValData struct { Note string `json:"note"` } -type AuthRadius struct { - Addr string `json:"addr"` - Secret string `json:"secret"` -} - // type Group struct { // Id int `json:"id" xorm:"pk autoincr not null"` // Name string `json:"name" xorm:"varchar(60) not null unique"` @@ -154,23 +148,26 @@ func SetGroup(g *Group) error { if err != nil { return errors.New("排除域名有误:" + err.Error()) } - // 处理认证方式的逻辑 + // 处理登入方式的逻辑 defAuth := map[string]interface{}{ "type": "local", } if len(g.Auth) == 0 { g.Auth = defAuth } - switch g.Auth["type"] { - case "local": + authType := g.Auth["type"].(string) + if authType == "local" { g.Auth = defAuth - case "radius": - err = checkRadiusData(g.Auth) + } else { + _, ok := authRegistry[authType] + if !ok { + return errors.New("未知的认证方式: " + fmt.Sprintf("%s", g.Auth["type"])) + } + auth := makeInstance(authType).(IUserAuth) + err = auth.checkData(g.Auth) if err != nil { return err } - default: - return errors.New("#" + fmt.Sprintf("%s", g.Auth["type"]) + "#未知的认证类型") } g.UpdatedAt = time.Now() @@ -195,23 +192,6 @@ func parseIpNet(s string) (string, *net.IPNet, error) { return ipMask, ipNet, nil } -func checkRadiusData(auth map[string]interface{}) error { - radisConf := AuthRadius{} - bodyBytes, err := json.Marshal(auth["radius"]) - if err != nil { - return errors.New("Radius的密钥/服务器地址填写有误") - } - json.Unmarshal(bodyBytes, &radisConf) - if !ValidateIpPort(radisConf.Addr) { - return errors.New("Radius的服务器地址填写有误") - } - // freeradius官网最大8000字符, 这里限制200 - if len(radisConf.Secret) < 8 || len(radisConf.Secret) > 200 { - return errors.New("Radius的密钥长度需在8~200个字符之间") - } - return nil -} - func CheckDomainNames(domains string) error { if domains == "" { return nil @@ -232,8 +212,3 @@ func ValidateDomainName(domain string) bool { RegExp := regexp.MustCompile(`^([a-zA-Z0-9][-a-zA-Z0-9]{0,62}\.)+[A-Za-z]{2,18}$`) return RegExp.MatchString(domain) } - -func ValidateIpPort(addr string) bool { - RegExp := regexp.MustCompile(`^(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\:([0-9]|[1-9]\d{1,3}|[1-5]\d{4}|6[0-5]{2}[0-3][0-5])$$`) - return RegExp.MatchString(addr) -} diff --git a/server/dbdata/group_test.go b/server/dbdata/group_test.go index ee1dee1..7d27c7f 100644 --- a/server/dbdata/group_test.go +++ b/server/dbdata/group_test.go @@ -24,8 +24,25 @@ func TestGetGroupNames(t *testing.T) { err = SetGroup(&g3) ast.Nil(err) + authData := map[string]interface{}{ + "type": "radius", + "radius": map[string]string{ + "addr": "192.168.8.12:1044", + "secret": "43214132", + }, + } + g4 := Group{Name: "g4", ClientDns: []ValData{{Val: "114.114.114.114"}}, Auth: authData} + err = SetGroup(&g4) + ast.Nil(err) + g5 := Group{Name: "g5", ClientDns: []ValData{{Val: "114.114.114.114"}}, DsIncludeDomains: "baidu.com,163.com"} + err = SetGroup(&g5) + ast.Nil(err) + g6 := Group{Name: "g6", ClientDns: []ValData{{Val: "114.114.114.114"}}, DsExcludeDomains: "com.cn,qq.com"} + err = SetGroup(&g6) + ast.Nil(err) + // 判断所有数据 - gAll := []string{"g1", "g2", "g3"} + gAll := []string{"g1", "g2", "g3", "g4", "g5", "g6"} gs := GetGroupNames() for _, v := range gs { ast.Equal(true, utils.InArrStr(gAll, v)) diff --git a/server/dbdata/user.go b/server/dbdata/user.go index 3db7461..7834013 100644 --- a/server/dbdata/user.go +++ b/server/dbdata/user.go @@ -1,8 +1,6 @@ package dbdata import ( - "context" - "encoding/json" "errors" "fmt" "sync" @@ -10,8 +8,6 @@ import ( "github.com/bjdgyc/anylink/pkg/utils" "github.com/xlzd/gotp" - "layeh.com/radius" - "layeh.com/radius/rfc2865" ) // type User struct { @@ -70,30 +66,33 @@ func SetUser(v *User) error { return err } -// 验证用户登陆信息 +// 验证用户登录信息 func CheckUser(name, pwd, group string) error { // 获取登入的group数据 groupData := &Group{} err := One("Name", group, groupData) - if err != nil { - return fmt.Errorf("%s %s", name, "No用户组") + if err != nil || groupData.Status != 1 { + return fmt.Errorf("%s - %s", name, "用户组错误") } // 初始化Auth if len(groupData.Auth) == 0 { groupData.Auth["type"] = "local" } - switch groupData.Auth["type"] { - case "", "local": + authType := groupData.Auth["type"].(string) + // 本地认证方式 + if authType == "local" { return checkLocalUser(name, pwd, group) - case "radius": - return checkRadiusUser(name, pwd, groupData.Auth) - default: - return fmt.Errorf("%s %s", name, "无效的认证类型") } - return nil + // 其它认证方式, 支持自定义 + _, ok := authRegistry[authType] + if !ok { + return fmt.Errorf("%s %s", "未知的认证方式: ", authType) + } + auth := makeInstance(authType).(IUserAuth) + return auth.checkUser(name, pwd, groupData) } -// 验证本地用户登陆信息 +// 验证本地用户登录信息 func checkLocalUser(name, pwd, group string) error { // TODO 严重问题 // return nil @@ -111,12 +110,6 @@ func checkLocalUser(name, pwd, group string) error { if !utils.InArrStr(v.Groups, group) { return fmt.Errorf("%s %s", name, "用户组错误") } - groupData := &Group{} - err = One("Name", group, groupData) - if err != nil || groupData.Status != 1 { - return fmt.Errorf("%s - %s", name, "用户组错误") - } - // 判断otp信息 pinCode := pwd if !v.DisableOtp { @@ -135,35 +128,6 @@ func checkLocalUser(name, pwd, group string) error { return nil } -func checkRadiusUser(name string, pwd string, auth map[string]interface{}) error { - if _, ok := auth["radius"]; !ok { - fmt.Errorf("%s %s", name, "Radius的radius值不存在") - } - radiusConf := AuthRadius{} - bodyBytes, err := json.Marshal(auth["radius"]) - if err != nil { - fmt.Errorf("%s %s", name, "Radius Marshal出现错误") - } - err = json.Unmarshal(bodyBytes, &radiusConf) - if err != nil { - fmt.Errorf("%s %s", name, "Radius Unmarshal出现错误") - } - // radius认证时,设置超时3秒 - packet := radius.New(radius.CodeAccessRequest, []byte(radiusConf.Secret)) - rfc2865.UserName_SetString(packet, name) - rfc2865.UserPassword_SetString(packet, pwd) - ctx, done := context.WithTimeout(context.Background(), 3*time.Second) - defer done() - response, err := radius.Exchange(ctx, packet, radiusConf.Addr) - if err != nil { - return fmt.Errorf("%s %s", name, "Radius服务器连接异常, 请检测服务器和端口") - } - if response.Code != radius.CodeAccessAccept { - return fmt.Errorf("%s %s", name, "Radius:用户名或密码错误") - } - return nil -} - var ( userOtpMux = sync.Mutex{} userOtp = map[string]time.Time{} diff --git a/server/dbdata/user_test.go b/server/dbdata/user_test.go index e076672..545ca45 100644 --- a/server/dbdata/user_test.go +++ b/server/dbdata/user_test.go @@ -40,4 +40,22 @@ func TestCheckUser(t *testing.T) { _ = SetUser(&u) err = CheckUser("aaa", u.PinCode, group) ast.Nil(err) + + // 添加一个radius组 + group2 := "group2" + authData := map[string]interface{}{ + "type": "radius", + "radius": map[string]string{ + "addr": "192.168.1.12:1044", + "secret": "43214132", + }, + } + g2 := Group{Name: group2, Status: 1, ClientDns: dns, RouteInclude: route, Auth: authData} + err = SetGroup(&g2) + ast.Nil(err) + err = CheckUser("aaa", "bbbbbbb", group2) + if ast.NotNil(err) { + ast.Equal("aaa Radius服务器连接异常, 请检测服务器和端口", err.Error()) + + } } diff --git a/server/dbdata/userauth.go b/server/dbdata/userauth.go new file mode 100644 index 0000000..fbc3eb5 --- /dev/null +++ b/server/dbdata/userauth.go @@ -0,0 +1,23 @@ +package dbdata + +import ( + "reflect" + "regexp" +) + +var authRegistry = make(map[string]reflect.Type) + +type IUserAuth interface { + checkData(authData map[string]interface{}) error + checkUser(name, pwd string, g *Group) error +} + +func makeInstance(name string) interface{} { + v := reflect.New(authRegistry[name]).Elem() + return v.Interface() +} + +func ValidateIpPort(addr string) bool { + RegExp := regexp.MustCompile(`^(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\:([0-9]|[1-9]\d{1,3}|[1-5]\d{4}|6[0-5]{2}[0-3][0-5])$$`) + return RegExp.MatchString(addr) +} diff --git a/server/dbdata/userauth_radius.go b/server/dbdata/userauth_radius.go new file mode 100644 index 0000000..4d15eb1 --- /dev/null +++ b/server/dbdata/userauth_radius.go @@ -0,0 +1,72 @@ +package dbdata + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "time" + + "layeh.com/radius" + "layeh.com/radius/rfc2865" +) + +type AuthRadius struct { + Addr string `json:"addr"` + Secret string `json:"secret"` +} + +func init() { + authRegistry["radius"] = reflect.TypeOf(AuthRadius{}) +} + +func (auth AuthRadius) checkData(authData map[string]interface{}) error { + authType := authData["type"].(string) + bodyBytes, err := json.Marshal(authData[authType]) + if err != nil { + return errors.New("Radius的密钥/服务器地址填写有误") + } + json.Unmarshal(bodyBytes, &auth) + if !ValidateIpPort(auth.Addr) { + return errors.New("Radius的服务器地址填写有误") + } + // freeradius官网最大8000字符, 这里限制200 + if len(auth.Secret) < 8 || len(auth.Secret) > 200 { + return errors.New("Radius的密钥长度需在8~200个字符之间") + } + return nil +} + +func (auth AuthRadius) checkUser(name, pwd string, g *Group) error { + pl := len(pwd) + if name == "" || pl < 1 { + return fmt.Errorf("%s %s", name, "密码错误") + } + authType := g.Auth["type"].(string) + if _, ok := g.Auth[authType]; !ok { + return fmt.Errorf("%s %s", name, "Radius的radius值不存在") + } + bodyBytes, err := json.Marshal(g.Auth[authType]) + if err != nil { + return fmt.Errorf("%s %s", name, "Radius Marshal出现错误") + } + err = json.Unmarshal(bodyBytes, &auth) + if err != nil { + return fmt.Errorf("%s %s", name, "Radius Unmarshal出现错误") + } + // radius认证时,设置超时3秒 + packet := radius.New(radius.CodeAccessRequest, []byte(auth.Secret)) + rfc2865.UserName_SetString(packet, name) + rfc2865.UserPassword_SetString(packet, pwd) + ctx, done := context.WithTimeout(context.Background(), 3*time.Second) + defer done() + response, err := radius.Exchange(ctx, packet, auth.Addr) + if err != nil { + return fmt.Errorf("%s %s", name, "Radius服务器连接异常, 请检测服务器和端口") + } + if response.Code != radius.CodeAccessAccept { + return fmt.Errorf("%s %s", name, "Radius:用户名或密码错误") + } + return nil +}