mirror of https://github.com/bjdgyc/anylink.git
修改客户端分配的ip为CIDR格式,请注意原来network格式
This commit is contained in:
parent
1c6572f5e3
commit
edb0fe2dc9
|
@ -16,7 +16,7 @@ func Login(w http.ResponseWriter, r *http.Request) {
|
||||||
// hd, _ := httputil.DumpRequest(r, true)
|
// hd, _ := httputil.DumpRequest(r, true)
|
||||||
// fmt.Println("DumpRequest: ", string(hd))
|
// fmt.Println("DumpRequest: ", string(hd))
|
||||||
|
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
admin_user := r.PostFormValue("admin_user")
|
admin_user := r.PostFormValue("admin_user")
|
||||||
admin_pass := r.PostFormValue("admin_pass")
|
admin_pass := r.PostFormValue("admin_pass")
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func GroupList(w http.ResponseWriter, r *http.Request) {
|
func GroupList(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
pageS := r.FormValue("page")
|
pageS := r.FormValue("page")
|
||||||
page, _ := strconv.Atoi(pageS)
|
page, _ := strconv.Atoi(pageS)
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
|
@ -48,7 +48,7 @@ func GroupNames(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func GroupDetail(w http.ResponseWriter, r *http.Request) {
|
func GroupDetail(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
idS := r.FormValue("id")
|
idS := r.FormValue("id")
|
||||||
id, _ := strconv.Atoi(idS)
|
id, _ := strconv.Atoi(idS)
|
||||||
if id < 1 {
|
if id < 1 {
|
||||||
|
@ -90,7 +90,7 @@ func GroupSet(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func GroupDel(w http.ResponseWriter, r *http.Request) {
|
func GroupDel(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
idS := r.FormValue("id")
|
idS := r.FormValue("id")
|
||||||
id, _ := strconv.Atoi(idS)
|
id, _ := strconv.Atoi(idS)
|
||||||
if id < 1 {
|
if id < 1 {
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func UserIpMapList(w http.ResponseWriter, r *http.Request) {
|
func UserIpMapList(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
pageS := r.FormValue("page")
|
pageS := r.FormValue("page")
|
||||||
page, _ := strconv.Atoi(pageS)
|
page, _ := strconv.Atoi(pageS)
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
|
@ -39,7 +39,7 @@ func UserIpMapList(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserIpMapDetail(w http.ResponseWriter, r *http.Request) {
|
func UserIpMapDetail(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
idS := r.FormValue("id")
|
idS := r.FormValue("id")
|
||||||
id, _ := strconv.Atoi(idS)
|
id, _ := strconv.Atoi(idS)
|
||||||
if id < 1 {
|
if id < 1 {
|
||||||
|
@ -58,7 +58,7 @@ func UserIpMapDetail(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserIpMapSet(w http.ResponseWriter, r *http.Request) {
|
func UserIpMapSet(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -92,7 +92,7 @@ func UserIpMapSet(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserIpMapDel(w http.ResponseWriter, r *http.Request) {
|
func UserIpMapDel(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
idS := r.FormValue("id")
|
idS := r.FormValue("id")
|
||||||
id, _ := strconv.Atoi(idS)
|
id, _ := strconv.Atoi(idS)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
@ -84,9 +83,8 @@ func SetSystem(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetSoft(w http.ResponseWriter, r *http.Request) {
|
func SetSoft(w http.ResponseWriter, r *http.Request) {
|
||||||
datas := base.ServerCfg2Slice()
|
data := base.ServerCfg2Slice()
|
||||||
b, _ := json.Marshal(datas)
|
RespSucess(w, data)
|
||||||
w.Write(b)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func decimal(f float64) float64 {
|
func decimal(f float64) float64 {
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func UserList(w http.ResponseWriter, r *http.Request) {
|
func UserList(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
prefix := r.FormValue("prefix")
|
prefix := r.FormValue("prefix")
|
||||||
pageS := r.FormValue("page")
|
pageS := r.FormValue("page")
|
||||||
page, _ := strconv.Atoi(pageS)
|
page, _ := strconv.Atoi(pageS)
|
||||||
|
@ -58,7 +58,7 @@ func UserList(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserDetail(w http.ResponseWriter, r *http.Request) {
|
func UserDetail(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
idS := r.FormValue("id")
|
idS := r.FormValue("id")
|
||||||
id, _ := strconv.Atoi(idS)
|
id, _ := strconv.Atoi(idS)
|
||||||
if id < 1 {
|
if id < 1 {
|
||||||
|
@ -77,7 +77,7 @@ func UserDetail(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserSet(w http.ResponseWriter, r *http.Request) {
|
func UserSet(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -111,7 +111,7 @@ func UserSet(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserDel(w http.ResponseWriter, r *http.Request) {
|
func UserDel(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
idS := r.FormValue("id")
|
idS := r.FormValue("id")
|
||||||
id, _ := strconv.Atoi(idS)
|
id, _ := strconv.Atoi(idS)
|
||||||
|
|
||||||
|
@ -130,7 +130,7 @@ func UserDel(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserOtpQr(w http.ResponseWriter, r *http.Request) {
|
func UserOtpQr(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
b64 := r.FormValue("b64")
|
b64 := r.FormValue("b64")
|
||||||
idS := r.FormValue("id")
|
idS := r.FormValue("id")
|
||||||
id, _ := strconv.Atoi(idS)
|
id, _ := strconv.Atoi(idS)
|
||||||
|
@ -148,11 +148,16 @@ func UserOtpQr(w http.ResponseWriter, r *http.Request) {
|
||||||
if b64 == "1" {
|
if b64 == "1" {
|
||||||
data, _ := qr.PNG(300)
|
data, _ := qr.PNG(300)
|
||||||
s := base64.StdEncoding.EncodeToString(data)
|
s := base64.StdEncoding.EncodeToString(data)
|
||||||
fmt.Fprint(w, s)
|
_, err = fmt.Fprint(w, s)
|
||||||
} else {
|
if err != nil {
|
||||||
qr.Write(300, w)
|
base.Error(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = qr.Write(300, w)
|
||||||
|
if err != nil {
|
||||||
|
base.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 在线用户
|
// 在线用户
|
||||||
|
@ -169,14 +174,14 @@ func UserOnline(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserOffline(w http.ResponseWriter, r *http.Request) {
|
func UserOffline(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
token := r.FormValue("token")
|
token := r.FormValue("token")
|
||||||
sessdata.CloseSess(token)
|
sessdata.CloseSess(token)
|
||||||
RespSucess(w, nil)
|
RespSucess(w, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserReline(w http.ResponseWriter, r *http.Request) {
|
func UserReline(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
token := r.FormValue("token")
|
token := r.FormValue("token")
|
||||||
sessdata.CloseCSess(token)
|
sessdata.CloseCSess(token)
|
||||||
RespSucess(w, nil)
|
RespSucess(w, nil)
|
||||||
|
@ -231,7 +236,10 @@ func userAccountMail(user *dbdata.User) error {
|
||||||
}
|
}
|
||||||
w := bytes.NewBufferString("")
|
w := bytes.NewBufferString("")
|
||||||
t, _ := template.New("auth_complete").Parse(htmlBody)
|
t, _ := template.New("auth_complete").Parse(htmlBody)
|
||||||
t.Execute(w, data)
|
err = t.Execute(w, data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
// fmt.Println(w.String())
|
// fmt.Println(w.String())
|
||||||
return SendMail(base.Cfg.Issuer+"平台通知", user.Email, w.String())
|
return SendMail(base.Cfg.Issuer+"平台通知", user.Email, w.String())
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,8 +8,8 @@ import (
|
||||||
"github.com/bjdgyc/anylink/base"
|
"github.com/bjdgyc/anylink/base"
|
||||||
"github.com/bjdgyc/anylink/dbdata"
|
"github.com/bjdgyc/anylink/dbdata"
|
||||||
"github.com/dgrijalva/jwt-go"
|
"github.com/dgrijalva/jwt-go"
|
||||||
"github.com/mojocn/base64Captcha"
|
|
||||||
mail "github.com/xhit/go-simple-mail/v2"
|
mail "github.com/xhit/go-simple-mail/v2"
|
||||||
|
// "github.com/mojocn/base64Captcha"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetJwtData(data map[string]interface{}, expiresAt int64) (string, error) {
|
func SetJwtData(data map[string]interface{}, expiresAt int64) (string, error) {
|
||||||
|
@ -43,15 +43,6 @@ func GetJwtData(jwtToken string) (map[string]interface{}, error) {
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createCaptcha() {
|
|
||||||
var store = base64Captcha.DefaultMemStore
|
|
||||||
var driver base64Captcha.Driver
|
|
||||||
driverString := &base64Captcha.DriverString{}
|
|
||||||
driver = driverString.ConvertFonts()
|
|
||||||
c := base64Captcha.NewCaptcha(driver, store)
|
|
||||||
_ = c
|
|
||||||
}
|
|
||||||
|
|
||||||
func SendMail(subject, to, htmlBody string) error {
|
func SendMail(subject, to, htmlBody string) error {
|
||||||
|
|
||||||
dataSmtp := &dbdata.SettingSmtp{}
|
dataSmtp := &dbdata.SettingSmtp{}
|
||||||
|
|
|
@ -43,8 +43,10 @@ func respHttp(w http.ResponseWriter, respCode int, data interface{}, errS ...int
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(b)
|
_, err = w.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
base.Error(err)
|
||||||
|
}
|
||||||
// 记录返回数据
|
// 记录返回数据
|
||||||
// logger.Category("response").Debug(string(b))
|
// logger.Category("response").Debug(string(b))
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRespSucess(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
RespSucess(w, "data")
|
||||||
|
// fmt.Println(w)
|
||||||
|
assert.Equal(w.Code, 200)
|
||||||
|
body, _ := ioutil.ReadAll(w.Body)
|
||||||
|
res := Resp{}
|
||||||
|
err := json.Unmarshal(body, &res)
|
||||||
|
assert.Nil(err)
|
||||||
|
assert.Equal(res.Code, 0)
|
||||||
|
assert.Equal(res.Data, "data")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRespError(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
RespError(w, 10, "err-msg")
|
||||||
|
// fmt.Println(w)
|
||||||
|
assert.Equal(w.Code, 200)
|
||||||
|
body, _ := ioutil.ReadAll(w.Body)
|
||||||
|
res := Resp{}
|
||||||
|
err := json.Unmarshal(body, &res)
|
||||||
|
assert.Nil(err)
|
||||||
|
assert.Equal(res.Code, 10)
|
||||||
|
assert.Equal(res.Msg, "err-msg")
|
||||||
|
}
|
|
@ -2,5 +2,5 @@ package base
|
||||||
|
|
||||||
const (
|
const (
|
||||||
APP_NAME = "AnyLink"
|
APP_NAME = "AnyLink"
|
||||||
APP_VER = "0.1.1"
|
APP_VER = "0.1.2"
|
||||||
)
|
)
|
||||||
|
|
|
@ -48,8 +48,7 @@ type ServerConfig struct {
|
||||||
JwtSecret string `toml:"jwt_secret" info:"JWT密钥"`
|
JwtSecret string `toml:"jwt_secret" info:"JWT密钥"`
|
||||||
|
|
||||||
LinkMode string `toml:"link_mode" info:"虚拟网络类型"` // tun tap
|
LinkMode string `toml:"link_mode" info:"虚拟网络类型"` // tun tap
|
||||||
Ipv4Network string `toml:"ipv4_network" info:"ipv4_network"` // 192.168.1.0
|
Ipv4CIDR string `toml:"ipv4_cidr" info:"ip地址网段"` // 192.168.1.0/24
|
||||||
Ipv4Netmask string `toml:"ipv4_netmask" info:"ipv4_netmask"` // 255.255.255.0
|
|
||||||
Ipv4Gateway string `toml:"ipv4_gateway" info:"ipv4_gateway"`
|
Ipv4Gateway string `toml:"ipv4_gateway" info:"ipv4_gateway"`
|
||||||
Ipv4Pool []string `toml:"ipv4_pool" info:"IPV4起止地址池"` // Pool[0]=192.168.1.100 Pool[1]=192.168.1.200
|
Ipv4Pool []string `toml:"ipv4_pool" info:"IPV4起止地址池"` // Pool[0]=192.168.1.100 Pool[1]=192.168.1.200
|
||||||
IpLease int `toml:"ip_lease" info:"IP租期(秒)"`
|
IpLease int `toml:"ip_lease" info:"IP租期(秒)"`
|
||||||
|
@ -102,16 +101,17 @@ func getAbsPath(base, cfile string) string {
|
||||||
return filepath.Join(base, cfile)
|
return filepath.Join(base, cfile)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServerCfg2Slice() interface{} {
|
type SCfg struct {
|
||||||
ref := reflect.ValueOf(Cfg)
|
|
||||||
s := ref.Elem()
|
|
||||||
|
|
||||||
type cfg struct {
|
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Info string `json:"info"`
|
Info string `json:"info"`
|
||||||
Data interface{} `json:"data"`
|
Data interface{} `json:"data"`
|
||||||
}
|
}
|
||||||
var datas []cfg
|
|
||||||
|
func ServerCfg2Slice() []SCfg {
|
||||||
|
ref := reflect.ValueOf(Cfg)
|
||||||
|
s := ref.Elem()
|
||||||
|
|
||||||
|
var datas []SCfg
|
||||||
|
|
||||||
typ := s.Type()
|
typ := s.Type()
|
||||||
numFields := s.NumField()
|
numFields := s.NumField()
|
||||||
|
@ -122,7 +122,7 @@ func ServerCfg2Slice() interface{} {
|
||||||
tags := strings.Split(tag, ",")
|
tags := strings.Split(tag, ",")
|
||||||
info := field.Tag.Get("info")
|
info := field.Tag.Get("info")
|
||||||
|
|
||||||
datas = append(datas, cfg{Name: tags[0], Info: info, Data: value.Interface()})
|
datas = append(datas, SCfg{Name: tags[0], Info: info, Data: value.Interface()})
|
||||||
}
|
}
|
||||||
|
|
||||||
return datas
|
return datas
|
||||||
|
|
|
@ -36,7 +36,7 @@ func logLevel2Int(l string) int {
|
||||||
}
|
}
|
||||||
lvl := _Info
|
lvl := _Info
|
||||||
for k, v := range levels {
|
for k, v := range levels {
|
||||||
if strings.ToLower(l) == strings.ToLower(v) {
|
if strings.EqualFold(strings.ToLower(l), strings.ToLower(v)) {
|
||||||
lvl = k
|
lvl = k
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,7 @@ func logLevel2Int(l string) int {
|
||||||
|
|
||||||
func output(l int, s ...interface{}) {
|
func output(l int, s ...interface{}) {
|
||||||
lvl := fmt.Sprintf("[%s] ", levels[l])
|
lvl := fmt.Sprintf("[%s] ", levels[l])
|
||||||
baseLog.Output(3, lvl+fmt.Sprintln(s...))
|
_ = baseLog.Output(3, lvl+fmt.Sprintln(s...))
|
||||||
}
|
}
|
||||||
|
|
||||||
func Debug(v ...interface{}) {
|
func Debug(v ...interface{}) {
|
||||||
|
|
|
@ -8,10 +8,6 @@
|
||||||
# Define Bridge Interface
|
# Define Bridge Interface
|
||||||
br="anylink0"
|
br="anylink0"
|
||||||
|
|
||||||
# Define list of TAP interfaces to be bridged,
|
|
||||||
# for example tap="tap0 tap1 tap2".
|
|
||||||
tap="tap0"
|
|
||||||
|
|
||||||
# Define physical ethernet interface to be bridged
|
# Define physical ethernet interface to be bridged
|
||||||
# with TAP interface(s) above.
|
# with TAP interface(s) above.
|
||||||
|
|
||||||
|
|
|
@ -35,8 +35,7 @@ proxy_protocol = false
|
||||||
link_mode = "tun"
|
link_mode = "tun"
|
||||||
|
|
||||||
#客户端分配的ip地址池
|
#客户端分配的ip地址池
|
||||||
ipv4_network = "192.168.10.0"
|
ipv4_cidr = "192.168.10.0/24"
|
||||||
ipv4_netmask = "255.255.255.0"
|
|
||||||
ipv4_gateway = "192.168.10.1"
|
ipv4_gateway = "192.168.10.1"
|
||||||
ipv4_pool = ["192.168.10.100", "192.168.10.200"]
|
ipv4_pool = ["192.168.10.100", "192.168.10.200"]
|
||||||
|
|
||||||
|
|
13
dbdata/db.go
13
dbdata/db.go
|
@ -43,28 +43,27 @@ func initData() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer Set(SettingBucket, Installed, true)
|
defer func() {
|
||||||
|
_ = Set(SettingBucket, Installed, true)
|
||||||
|
}()
|
||||||
|
|
||||||
smtp := &SettingSmtp{
|
smtp := &SettingSmtp{
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: 25,
|
Port: 25,
|
||||||
From: "vpn@xx.com",
|
From: "vpn@xx.com",
|
||||||
}
|
}
|
||||||
SettingSet(smtp)
|
_ = SettingSet(smtp)
|
||||||
|
|
||||||
other := &SettingOther{
|
other := &SettingOther{
|
||||||
Banner: "您已接入公司网络,请按照公司规定使用。\n请勿进行非工作下载及视频行为!",
|
Banner: "您已接入公司网络,请按照公司规定使用。\n请勿进行非工作下载及视频行为!",
|
||||||
AccountMail: accountMail,
|
AccountMail: accountMail,
|
||||||
}
|
}
|
||||||
SettingSet(other)
|
_ = SettingSet(other)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckErrNotFound(err error) bool {
|
func CheckErrNotFound(err error) bool {
|
||||||
if err == storm.ErrNotFound {
|
return err == storm.ErrNotFound
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const accountMail = `<p>您好:</p>
|
const accountMail = `<p>您好:</p>
|
||||||
|
|
|
@ -27,7 +27,7 @@ func TestDb(t *testing.T) {
|
||||||
defer closeIpdata()
|
defer closeIpdata()
|
||||||
|
|
||||||
u := User{Username: "a"}
|
u := User{Username: "a"}
|
||||||
Save(&u)
|
_ = Save(&u)
|
||||||
|
|
||||||
assert.Equal(u.Id, 1)
|
assert.Equal(u.Id, 1)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,4 +2,5 @@ package handler
|
||||||
|
|
||||||
// 暂时没有实现
|
// 暂时没有实现
|
||||||
func startDtls() {
|
func startDtls() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,7 +87,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
|
||||||
sess.MacAddr = strings.ToLower(cr.MacAddressList.MacAddress)
|
sess.MacAddr = strings.ToLower(cr.MacAddressList.MacAddress)
|
||||||
sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal
|
sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal
|
||||||
other := &dbdata.SettingOther{}
|
other := &dbdata.SettingOther{}
|
||||||
dbdata.SettingGet(other)
|
_ = dbdata.SettingGet(other)
|
||||||
rd := RequestData{SessionId: sess.Sid, SessionToken: sess.Sid + "@" + sess.Token,
|
rd := RequestData{SessionId: sess.Sid, SessionToken: sess.Sid + "@" + sess.Token,
|
||||||
Banner: other.Banner}
|
Banner: other.Banner}
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
@ -102,7 +102,7 @@ const (
|
||||||
func tplRequest(typ int, w io.Writer, data RequestData) {
|
func tplRequest(typ int, w io.Writer, data RequestData) {
|
||||||
if typ == tpl_request {
|
if typ == tpl_request {
|
||||||
t, _ := template.New("auth_request").Parse(auth_request)
|
t, _ := template.New("auth_request").Parse(auth_request)
|
||||||
t.Execute(w, data)
|
_ = t.Execute(w, data)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,7 +111,7 @@ func tplRequest(typ int, w io.Writer, data RequestData) {
|
||||||
data.Banner = strings.ReplaceAll(data.Banner, "\n", "
")
|
data.Banner = strings.ReplaceAll(data.Banner, "\n", "
")
|
||||||
}
|
}
|
||||||
t, _ := template.New("auth_complete").Parse(auth_complete)
|
t, _ := template.New("auth_complete").Parse(auth_complete)
|
||||||
t.Execute(w, data)
|
_ = t.Execute(w, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置输出信息
|
// 设置输出信息
|
||||||
|
|
|
@ -2,11 +2,9 @@ package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const BufferSize = 2048
|
const BufferSize = 2048
|
||||||
|
@ -43,27 +41,6 @@ type macAddressList struct {
|
||||||
MacAddress string `xml:"mac-address"`
|
MacAddress string `xml:"mac-address"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// 判断anyconnect客户端
|
|
||||||
func checkLinkClient(h http.HandlerFunc) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// TODO 调试信息输出
|
|
||||||
// hd, _ := httputil.DumpRequest(r, true)
|
|
||||||
// fmt.Println("DumpRequest: ", string(hd))
|
|
||||||
// fmt.Println(r.RemoteAddr)
|
|
||||||
|
|
||||||
userAgent := strings.ToLower(r.UserAgent())
|
|
||||||
x_Aggregate_Auth := r.Header.Get("X-Aggregate-Auth")
|
|
||||||
x_Transcend_Version := r.Header.Get("X-Transcend-Version")
|
|
||||||
if strings.Contains(userAgent, "anyconnect") &&
|
|
||||||
x_Aggregate_Auth == "1" && x_Transcend_Version == "1" {
|
|
||||||
h(w, r)
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(http.StatusForbidden)
|
|
||||||
fmt.Fprintf(w, "error request")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCommonHeader(w http.ResponseWriter) {
|
func setCommonHeader(w http.ResponseWriter) {
|
||||||
// Content-Length Date 默认已经存在
|
// Content-Length Date 默认已经存在
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
|
|
@ -26,7 +26,7 @@ func LinkHome(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func LinkOtpQr(w http.ResponseWriter, r *http.Request) {
|
func LinkOtpQr(w http.ResponseWriter, r *http.Request) {
|
||||||
r.ParseForm()
|
_ = r.ParseForm()
|
||||||
idS := r.FormValue("id")
|
idS := r.FormValue("id")
|
||||||
jwtToken := r.FormValue("jwt")
|
jwtToken := r.FormValue("jwt")
|
||||||
data, err := admin.GetJwtData(jwtToken)
|
data, err := admin.GetJwtData(jwtToken)
|
||||||
|
|
|
@ -29,6 +29,9 @@ func checkTap() {
|
||||||
bridgeHw = brFace.HardwareAddr
|
bridgeHw = brFace.HardwareAddr
|
||||||
|
|
||||||
addrs, err := brFace.Addrs()
|
addrs, err := brFace.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
base.Fatal("testTap err: ", err)
|
||||||
|
}
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
ip, _, err := net.ParseCIDR(addr.String())
|
ip, _, err := net.ParseCIDR(addr.String())
|
||||||
if err != nil || ip.To4() == nil {
|
if err != nil || ip.To4() == nil {
|
||||||
|
|
|
@ -74,7 +74,7 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("X-CSTP-Version", "1")
|
w.Header().Set("X-CSTP-Version", "1")
|
||||||
w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.")
|
w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.")
|
||||||
w.Header().Set("X-CSTP-Address", cSess.IpAddr.String()) // 分配的ip地址
|
w.Header().Set("X-CSTP-Address", cSess.IpAddr.String()) // 分配的ip地址
|
||||||
w.Header().Set("X-CSTP-Netmask", base.Cfg.Ipv4Netmask) // 子网掩码
|
w.Header().Set("X-CSTP-Netmask", sessdata.IpPool.Ipv4Mask.String()) // 子网掩码
|
||||||
w.Header().Set("X-CSTP-Hostname", hn) // 机器名称
|
w.Header().Set("X-CSTP-Hostname", hn) // 机器名称
|
||||||
|
|
||||||
// 允许本地LAN访问vpn网络,必须放在路由的第一个
|
// 允许本地LAN访问vpn网络,必须放在路由的第一个
|
||||||
|
@ -131,11 +131,11 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
||||||
// w.Header().Set("X-CSTP-Post-Auth-XML", ``)
|
// w.Header().Set("X-CSTP-Post-Auth-XML", ``)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
h := w.Header().Clone()
|
hClone := w.Header().Clone()
|
||||||
headers := make([]byte, 0)
|
headers := make([]byte, 0)
|
||||||
buf := bytes.NewBuffer(headers)
|
buf := bytes.NewBuffer(headers)
|
||||||
h.Write(buf)
|
_ = hClone.Write(buf)
|
||||||
base.Debug(string(buf.Bytes()))
|
base.Debug(buf.String())
|
||||||
|
|
||||||
hj := w.(http.Hijacker)
|
hj := w.(http.Hijacker)
|
||||||
conn, _, err := hj.Hijack()
|
conn, _, err := hj.Hijack()
|
||||||
|
|
|
@ -21,5 +21,5 @@ func Start() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Stop() {
|
func Stop() {
|
||||||
dbdata.Stop()
|
_ = dbdata.Stop()
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ func tableLookup(ip net.IP) *Addr {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 判断老化过期时间
|
// 判断老化过期时间
|
||||||
tsub := time.Now().Sub(addr.disTime)
|
tsub := time.Since(addr.disTime)
|
||||||
switch addr.Type {
|
switch addr.Type {
|
||||||
case TypeNormal:
|
case TypeNormal:
|
||||||
if tsub > StaleTimeNormal {
|
if tsub > StaleTimeNormal {
|
||||||
|
|
|
@ -48,7 +48,7 @@ func doPing(ip string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.SetReadDeadline(time.Now().Add(time.Second * 2))
|
_ = conn.SetReadDeadline(time.Now().Add(time.Second * 2))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
buf := make([]byte, 512)
|
buf := make([]byte, 512)
|
||||||
|
|
|
@ -198,8 +198,10 @@ func (p *Conn) checkPrefixOnce() {
|
||||||
func (p *Conn) checkPrefix() error {
|
func (p *Conn) checkPrefix() error {
|
||||||
if p.proxyHeaderTimeout != 0 {
|
if p.proxyHeaderTimeout != 0 {
|
||||||
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
|
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
|
||||||
p.conn.SetReadDeadline(readDeadLine)
|
_ = p.conn.SetReadDeadline(readDeadLine)
|
||||||
defer p.conn.SetReadDeadline(time.Time{})
|
defer func() {
|
||||||
|
_ = p.conn.SetReadDeadline(time.Time{})
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Incrementally check each byte of the prefix
|
// Incrementally check each byte of the prefix
|
||||||
|
|
|
@ -45,8 +45,6 @@ func CopyStruct(a interface{}, b interface{}, fields ...string) (err error) {
|
||||||
// a中有同名的字段并且类型一致才复制
|
// a中有同名的字段并且类型一致才复制
|
||||||
if f.IsValid() && f.Kind() == bValue.Kind() {
|
if f.IsValid() && f.Kind() == bValue.Kind() {
|
||||||
f.Set(bValue)
|
f.Set(bValue)
|
||||||
} else {
|
|
||||||
// fmt.Printf("no such field or different kind, fieldName: %s\n", name)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -19,7 +19,8 @@ type ipPoolConfig struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
// 计算动态ip
|
// 计算动态ip
|
||||||
Ipv4Gateway net.IP
|
Ipv4Gateway net.IP
|
||||||
Ipv4IPNet net.IPNet
|
Ipv4Mask net.IP
|
||||||
|
Ipv4IPNet *net.IPNet
|
||||||
IpLongMin uint32
|
IpLongMin uint32
|
||||||
IpLongMax uint32
|
IpLongMax uint32
|
||||||
}
|
}
|
||||||
|
@ -27,11 +28,12 @@ type ipPoolConfig struct {
|
||||||
func initIpPool() {
|
func initIpPool() {
|
||||||
|
|
||||||
// 地址处理
|
// 地址处理
|
||||||
// ip地址
|
_, ipNet, err := net.ParseCIDR(base.Cfg.Ipv4CIDR)
|
||||||
ip := net.ParseIP(base.Cfg.Ipv4Network)
|
if err != nil {
|
||||||
// 子网掩码
|
panic(err)
|
||||||
maskIp := net.ParseIP(base.Cfg.Ipv4Netmask).To4()
|
}
|
||||||
IpPool.Ipv4IPNet = net.IPNet{IP: ip, Mask: net.IPMask(maskIp)}
|
IpPool.Ipv4IPNet = ipNet
|
||||||
|
IpPool.Ipv4Mask = net.IP(ipNet.Mask)
|
||||||
IpPool.Ipv4Gateway = net.ParseIP(base.Cfg.Ipv4Gateway)
|
IpPool.Ipv4Gateway = net.ParseIP(base.Cfg.Ipv4Gateway)
|
||||||
|
|
||||||
// 网络地址零值
|
// 网络地址零值
|
||||||
|
@ -74,11 +76,11 @@ func AcquireIp(username, macAddr string) net.IP {
|
||||||
mi.Username = username
|
mi.Username = username
|
||||||
mi.LastLogin = tNow
|
mi.LastLogin = tNow
|
||||||
// 回写db数据
|
// 回写db数据
|
||||||
dbdata.Save(mi)
|
_ = dbdata.Save(mi)
|
||||||
ipActive[ipStr] = true
|
ipActive[ipStr] = true
|
||||||
return ip
|
return ip
|
||||||
} else {
|
} else {
|
||||||
dbdata.Del(mi)
|
_ = dbdata.Del(mi)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,7 +94,7 @@ func AcquireIp(username, macAddr string) net.IP {
|
||||||
if err != nil && dbdata.CheckErrNotFound(err) {
|
if err != nil && dbdata.CheckErrNotFound(err) {
|
||||||
// 该ip没有被使用
|
// 该ip没有被使用
|
||||||
mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
|
mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
|
||||||
dbdata.Save(mi)
|
_ = dbdata.Save(mi)
|
||||||
ipActive[ipStr] = true
|
ipActive[ipStr] = true
|
||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
@ -121,10 +123,10 @@ func AcquireIp(username, macAddr string) net.IP {
|
||||||
|
|
||||||
// 已经超过租期
|
// 已经超过租期
|
||||||
if tNow.Sub(v.LastLogin) > time.Duration(base.Cfg.IpLease)*time.Second {
|
if tNow.Sub(v.LastLogin) > time.Duration(base.Cfg.IpLease)*time.Second {
|
||||||
dbdata.Del(v)
|
_ = dbdata.Del(v)
|
||||||
mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
|
mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
|
||||||
// 重写db数据
|
// 重写db数据
|
||||||
dbdata.Save(mi)
|
_ = dbdata.Save(mi)
|
||||||
ipActive[ipStr] = true
|
ipActive[ipStr] = true
|
||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
@ -145,7 +147,7 @@ func AcquireIp(username, macAddr string) net.IP {
|
||||||
ipStr := ip.String()
|
ipStr := ip.String()
|
||||||
mi = &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
|
mi = &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
|
||||||
// 回写db数据
|
// 回写db数据
|
||||||
dbdata.Save(mi)
|
_ = dbdata.Save(mi)
|
||||||
ipActive[ipStr] = true
|
ipActive[ipStr] = true
|
||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
@ -160,6 +162,6 @@ func ReleaseIp(ip net.IP, macAddr string) {
|
||||||
err := dbdata.One("IpAddr", ip, mi)
|
err := dbdata.One("IpAddr", ip, mi)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
mi.LastLogin = time.Now()
|
mi.LastLogin = time.Now()
|
||||||
dbdata.Save(mi)
|
_ = dbdata.Save(mi)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,8 +15,7 @@ import (
|
||||||
func preData(tmpDir string) {
|
func preData(tmpDir string) {
|
||||||
tmpDb := path.Join(tmpDir, "test.db")
|
tmpDb := path.Join(tmpDir, "test.db")
|
||||||
base.Cfg.DbFile = tmpDb
|
base.Cfg.DbFile = tmpDb
|
||||||
base.Cfg.Ipv4Network = "192.168.3.0"
|
base.Cfg.Ipv4CIDR = "192.168.3.0/24"
|
||||||
base.Cfg.Ipv4Netmask = "255.255.255.0"
|
|
||||||
base.Cfg.Ipv4Pool = []string{"192.168.3.1", "192.168.3.199"}
|
base.Cfg.Ipv4Pool = []string{"192.168.3.1", "192.168.3.199"}
|
||||||
base.Cfg.MaxClient = 100
|
base.Cfg.MaxClient = 100
|
||||||
base.Cfg.MaxUserClient = 3
|
base.Cfg.MaxUserClient = 3
|
||||||
|
@ -26,12 +25,12 @@ func preData(tmpDir string) {
|
||||||
Name: "group1",
|
Name: "group1",
|
||||||
Bandwidth: 1000,
|
Bandwidth: 1000,
|
||||||
}
|
}
|
||||||
dbdata.Save(&group)
|
_ = dbdata.Save(&group)
|
||||||
initIpPool()
|
initIpPool()
|
||||||
}
|
}
|
||||||
|
|
||||||
func cleardata(tmpDir string) {
|
func cleardata(tmpDir string) {
|
||||||
dbdata.Stop()
|
_ = dbdata.Stop()
|
||||||
tmpDb := path.Join(tmpDir, "test.db")
|
tmpDb := path.Join(tmpDir, "test.db")
|
||||||
os.Remove(tmpDb)
|
os.Remove(tmpDb)
|
||||||
}
|
}
|
||||||
|
@ -45,15 +44,15 @@ func TestIpPool(t *testing.T) {
|
||||||
var ip net.IP
|
var ip net.IP
|
||||||
|
|
||||||
for i := 1; i <= 100; i++ {
|
for i := 1; i <= 100; i++ {
|
||||||
ip = AcquireIp("user", fmt.Sprintf("mac-%d", i))
|
_ = AcquireIp("user", fmt.Sprintf("mac-%d", i))
|
||||||
}
|
}
|
||||||
ip = AcquireIp("user", fmt.Sprintf("mac-new"))
|
ip = AcquireIp("user", "mac-new")
|
||||||
assert.True(net.IPv4(192, 168, 3, 101).Equal(ip))
|
assert.True(net.IPv4(192, 168, 3, 101).Equal(ip))
|
||||||
for i := 102; i <= 199; i++ {
|
for i := 102; i <= 199; i++ {
|
||||||
ip = AcquireIp("user", fmt.Sprintf("mac-%d", i))
|
ip = AcquireIp("user", fmt.Sprintf("mac-%d", i))
|
||||||
}
|
}
|
||||||
assert.True(net.IPv4(192, 168, 3, 199).Equal(ip))
|
assert.True(net.IPv4(192, 168, 3, 199).Equal(ip))
|
||||||
ip = AcquireIp("user", fmt.Sprintf("mac-nil"))
|
ip = AcquireIp("user", "mac-nil")
|
||||||
assert.Nil(ip)
|
assert.Nil(ip)
|
||||||
|
|
||||||
ReleaseIp(net.IPv4(192, 168, 3, 88), "mac-88")
|
ReleaseIp(net.IPv4(192, 168, 3, 88), "mac-88")
|
||||||
|
|
|
@ -43,10 +43,11 @@ func TestLimitClient(t *testing.T) {
|
||||||
func TestLimitWait(t *testing.T) {
|
func TestLimitWait(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
limit := NewLimitRater(1, 2)
|
limit := NewLimitRater(1, 2)
|
||||||
limit.Wait(2)
|
|
||||||
start := time.Now()
|
|
||||||
err := limit.Wait(2)
|
err := limit.Wait(2)
|
||||||
assert.Nil(err)
|
assert.Nil(err)
|
||||||
|
start := time.Now()
|
||||||
|
err = limit.Wait(2)
|
||||||
|
assert.Nil(err)
|
||||||
err = limit.Wait(1)
|
err = limit.Wait(1)
|
||||||
assert.Nil(err)
|
assert.Nil(err)
|
||||||
end := time.Now()
|
end := time.Now()
|
||||||
|
|
|
@ -34,10 +34,7 @@ func (o Onlines) Len() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o Onlines) Less(i, j int) bool {
|
func (o Onlines) Less(i, j int) bool {
|
||||||
if bytes.Compare(o[i].Ip, o[j].Ip) < 0 {
|
return bytes.Compare(o[i].Ip, o[j].Ip) < 0
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o Onlines) Swap(i, j int) {
|
func (o Onlines) Swap(i, j int) {
|
||||||
|
|
|
@ -78,13 +78,13 @@ func checkSession() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
timeout := time.Duration(base.Cfg.SessionTimeout) * time.Second
|
timeout := time.Duration(base.Cfg.SessionTimeout) * time.Second
|
||||||
tick := time.Tick(time.Second * 60)
|
tick := time.NewTicker(time.Second * 60)
|
||||||
for range tick {
|
for range tick.C {
|
||||||
sessMux.Lock()
|
sessMux.Lock()
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
for k, v := range sessions {
|
for k, v := range sessions {
|
||||||
v.mux.Lock()
|
v.mux.Lock()
|
||||||
if v.IsActive != true {
|
if !v.IsActive {
|
||||||
if t.Sub(v.LastLogin) > timeout {
|
if t.Sub(v.LastLogin) > timeout {
|
||||||
delete(sessions, k)
|
delete(sessions, k)
|
||||||
}
|
}
|
||||||
|
@ -133,12 +133,12 @@ func (s *Session) NewConn() *ConnSession {
|
||||||
macAddr := s.MacAddr
|
macAddr := s.MacAddr
|
||||||
username := s.Username
|
username := s.Username
|
||||||
s.mux.Unlock()
|
s.mux.Unlock()
|
||||||
if active == true {
|
if active {
|
||||||
s.CSess.Close()
|
s.CSess.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
limit := LimitClient(username, false)
|
limit := LimitClient(username, false)
|
||||||
if limit == false {
|
if !limit {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// 获取客户端mac地址
|
// 获取客户端mac地址
|
||||||
|
@ -208,8 +208,10 @@ func (cs *ConnSession) Close() {
|
||||||
const BandwidthPeriodSec = 2 // 流量速率统计周期(秒)
|
const BandwidthPeriodSec = 2 // 流量速率统计周期(秒)
|
||||||
|
|
||||||
func (cs *ConnSession) ratePeriod() {
|
func (cs *ConnSession) ratePeriod() {
|
||||||
tick := time.Tick(time.Second * BandwidthPeriodSec)
|
tick := time.NewTicker(time.Second * BandwidthPeriodSec)
|
||||||
for range tick {
|
defer tick.Stop()
|
||||||
|
|
||||||
|
for range tick.C {
|
||||||
select {
|
select {
|
||||||
case <-cs.CloseChan:
|
case <-cs.CloseChan:
|
||||||
return
|
return
|
||||||
|
|
|
@ -28,9 +28,11 @@ func TestConnSession(t *testing.T) {
|
||||||
|
|
||||||
cSess := sess.NewConn()
|
cSess := sess.NewConn()
|
||||||
|
|
||||||
cSess.RateLimit(100, true)
|
err := cSess.RateLimit(100, true)
|
||||||
|
assert.Nil(err)
|
||||||
assert.Equal(cSess.BandwidthUp, uint32(100))
|
assert.Equal(cSess.BandwidthUp, uint32(100))
|
||||||
cSess.RateLimit(200, false)
|
err = cSess.RateLimit(200, false)
|
||||||
|
assert.Nil(err)
|
||||||
assert.Equal(cSess.BandwidthDown, uint32(200))
|
assert.Equal(cSess.BandwidthDown, uint32(200))
|
||||||
cSess.Close()
|
cSess.Close()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue