mirror of https://github.com/bjdgyc/anylink.git
1.重构认证方式的代码,方便未来扩展 2.补充测试用例
This commit is contained in:
parent
c38f1e9b8c
commit
b06c035cce
|
@ -1,7 +1,6 @@
|
||||||
package dbdata
|
package dbdata
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
@ -33,11 +32,6 @@ type ValData struct {
|
||||||
Note string `json:"note"`
|
Note string `json:"note"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthRadius struct {
|
|
||||||
Addr string `json:"addr"`
|
|
||||||
Secret string `json:"secret"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// type Group struct {
|
// type Group struct {
|
||||||
// Id int `json:"id" xorm:"pk autoincr not null"`
|
// Id int `json:"id" xorm:"pk autoincr not null"`
|
||||||
// Name string `json:"name" xorm:"varchar(60) not null unique"`
|
// Name string `json:"name" xorm:"varchar(60) not null unique"`
|
||||||
|
@ -154,23 +148,26 @@ func SetGroup(g *Group) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("排除域名有误:" + err.Error())
|
return errors.New("排除域名有误:" + err.Error())
|
||||||
}
|
}
|
||||||
// 处理认证方式的逻辑
|
// 处理登入方式的逻辑
|
||||||
defAuth := map[string]interface{}{
|
defAuth := map[string]interface{}{
|
||||||
"type": "local",
|
"type": "local",
|
||||||
}
|
}
|
||||||
if len(g.Auth) == 0 {
|
if len(g.Auth) == 0 {
|
||||||
g.Auth = defAuth
|
g.Auth = defAuth
|
||||||
}
|
}
|
||||||
switch g.Auth["type"] {
|
authType := g.Auth["type"].(string)
|
||||||
case "local":
|
if authType == "local" {
|
||||||
g.Auth = defAuth
|
g.Auth = defAuth
|
||||||
case "radius":
|
} else {
|
||||||
err = checkRadiusData(g.Auth)
|
_, ok := authRegistry[authType]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("未知的认证方式: " + fmt.Sprintf("%s", g.Auth["type"]))
|
||||||
|
}
|
||||||
|
auth := makeInstance(authType).(IUserAuth)
|
||||||
|
err = auth.checkData(g.Auth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
return errors.New("#" + fmt.Sprintf("%s", g.Auth["type"]) + "#未知的认证类型")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
g.UpdatedAt = time.Now()
|
g.UpdatedAt = time.Now()
|
||||||
|
@ -195,23 +192,6 @@ func parseIpNet(s string) (string, *net.IPNet, error) {
|
||||||
return ipMask, ipNet, nil
|
return ipMask, ipNet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkRadiusData(auth map[string]interface{}) error {
|
|
||||||
radisConf := AuthRadius{}
|
|
||||||
bodyBytes, err := json.Marshal(auth["radius"])
|
|
||||||
if err != nil {
|
|
||||||
return errors.New("Radius的密钥/服务器地址填写有误")
|
|
||||||
}
|
|
||||||
json.Unmarshal(bodyBytes, &radisConf)
|
|
||||||
if !ValidateIpPort(radisConf.Addr) {
|
|
||||||
return errors.New("Radius的服务器地址填写有误")
|
|
||||||
}
|
|
||||||
// freeradius官网最大8000字符, 这里限制200
|
|
||||||
if len(radisConf.Secret) < 8 || len(radisConf.Secret) > 200 {
|
|
||||||
return errors.New("Radius的密钥长度需在8~200个字符之间")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CheckDomainNames(domains string) error {
|
func CheckDomainNames(domains string) error {
|
||||||
if domains == "" {
|
if domains == "" {
|
||||||
return nil
|
return nil
|
||||||
|
@ -232,8 +212,3 @@ func ValidateDomainName(domain string) bool {
|
||||||
RegExp := regexp.MustCompile(`^([a-zA-Z0-9][-a-zA-Z0-9]{0,62}\.)+[A-Za-z]{2,18}$`)
|
RegExp := regexp.MustCompile(`^([a-zA-Z0-9][-a-zA-Z0-9]{0,62}\.)+[A-Za-z]{2,18}$`)
|
||||||
return RegExp.MatchString(domain)
|
return RegExp.MatchString(domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateIpPort(addr string) bool {
|
|
||||||
RegExp := regexp.MustCompile(`^(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\:([0-9]|[1-9]\d{1,3}|[1-5]\d{4}|6[0-5]{2}[0-3][0-5])$$`)
|
|
||||||
return RegExp.MatchString(addr)
|
|
||||||
}
|
|
||||||
|
|
|
@ -24,8 +24,25 @@ func TestGetGroupNames(t *testing.T) {
|
||||||
err = SetGroup(&g3)
|
err = SetGroup(&g3)
|
||||||
ast.Nil(err)
|
ast.Nil(err)
|
||||||
|
|
||||||
|
authData := map[string]interface{}{
|
||||||
|
"type": "radius",
|
||||||
|
"radius": map[string]string{
|
||||||
|
"addr": "192.168.8.12:1044",
|
||||||
|
"secret": "43214132",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g4 := Group{Name: "g4", ClientDns: []ValData{{Val: "114.114.114.114"}}, Auth: authData}
|
||||||
|
err = SetGroup(&g4)
|
||||||
|
ast.Nil(err)
|
||||||
|
g5 := Group{Name: "g5", ClientDns: []ValData{{Val: "114.114.114.114"}}, DsIncludeDomains: "baidu.com,163.com"}
|
||||||
|
err = SetGroup(&g5)
|
||||||
|
ast.Nil(err)
|
||||||
|
g6 := Group{Name: "g6", ClientDns: []ValData{{Val: "114.114.114.114"}}, DsExcludeDomains: "com.cn,qq.com"}
|
||||||
|
err = SetGroup(&g6)
|
||||||
|
ast.Nil(err)
|
||||||
|
|
||||||
// 判断所有数据
|
// 判断所有数据
|
||||||
gAll := []string{"g1", "g2", "g3"}
|
gAll := []string{"g1", "g2", "g3", "g4", "g5", "g6"}
|
||||||
gs := GetGroupNames()
|
gs := GetGroupNames()
|
||||||
for _, v := range gs {
|
for _, v := range gs {
|
||||||
ast.Equal(true, utils.InArrStr(gAll, v))
|
ast.Equal(true, utils.InArrStr(gAll, v))
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package dbdata
|
package dbdata
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -10,8 +8,6 @@ import (
|
||||||
|
|
||||||
"github.com/bjdgyc/anylink/pkg/utils"
|
"github.com/bjdgyc/anylink/pkg/utils"
|
||||||
"github.com/xlzd/gotp"
|
"github.com/xlzd/gotp"
|
||||||
"layeh.com/radius"
|
|
||||||
"layeh.com/radius/rfc2865"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// type User struct {
|
// type User struct {
|
||||||
|
@ -82,15 +78,18 @@ func CheckUser(name, pwd, group string) error {
|
||||||
if len(groupData.Auth) == 0 {
|
if len(groupData.Auth) == 0 {
|
||||||
groupData.Auth["type"] = "local"
|
groupData.Auth["type"] = "local"
|
||||||
}
|
}
|
||||||
switch groupData.Auth["type"] {
|
authType := groupData.Auth["type"].(string)
|
||||||
case "", "local":
|
// 本地认证方式
|
||||||
|
if authType == "local" {
|
||||||
return checkLocalUser(name, pwd, group)
|
return checkLocalUser(name, pwd, group)
|
||||||
case "radius":
|
|
||||||
return checkRadiusUser(name, pwd, groupData.Auth)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("%s %s", name, "无效的认证类型")
|
|
||||||
}
|
}
|
||||||
return nil
|
// 其它认证方式, 支持自定义
|
||||||
|
_, ok := authRegistry[authType]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("%s %s", "未知的认证方式: ", authType)
|
||||||
|
}
|
||||||
|
auth := makeInstance(authType).(IUserAuth)
|
||||||
|
return auth.checkUser(name, pwd, groupData.Auth)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证本地用户登陆信息
|
// 验证本地用户登陆信息
|
||||||
|
@ -135,35 +134,6 @@ func checkLocalUser(name, pwd, group string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkRadiusUser(name string, pwd string, auth map[string]interface{}) error {
|
|
||||||
if _, ok := auth["radius"]; !ok {
|
|
||||||
fmt.Errorf("%s %s", name, "Radius的radius值不存在")
|
|
||||||
}
|
|
||||||
radiusConf := AuthRadius{}
|
|
||||||
bodyBytes, err := json.Marshal(auth["radius"])
|
|
||||||
if err != nil {
|
|
||||||
fmt.Errorf("%s %s", name, "Radius Marshal出现错误")
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(bodyBytes, &radiusConf)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Errorf("%s %s", name, "Radius Unmarshal出现错误")
|
|
||||||
}
|
|
||||||
// radius认证时,设置超时3秒
|
|
||||||
packet := radius.New(radius.CodeAccessRequest, []byte(radiusConf.Secret))
|
|
||||||
rfc2865.UserName_SetString(packet, name)
|
|
||||||
rfc2865.UserPassword_SetString(packet, pwd)
|
|
||||||
ctx, done := context.WithTimeout(context.Background(), 3*time.Second)
|
|
||||||
defer done()
|
|
||||||
response, err := radius.Exchange(ctx, packet, radiusConf.Addr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%s %s", name, "Radius服务器连接异常, 请检测服务器和端口")
|
|
||||||
}
|
|
||||||
if response.Code != radius.CodeAccessAccept {
|
|
||||||
return fmt.Errorf("%s %s", name, "Radius:用户名或密码错误")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
userOtpMux = sync.Mutex{}
|
userOtpMux = sync.Mutex{}
|
||||||
userOtp = map[string]time.Time{}
|
userOtp = map[string]time.Time{}
|
||||||
|
|
|
@ -40,4 +40,22 @@ func TestCheckUser(t *testing.T) {
|
||||||
_ = SetUser(&u)
|
_ = SetUser(&u)
|
||||||
err = CheckUser("aaa", u.PinCode, group)
|
err = CheckUser("aaa", u.PinCode, group)
|
||||||
ast.Nil(err)
|
ast.Nil(err)
|
||||||
|
|
||||||
|
// 添加一个radius组
|
||||||
|
group2 := "group2"
|
||||||
|
authData := map[string]interface{}{
|
||||||
|
"type": "radius",
|
||||||
|
"radius": map[string]string{
|
||||||
|
"addr": "192.168.1.12:1044",
|
||||||
|
"secret": "43214132",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g2 := Group{Name: group2, Status: 1, ClientDns: dns, RouteInclude: route, Auth: authData}
|
||||||
|
err = SetGroup(&g2)
|
||||||
|
ast.Nil(err)
|
||||||
|
err = CheckUser("aaa", "bbbbbbb", group2)
|
||||||
|
if ast.NotNil(err) {
|
||||||
|
ast.Equal("aaa Radius服务器连接异常, 请检测服务器和端口", err.Error())
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
package dbdata
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
var authRegistry = make(map[string]reflect.Type)
|
||||||
|
|
||||||
|
type IUserAuth interface {
|
||||||
|
checkData(authData map[string]interface{}) error
|
||||||
|
checkUser(name string, pwd string, authData map[string]interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeInstance(name string) interface{} {
|
||||||
|
v := reflect.New(authRegistry[name]).Elem()
|
||||||
|
return v.Interface()
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
package dbdata
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"layeh.com/radius"
|
||||||
|
"layeh.com/radius/rfc2865"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthRadius struct {
|
||||||
|
Addr string `json:"addr"`
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
authRegistry["radius"] = reflect.TypeOf(AuthRadius{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth AuthRadius) checkData(authData map[string]interface{}) error {
|
||||||
|
authType := authData["type"].(string)
|
||||||
|
bodyBytes, err := json.Marshal(authData[authType])
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("Radius的密钥/服务器地址填写有误")
|
||||||
|
}
|
||||||
|
json.Unmarshal(bodyBytes, &auth)
|
||||||
|
if !ValidateIpPort(auth.Addr) {
|
||||||
|
return errors.New("Radius的服务器地址填写有误")
|
||||||
|
}
|
||||||
|
// freeradius官网最大8000字符, 这里限制200
|
||||||
|
if len(auth.Secret) < 8 || len(auth.Secret) > 200 {
|
||||||
|
return errors.New("Radius的密钥长度需在8~200个字符之间")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth AuthRadius) checkUser(name string, pwd string, authData map[string]interface{}) error {
|
||||||
|
pl := len(pwd)
|
||||||
|
if name == "" || pl < 1 {
|
||||||
|
return fmt.Errorf("%s %s", name, "密码错误")
|
||||||
|
}
|
||||||
|
authType := authData["type"].(string)
|
||||||
|
if _, ok := authData[authType]; !ok {
|
||||||
|
return fmt.Errorf("%s %s", name, "Radius的radius值不存在")
|
||||||
|
}
|
||||||
|
bodyBytes, err := json.Marshal(authData[authType])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s %s", name, "Radius Marshal出现错误")
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(bodyBytes, &auth)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s %s", name, "Radius Unmarshal出现错误")
|
||||||
|
}
|
||||||
|
// radius认证时,设置超时3秒
|
||||||
|
packet := radius.New(radius.CodeAccessRequest, []byte(auth.Secret))
|
||||||
|
rfc2865.UserName_SetString(packet, name)
|
||||||
|
rfc2865.UserPassword_SetString(packet, pwd)
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer done()
|
||||||
|
response, err := radius.Exchange(ctx, packet, auth.Addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s %s", name, "Radius服务器连接异常, 请检测服务器和端口")
|
||||||
|
}
|
||||||
|
if response.Code != radius.CodeAccessAccept {
|
||||||
|
return fmt.Errorf("%s %s", name, "Radius:用户名或密码错误")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateIpPort(addr string) bool {
|
||||||
|
RegExp := regexp.MustCompile(`^(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\.(\d{1,2}|1\d\d|2[0-4]\d|25[0-5])\:([0-9]|[1-9]\d{1,3}|[1-5]\d{4}|6[0-5]{2}[0-3][0-5])$$`)
|
||||||
|
return RegExp.MatchString(addr)
|
||||||
|
}
|
Loading…
Reference in New Issue