From edb0fe2dc9b9393fa9840898fffedd5320508cae Mon Sep 17 00:00:00 2001 From: bjd Date: Thu, 4 Feb 2021 13:32:10 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=AE=A2=E6=88=B7=E7=AB=AF?= =?UTF-8?q?=E5=88=86=E9=85=8D=E7=9A=84ip=E4=B8=BACIDR=E6=A0=BC=E5=BC=8F?= =?UTF-8?q?=EF=BC=8C=E8=AF=B7=E6=B3=A8=E6=84=8F=E5=8E=9F=E6=9D=A5network?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- admin/api_base.go | 2 +- admin/api_group.go | 6 +++--- admin/api_ip_map.go | 8 ++++---- admin/api_set.go | 6 ++---- admin/api_user.go | 32 +++++++++++++++++++------------ admin/common.go | 11 +---------- admin/resp.go | 6 ++++-- admin/resp_test.go | 39 ++++++++++++++++++++++++++++++++++++++ base/app_ver.go | 2 +- base/cfg_server.go | 22 ++++++++++----------- base/log.go | 4 ++-- bridge-init.sh | 4 ---- conf/server.toml | 3 +-- dbdata/db.go | 13 ++++++------- dbdata/db_test.go | 2 +- handler/dtls.go | 1 + handler/link_auth.go | 6 +++--- handler/link_base.go | 23 ---------------------- handler/link_home.go | 2 +- handler/link_tap.go | 3 +++ handler/link_tunnel.go | 12 ++++++------ handler/start.go | 2 +- pkg/arpdis/addr.go | 2 +- pkg/arpdis/icmp.go | 2 +- pkg/proxyproto/protocol.go | 6 ++++-- sessdata/copy_struct.go | 2 -- sessdata/ip_pool.go | 28 ++++++++++++++------------- sessdata/ip_pool_test.go | 13 ++++++------- sessdata/limit_test.go | 5 +++-- sessdata/online.go | 5 +---- sessdata/session.go | 16 +++++++++------- sessdata/session_test.go | 6 ++++-- 32 files changed, 155 insertions(+), 139 deletions(-) create mode 100644 admin/resp_test.go diff --git a/admin/api_base.go b/admin/api_base.go index ab7bb46..c44a721 100644 --- a/admin/api_base.go +++ b/admin/api_base.go @@ -16,7 +16,7 @@ func Login(w http.ResponseWriter, r *http.Request) { // hd, _ := httputil.DumpRequest(r, true) // fmt.Println("DumpRequest: ", string(hd)) - r.ParseForm() + _ = r.ParseForm() admin_user := r.PostFormValue("admin_user") admin_pass := r.PostFormValue("admin_pass") diff --git a/admin/api_group.go b/admin/api_group.go index ba9302a..087333d 100644 --- a/admin/api_group.go +++ b/admin/api_group.go @@ -10,7 +10,7 @@ import ( ) func GroupList(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() pageS := r.FormValue("page") page, _ := strconv.Atoi(pageS) if page < 1 { @@ -48,7 +48,7 @@ func GroupNames(w http.ResponseWriter, r *http.Request) { } func GroupDetail(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() idS := r.FormValue("id") id, _ := strconv.Atoi(idS) if id < 1 { @@ -90,7 +90,7 @@ func GroupSet(w http.ResponseWriter, r *http.Request) { } func GroupDel(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() idS := r.FormValue("id") id, _ := strconv.Atoi(idS) if id < 1 { diff --git a/admin/api_ip_map.go b/admin/api_ip_map.go index ca4a024..3f57fba 100644 --- a/admin/api_ip_map.go +++ b/admin/api_ip_map.go @@ -11,7 +11,7 @@ import ( ) func UserIpMapList(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() pageS := r.FormValue("page") page, _ := strconv.Atoi(pageS) if page < 1 { @@ -39,7 +39,7 @@ func UserIpMapList(w http.ResponseWriter, r *http.Request) { } func UserIpMapDetail(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() idS := r.FormValue("id") id, _ := strconv.Atoi(idS) if id < 1 { @@ -58,7 +58,7 @@ func UserIpMapDetail(w http.ResponseWriter, r *http.Request) { } func UserIpMapSet(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() body, err := ioutil.ReadAll(r.Body) if err != nil { @@ -92,7 +92,7 @@ func UserIpMapSet(w http.ResponseWriter, r *http.Request) { } func UserIpMapDel(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() idS := r.FormValue("id") id, _ := strconv.Atoi(idS) diff --git a/admin/api_set.go b/admin/api_set.go index 6be5737..d40c4e2 100644 --- a/admin/api_set.go +++ b/admin/api_set.go @@ -1,7 +1,6 @@ package admin import ( - "encoding/json" "fmt" "net/http" "runtime" @@ -84,9 +83,8 @@ func SetSystem(w http.ResponseWriter, r *http.Request) { } func SetSoft(w http.ResponseWriter, r *http.Request) { - datas := base.ServerCfg2Slice() - b, _ := json.Marshal(datas) - w.Write(b) + data := base.ServerCfg2Slice() + RespSucess(w, data) } func decimal(f float64) float64 { diff --git a/admin/api_user.go b/admin/api_user.go index 960cafa..8442ee7 100644 --- a/admin/api_user.go +++ b/admin/api_user.go @@ -19,7 +19,7 @@ import ( ) func UserList(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() prefix := r.FormValue("prefix") pageS := r.FormValue("page") page, _ := strconv.Atoi(pageS) @@ -58,7 +58,7 @@ func UserList(w http.ResponseWriter, r *http.Request) { } func UserDetail(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() idS := r.FormValue("id") id, _ := strconv.Atoi(idS) if id < 1 { @@ -77,7 +77,7 @@ func UserDetail(w http.ResponseWriter, r *http.Request) { } func UserSet(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() body, err := ioutil.ReadAll(r.Body) if err != nil { @@ -111,7 +111,7 @@ func UserSet(w http.ResponseWriter, r *http.Request) { } func UserDel(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() idS := r.FormValue("id") id, _ := strconv.Atoi(idS) @@ -130,7 +130,7 @@ func UserDel(w http.ResponseWriter, r *http.Request) { } func UserOtpQr(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() b64 := r.FormValue("b64") idS := r.FormValue("id") id, _ := strconv.Atoi(idS) @@ -148,11 +148,16 @@ func UserOtpQr(w http.ResponseWriter, r *http.Request) { if b64 == "1" { data, _ := qr.PNG(300) s := base64.StdEncoding.EncodeToString(data) - fmt.Fprint(w, s) - } else { - qr.Write(300, w) + _, err = fmt.Fprint(w, s) + if err != nil { + base.Error(err) + } + return + } + err = qr.Write(300, w) + if err != nil { + base.Error(err) } - } // 在线用户 @@ -169,14 +174,14 @@ func UserOnline(w http.ResponseWriter, r *http.Request) { } func UserOffline(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() token := r.FormValue("token") sessdata.CloseSess(token) RespSucess(w, nil) } func UserReline(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() token := r.FormValue("token") sessdata.CloseCSess(token) RespSucess(w, nil) @@ -231,7 +236,10 @@ func userAccountMail(user *dbdata.User) error { } w := bytes.NewBufferString("") t, _ := template.New("auth_complete").Parse(htmlBody) - t.Execute(w, data) + err = t.Execute(w, data) + if err != nil { + return err + } // fmt.Println(w.String()) return SendMail(base.Cfg.Issuer+"平台通知", user.Email, w.String()) } diff --git a/admin/common.go b/admin/common.go index 98ee065..14796bf 100644 --- a/admin/common.go +++ b/admin/common.go @@ -8,8 +8,8 @@ import ( "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/dbdata" "github.com/dgrijalva/jwt-go" - "github.com/mojocn/base64Captcha" mail "github.com/xhit/go-simple-mail/v2" + // "github.com/mojocn/base64Captcha" ) func SetJwtData(data map[string]interface{}, expiresAt int64) (string, error) { @@ -43,15 +43,6 @@ func GetJwtData(jwtToken string) (map[string]interface{}, error) { return claims, nil } -func createCaptcha() { - var store = base64Captcha.DefaultMemStore - var driver base64Captcha.Driver - driverString := &base64Captcha.DriverString{} - driver = driverString.ConvertFonts() - c := base64Captcha.NewCaptcha(driver, store) - _ = c -} - func SendMail(subject, to, htmlBody string) error { dataSmtp := &dbdata.SettingSmtp{} diff --git a/admin/resp.go b/admin/resp.go index fdd953f..a133f24 100644 --- a/admin/resp.go +++ b/admin/resp.go @@ -43,8 +43,10 @@ func respHttp(w http.ResponseWriter, respCode int, data interface{}, errS ...int w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(http.StatusOK) - w.Write(b) - + _, err = w.Write(b) + if err != nil { + base.Error(err) + } // 记录返回数据 // logger.Category("response").Debug(string(b)) } diff --git a/admin/resp_test.go b/admin/resp_test.go new file mode 100644 index 0000000..4d3dd60 --- /dev/null +++ b/admin/resp_test.go @@ -0,0 +1,39 @@ +package admin + +import ( + "encoding/json" + "io/ioutil" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRespSucess(t *testing.T) { + assert := assert.New(t) + w := httptest.NewRecorder() + RespSucess(w, "data") + // fmt.Println(w) + assert.Equal(w.Code, 200) + body, _ := ioutil.ReadAll(w.Body) + res := Resp{} + err := json.Unmarshal(body, &res) + assert.Nil(err) + assert.Equal(res.Code, 0) + assert.Equal(res.Data, "data") + +} + +func TestRespError(t *testing.T) { + assert := assert.New(t) + w := httptest.NewRecorder() + RespError(w, 10, "err-msg") + // fmt.Println(w) + assert.Equal(w.Code, 200) + body, _ := ioutil.ReadAll(w.Body) + res := Resp{} + err := json.Unmarshal(body, &res) + assert.Nil(err) + assert.Equal(res.Code, 10) + assert.Equal(res.Msg, "err-msg") +} diff --git a/base/app_ver.go b/base/app_ver.go index 9c31e5e..9e7adc9 100644 --- a/base/app_ver.go +++ b/base/app_ver.go @@ -2,5 +2,5 @@ package base const ( APP_NAME = "AnyLink" - APP_VER = "0.1.1" + APP_VER = "0.1.2" ) diff --git a/base/cfg_server.go b/base/cfg_server.go index bda9397..7280e52 100644 --- a/base/cfg_server.go +++ b/base/cfg_server.go @@ -47,9 +47,8 @@ type ServerConfig struct { AdminPass string `toml:"admin_pass" info:"管理用户密码"` JwtSecret string `toml:"jwt_secret" info:"JWT密钥"` - LinkMode string `toml:"link_mode" info:"虚拟网络类型"` // tun tap - Ipv4Network string `toml:"ipv4_network" info:"ipv4_network"` // 192.168.1.0 - Ipv4Netmask string `toml:"ipv4_netmask" info:"ipv4_netmask"` // 255.255.255.0 + LinkMode string `toml:"link_mode" info:"虚拟网络类型"` // tun tap + Ipv4CIDR string `toml:"ipv4_cidr" info:"ip地址网段"` // 192.168.1.0/24 Ipv4Gateway string `toml:"ipv4_gateway" info:"ipv4_gateway"` Ipv4Pool []string `toml:"ipv4_pool" info:"IPV4起止地址池"` // Pool[0]=192.168.1.100 Pool[1]=192.168.1.200 IpLease int `toml:"ip_lease" info:"IP租期(秒)"` @@ -102,16 +101,17 @@ func getAbsPath(base, cfile string) string { return filepath.Join(base, cfile) } -func ServerCfg2Slice() interface{} { +type SCfg struct { + Name string `json:"name"` + Info string `json:"info"` + Data interface{} `json:"data"` +} + +func ServerCfg2Slice() []SCfg { ref := reflect.ValueOf(Cfg) s := ref.Elem() - type cfg struct { - Name string `json:"name"` - Info string `json:"info"` - Data interface{} `json:"data"` - } - var datas []cfg + var datas []SCfg typ := s.Type() numFields := s.NumField() @@ -122,7 +122,7 @@ func ServerCfg2Slice() interface{} { tags := strings.Split(tag, ",") info := field.Tag.Get("info") - datas = append(datas, cfg{Name: tags[0], Info: info, Data: value.Interface()}) + datas = append(datas, SCfg{Name: tags[0], Info: info, Data: value.Interface()}) } return datas diff --git a/base/log.go b/base/log.go index 2a0516f..7c5efc0 100644 --- a/base/log.go +++ b/base/log.go @@ -36,7 +36,7 @@ func logLevel2Int(l string) int { } lvl := _Info for k, v := range levels { - if strings.ToLower(l) == strings.ToLower(v) { + if strings.EqualFold(strings.ToLower(l), strings.ToLower(v)) { lvl = k } } @@ -45,7 +45,7 @@ func logLevel2Int(l string) int { func output(l int, s ...interface{}) { lvl := fmt.Sprintf("[%s] ", levels[l]) - baseLog.Output(3, lvl+fmt.Sprintln(s...)) + _ = baseLog.Output(3, lvl+fmt.Sprintln(s...)) } func Debug(v ...interface{}) { diff --git a/bridge-init.sh b/bridge-init.sh index f60dac1..8837020 100644 --- a/bridge-init.sh +++ b/bridge-init.sh @@ -8,10 +8,6 @@ # Define Bridge Interface br="anylink0" -# Define list of TAP interfaces to be bridged, -# for example tap="tap0 tap1 tap2". -tap="tap0" - # Define physical ethernet interface to be bridged # with TAP interface(s) above. diff --git a/conf/server.toml b/conf/server.toml index 533bf62..98f873b 100644 --- a/conf/server.toml +++ b/conf/server.toml @@ -35,8 +35,7 @@ proxy_protocol = false link_mode = "tun" #客户端分配的ip地址池 -ipv4_network = "192.168.10.0" -ipv4_netmask = "255.255.255.0" +ipv4_cidr = "192.168.10.0/24" ipv4_gateway = "192.168.10.1" ipv4_pool = ["192.168.10.100", "192.168.10.200"] diff --git a/dbdata/db.go b/dbdata/db.go index b5bb20b..b74bc63 100644 --- a/dbdata/db.go +++ b/dbdata/db.go @@ -43,28 +43,27 @@ func initData() { return } - defer Set(SettingBucket, Installed, true) + defer func() { + _ = Set(SettingBucket, Installed, true) + }() smtp := &SettingSmtp{ Host: "127.0.0.1", Port: 25, From: "vpn@xx.com", } - SettingSet(smtp) + _ = SettingSet(smtp) other := &SettingOther{ Banner: "您已接入公司网络,请按照公司规定使用。\n请勿进行非工作下载及视频行为!", AccountMail: accountMail, } - SettingSet(other) + _ = SettingSet(other) } func CheckErrNotFound(err error) bool { - if err == storm.ErrNotFound { - return true - } - return false + return err == storm.ErrNotFound } const accountMail = `

您好:

diff --git a/dbdata/db_test.go b/dbdata/db_test.go index 60c5c72..9950f9d 100644 --- a/dbdata/db_test.go +++ b/dbdata/db_test.go @@ -27,7 +27,7 @@ func TestDb(t *testing.T) { defer closeIpdata() u := User{Username: "a"} - Save(&u) + _ = Save(&u) assert.Equal(u.Id, 1) } diff --git a/handler/dtls.go b/handler/dtls.go index 9fe46bb..32c9470 100644 --- a/handler/dtls.go +++ b/handler/dtls.go @@ -2,4 +2,5 @@ package handler // 暂时没有实现 func startDtls() { + } diff --git a/handler/link_auth.go b/handler/link_auth.go index e5ab17e..6701995 100644 --- a/handler/link_auth.go +++ b/handler/link_auth.go @@ -87,7 +87,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { sess.MacAddr = strings.ToLower(cr.MacAddressList.MacAddress) sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal other := &dbdata.SettingOther{} - dbdata.SettingGet(other) + _ = dbdata.SettingGet(other) rd := RequestData{SessionId: sess.Sid, SessionToken: sess.Sid + "@" + sess.Token, Banner: other.Banner} w.WriteHeader(http.StatusOK) @@ -102,7 +102,7 @@ const ( func tplRequest(typ int, w io.Writer, data RequestData) { if typ == tpl_request { t, _ := template.New("auth_request").Parse(auth_request) - t.Execute(w, data) + _ = t.Execute(w, data) return } @@ -111,7 +111,7 @@ func tplRequest(typ int, w io.Writer, data RequestData) { data.Banner = strings.ReplaceAll(data.Banner, "\n", " ") } t, _ := template.New("auth_complete").Parse(auth_complete) - t.Execute(w, data) + _ = t.Execute(w, data) } // 设置输出信息 diff --git a/handler/link_base.go b/handler/link_base.go index 14f6409..8c3b1b9 100644 --- a/handler/link_base.go +++ b/handler/link_base.go @@ -2,11 +2,9 @@ package handler import ( "encoding/xml" - "fmt" "log" "net/http" "os/exec" - "strings" ) const BufferSize = 2048 @@ -43,27 +41,6 @@ type macAddressList struct { MacAddress string `xml:"mac-address"` } -// 判断anyconnect客户端 -func checkLinkClient(h http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // TODO 调试信息输出 - // hd, _ := httputil.DumpRequest(r, true) - // fmt.Println("DumpRequest: ", string(hd)) - // fmt.Println(r.RemoteAddr) - - userAgent := strings.ToLower(r.UserAgent()) - x_Aggregate_Auth := r.Header.Get("X-Aggregate-Auth") - x_Transcend_Version := r.Header.Get("X-Transcend-Version") - if strings.Contains(userAgent, "anyconnect") && - x_Aggregate_Auth == "1" && x_Transcend_Version == "1" { - h(w, r) - } else { - w.WriteHeader(http.StatusForbidden) - fmt.Fprintf(w, "error request") - } - } -} - func setCommonHeader(w http.ResponseWriter) { // Content-Length Date 默认已经存在 w.Header().Set("Content-Type", "text/html; charset=utf-8") diff --git a/handler/link_home.go b/handler/link_home.go index f927f6f..bb7c05f 100644 --- a/handler/link_home.go +++ b/handler/link_home.go @@ -26,7 +26,7 @@ func LinkHome(w http.ResponseWriter, r *http.Request) { } func LinkOtpQr(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + _ = r.ParseForm() idS := r.FormValue("id") jwtToken := r.FormValue("jwt") data, err := admin.GetJwtData(jwtToken) diff --git a/handler/link_tap.go b/handler/link_tap.go index 21389fb..a43a3d1 100644 --- a/handler/link_tap.go +++ b/handler/link_tap.go @@ -29,6 +29,9 @@ func checkTap() { bridgeHw = brFace.HardwareAddr addrs, err := brFace.Addrs() + if err != nil { + base.Fatal("testTap err: ", err) + } for _, addr := range addrs { ip, _, err := net.ParseCIDR(addr.String()) if err != nil || ip.To4() == nil { diff --git a/handler/link_tunnel.go b/handler/link_tunnel.go index d6f51bb..16b466f 100644 --- a/handler/link_tunnel.go +++ b/handler/link_tunnel.go @@ -73,9 +73,9 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { w.Header().Set("Server", fmt.Sprintf("%s %s", base.APP_NAME, base.APP_VER)) w.Header().Set("X-CSTP-Version", "1") w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.") - w.Header().Set("X-CSTP-Address", cSess.IpAddr.String()) // 分配的ip地址 - w.Header().Set("X-CSTP-Netmask", base.Cfg.Ipv4Netmask) // 子网掩码 - w.Header().Set("X-CSTP-Hostname", hn) // 机器名称 + w.Header().Set("X-CSTP-Address", cSess.IpAddr.String()) // 分配的ip地址 + w.Header().Set("X-CSTP-Netmask", sessdata.IpPool.Ipv4Mask.String()) // 子网掩码 + w.Header().Set("X-CSTP-Hostname", hn) // 机器名称 // 允许本地LAN访问vpn网络,必须放在路由的第一个 if cSess.Group.AllowLan { @@ -131,11 +131,11 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { // w.Header().Set("X-CSTP-Post-Auth-XML", ``) w.WriteHeader(http.StatusOK) - h := w.Header().Clone() + hClone := w.Header().Clone() headers := make([]byte, 0) buf := bytes.NewBuffer(headers) - h.Write(buf) - base.Debug(string(buf.Bytes())) + _ = hClone.Write(buf) + base.Debug(buf.String()) hj := w.(http.Hijacker) conn, _, err := hj.Hijack() diff --git a/handler/start.go b/handler/start.go index feee6ea..d074244 100644 --- a/handler/start.go +++ b/handler/start.go @@ -21,5 +21,5 @@ func Start() { } func Stop() { - dbdata.Stop() + _ = dbdata.Stop() } diff --git a/pkg/arpdis/addr.go b/pkg/arpdis/addr.go index 8db9c10..fb46bad 100644 --- a/pkg/arpdis/addr.go +++ b/pkg/arpdis/addr.go @@ -48,7 +48,7 @@ func tableLookup(ip net.IP) *Addr { } // 判断老化过期时间 - tsub := time.Now().Sub(addr.disTime) + tsub := time.Since(addr.disTime) switch addr.Type { case TypeNormal: if tsub > StaleTimeNormal { diff --git a/pkg/arpdis/icmp.go b/pkg/arpdis/icmp.go index d89d7ae..925e16c 100644 --- a/pkg/arpdis/icmp.go +++ b/pkg/arpdis/icmp.go @@ -48,7 +48,7 @@ func doPing(ip string) error { return err } - conn.SetReadDeadline(time.Now().Add(time.Second * 2)) + _ = conn.SetReadDeadline(time.Now().Add(time.Second * 2)) for { buf := make([]byte, 512) diff --git a/pkg/proxyproto/protocol.go b/pkg/proxyproto/protocol.go index c4468f0..f91f0b0 100644 --- a/pkg/proxyproto/protocol.go +++ b/pkg/proxyproto/protocol.go @@ -198,8 +198,10 @@ func (p *Conn) checkPrefixOnce() { func (p *Conn) checkPrefix() error { if p.proxyHeaderTimeout != 0 { readDeadLine := time.Now().Add(p.proxyHeaderTimeout) - p.conn.SetReadDeadline(readDeadLine) - defer p.conn.SetReadDeadline(time.Time{}) + _ = p.conn.SetReadDeadline(readDeadLine) + defer func() { + _ = p.conn.SetReadDeadline(time.Time{}) + }() } // Incrementally check each byte of the prefix diff --git a/sessdata/copy_struct.go b/sessdata/copy_struct.go index 33db3c5..b93e2f5 100644 --- a/sessdata/copy_struct.go +++ b/sessdata/copy_struct.go @@ -45,8 +45,6 @@ func CopyStruct(a interface{}, b interface{}, fields ...string) (err error) { // a中有同名的字段并且类型一致才复制 if f.IsValid() && f.Kind() == bValue.Kind() { f.Set(bValue) - } else { - // fmt.Printf("no such field or different kind, fieldName: %s\n", name) } } return diff --git a/sessdata/ip_pool.go b/sessdata/ip_pool.go index 55f3079..f3f51a0 100644 --- a/sessdata/ip_pool.go +++ b/sessdata/ip_pool.go @@ -19,7 +19,8 @@ type ipPoolConfig struct { mux sync.Mutex // 计算动态ip Ipv4Gateway net.IP - Ipv4IPNet net.IPNet + Ipv4Mask net.IP + Ipv4IPNet *net.IPNet IpLongMin uint32 IpLongMax uint32 } @@ -27,11 +28,12 @@ type ipPoolConfig struct { func initIpPool() { // 地址处理 - // ip地址 - ip := net.ParseIP(base.Cfg.Ipv4Network) - // 子网掩码 - maskIp := net.ParseIP(base.Cfg.Ipv4Netmask).To4() - IpPool.Ipv4IPNet = net.IPNet{IP: ip, Mask: net.IPMask(maskIp)} + _, ipNet, err := net.ParseCIDR(base.Cfg.Ipv4CIDR) + if err != nil { + panic(err) + } + IpPool.Ipv4IPNet = ipNet + IpPool.Ipv4Mask = net.IP(ipNet.Mask) IpPool.Ipv4Gateway = net.ParseIP(base.Cfg.Ipv4Gateway) // 网络地址零值 @@ -74,11 +76,11 @@ func AcquireIp(username, macAddr string) net.IP { mi.Username = username mi.LastLogin = tNow // 回写db数据 - dbdata.Save(mi) + _ = dbdata.Save(mi) ipActive[ipStr] = true return ip } else { - dbdata.Del(mi) + _ = dbdata.Del(mi) } } @@ -92,7 +94,7 @@ func AcquireIp(username, macAddr string) net.IP { if err != nil && dbdata.CheckErrNotFound(err) { // 该ip没有被使用 mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow} - dbdata.Save(mi) + _ = dbdata.Save(mi) ipActive[ipStr] = true return ip } @@ -121,10 +123,10 @@ func AcquireIp(username, macAddr string) net.IP { // 已经超过租期 if tNow.Sub(v.LastLogin) > time.Duration(base.Cfg.IpLease)*time.Second { - dbdata.Del(v) + _ = dbdata.Del(v) mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow} // 重写db数据 - dbdata.Save(mi) + _ = dbdata.Save(mi) ipActive[ipStr] = true return ip } @@ -145,7 +147,7 @@ func AcquireIp(username, macAddr string) net.IP { ipStr := ip.String() mi = &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow} // 回写db数据 - dbdata.Save(mi) + _ = dbdata.Save(mi) ipActive[ipStr] = true return ip } @@ -160,6 +162,6 @@ func ReleaseIp(ip net.IP, macAddr string) { err := dbdata.One("IpAddr", ip, mi) if err == nil { mi.LastLogin = time.Now() - dbdata.Save(mi) + _ = dbdata.Save(mi) } } diff --git a/sessdata/ip_pool_test.go b/sessdata/ip_pool_test.go index d1583cb..92d0327 100644 --- a/sessdata/ip_pool_test.go +++ b/sessdata/ip_pool_test.go @@ -15,8 +15,7 @@ import ( func preData(tmpDir string) { tmpDb := path.Join(tmpDir, "test.db") base.Cfg.DbFile = tmpDb - base.Cfg.Ipv4Network = "192.168.3.0" - base.Cfg.Ipv4Netmask = "255.255.255.0" + base.Cfg.Ipv4CIDR = "192.168.3.0/24" base.Cfg.Ipv4Pool = []string{"192.168.3.1", "192.168.3.199"} base.Cfg.MaxClient = 100 base.Cfg.MaxUserClient = 3 @@ -26,12 +25,12 @@ func preData(tmpDir string) { Name: "group1", Bandwidth: 1000, } - dbdata.Save(&group) + _ = dbdata.Save(&group) initIpPool() } func cleardata(tmpDir string) { - dbdata.Stop() + _ = dbdata.Stop() tmpDb := path.Join(tmpDir, "test.db") os.Remove(tmpDb) } @@ -45,15 +44,15 @@ func TestIpPool(t *testing.T) { var ip net.IP for i := 1; i <= 100; i++ { - ip = AcquireIp("user", fmt.Sprintf("mac-%d", i)) + _ = AcquireIp("user", fmt.Sprintf("mac-%d", i)) } - ip = AcquireIp("user", fmt.Sprintf("mac-new")) + ip = AcquireIp("user", "mac-new") assert.True(net.IPv4(192, 168, 3, 101).Equal(ip)) for i := 102; i <= 199; i++ { ip = AcquireIp("user", fmt.Sprintf("mac-%d", i)) } assert.True(net.IPv4(192, 168, 3, 199).Equal(ip)) - ip = AcquireIp("user", fmt.Sprintf("mac-nil")) + ip = AcquireIp("user", "mac-nil") assert.Nil(ip) ReleaseIp(net.IPv4(192, 168, 3, 88), "mac-88") diff --git a/sessdata/limit_test.go b/sessdata/limit_test.go index 04252e3..6ebd178 100644 --- a/sessdata/limit_test.go +++ b/sessdata/limit_test.go @@ -43,10 +43,11 @@ func TestLimitClient(t *testing.T) { func TestLimitWait(t *testing.T) { assert := assert.New(t) limit := NewLimitRater(1, 2) - limit.Wait(2) - start := time.Now() err := limit.Wait(2) assert.Nil(err) + start := time.Now() + err = limit.Wait(2) + assert.Nil(err) err = limit.Wait(1) assert.Nil(err) end := time.Now() diff --git a/sessdata/online.go b/sessdata/online.go index 634214d..df1cee8 100644 --- a/sessdata/online.go +++ b/sessdata/online.go @@ -34,10 +34,7 @@ func (o Onlines) Len() int { } func (o Onlines) Less(i, j int) bool { - if bytes.Compare(o[i].Ip, o[j].Ip) < 0 { - return true - } - return false + return bytes.Compare(o[i].Ip, o[j].Ip) < 0 } func (o Onlines) Swap(i, j int) { diff --git a/sessdata/session.go b/sessdata/session.go index 6331cf7..8418879 100644 --- a/sessdata/session.go +++ b/sessdata/session.go @@ -78,13 +78,13 @@ func checkSession() { return } timeout := time.Duration(base.Cfg.SessionTimeout) * time.Second - tick := time.Tick(time.Second * 60) - for range tick { + tick := time.NewTicker(time.Second * 60) + for range tick.C { sessMux.Lock() t := time.Now() for k, v := range sessions { v.mux.Lock() - if v.IsActive != true { + if !v.IsActive { if t.Sub(v.LastLogin) > timeout { delete(sessions, k) } @@ -133,12 +133,12 @@ func (s *Session) NewConn() *ConnSession { macAddr := s.MacAddr username := s.Username s.mux.Unlock() - if active == true { + if active { s.CSess.Close() } limit := LimitClient(username, false) - if limit == false { + if !limit { return nil } // 获取客户端mac地址 @@ -208,8 +208,10 @@ func (cs *ConnSession) Close() { const BandwidthPeriodSec = 2 // 流量速率统计周期(秒) func (cs *ConnSession) ratePeriod() { - tick := time.Tick(time.Second * BandwidthPeriodSec) - for range tick { + tick := time.NewTicker(time.Second * BandwidthPeriodSec) + defer tick.Stop() + + for range tick.C { select { case <-cs.CloseChan: return diff --git a/sessdata/session_test.go b/sessdata/session_test.go index 268d918..4499762 100644 --- a/sessdata/session_test.go +++ b/sessdata/session_test.go @@ -28,9 +28,11 @@ func TestConnSession(t *testing.T) { cSess := sess.NewConn() - cSess.RateLimit(100, true) + err := cSess.RateLimit(100, true) + assert.Nil(err) assert.Equal(cSess.BandwidthUp, uint32(100)) - cSess.RateLimit(200, false) + err = cSess.RateLimit(200, false) + assert.Nil(err) assert.Equal(cSess.BandwidthDown, uint32(200)) cSess.Close() }