mirror of https://github.com/bjdgyc/anylink.git
修改客户端分配的ip为CIDR格式,请注意原来network格式
This commit is contained in:
parent
1c6572f5e3
commit
edb0fe2dc9
|
@ -16,7 +16,7 @@ func Login(w http.ResponseWriter, r *http.Request) {
|
|||
// hd, _ := httputil.DumpRequest(r, true)
|
||||
// fmt.Println("DumpRequest: ", string(hd))
|
||||
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
admin_user := r.PostFormValue("admin_user")
|
||||
admin_pass := r.PostFormValue("admin_pass")
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func GroupList(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
pageS := r.FormValue("page")
|
||||
page, _ := strconv.Atoi(pageS)
|
||||
if page < 1 {
|
||||
|
@ -48,7 +48,7 @@ func GroupNames(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func GroupDetail(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
idS := r.FormValue("id")
|
||||
id, _ := strconv.Atoi(idS)
|
||||
if id < 1 {
|
||||
|
@ -90,7 +90,7 @@ func GroupSet(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func GroupDel(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
idS := r.FormValue("id")
|
||||
id, _ := strconv.Atoi(idS)
|
||||
if id < 1 {
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func UserIpMapList(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
pageS := r.FormValue("page")
|
||||
page, _ := strconv.Atoi(pageS)
|
||||
if page < 1 {
|
||||
|
@ -39,7 +39,7 @@ func UserIpMapList(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func UserIpMapDetail(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
idS := r.FormValue("id")
|
||||
id, _ := strconv.Atoi(idS)
|
||||
if id < 1 {
|
||||
|
@ -58,7 +58,7 @@ func UserIpMapDetail(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func UserIpMapSet(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
|
@ -92,7 +92,7 @@ func UserIpMapSet(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func UserIpMapDel(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
idS := r.FormValue("id")
|
||||
id, _ := strconv.Atoi(idS)
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
@ -84,9 +83,8 @@ func SetSystem(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func SetSoft(w http.ResponseWriter, r *http.Request) {
|
||||
datas := base.ServerCfg2Slice()
|
||||
b, _ := json.Marshal(datas)
|
||||
w.Write(b)
|
||||
data := base.ServerCfg2Slice()
|
||||
RespSucess(w, data)
|
||||
}
|
||||
|
||||
func decimal(f float64) float64 {
|
||||
|
|
|
@ -19,7 +19,7 @@ import (
|
|||
)
|
||||
|
||||
func UserList(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
prefix := r.FormValue("prefix")
|
||||
pageS := r.FormValue("page")
|
||||
page, _ := strconv.Atoi(pageS)
|
||||
|
@ -58,7 +58,7 @@ func UserList(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func UserDetail(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
idS := r.FormValue("id")
|
||||
id, _ := strconv.Atoi(idS)
|
||||
if id < 1 {
|
||||
|
@ -77,7 +77,7 @@ func UserDetail(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func UserSet(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
|
@ -111,7 +111,7 @@ func UserSet(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func UserDel(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
idS := r.FormValue("id")
|
||||
id, _ := strconv.Atoi(idS)
|
||||
|
||||
|
@ -130,7 +130,7 @@ func UserDel(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func UserOtpQr(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
b64 := r.FormValue("b64")
|
||||
idS := r.FormValue("id")
|
||||
id, _ := strconv.Atoi(idS)
|
||||
|
@ -148,11 +148,16 @@ func UserOtpQr(w http.ResponseWriter, r *http.Request) {
|
|||
if b64 == "1" {
|
||||
data, _ := qr.PNG(300)
|
||||
s := base64.StdEncoding.EncodeToString(data)
|
||||
fmt.Fprint(w, s)
|
||||
} else {
|
||||
qr.Write(300, w)
|
||||
_, err = fmt.Fprint(w, s)
|
||||
if err != nil {
|
||||
base.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
err = qr.Write(300, w)
|
||||
if err != nil {
|
||||
base.Error(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 在线用户
|
||||
|
@ -169,14 +174,14 @@ func UserOnline(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func UserOffline(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
token := r.FormValue("token")
|
||||
sessdata.CloseSess(token)
|
||||
RespSucess(w, nil)
|
||||
}
|
||||
|
||||
func UserReline(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
token := r.FormValue("token")
|
||||
sessdata.CloseCSess(token)
|
||||
RespSucess(w, nil)
|
||||
|
@ -231,7 +236,10 @@ func userAccountMail(user *dbdata.User) error {
|
|||
}
|
||||
w := bytes.NewBufferString("")
|
||||
t, _ := template.New("auth_complete").Parse(htmlBody)
|
||||
t.Execute(w, data)
|
||||
err = t.Execute(w, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// fmt.Println(w.String())
|
||||
return SendMail(base.Cfg.Issuer+"平台通知", user.Email, w.String())
|
||||
}
|
||||
|
|
|
@ -8,8 +8,8 @@ import (
|
|||
"github.com/bjdgyc/anylink/base"
|
||||
"github.com/bjdgyc/anylink/dbdata"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/mojocn/base64Captcha"
|
||||
mail "github.com/xhit/go-simple-mail/v2"
|
||||
// "github.com/mojocn/base64Captcha"
|
||||
)
|
||||
|
||||
func SetJwtData(data map[string]interface{}, expiresAt int64) (string, error) {
|
||||
|
@ -43,15 +43,6 @@ func GetJwtData(jwtToken string) (map[string]interface{}, error) {
|
|||
return claims, nil
|
||||
}
|
||||
|
||||
func createCaptcha() {
|
||||
var store = base64Captcha.DefaultMemStore
|
||||
var driver base64Captcha.Driver
|
||||
driverString := &base64Captcha.DriverString{}
|
||||
driver = driverString.ConvertFonts()
|
||||
c := base64Captcha.NewCaptcha(driver, store)
|
||||
_ = c
|
||||
}
|
||||
|
||||
func SendMail(subject, to, htmlBody string) error {
|
||||
|
||||
dataSmtp := &dbdata.SettingSmtp{}
|
||||
|
|
|
@ -43,8 +43,10 @@ func respHttp(w http.ResponseWriter, respCode int, data interface{}, errS ...int
|
|||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(b)
|
||||
|
||||
_, err = w.Write(b)
|
||||
if err != nil {
|
||||
base.Error(err)
|
||||
}
|
||||
// 记录返回数据
|
||||
// logger.Category("response").Debug(string(b))
|
||||
}
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRespSucess(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
w := httptest.NewRecorder()
|
||||
RespSucess(w, "data")
|
||||
// fmt.Println(w)
|
||||
assert.Equal(w.Code, 200)
|
||||
body, _ := ioutil.ReadAll(w.Body)
|
||||
res := Resp{}
|
||||
err := json.Unmarshal(body, &res)
|
||||
assert.Nil(err)
|
||||
assert.Equal(res.Code, 0)
|
||||
assert.Equal(res.Data, "data")
|
||||
|
||||
}
|
||||
|
||||
func TestRespError(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
w := httptest.NewRecorder()
|
||||
RespError(w, 10, "err-msg")
|
||||
// fmt.Println(w)
|
||||
assert.Equal(w.Code, 200)
|
||||
body, _ := ioutil.ReadAll(w.Body)
|
||||
res := Resp{}
|
||||
err := json.Unmarshal(body, &res)
|
||||
assert.Nil(err)
|
||||
assert.Equal(res.Code, 10)
|
||||
assert.Equal(res.Msg, "err-msg")
|
||||
}
|
|
@ -2,5 +2,5 @@ package base
|
|||
|
||||
const (
|
||||
APP_NAME = "AnyLink"
|
||||
APP_VER = "0.1.1"
|
||||
APP_VER = "0.1.2"
|
||||
)
|
||||
|
|
|
@ -47,9 +47,8 @@ type ServerConfig struct {
|
|||
AdminPass string `toml:"admin_pass" info:"管理用户密码"`
|
||||
JwtSecret string `toml:"jwt_secret" info:"JWT密钥"`
|
||||
|
||||
LinkMode string `toml:"link_mode" info:"虚拟网络类型"` // tun tap
|
||||
Ipv4Network string `toml:"ipv4_network" info:"ipv4_network"` // 192.168.1.0
|
||||
Ipv4Netmask string `toml:"ipv4_netmask" info:"ipv4_netmask"` // 255.255.255.0
|
||||
LinkMode string `toml:"link_mode" info:"虚拟网络类型"` // tun tap
|
||||
Ipv4CIDR string `toml:"ipv4_cidr" info:"ip地址网段"` // 192.168.1.0/24
|
||||
Ipv4Gateway string `toml:"ipv4_gateway" info:"ipv4_gateway"`
|
||||
Ipv4Pool []string `toml:"ipv4_pool" info:"IPV4起止地址池"` // Pool[0]=192.168.1.100 Pool[1]=192.168.1.200
|
||||
IpLease int `toml:"ip_lease" info:"IP租期(秒)"`
|
||||
|
@ -102,16 +101,17 @@ func getAbsPath(base, cfile string) string {
|
|||
return filepath.Join(base, cfile)
|
||||
}
|
||||
|
||||
func ServerCfg2Slice() interface{} {
|
||||
type SCfg struct {
|
||||
Name string `json:"name"`
|
||||
Info string `json:"info"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
func ServerCfg2Slice() []SCfg {
|
||||
ref := reflect.ValueOf(Cfg)
|
||||
s := ref.Elem()
|
||||
|
||||
type cfg struct {
|
||||
Name string `json:"name"`
|
||||
Info string `json:"info"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
var datas []cfg
|
||||
var datas []SCfg
|
||||
|
||||
typ := s.Type()
|
||||
numFields := s.NumField()
|
||||
|
@ -122,7 +122,7 @@ func ServerCfg2Slice() interface{} {
|
|||
tags := strings.Split(tag, ",")
|
||||
info := field.Tag.Get("info")
|
||||
|
||||
datas = append(datas, cfg{Name: tags[0], Info: info, Data: value.Interface()})
|
||||
datas = append(datas, SCfg{Name: tags[0], Info: info, Data: value.Interface()})
|
||||
}
|
||||
|
||||
return datas
|
||||
|
|
|
@ -36,7 +36,7 @@ func logLevel2Int(l string) int {
|
|||
}
|
||||
lvl := _Info
|
||||
for k, v := range levels {
|
||||
if strings.ToLower(l) == strings.ToLower(v) {
|
||||
if strings.EqualFold(strings.ToLower(l), strings.ToLower(v)) {
|
||||
lvl = k
|
||||
}
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ func logLevel2Int(l string) int {
|
|||
|
||||
func output(l int, s ...interface{}) {
|
||||
lvl := fmt.Sprintf("[%s] ", levels[l])
|
||||
baseLog.Output(3, lvl+fmt.Sprintln(s...))
|
||||
_ = baseLog.Output(3, lvl+fmt.Sprintln(s...))
|
||||
}
|
||||
|
||||
func Debug(v ...interface{}) {
|
||||
|
|
|
@ -8,10 +8,6 @@
|
|||
# Define Bridge Interface
|
||||
br="anylink0"
|
||||
|
||||
# Define list of TAP interfaces to be bridged,
|
||||
# for example tap="tap0 tap1 tap2".
|
||||
tap="tap0"
|
||||
|
||||
# Define physical ethernet interface to be bridged
|
||||
# with TAP interface(s) above.
|
||||
|
||||
|
|
|
@ -35,8 +35,7 @@ proxy_protocol = false
|
|||
link_mode = "tun"
|
||||
|
||||
#客户端分配的ip地址池
|
||||
ipv4_network = "192.168.10.0"
|
||||
ipv4_netmask = "255.255.255.0"
|
||||
ipv4_cidr = "192.168.10.0/24"
|
||||
ipv4_gateway = "192.168.10.1"
|
||||
ipv4_pool = ["192.168.10.100", "192.168.10.200"]
|
||||
|
||||
|
|
13
dbdata/db.go
13
dbdata/db.go
|
@ -43,28 +43,27 @@ func initData() {
|
|||
return
|
||||
}
|
||||
|
||||
defer Set(SettingBucket, Installed, true)
|
||||
defer func() {
|
||||
_ = Set(SettingBucket, Installed, true)
|
||||
}()
|
||||
|
||||
smtp := &SettingSmtp{
|
||||
Host: "127.0.0.1",
|
||||
Port: 25,
|
||||
From: "vpn@xx.com",
|
||||
}
|
||||
SettingSet(smtp)
|
||||
_ = SettingSet(smtp)
|
||||
|
||||
other := &SettingOther{
|
||||
Banner: "您已接入公司网络,请按照公司规定使用。\n请勿进行非工作下载及视频行为!",
|
||||
AccountMail: accountMail,
|
||||
}
|
||||
SettingSet(other)
|
||||
_ = SettingSet(other)
|
||||
|
||||
}
|
||||
|
||||
func CheckErrNotFound(err error) bool {
|
||||
if err == storm.ErrNotFound {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return err == storm.ErrNotFound
|
||||
}
|
||||
|
||||
const accountMail = `<p>您好:</p>
|
||||
|
|
|
@ -27,7 +27,7 @@ func TestDb(t *testing.T) {
|
|||
defer closeIpdata()
|
||||
|
||||
u := User{Username: "a"}
|
||||
Save(&u)
|
||||
_ = Save(&u)
|
||||
|
||||
assert.Equal(u.Id, 1)
|
||||
}
|
||||
|
|
|
@ -2,4 +2,5 @@ package handler
|
|||
|
||||
// 暂时没有实现
|
||||
func startDtls() {
|
||||
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) {
|
|||
sess.MacAddr = strings.ToLower(cr.MacAddressList.MacAddress)
|
||||
sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal
|
||||
other := &dbdata.SettingOther{}
|
||||
dbdata.SettingGet(other)
|
||||
_ = dbdata.SettingGet(other)
|
||||
rd := RequestData{SessionId: sess.Sid, SessionToken: sess.Sid + "@" + sess.Token,
|
||||
Banner: other.Banner}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
@ -102,7 +102,7 @@ const (
|
|||
func tplRequest(typ int, w io.Writer, data RequestData) {
|
||||
if typ == tpl_request {
|
||||
t, _ := template.New("auth_request").Parse(auth_request)
|
||||
t.Execute(w, data)
|
||||
_ = t.Execute(w, data)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -111,7 +111,7 @@ func tplRequest(typ int, w io.Writer, data RequestData) {
|
|||
data.Banner = strings.ReplaceAll(data.Banner, "\n", "
")
|
||||
}
|
||||
t, _ := template.New("auth_complete").Parse(auth_complete)
|
||||
t.Execute(w, data)
|
||||
_ = t.Execute(w, data)
|
||||
}
|
||||
|
||||
// 设置输出信息
|
||||
|
|
|
@ -2,11 +2,9 @@ package handler
|
|||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const BufferSize = 2048
|
||||
|
@ -43,27 +41,6 @@ type macAddressList struct {
|
|||
MacAddress string `xml:"mac-address"`
|
||||
}
|
||||
|
||||
// 判断anyconnect客户端
|
||||
func checkLinkClient(h http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO 调试信息输出
|
||||
// hd, _ := httputil.DumpRequest(r, true)
|
||||
// fmt.Println("DumpRequest: ", string(hd))
|
||||
// fmt.Println(r.RemoteAddr)
|
||||
|
||||
userAgent := strings.ToLower(r.UserAgent())
|
||||
x_Aggregate_Auth := r.Header.Get("X-Aggregate-Auth")
|
||||
x_Transcend_Version := r.Header.Get("X-Transcend-Version")
|
||||
if strings.Contains(userAgent, "anyconnect") &&
|
||||
x_Aggregate_Auth == "1" && x_Transcend_Version == "1" {
|
||||
h(w, r)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
fmt.Fprintf(w, "error request")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setCommonHeader(w http.ResponseWriter) {
|
||||
// Content-Length Date 默认已经存在
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
|
|
|
@ -26,7 +26,7 @@ func LinkHome(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func LinkOtpQr(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
_ = r.ParseForm()
|
||||
idS := r.FormValue("id")
|
||||
jwtToken := r.FormValue("jwt")
|
||||
data, err := admin.GetJwtData(jwtToken)
|
||||
|
|
|
@ -29,6 +29,9 @@ func checkTap() {
|
|||
bridgeHw = brFace.HardwareAddr
|
||||
|
||||
addrs, err := brFace.Addrs()
|
||||
if err != nil {
|
||||
base.Fatal("testTap err: ", err)
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
ip, _, err := net.ParseCIDR(addr.String())
|
||||
if err != nil || ip.To4() == nil {
|
||||
|
|
|
@ -73,9 +73,9 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
|||
w.Header().Set("Server", fmt.Sprintf("%s %s", base.APP_NAME, base.APP_VER))
|
||||
w.Header().Set("X-CSTP-Version", "1")
|
||||
w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.")
|
||||
w.Header().Set("X-CSTP-Address", cSess.IpAddr.String()) // 分配的ip地址
|
||||
w.Header().Set("X-CSTP-Netmask", base.Cfg.Ipv4Netmask) // 子网掩码
|
||||
w.Header().Set("X-CSTP-Hostname", hn) // 机器名称
|
||||
w.Header().Set("X-CSTP-Address", cSess.IpAddr.String()) // 分配的ip地址
|
||||
w.Header().Set("X-CSTP-Netmask", sessdata.IpPool.Ipv4Mask.String()) // 子网掩码
|
||||
w.Header().Set("X-CSTP-Hostname", hn) // 机器名称
|
||||
|
||||
// 允许本地LAN访问vpn网络,必须放在路由的第一个
|
||||
if cSess.Group.AllowLan {
|
||||
|
@ -131,11 +131,11 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
|||
// w.Header().Set("X-CSTP-Post-Auth-XML", ``)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
h := w.Header().Clone()
|
||||
hClone := w.Header().Clone()
|
||||
headers := make([]byte, 0)
|
||||
buf := bytes.NewBuffer(headers)
|
||||
h.Write(buf)
|
||||
base.Debug(string(buf.Bytes()))
|
||||
_ = hClone.Write(buf)
|
||||
base.Debug(buf.String())
|
||||
|
||||
hj := w.(http.Hijacker)
|
||||
conn, _, err := hj.Hijack()
|
||||
|
|
|
@ -21,5 +21,5 @@ func Start() {
|
|||
}
|
||||
|
||||
func Stop() {
|
||||
dbdata.Stop()
|
||||
_ = dbdata.Stop()
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ func tableLookup(ip net.IP) *Addr {
|
|||
}
|
||||
|
||||
// 判断老化过期时间
|
||||
tsub := time.Now().Sub(addr.disTime)
|
||||
tsub := time.Since(addr.disTime)
|
||||
switch addr.Type {
|
||||
case TypeNormal:
|
||||
if tsub > StaleTimeNormal {
|
||||
|
|
|
@ -48,7 +48,7 @@ func doPing(ip string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second * 2))
|
||||
_ = conn.SetReadDeadline(time.Now().Add(time.Second * 2))
|
||||
|
||||
for {
|
||||
buf := make([]byte, 512)
|
||||
|
|
|
@ -198,8 +198,10 @@ func (p *Conn) checkPrefixOnce() {
|
|||
func (p *Conn) checkPrefix() error {
|
||||
if p.proxyHeaderTimeout != 0 {
|
||||
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
|
||||
p.conn.SetReadDeadline(readDeadLine)
|
||||
defer p.conn.SetReadDeadline(time.Time{})
|
||||
_ = p.conn.SetReadDeadline(readDeadLine)
|
||||
defer func() {
|
||||
_ = p.conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
}
|
||||
|
||||
// Incrementally check each byte of the prefix
|
||||
|
|
|
@ -45,8 +45,6 @@ func CopyStruct(a interface{}, b interface{}, fields ...string) (err error) {
|
|||
// a中有同名的字段并且类型一致才复制
|
||||
if f.IsValid() && f.Kind() == bValue.Kind() {
|
||||
f.Set(bValue)
|
||||
} else {
|
||||
// fmt.Printf("no such field or different kind, fieldName: %s\n", name)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
|
|
@ -19,7 +19,8 @@ type ipPoolConfig struct {
|
|||
mux sync.Mutex
|
||||
// 计算动态ip
|
||||
Ipv4Gateway net.IP
|
||||
Ipv4IPNet net.IPNet
|
||||
Ipv4Mask net.IP
|
||||
Ipv4IPNet *net.IPNet
|
||||
IpLongMin uint32
|
||||
IpLongMax uint32
|
||||
}
|
||||
|
@ -27,11 +28,12 @@ type ipPoolConfig struct {
|
|||
func initIpPool() {
|
||||
|
||||
// 地址处理
|
||||
// ip地址
|
||||
ip := net.ParseIP(base.Cfg.Ipv4Network)
|
||||
// 子网掩码
|
||||
maskIp := net.ParseIP(base.Cfg.Ipv4Netmask).To4()
|
||||
IpPool.Ipv4IPNet = net.IPNet{IP: ip, Mask: net.IPMask(maskIp)}
|
||||
_, ipNet, err := net.ParseCIDR(base.Cfg.Ipv4CIDR)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
IpPool.Ipv4IPNet = ipNet
|
||||
IpPool.Ipv4Mask = net.IP(ipNet.Mask)
|
||||
IpPool.Ipv4Gateway = net.ParseIP(base.Cfg.Ipv4Gateway)
|
||||
|
||||
// 网络地址零值
|
||||
|
@ -74,11 +76,11 @@ func AcquireIp(username, macAddr string) net.IP {
|
|||
mi.Username = username
|
||||
mi.LastLogin = tNow
|
||||
// 回写db数据
|
||||
dbdata.Save(mi)
|
||||
_ = dbdata.Save(mi)
|
||||
ipActive[ipStr] = true
|
||||
return ip
|
||||
} else {
|
||||
dbdata.Del(mi)
|
||||
_ = dbdata.Del(mi)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -92,7 +94,7 @@ func AcquireIp(username, macAddr string) net.IP {
|
|||
if err != nil && dbdata.CheckErrNotFound(err) {
|
||||
// 该ip没有被使用
|
||||
mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
|
||||
dbdata.Save(mi)
|
||||
_ = dbdata.Save(mi)
|
||||
ipActive[ipStr] = true
|
||||
return ip
|
||||
}
|
||||
|
@ -121,10 +123,10 @@ func AcquireIp(username, macAddr string) net.IP {
|
|||
|
||||
// 已经超过租期
|
||||
if tNow.Sub(v.LastLogin) > time.Duration(base.Cfg.IpLease)*time.Second {
|
||||
dbdata.Del(v)
|
||||
_ = dbdata.Del(v)
|
||||
mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
|
||||
// 重写db数据
|
||||
dbdata.Save(mi)
|
||||
_ = dbdata.Save(mi)
|
||||
ipActive[ipStr] = true
|
||||
return ip
|
||||
}
|
||||
|
@ -145,7 +147,7 @@ func AcquireIp(username, macAddr string) net.IP {
|
|||
ipStr := ip.String()
|
||||
mi = &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow}
|
||||
// 回写db数据
|
||||
dbdata.Save(mi)
|
||||
_ = dbdata.Save(mi)
|
||||
ipActive[ipStr] = true
|
||||
return ip
|
||||
}
|
||||
|
@ -160,6 +162,6 @@ func ReleaseIp(ip net.IP, macAddr string) {
|
|||
err := dbdata.One("IpAddr", ip, mi)
|
||||
if err == nil {
|
||||
mi.LastLogin = time.Now()
|
||||
dbdata.Save(mi)
|
||||
_ = dbdata.Save(mi)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,8 +15,7 @@ import (
|
|||
func preData(tmpDir string) {
|
||||
tmpDb := path.Join(tmpDir, "test.db")
|
||||
base.Cfg.DbFile = tmpDb
|
||||
base.Cfg.Ipv4Network = "192.168.3.0"
|
||||
base.Cfg.Ipv4Netmask = "255.255.255.0"
|
||||
base.Cfg.Ipv4CIDR = "192.168.3.0/24"
|
||||
base.Cfg.Ipv4Pool = []string{"192.168.3.1", "192.168.3.199"}
|
||||
base.Cfg.MaxClient = 100
|
||||
base.Cfg.MaxUserClient = 3
|
||||
|
@ -26,12 +25,12 @@ func preData(tmpDir string) {
|
|||
Name: "group1",
|
||||
Bandwidth: 1000,
|
||||
}
|
||||
dbdata.Save(&group)
|
||||
_ = dbdata.Save(&group)
|
||||
initIpPool()
|
||||
}
|
||||
|
||||
func cleardata(tmpDir string) {
|
||||
dbdata.Stop()
|
||||
_ = dbdata.Stop()
|
||||
tmpDb := path.Join(tmpDir, "test.db")
|
||||
os.Remove(tmpDb)
|
||||
}
|
||||
|
@ -45,15 +44,15 @@ func TestIpPool(t *testing.T) {
|
|||
var ip net.IP
|
||||
|
||||
for i := 1; i <= 100; i++ {
|
||||
ip = AcquireIp("user", fmt.Sprintf("mac-%d", i))
|
||||
_ = AcquireIp("user", fmt.Sprintf("mac-%d", i))
|
||||
}
|
||||
ip = AcquireIp("user", fmt.Sprintf("mac-new"))
|
||||
ip = AcquireIp("user", "mac-new")
|
||||
assert.True(net.IPv4(192, 168, 3, 101).Equal(ip))
|
||||
for i := 102; i <= 199; i++ {
|
||||
ip = AcquireIp("user", fmt.Sprintf("mac-%d", i))
|
||||
}
|
||||
assert.True(net.IPv4(192, 168, 3, 199).Equal(ip))
|
||||
ip = AcquireIp("user", fmt.Sprintf("mac-nil"))
|
||||
ip = AcquireIp("user", "mac-nil")
|
||||
assert.Nil(ip)
|
||||
|
||||
ReleaseIp(net.IPv4(192, 168, 3, 88), "mac-88")
|
||||
|
|
|
@ -43,10 +43,11 @@ func TestLimitClient(t *testing.T) {
|
|||
func TestLimitWait(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
limit := NewLimitRater(1, 2)
|
||||
limit.Wait(2)
|
||||
start := time.Now()
|
||||
err := limit.Wait(2)
|
||||
assert.Nil(err)
|
||||
start := time.Now()
|
||||
err = limit.Wait(2)
|
||||
assert.Nil(err)
|
||||
err = limit.Wait(1)
|
||||
assert.Nil(err)
|
||||
end := time.Now()
|
||||
|
|
|
@ -34,10 +34,7 @@ func (o Onlines) Len() int {
|
|||
}
|
||||
|
||||
func (o Onlines) Less(i, j int) bool {
|
||||
if bytes.Compare(o[i].Ip, o[j].Ip) < 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return bytes.Compare(o[i].Ip, o[j].Ip) < 0
|
||||
}
|
||||
|
||||
func (o Onlines) Swap(i, j int) {
|
||||
|
|
|
@ -78,13 +78,13 @@ func checkSession() {
|
|||
return
|
||||
}
|
||||
timeout := time.Duration(base.Cfg.SessionTimeout) * time.Second
|
||||
tick := time.Tick(time.Second * 60)
|
||||
for range tick {
|
||||
tick := time.NewTicker(time.Second * 60)
|
||||
for range tick.C {
|
||||
sessMux.Lock()
|
||||
t := time.Now()
|
||||
for k, v := range sessions {
|
||||
v.mux.Lock()
|
||||
if v.IsActive != true {
|
||||
if !v.IsActive {
|
||||
if t.Sub(v.LastLogin) > timeout {
|
||||
delete(sessions, k)
|
||||
}
|
||||
|
@ -133,12 +133,12 @@ func (s *Session) NewConn() *ConnSession {
|
|||
macAddr := s.MacAddr
|
||||
username := s.Username
|
||||
s.mux.Unlock()
|
||||
if active == true {
|
||||
if active {
|
||||
s.CSess.Close()
|
||||
}
|
||||
|
||||
limit := LimitClient(username, false)
|
||||
if limit == false {
|
||||
if !limit {
|
||||
return nil
|
||||
}
|
||||
// 获取客户端mac地址
|
||||
|
@ -208,8 +208,10 @@ func (cs *ConnSession) Close() {
|
|||
const BandwidthPeriodSec = 2 // 流量速率统计周期(秒)
|
||||
|
||||
func (cs *ConnSession) ratePeriod() {
|
||||
tick := time.Tick(time.Second * BandwidthPeriodSec)
|
||||
for range tick {
|
||||
tick := time.NewTicker(time.Second * BandwidthPeriodSec)
|
||||
defer tick.Stop()
|
||||
|
||||
for range tick.C {
|
||||
select {
|
||||
case <-cs.CloseChan:
|
||||
return
|
||||
|
|
|
@ -28,9 +28,11 @@ func TestConnSession(t *testing.T) {
|
|||
|
||||
cSess := sess.NewConn()
|
||||
|
||||
cSess.RateLimit(100, true)
|
||||
err := cSess.RateLimit(100, true)
|
||||
assert.Nil(err)
|
||||
assert.Equal(cSess.BandwidthUp, uint32(100))
|
||||
cSess.RateLimit(200, false)
|
||||
err = cSess.RateLimit(200, false)
|
||||
assert.Nil(err)
|
||||
assert.Equal(cSess.BandwidthDown, uint32(200))
|
||||
cSess.Close()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue