修改客户端分配的ip为CIDR格式,请注意原来network格式

This commit is contained in:
bjd 2021-02-04 13:32:10 +08:00
parent 1c6572f5e3
commit edb0fe2dc9
32 changed files with 155 additions and 139 deletions

View File

@ -16,7 +16,7 @@ func Login(w http.ResponseWriter, r *http.Request) {
// hd, _ := httputil.DumpRequest(r, true)
// fmt.Println("DumpRequest: ", string(hd))
r.ParseForm()
_ = r.ParseForm()
admin_user := r.PostFormValue("admin_user")
admin_pass := r.PostFormValue("admin_pass")

View File

@ -10,7 +10,7 @@ import (
)
func GroupList(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
pageS := r.FormValue("page")
page, _ := strconv.Atoi(pageS)
if page < 1 {
@ -48,7 +48,7 @@ func GroupNames(w http.ResponseWriter, r *http.Request) {
}
func GroupDetail(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
idS := r.FormValue("id")
id, _ := strconv.Atoi(idS)
if id < 1 {
@ -90,7 +90,7 @@ func GroupSet(w http.ResponseWriter, r *http.Request) {
}
func GroupDel(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
idS := r.FormValue("id")
id, _ := strconv.Atoi(idS)
if id < 1 {

View File

@ -11,7 +11,7 @@ import (
)
func UserIpMapList(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
pageS := r.FormValue("page")
page, _ := strconv.Atoi(pageS)
if page < 1 {
@ -39,7 +39,7 @@ func UserIpMapList(w http.ResponseWriter, r *http.Request) {
}
func UserIpMapDetail(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
idS := r.FormValue("id")
id, _ := strconv.Atoi(idS)
if id < 1 {
@ -58,7 +58,7 @@ func UserIpMapDetail(w http.ResponseWriter, r *http.Request) {
}
func UserIpMapSet(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
body, err := ioutil.ReadAll(r.Body)
if err != nil {
@ -92,7 +92,7 @@ func UserIpMapSet(w http.ResponseWriter, r *http.Request) {
}
func UserIpMapDel(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
idS := r.FormValue("id")
id, _ := strconv.Atoi(idS)

View File

@ -1,7 +1,6 @@
package admin
import (
"encoding/json"
"fmt"
"net/http"
"runtime"
@ -84,9 +83,8 @@ func SetSystem(w http.ResponseWriter, r *http.Request) {
}
func SetSoft(w http.ResponseWriter, r *http.Request) {
datas := base.ServerCfg2Slice()
b, _ := json.Marshal(datas)
w.Write(b)
data := base.ServerCfg2Slice()
RespSucess(w, data)
}
func decimal(f float64) float64 {

View File

@ -19,7 +19,7 @@ import (
)
func UserList(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
prefix := r.FormValue("prefix")
pageS := r.FormValue("page")
page, _ := strconv.Atoi(pageS)
@ -58,7 +58,7 @@ func UserList(w http.ResponseWriter, r *http.Request) {
}
func UserDetail(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
idS := r.FormValue("id")
id, _ := strconv.Atoi(idS)
if id < 1 {
@ -77,7 +77,7 @@ func UserDetail(w http.ResponseWriter, r *http.Request) {
}
func UserSet(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
body, err := ioutil.ReadAll(r.Body)
if err != nil {
@ -111,7 +111,7 @@ func UserSet(w http.ResponseWriter, r *http.Request) {
}
func UserDel(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
idS := r.FormValue("id")
id, _ := strconv.Atoi(idS)
@ -130,7 +130,7 @@ func UserDel(w http.ResponseWriter, r *http.Request) {
}
func UserOtpQr(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
b64 := r.FormValue("b64")
idS := r.FormValue("id")
id, _ := strconv.Atoi(idS)
@ -148,11 +148,16 @@ func UserOtpQr(w http.ResponseWriter, r *http.Request) {
if b64 == "1" {
data, _ := qr.PNG(300)
s := base64.StdEncoding.EncodeToString(data)
fmt.Fprint(w, s)
} else {
qr.Write(300, w)
_, err = fmt.Fprint(w, s)
if err != nil {
base.Error(err)
}
return
}
err = qr.Write(300, w)
if err != nil {
base.Error(err)
}
}
// 在线用户
@ -169,14 +174,14 @@ func UserOnline(w http.ResponseWriter, r *http.Request) {
}
func UserOffline(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
token := r.FormValue("token")
sessdata.CloseSess(token)
RespSucess(w, nil)
}
func UserReline(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
token := r.FormValue("token")
sessdata.CloseCSess(token)
RespSucess(w, nil)
@ -231,7 +236,10 @@ func userAccountMail(user *dbdata.User) error {
}
w := bytes.NewBufferString("")
t, _ := template.New("auth_complete").Parse(htmlBody)
t.Execute(w, data)
err = t.Execute(w, data)
if err != nil {
return err
}
// fmt.Println(w.String())
return SendMail(base.Cfg.Issuer+"平台通知", user.Email, w.String())
}

View File

@ -8,8 +8,8 @@ import (
"github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/dbdata"
"github.com/dgrijalva/jwt-go"
"github.com/mojocn/base64Captcha"
mail "github.com/xhit/go-simple-mail/v2"
// "github.com/mojocn/base64Captcha"
)
func SetJwtData(data map[string]interface{}, expiresAt int64) (string, error) {
@ -43,15 +43,6 @@ func GetJwtData(jwtToken string) (map[string]interface{}, error) {
return claims, nil
}
func createCaptcha() {
var store = base64Captcha.DefaultMemStore
var driver base64Captcha.Driver
driverString := &base64Captcha.DriverString{}
driver = driverString.ConvertFonts()
c := base64Captcha.NewCaptcha(driver, store)
_ = c
}
func SendMail(subject, to, htmlBody string) error {
dataSmtp := &dbdata.SettingSmtp{}

View File

@ -43,8 +43,10 @@ func respHttp(w http.ResponseWriter, respCode int, data interface{}, errS ...int
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write(b)
_, err = w.Write(b)
if err != nil {
base.Error(err)
}
// 记录返回数据
// logger.Category("response").Debug(string(b))
}

39
admin/resp_test.go Normal file
View File

@ -0,0 +1,39 @@
package admin
import (
"encoding/json"
"io/ioutil"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRespSucess(t *testing.T) {
assert := assert.New(t)
w := httptest.NewRecorder()
RespSucess(w, "data")
// fmt.Println(w)
assert.Equal(w.Code, 200)
body, _ := ioutil.ReadAll(w.Body)
res := Resp{}
err := json.Unmarshal(body, &res)
assert.Nil(err)
assert.Equal(res.Code, 0)
assert.Equal(res.Data, "data")
}
func TestRespError(t *testing.T) {
assert := assert.New(t)
w := httptest.NewRecorder()
RespError(w, 10, "err-msg")
// fmt.Println(w)
assert.Equal(w.Code, 200)
body, _ := ioutil.ReadAll(w.Body)
res := Resp{}
err := json.Unmarshal(body, &res)
assert.Nil(err)
assert.Equal(res.Code, 10)
assert.Equal(res.Msg, "err-msg")
}

View File

@ -2,5 +2,5 @@ package base
const (
APP_NAME = "AnyLink"
APP_VER = "0.1.1"
APP_VER = "0.1.2"
)

View File

@ -48,8 +48,7 @@ type ServerConfig struct {
JwtSecret string `toml:"jwt_secret" info:"JWT密钥"`
LinkMode string `toml:"link_mode" info:"虚拟网络类型"` // tun tap
Ipv4Network string `toml:"ipv4_network" info:"ipv4_network"` // 192.168.1.0
Ipv4Netmask string `toml:"ipv4_netmask" info:"ipv4_netmask"` // 255.255.255.0
Ipv4CIDR string `toml:"ipv4_cidr" info:"ip地址网段"` // 192.168.1.0/24
Ipv4Gateway string `toml:"ipv4_gateway" info:"ipv4_gateway"`
Ipv4Pool []string `toml:"ipv4_pool" info:"IPV4起止地址池"` // Pool[0]=192.168.1.100 Pool[1]=192.168.1.200
IpLease int `toml:"ip_lease" info:"IP租期(秒)"`
@ -102,16 +101,17 @@ func getAbsPath(base, cfile string) string {
return filepath.Join(base, cfile)
}
func ServerCfg2Slice() interface{} {
ref := reflect.ValueOf(Cfg)
s := ref.Elem()
type cfg struct {
type SCfg struct {
Name string `json:"name"`
Info string `json:"info"`
Data interface{} `json:"data"`
}
var datas []cfg
}
func ServerCfg2Slice() []SCfg {
ref := reflect.ValueOf(Cfg)
s := ref.Elem()
var datas []SCfg
typ := s.Type()
numFields := s.NumField()
@ -122,7 +122,7 @@ func ServerCfg2Slice() interface{} {
tags := strings.Split(tag, ",")
info := field.Tag.Get("info")
datas = append(datas, cfg{Name: tags[0], Info: info, Data: value.Interface()})
datas = append(datas, SCfg{Name: tags[0], Info: info, Data: value.Interface()})
}
return datas

View File

@ -36,7 +36,7 @@ func logLevel2Int(l string) int {
}
lvl := _Info
for k, v := range levels {
if strings.ToLower(l) == strings.ToLower(v) {
if strings.EqualFold(strings.ToLower(l), strings.ToLower(v)) {
lvl = k
}
}
@ -45,7 +45,7 @@ func logLevel2Int(l string) int {
func output(l int, s ...interface{}) {
lvl := fmt.Sprintf("[%s] ", levels[l])
baseLog.Output(3, lvl+fmt.Sprintln(s...))
_ = baseLog.Output(3, lvl+fmt.Sprintln(s...))
}
func Debug(v ...interface{}) {

View File

@ -8,10 +8,6 @@
# Define Bridge Interface
br="anylink0"
# Define list of TAP interfaces to be bridged,
# for example tap="tap0 tap1 tap2".
tap="tap0"
# Define physical ethernet interface to be bridged
# with TAP interface(s) above.

View File

@ -35,8 +35,7 @@ proxy_protocol = false
link_mode = "tun"
#客户端分配的ip地址池
ipv4_network = "192.168.10.0"
ipv4_netmask = "255.255.255.0"
ipv4_cidr = "192.168.10.0/24"
ipv4_gateway = "192.168.10.1"
ipv4_pool = ["192.168.10.100", "192.168.10.200"]

View File

@ -43,28 +43,27 @@ func initData() {
return
}
defer Set(SettingBucket, Installed, true)
defer func() {
_ = Set(SettingBucket, Installed, true)
}()
smtp := &SettingSmtp{
Host: "127.0.0.1",
Port: 25,
From: "vpn@xx.com",
}
SettingSet(smtp)
_ = SettingSet(smtp)
other := &SettingOther{
Banner: "您已接入公司网络,请按照公司规定使用。\n请勿进行非工作下载及视频行为",
AccountMail: accountMail,
}
SettingSet(other)
_ = SettingSet(other)
}
func CheckErrNotFound(err error) bool {
if err == storm.ErrNotFound {
return true
}
return false
return err == storm.ErrNotFound
}
const accountMail = `<p>您好:</p>

View File

@ -27,7 +27,7 @@ func TestDb(t *testing.T) {
defer closeIpdata()
u := User{Username: "a"}
Save(&u)
_ = Save(&u)
assert.Equal(u.Id, 1)
}

View File

@ -2,4 +2,5 @@ package handler
// 暂时没有实现
func startDtls() {
}

View File

@ -87,7 +87,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
sess.MacAddr = strings.ToLower(cr.MacAddressList.MacAddress)
sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal
other := &dbdata.SettingOther{}
dbdata.SettingGet(other)
_ = dbdata.SettingGet(other)
rd := RequestData{SessionId: sess.Sid, SessionToken: sess.Sid + "@" + sess.Token,
Banner: other.Banner}
w.WriteHeader(http.StatusOK)
@ -102,7 +102,7 @@ const (
func tplRequest(typ int, w io.Writer, data RequestData) {
if typ == tpl_request {
t, _ := template.New("auth_request").Parse(auth_request)
t.Execute(w, data)
_ = t.Execute(w, data)
return
}
@ -111,7 +111,7 @@ func tplRequest(typ int, w io.Writer, data RequestData) {
data.Banner = strings.ReplaceAll(data.Banner, "\n", "&#x0A;")
}
t, _ := template.New("auth_complete").Parse(auth_complete)
t.Execute(w, data)
_ = t.Execute(w, data)
}
// 设置输出信息

View File

@ -2,11 +2,9 @@ package handler
import (
"encoding/xml"
"fmt"
"log"
"net/http"
"os/exec"
"strings"
)
const BufferSize = 2048
@ -43,27 +41,6 @@ type macAddressList struct {
MacAddress string `xml:"mac-address"`
}
// 判断anyconnect客户端
func checkLinkClient(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// TODO 调试信息输出
// hd, _ := httputil.DumpRequest(r, true)
// fmt.Println("DumpRequest: ", string(hd))
// fmt.Println(r.RemoteAddr)
userAgent := strings.ToLower(r.UserAgent())
x_Aggregate_Auth := r.Header.Get("X-Aggregate-Auth")
x_Transcend_Version := r.Header.Get("X-Transcend-Version")
if strings.Contains(userAgent, "anyconnect") &&
x_Aggregate_Auth == "1" && x_Transcend_Version == "1" {
h(w, r)
} else {
w.WriteHeader(http.StatusForbidden)
fmt.Fprintf(w, "error request")
}
}
}
func setCommonHeader(w http.ResponseWriter) {
// Content-Length Date 默认已经存在
w.Header().Set("Content-Type", "text/html; charset=utf-8")

View File

@ -26,7 +26,7 @@ func LinkHome(w http.ResponseWriter, r *http.Request) {
}
func LinkOtpQr(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
_ = r.ParseForm()
idS := r.FormValue("id")
jwtToken := r.FormValue("jwt")
data, err := admin.GetJwtData(jwtToken)

View File

@ -29,6 +29,9 @@ func checkTap() {
bridgeHw = brFace.HardwareAddr
addrs, err := brFace.Addrs()
if err != nil {
base.Fatal("testTap err: ", err)
}
for _, addr := range addrs {
ip, _, err := net.ParseCIDR(addr.String())
if err != nil || ip.To4() == nil {

View File

@ -74,7 +74,7 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-CSTP-Version", "1")
w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.")
w.Header().Set("X-CSTP-Address", cSess.IpAddr.String()) // 分配的ip地址
w.Header().Set("X-CSTP-Netmask", base.Cfg.Ipv4Netmask) // 子网掩码
w.Header().Set("X-CSTP-Netmask", sessdata.IpPool.Ipv4Mask.String()) // 子网掩码
w.Header().Set("X-CSTP-Hostname", hn) // 机器名称
// 允许本地LAN访问vpn网络必须放在路由的第一个
@ -131,11 +131,11 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
// w.Header().Set("X-CSTP-Post-Auth-XML", ``)
w.WriteHeader(http.StatusOK)
h := w.Header().Clone()
hClone := w.Header().Clone()
headers := make([]byte, 0)
buf := bytes.NewBuffer(headers)
h.Write(buf)
base.Debug(string(buf.Bytes()))
_ = hClone.Write(buf)
base.Debug(buf.String())
hj := w.(http.Hijacker)
conn, _, err := hj.Hijack()

View File

@ -21,5 +21,5 @@ func Start() {
}
func Stop() {
dbdata.Stop()
_ = dbdata.Stop()
}

View File

@ -48,7 +48,7 @@ func tableLookup(ip net.IP) *Addr {
}
// 判断老化过期时间
tsub := time.Now().Sub(addr.disTime)
tsub := time.Since(addr.disTime)
switch addr.Type {
case TypeNormal:
if tsub > StaleTimeNormal {

View File

@ -48,7 +48,7 @@ func doPing(ip string) error {
return err
}
conn.SetReadDeadline(time.Now().Add(time.Second * 2))
_ = conn.SetReadDeadline(time.Now().Add(time.Second * 2))
for {
buf := make([]byte, 512)

View File

@ -198,8 +198,10 @@ func (p *Conn) checkPrefixOnce() {
func (p *Conn) checkPrefix() error {
if p.proxyHeaderTimeout != 0 {
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
p.conn.SetReadDeadline(readDeadLine)
defer p.conn.SetReadDeadline(time.Time{})
_ = p.conn.SetReadDeadline(readDeadLine)
defer func() {
_ = p.conn.SetReadDeadline(time.Time{})
}()
}
// Incrementally check each byte of the prefix

View File

@ -45,8 +45,6 @@ func CopyStruct(a interface{}, b interface{}, fields ...string) (err error) {
// a中有同名的字段并且类型一致才复制
if f.IsValid() && f.Kind() == bValue.Kind() {
f.Set(bValue)
} else {
// fmt.Printf("no such field or different kind, fieldName: %s\n", name)
}
}
return

View File

@ -19,7 +19,8 @@ type ipPoolConfig struct {
mux sync.Mutex
// 计算动态ip
Ipv4Gateway net.IP
Ipv4IPNet net.IPNet
Ipv4Mask net.IP
Ipv4IPNet *net.IPNet
IpLongMin uint32
IpLongMax uint32
}
@ -27,11 +28,12 @@ type ipPoolConfig struct {
func initIpPool() {
// 地址处理
// ip地址
ip := net.ParseIP(base.Cfg.Ipv4Network)
// 子网掩码
maskIp := net.ParseIP(base.Cfg.Ipv4Netmask).To4()
IpPool.Ipv4IPNet = net.IPNet{IP: ip, Mask: net.IPMask(maskIp)}
_, ipNet, err := net.ParseCIDR(base.Cfg.Ipv4CIDR)
if err != nil {
panic(err)
}
IpPool.Ipv4IPNet = ipNet
IpPool.Ipv4Mask = net.IP(ipNet.Mask)
IpPool.Ipv4Gateway = net.ParseIP(base.Cfg.Ipv4Gateway)
// 网络地址零值
@ -74,11 +76,11 @@ func AcquireIp(username, macAddr string) net.IP {
mi.Username = username
mi.LastLogin = tNow
// 回写db数据
dbdata.Save(mi)
_ = dbdata.Save(mi)
ipActive[ipStr] = true
return ip
} else {
dbdata.Del(mi)
_ = dbdata.Del(mi)
}
}
@ -92,7 +94,7 @@ func AcquireIp(username, macAddr string) net.IP {
if err != nil && dbdata.CheckErrNotFound(err) {
// 该ip没有被使用
mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
dbdata.Save(mi)
_ = dbdata.Save(mi)
ipActive[ipStr] = true
return ip
}
@ -121,10 +123,10 @@ func AcquireIp(username, macAddr string) net.IP {
// 已经超过租期
if tNow.Sub(v.LastLogin) > time.Duration(base.Cfg.IpLease)*time.Second {
dbdata.Del(v)
_ = dbdata.Del(v)
mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
// 重写db数据
dbdata.Save(mi)
_ = dbdata.Save(mi)
ipActive[ipStr] = true
return ip
}
@ -145,7 +147,7 @@ func AcquireIp(username, macAddr string) net.IP {
ipStr := ip.String()
mi = &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
// 回写db数据
dbdata.Save(mi)
_ = dbdata.Save(mi)
ipActive[ipStr] = true
return ip
}
@ -160,6 +162,6 @@ func ReleaseIp(ip net.IP, macAddr string) {
err := dbdata.One("IpAddr", ip, mi)
if err == nil {
mi.LastLogin = time.Now()
dbdata.Save(mi)
_ = dbdata.Save(mi)
}
}

View File

@ -15,8 +15,7 @@ import (
func preData(tmpDir string) {
tmpDb := path.Join(tmpDir, "test.db")
base.Cfg.DbFile = tmpDb
base.Cfg.Ipv4Network = "192.168.3.0"
base.Cfg.Ipv4Netmask = "255.255.255.0"
base.Cfg.Ipv4CIDR = "192.168.3.0/24"
base.Cfg.Ipv4Pool = []string{"192.168.3.1", "192.168.3.199"}
base.Cfg.MaxClient = 100
base.Cfg.MaxUserClient = 3
@ -26,12 +25,12 @@ func preData(tmpDir string) {
Name: "group1",
Bandwidth: 1000,
}
dbdata.Save(&group)
_ = dbdata.Save(&group)
initIpPool()
}
func cleardata(tmpDir string) {
dbdata.Stop()
_ = dbdata.Stop()
tmpDb := path.Join(tmpDir, "test.db")
os.Remove(tmpDb)
}
@ -45,15 +44,15 @@ func TestIpPool(t *testing.T) {
var ip net.IP
for i := 1; i <= 100; i++ {
ip = AcquireIp("user", fmt.Sprintf("mac-%d", i))
_ = AcquireIp("user", fmt.Sprintf("mac-%d", i))
}
ip = AcquireIp("user", fmt.Sprintf("mac-new"))
ip = AcquireIp("user", "mac-new")
assert.True(net.IPv4(192, 168, 3, 101).Equal(ip))
for i := 102; i <= 199; i++ {
ip = AcquireIp("user", fmt.Sprintf("mac-%d", i))
}
assert.True(net.IPv4(192, 168, 3, 199).Equal(ip))
ip = AcquireIp("user", fmt.Sprintf("mac-nil"))
ip = AcquireIp("user", "mac-nil")
assert.Nil(ip)
ReleaseIp(net.IPv4(192, 168, 3, 88), "mac-88")

View File

@ -43,10 +43,11 @@ func TestLimitClient(t *testing.T) {
func TestLimitWait(t *testing.T) {
assert := assert.New(t)
limit := NewLimitRater(1, 2)
limit.Wait(2)
start := time.Now()
err := limit.Wait(2)
assert.Nil(err)
start := time.Now()
err = limit.Wait(2)
assert.Nil(err)
err = limit.Wait(1)
assert.Nil(err)
end := time.Now()

View File

@ -34,10 +34,7 @@ func (o Onlines) Len() int {
}
func (o Onlines) Less(i, j int) bool {
if bytes.Compare(o[i].Ip, o[j].Ip) < 0 {
return true
}
return false
return bytes.Compare(o[i].Ip, o[j].Ip) < 0
}
func (o Onlines) Swap(i, j int) {

View File

@ -78,13 +78,13 @@ func checkSession() {
return
}
timeout := time.Duration(base.Cfg.SessionTimeout) * time.Second
tick := time.Tick(time.Second * 60)
for range tick {
tick := time.NewTicker(time.Second * 60)
for range tick.C {
sessMux.Lock()
t := time.Now()
for k, v := range sessions {
v.mux.Lock()
if v.IsActive != true {
if !v.IsActive {
if t.Sub(v.LastLogin) > timeout {
delete(sessions, k)
}
@ -133,12 +133,12 @@ func (s *Session) NewConn() *ConnSession {
macAddr := s.MacAddr
username := s.Username
s.mux.Unlock()
if active == true {
if active {
s.CSess.Close()
}
limit := LimitClient(username, false)
if limit == false {
if !limit {
return nil
}
// 获取客户端mac地址
@ -208,8 +208,10 @@ func (cs *ConnSession) Close() {
const BandwidthPeriodSec = 2 // 流量速率统计周期(秒)
func (cs *ConnSession) ratePeriod() {
tick := time.Tick(time.Second * BandwidthPeriodSec)
for range tick {
tick := time.NewTicker(time.Second * BandwidthPeriodSec)
defer tick.Stop()
for range tick.C {
select {
case <-cs.CloseChan:
return

View File

@ -28,9 +28,11 @@ func TestConnSession(t *testing.T) {
cSess := sess.NewConn()
cSess.RateLimit(100, true)
err := cSess.RateLimit(100, true)
assert.Nil(err)
assert.Equal(cSess.BandwidthUp, uint32(100))
cSess.RateLimit(200, false)
err = cSess.RateLimit(200, false)
assert.Nil(err)
assert.Equal(cSess.BandwidthDown, uint32(200))
cSess.Close()
}