diff --git a/.gitignore b/.gitignore index e26bd7b..b52a489 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,6 @@ # Dependency directories (remove the comment below to include it) vendor/ - +ui/ .idea/ anylink \ No newline at end of file diff --git a/README.md b/README.md index 4f02bfa..62d5f7d 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,46 @@ # AnyLink -AnyLink 是一个企业级远程办公vpn软件,可以支持多人同时在线使用。 +[![PkgGoDev](https://pkg.go.dev/badge/github.com/bjdgyc/anylink)](https://pkg.go.dev/github.com/bjdgyc/anylink) + +AnyLink 是一个企业级远程办公ssl vpn软件,可以支持多人同时在线使用。 + +## Repo + +> github: https://github.com/bjdgyc/anylink + +> gitee: https://gitee.com/bjdgyc/anylink ## Introduction -AnyLink 基于 [ietf-openconnect](https://tools.ietf.org/html/draft-mavrogiannopoulos-openconnect-02) 协议开发,并且借鉴了 [ocserv](http://ocserv.gitlab.io/www/index.html) 的开发思路,使其可以同时兼容 AnyConnect 客户端。 +AnyLink 基于 [ietf-openconnect](https://tools.ietf.org/html/draft-mavrogiannopoulos-openconnect-02) +协议开发,并且借鉴了 [ocserv](http://ocserv.gitlab.io/www/index.html) 的开发思路,使其可以同时兼容 AnyConnect 客户端。 AnyLink 使用TLS/DTLS进行数据加密,因此需要RSA或ECC证书,可以通过 Let's Encrypt 和 TrustAsia 申请免费的SSL证书。 AnyLink 服务端仅在CentOS7测试通过,如需要安装在其他系统,需要服务端支持tun/tap功能、ip设置命令。 - ## Installation ``` +rootPath=`pwd` + git clone https://github.com/bjdgyc/anylink.git -cd anylink +git clone https://github.com/bjdgyc/anylink-web.git + +cd $rootPath/anylink-web +npm install +npm run build + +cd $rootPath/anylink go build -o anylink -ldflags "-X main.COMMIT_ID=`git rev-parse HEAD`" + +mkdir $linkPath/anylink-deploy +$linkPath/anylink-deploy +cp -r $rootPath/anylink-web/ui . +cp -r $rootPath/anylink/anylink . +cp -r $rootPath/anylink/conf . +cp -r $rootPath/anylink/downfiles . + #注意使用root权限运行 sudo ./anylink -conf="conf/server.toml" ``` @@ -28,17 +52,16 @@ sudo ./anylink -conf="conf/server.toml" - [x] 兼容AnyConnect - [x] 基于tun设备的nat访问模式 - [x] 基于tap设备的桥接访问模式 -- [x] 多用户支持 - [x] 支持 [proxy protocol v1](http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt) 协议 +- [x] 用户组支持 +- [x] 多用户支持 +- [x] TOTP令牌支持 +- [x] 流量控制 +- [x] 后台管理界面 -- [ ] 用户组支持 -- [ ] TOTP令牌支持 -- [ ] 流量控制 - [ ] 访问权限管理 -- [ ] 后台管理界面 - [ ] DTLS-UDP通道 - ## Config 默认配置文件内有详细的注释,根据注释填写配置即可。 @@ -47,11 +70,9 @@ sudo ./anylink -conf="conf/server.toml" ## Setting -网络模式选择,需要配置 `link_mode` 参数,如 `link_mode="tun"`,`link_mode="tap"` 两种参数。 -不同的参数需要对服务器做相应的设置。 +网络模式选择,需要配置 `link_mode` 参数,如 `link_mode="tun"`,`link_mode="tap"` 两种参数。 不同的参数需要对服务器做相应的设置。 -建议优先选择tun模式,因客户端传输的是IP层数据,无须进行数据转换。 -tap模式是在用户态做的链路层到IP层的数据互相转换,性能会有所下降。 +建议优先选择tun模式,因客户端传输的是IP层数据,无须进行数据转换。 tap模式是在用户态做的链路层到IP层的数据互相转换,性能会有所下降。 如果需要在虚拟机内开启tap模式,请确认虚拟机的网卡开启混杂模式。 ### tun设置 @@ -69,18 +90,17 @@ tap模式是在用户态做的链路层到IP层的数据互相转换,性能会 # eth0为服务器内网网卡 iptables -t nat -A POSTROUTING -s 192.168.10.0/255.255.255.0 -o eth0 -j MASQUERADE ``` - + 3. 使用AnyConnect客户端连接即可 - ### tap设置 - + 1. 创建桥接网卡 ``` 注意 server.toml 的ip参数,需要与 bridge.sh 的配置参数一致 ``` - -2. 修改 bridge-init.sh 内的参数 + +2. 修改 bridge.sh 内的参数 ``` # file: ./bridge.sh eth="eth0" @@ -89,13 +109,11 @@ tap模式是在用户态做的链路层到IP层的数据互相转换,性能会 eth_broadcast="192.168.1.255" eth_gateway="192.168.1.1" ``` - + 3. 执行 bridge.sh 文件 ``` sh bridge.sh ``` - - ## License diff --git a/admin/api_base.go b/admin/api_base.go new file mode 100644 index 0000000..805b9d4 --- /dev/null +++ b/admin/api_base.go @@ -0,0 +1,80 @@ +package admin + +import ( + "fmt" + "net/http" + "time" + + "github.com/bjdgyc/anylink/pkg/utils" + + "github.com/bjdgyc/anylink/base" + "github.com/gorilla/mux" +) + +// 登陆接口 +func Login(w http.ResponseWriter, r *http.Request) { + // TODO 调试信息输出 + // hd, _ := httputil.DumpRequest(r, true) + // fmt.Println("DumpRequest: ", string(hd)) + + r.ParseForm() + admin_user := r.PostFormValue("admin_user") + admin_pass := r.PostFormValue("admin_pass") + + // 认证错误 + if !(admin_user == base.Cfg.AdminUser && + utils.PasswordVerify(admin_pass, base.Cfg.AdminPass)) { + RespError(w, RespUserOrPassErr) + return + } + + // token有效期 + expiresAt := time.Now().Unix() + 3600*3 + jwtData := map[string]interface{}{"admin_user": admin_user} + tokenString, err := SetJwtData(jwtData, expiresAt) + if err != nil { + RespError(w, 1, err) + return + } + + data := make(map[string]interface{}) + data["token"] = tokenString + data["admin_user"] = admin_user + data["expires_at"] = expiresAt + + RespSucess(w, data) +} + +func authMiddleware(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "*") + if r.Method == http.MethodOptions { + return + } + + route := mux.CurrentRoute(r) + name := route.GetName() + // fmt.Println("bb", r.URL.Path, name) + if name == "login" || name == "static" { + // 不进行鉴权 + next.ServeHTTP(w, r) + return + } + + // 进行登陆鉴权 + jwtToken := r.Header.Get("Jwt") + if jwtToken == "" { + jwtToken = r.FormValue("jwt") + } + data, err := GetJwtData(jwtToken) + if err != nil || base.Cfg.AdminUser != fmt.Sprint(data["admin_user"]) { + w.WriteHeader(http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} diff --git a/admin/api_group.go b/admin/api_group.go new file mode 100644 index 0000000..ba9302a --- /dev/null +++ b/admin/api_group.go @@ -0,0 +1,108 @@ +package admin + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "strconv" + + "github.com/bjdgyc/anylink/dbdata" +) + +func GroupList(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + pageS := r.FormValue("page") + page, _ := strconv.Atoi(pageS) + if page < 1 { + page = 1 + } + + var pageSize = dbdata.PageSize + + count := dbdata.CountAll(&dbdata.Group{}) + + var datas []dbdata.Group + err := dbdata.All(&datas, pageSize, page) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + data := map[string]interface{}{ + "count": count, + "page_size": pageSize, + "datas": datas, + } + + RespSucess(w, data) +} + +func GroupNames(w http.ResponseWriter, r *http.Request) { + var names = dbdata.GetGroupNames() + data := map[string]interface{}{ + "count": len(names), + "page_size": 0, + "datas": names, + } + RespSucess(w, data) +} + +func GroupDetail(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + idS := r.FormValue("id") + id, _ := strconv.Atoi(idS) + if id < 1 { + RespError(w, RespParamErr, "Id错误") + return + } + + var data dbdata.Group + err := dbdata.One("Id", id, &data) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + RespSucess(w, data) +} + +func GroupSet(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + defer r.Body.Close() + v := &dbdata.Group{} + err = json.Unmarshal(body, v) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + err = dbdata.SetGroup(v) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + RespSucess(w, nil) +} + +func GroupDel(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + idS := r.FormValue("id") + id, _ := strconv.Atoi(idS) + if id < 1 { + RespError(w, RespParamErr, "Id错误") + return + } + + data := dbdata.Group{Id: id} + err := dbdata.Del(&data) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + RespSucess(w, nil) +} diff --git a/admin/api_ip_map.go b/admin/api_ip_map.go new file mode 100644 index 0000000..ca4a024 --- /dev/null +++ b/admin/api_ip_map.go @@ -0,0 +1,111 @@ +package admin + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "strconv" + "time" + + "github.com/bjdgyc/anylink/dbdata" +) + +func UserIpMapList(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + pageS := r.FormValue("page") + page, _ := strconv.Atoi(pageS) + if page < 1 { + page = 1 + } + + var pageSize = dbdata.PageSize + + count := dbdata.CountAll(&dbdata.IpMap{}) + + var datas []dbdata.IpMap + err := dbdata.All(&datas, pageSize, page) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + data := map[string]interface{}{ + "count": count, + "page_size": pageSize, + "datas": datas, + } + + RespSucess(w, data) +} + +func UserIpMapDetail(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + idS := r.FormValue("id") + id, _ := strconv.Atoi(idS) + if id < 1 { + RespError(w, RespParamErr, "用户名错误") + return + } + + var data dbdata.IpMap + err := dbdata.One("Id", id, &data) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + RespSucess(w, data) +} + +func UserIpMapSet(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + defer r.Body.Close() + v := &dbdata.IpMap{} + err = json.Unmarshal(body, v) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + // fmt.Println(v, len(v.Ip), len(v.MacAddr)) + + if len(v.IpAddr) < 4 || len(v.MacAddr) < 6 { + RespError(w, RespParamErr, "IP或MAC错误") + return + } + + v.UpdatedAt = time.Now() + err = dbdata.Save(v) + + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + RespSucess(w, nil) +} + +func UserIpMapDel(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + idS := r.FormValue("id") + id, _ := strconv.Atoi(idS) + + if id < 1 { + RespError(w, RespParamErr, "IP映射id错误") + return + } + + data := dbdata.IpMap{Id: id} + err := dbdata.Del(&data) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + RespSucess(w, nil) +} diff --git a/admin/api_other.go b/admin/api_other.go new file mode 100644 index 0000000..00f926a --- /dev/null +++ b/admin/api_other.go @@ -0,0 +1,62 @@ +package admin + +import ( + "encoding/json" + "io/ioutil" + "net/http" + + "github.com/bjdgyc/anylink/dbdata" +) + +func setOtherGet(data interface{}, w http.ResponseWriter) { + err := dbdata.SettingGet(data) + if err != nil && !dbdata.CheckErrNotFound(err) { + RespError(w, RespInternalErr, err) + return + } + RespSucess(w, data) +} + +func setOtherEdit(data interface{}, w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + defer r.Body.Close() + + err = json.Unmarshal(body, data) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + // fmt.Println(data) + + err = dbdata.SettingSet(data) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + RespSucess(w, data) +} + +func SetOtherSmtp(w http.ResponseWriter, r *http.Request) { + data := &dbdata.SettingSmtp{} + setOtherGet(data, w) +} + +func SetOtherSmtpEdit(w http.ResponseWriter, r *http.Request) { + data := &dbdata.SettingSmtp{} + setOtherEdit(data, w, r) +} + +func SetOther(w http.ResponseWriter, r *http.Request) { + data := &dbdata.SettingOther{} + setOtherGet(data, w) +} + +func SetOtherEdit(w http.ResponseWriter, r *http.Request) { + data := &dbdata.SettingOther{} + setOtherEdit(data, w, r) +} diff --git a/admin/api_set.go b/admin/api_set.go new file mode 100644 index 0000000..6be5737 --- /dev/null +++ b/admin/api_set.go @@ -0,0 +1,95 @@ +package admin + +import ( + "encoding/json" + "fmt" + "net/http" + "runtime" + + "github.com/bjdgyc/anylink/dbdata" + "github.com/bjdgyc/anylink/sessdata" + + "github.com/bjdgyc/anylink/base" + "github.com/bjdgyc/anylink/pkg/utils" + "github.com/shirou/gopsutil/cpu" + "github.com/shirou/gopsutil/disk" + "github.com/shirou/gopsutil/host" + "github.com/shirou/gopsutil/load" + "github.com/shirou/gopsutil/mem" +) + +func SetHome(w http.ResponseWriter, r *http.Request) { + data := make(map[string]interface{}) + + sess := sessdata.OnlineSess() + + data["counts"] = map[string]int{ + "online": len(sess), + "user": dbdata.CountAll(&dbdata.User{}), + "group": dbdata.CountAll(&dbdata.Group{}), + "ip_map": dbdata.CountAll(&dbdata.IpMap{}), + } + + RespSucess(w, data) +} + +func SetSystem(w http.ResponseWriter, r *http.Request) { + data := make(map[string]interface{}) + + m, _ := mem.VirtualMemory() + data["mem"] = map[string]interface{}{ + "total": utils.HumanByte(m.Total), + "free": utils.HumanByte(m.Free), + "percent": decimal(m.UsedPercent), + } + + d, _ := disk.Usage("/") + data["disk"] = map[string]interface{}{ + "total": utils.HumanByte(d.Total), + "free": utils.HumanByte(d.Free), + "percent": decimal(d.UsedPercent), + } + + cc, _ := cpu.Counts(true) + c, _ := cpu.Info() + ci := c[0] + cpuUsedPercent, _ := cpu.Percent(0, false) + cup := cpuUsedPercent[0] + if cup == 0 { + cup = 1 + } + data["cpu"] = map[string]interface{}{ + "core": cc, + "modelName": ci.ModelName, + "ghz": fmt.Sprintf("%.2f GHz", ci.Mhz/1000), + "percent": decimal(cup), + } + + hi, _ := host.Info() + l, _ := load.Avg() + data["sys"] = map[string]interface{}{ + "goOs": runtime.GOOS, + "goArch": runtime.GOARCH, + "goVersion": runtime.Version(), + "goroutine": runtime.NumGoroutine(), + + "hostname": hi.Hostname, + "platform": fmt.Sprintf("%v %v %v", hi.Platform, hi.PlatformFamily, hi.PlatformVersion), + "kernel": hi.KernelVersion, + + "load": fmt.Sprint(l.Load1, l.Load5, l.Load15), + } + + RespSucess(w, data) +} + +func SetSoft(w http.ResponseWriter, r *http.Request) { + datas := base.ServerCfg2Slice() + b, _ := json.Marshal(datas) + w.Write(b) +} + +func decimal(f float64) float64 { + i := int(f * 100) + return float64(i) / 100 +} diff --git a/admin/api_user.go b/admin/api_user.go new file mode 100644 index 0000000..6e3db9e --- /dev/null +++ b/admin/api_user.go @@ -0,0 +1,226 @@ +package admin + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "strconv" + "strings" + "text/template" + "time" + + "github.com/bjdgyc/anylink/base" + "github.com/bjdgyc/anylink/dbdata" + "github.com/bjdgyc/anylink/sessdata" + "github.com/skip2/go-qrcode" +) + +func UserList(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + prefix := r.FormValue("prefix") + pageS := r.FormValue("page") + page, _ := strconv.Atoi(pageS) + if page < 1 { + page = 1 + } + + var ( + pageSize = dbdata.PageSize + count int + datas []dbdata.User + err error + ) + + // 查询前缀匹配 + if len(prefix) > 0 { + count = pageSize + err = dbdata.Prefix("Username", prefix, &datas, pageSize, 1) + } else { + count = dbdata.CountAll(&dbdata.User{}) + err = dbdata.All(&datas, pageSize, page) + } + + if err != nil && !dbdata.CheckErrNotFound(err) { + RespError(w, RespInternalErr, err) + return + } + + data := map[string]interface{}{ + "count": count, + "page_size": pageSize, + "datas": datas, + } + + RespSucess(w, data) +} + +func UserDetail(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + idS := r.FormValue("id") + id, _ := strconv.Atoi(idS) + if id < 1 { + RespError(w, RespParamErr, "用户名错误") + return + } + + var user dbdata.User + err := dbdata.One("Id", id, &user) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + RespSucess(w, user) +} + +func UserSet(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + defer r.Body.Close() + data := &dbdata.User{} + err = json.Unmarshal(body, data) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + err = dbdata.SetUser(data) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + // 发送邮件 + if data.SendEmail { + userAccountMail(data) + } + + RespSucess(w, nil) +} + +func UserDel(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + idS := r.FormValue("id") + id, _ := strconv.Atoi(idS) + + if id < 1 { + RespError(w, RespParamErr, "用户id错误") + return + } + + user := dbdata.User{Id: id} + err := dbdata.Del(&user) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + RespSucess(w, nil) +} + +func UserOtpQr(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + b64 := r.FormValue("b64") + idS := r.FormValue("id") + id, _ := strconv.Atoi(idS) + var user dbdata.User + err := dbdata.One("Id", id, &user) + if err != nil { + RespError(w, RespInternalErr, err) + return + } + + issuer := base.Cfg.Issuer + qrstr := fmt.Sprintf("otpauth://totp/%s:%s?issuer=%s&secret=%s", issuer, user.Email, issuer, user.OtpSecret) + qr, _ := qrcode.New(qrstr, qrcode.High) + + if b64 == "1" { + data, _ := qr.PNG(300) + s := base64.StdEncoding.EncodeToString(data) + fmt.Fprint(w, s) + } else { + qr.Write(300, w) + } + +} + +// 在线用户 +func UserOnline(w http.ResponseWriter, r *http.Request) { + datas := sessdata.OnlineSess() + + data := map[string]interface{}{ + "count": len(datas), + "page_size": dbdata.PageSize, + "datas": datas, + } + + RespSucess(w, data) +} + +func UserOffline(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + token := r.FormValue("token") + sessdata.CloseSess(token) + RespSucess(w, nil) +} + +type userAccountMailData struct { + Issuer string + LinkAddr string + Group string + Username string + PinCode string + OtpImg string +} + +func userAccountMail(user *dbdata.User) error { + // 平台通知 + htmlBody := ` + + + + + Hello AnyLink! + + +%s + + +` + dataOther := &dbdata.SettingOther{} + err := dbdata.SettingGet(dataOther) + if err != nil { + base.Error(err) + return err + } + htmlBody = fmt.Sprintf(htmlBody, dataOther.AccountMail) + // fmt.Println(htmlBody) + + // token有效期3天 + expiresAt := time.Now().Unix() + 3600*24*3 + jwtData := map[string]interface{}{"id": user.Id} + tokenString, err := SetJwtData(jwtData, expiresAt) + if err != nil { + return err + } + + data := userAccountMailData{ + LinkAddr: base.Cfg.LinkAddr, + Group: strings.Join(user.Groups, ","), + Username: user.Username, + PinCode: user.PinCode, + OtpImg: fmt.Sprintf("https://%s/otp_qr?id=%d&jwt=%s", base.Cfg.LinkAddr, user.Id, tokenString), + } + w := bytes.NewBufferString("") + t, _ := template.New("auth_complete").Parse(htmlBody) + t.Execute(w, data) + // fmt.Println(w.String()) + return SendMail(base.Cfg.Issuer+"平台通知", user.Email, w.String()) +} diff --git a/admin/common.go b/admin/common.go new file mode 100644 index 0000000..e8d31ea --- /dev/null +++ b/admin/common.go @@ -0,0 +1,115 @@ +package admin + +import ( + "crypto/tls" + "errors" + "time" + + "github.com/bjdgyc/anylink/base" + "github.com/bjdgyc/anylink/dbdata" + "github.com/dgrijalva/jwt-go" + "github.com/mojocn/base64Captcha" + mail "github.com/xhit/go-simple-mail/v2" +) + +func SetJwtData(data map[string]interface{}, expiresAt int64) (string, error) { + jwtData := jwt.MapClaims{"exp": expiresAt} + for k, v := range data { + jwtData[k] = v + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtData) + + // Sign and get the complete encoded token as a string using the secret + tokenString, err := token.SignedString([]byte(base.Cfg.JwtSecret)) + return tokenString, err +} + +func GetJwtData(jwtToken string) (map[string]interface{}, error) { + token, err := jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) { + // since we only use the one private key to sign the tokens, + // we also only use its public counter part to verify + return []byte(base.Cfg.JwtSecret), nil + }) + + if err != nil || !token.Valid { + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, errors.New("data is parse err") + } + + return claims, nil +} + +func createCaptcha() { + var store = base64Captcha.DefaultMemStore + var driver base64Captcha.Driver + driverString := &base64Captcha.DriverString{} + driver = driverString.ConvertFonts() + c := base64Captcha.NewCaptcha(driver, store) + _ = c +} + +func SendMail(subject, to, htmlBody string) error { + + dataSmtp := &dbdata.SettingSmtp{} + err := dbdata.SettingGet(dataSmtp) + if err != nil { + base.Error(err) + return err + } + + server := mail.NewSMTPClient() + + // SMTP Server + server.Host = dataSmtp.Host + server.Port = dataSmtp.Port + server.Username = dataSmtp.Username + server.Password = dataSmtp.Password + // server.Encryption = mail.EncryptionTLS + + // Since v2.3.0 you can specified authentication type: + // - PLAIN (default) + // - LOGIN + // - CRAM-MD5 + server.Authentication = mail.AuthPlain + + // Variable to keep alive connection + server.KeepAlive = false + + // Timeout for connect to SMTP Server + server.ConnectTimeout = 10 * time.Second + + // Timeout for send the data and wait respond + server.SendTimeout = 10 * time.Second + + // Set TLSConfig to provide custom TLS configuration. For example, + // to skip TLS verification (useful for testing): + server.TLSConfig = &tls.Config{InsecureSkipVerify: true} + + // SMTP client + smtpClient, err := server.Connect() + + if err != nil { + base.Error(err) + return err + } + + // New email simple html with inline and CC + email := mail.NewMSG() + email.SetFrom(dataSmtp.From). + AddTo(to). + SetSubject(subject) + + email.SetBody(mail.TextHTML, htmlBody) + + // Call Send and pass the client + err = email.Send(smtpClient) + if err != nil { + base.Error(err) + } + + return err +} diff --git a/admin/error.go b/admin/error.go new file mode 100644 index 0000000..ad5d4cb --- /dev/null +++ b/admin/error.go @@ -0,0 +1,15 @@ +package admin + +// 返回码 +const ( + RespSuccess = 0 + RespInternalErr = 1 + RespTokenErr = 2 + RespUserOrPassErr = 3 + RespParamErr = 4 +) + +var RespMap = map[int]string{ + RespTokenErr: "客户端TOKEN错误", + RespUserOrPassErr: "用户名或密码错误", +} diff --git a/admin/resp.go b/admin/resp.go new file mode 100644 index 0000000..fdd953f --- /dev/null +++ b/admin/resp.go @@ -0,0 +1,62 @@ +package admin + +import ( + "encoding/json" + "fmt" + "net/http" + "runtime" + + "github.com/bjdgyc/anylink/base" +) + +type Resp struct { + Code int `json:"code"` + Msg string `json:"msg"` + Location string `json:"location"` + Data interface{} `json:"data"` +} + +func respHttp(w http.ResponseWriter, respCode int, data interface{}, errS ...interface{}) { + resp := Resp{ + Code: respCode, + Msg: "success", + Data: data, + } + _, file, line, _ := runtime.Caller(2) + resp.Location = fmt.Sprintf("%v:%v", file, line) + + if respCode != 0 { + resp.Msg = "" + if v, ok := RespMap[respCode]; ok { + resp.Msg += v + } + + if len(errS) > 0 { + resp.Msg += fmt.Sprint(errS...) + } + } + + b, err := json.Marshal(resp) + if err != nil { + base.Error(err, resp) + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write(b) + + // 记录返回数据 + // logger.Category("response").Debug(string(b)) +} + +func RespSucess(w http.ResponseWriter, data interface{}) { + respHttp(w, 0, data, "") +} + +func RespError(w http.ResponseWriter, respCode int, errS ...interface{}) { + respHttp(w, respCode, nil, errS...) +} + +func RespData(w http.ResponseWriter, data interface{}, err error) { + respHttp(w, http.StatusOK, data, "") +} diff --git a/admin/server.go b/admin/server.go new file mode 100644 index 0000000..838104a --- /dev/null +++ b/admin/server.go @@ -0,0 +1,66 @@ +// admin:后台管理接口 +package admin + +import ( + "net/http" + "net/http/pprof" + + "github.com/bjdgyc/anylink/base" + "github.com/gorilla/mux" +) + +// 开启服务 +func StartAdmin() { + r := mux.NewRouter() + r.Use(authMiddleware) + + r.HandleFunc("/base/login", Login).Name("login") + r.HandleFunc("/set/home", SetHome) + r.HandleFunc("/set/system", SetSystem) + r.HandleFunc("/set/soft", SetSoft) + r.HandleFunc("/set/other", SetOther) + r.HandleFunc("/set/other/edit", SetOtherEdit) + r.HandleFunc("/set/other/smtp", SetOtherSmtp) + r.HandleFunc("/set/other/smtp/edit", SetOtherSmtpEdit) + + r.HandleFunc("/user/list", UserList) + r.HandleFunc("/user/detail", UserDetail) + r.HandleFunc("/user/set", UserSet) + r.HandleFunc("/user/del", UserDel) + r.HandleFunc("/user/online", UserOnline) + r.HandleFunc("/user/offline", UserOffline) + r.HandleFunc("/user/otp_qr", UserOtpQr) + r.HandleFunc("/user/ip_map/list", UserIpMapList) + r.HandleFunc("/user/ip_map/detail", UserIpMapDetail) + r.HandleFunc("/user/ip_map/set", UserIpMapSet) + r.HandleFunc("/user/ip_map/del", UserIpMapDel) + + r.HandleFunc("/group/list", GroupList) + r.HandleFunc("/group/names", GroupNames) + r.HandleFunc("/group/detail", GroupDetail) + r.HandleFunc("/group/set", GroupSet) + r.HandleFunc("/group/del", GroupDel) + + // pprof + r.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + r.HandleFunc("/debug/pprof/profile", pprof.Profile) + r.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + r.HandleFunc("/debug/pprof/trace", pprof.Trace) + r.HandleFunc("/debug/pprof", location("/debug/pprof/")) + r.PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index) + + r.PathPrefix("/").Handler(http.FileServer(http.Dir(base.Cfg.UiPath))).Name("static") + + base.Info("Listen admin", base.Cfg.AdminAddr) + err := http.ListenAndServe(base.Cfg.AdminAddr, r) + if err != nil { + base.Fatal(err) + } +} + +func location(url string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", url) + w.WriteHeader(http.StatusFound) + } +} diff --git a/base/app_ver.go b/base/app_ver.go new file mode 100644 index 0000000..3cb5287 --- /dev/null +++ b/base/app_ver.go @@ -0,0 +1,6 @@ +package base + +const ( + APP_NAME = "AnyLink" + APP_VER = "0.0.5" +) diff --git a/base/cfg_server.go b/base/cfg_server.go new file mode 100644 index 0000000..f08e266 --- /dev/null +++ b/base/cfg_server.go @@ -0,0 +1,123 @@ +package base + +import ( + "fmt" + "io/ioutil" + "path/filepath" + "reflect" + "strings" + + "github.com/pelletier/go-toml" +) + +const ( + LinkModeTUN = "tun" + LinkModeTAP = "tap" +) + +var ( + Cfg = &ServerConfig{} +) + +// # ReKey time (in seconds) +// rekey-time = 172800 +// # ReKey method +// # Valid options: ssl, new-tunnel +// # ssl: Will perform an efficient rehandshake on the channel allowing +// # a seamless connection during rekey. +// # new-tunnel: Will instruct the client to discard and re-establish the channel. +// # Use this option only if the connecting clients have issues with the ssl +// # option. +// rekey-method = ssl + +type ServerConfig struct { + LinkAddr string `toml:"link_addr" info:"vpn服务对外地址"` + ServerAddr string `toml:"server_addr" info:"前台服务监听地址"` + AdminAddr string `toml:"admin_addr" info:"后台服务监听地址"` + ProxyProtocol bool `toml:"proxy_protocol" info:"TCP代理协议"` + DbFile string `toml:"db_file" info:"数据库地址"` + CertFile string `toml:"cert_file" info:"证书文件"` + CertKey string `toml:"cert_key" info:"证书密钥"` + UiPath string `toml:"ui_path" info:"ui文件路径"` + FilesPath string `toml:"files_path" info:"外部下载文件路径"` + LogLevel string `toml:"log_level" info:"日志等级"` + Issuer string `toml:"issuer" info:"系统名称"` + AdminUser string `toml:"admin_user" info:"管理用户名"` + AdminPass string `toml:"admin_pass" info:"管理用户密码"` + JwtSecret string `toml:"jwt_secret" info:"JWT密钥"` + + LinkMode string `toml:"link_mode" info:"虚拟网络类型"` // tun tap + Ipv4Network string `toml:"ipv4_network" info:"ipv4_network"` // 192.168.1.0 + Ipv4Netmask string `toml:"ipv4_netmask" info:"ipv4_netmask"` // 255.255.255.0 + Ipv4Gateway string `toml:"ipv4_gateway" info:"ipv4_gateway"` + Ipv4Pool []string `toml:"ipv4_pool" info:"IPV4起止地址池"` // Pool[0]=192.168.1.100 Pool[1]=192.168.1.200 + IpLease int `toml:"ip_lease" info:"IP租期(秒)"` + + MaxClient int `toml:"max_client" info:"最大用户连接"` + MaxUserClient int `toml:"max_user_client" info:"最大单用户连接"` + DefaultGroup string `toml:"default_group" info:"默认用户组"` + CstpKeepalive int `toml:"cstp_keepalive" info:"keepalive时间(秒)"` // in seconds + CstpDpd int `toml:"cstp_dpd" info:"死链接检测时间(秒)"` // Dead peer detection in seconds + MobileKeepalive int `toml:"mobile_keepalive" info:"移动端keepalive接检测时间(秒)"` + MobileDpd int `toml:"mobile_dpd" info:"移动端死链接检测时间(秒)"` + + SessionTimeout int `toml:"session_timeout" info:"session过期时间(秒)"` // in seconds + AuthTimeout int `toml:"auth_timeout" info:"auth_timeout"` // in seconds +} + +func initServerCfg() { + b, err := ioutil.ReadFile(serverFile) + if err != nil { + panic(err) + } + err = toml.Unmarshal(b, Cfg) + if err != nil { + panic(err) + } + + sf, _ := filepath.Abs(serverFile) + base := filepath.Dir(sf) + + // 转换成绝对路径 + Cfg.DbFile = getAbsPath(base, Cfg.DbFile) + Cfg.CertFile = getAbsPath(base, Cfg.CertFile) + Cfg.CertKey = getAbsPath(base, Cfg.CertKey) + Cfg.UiPath = getAbsPath(base, Cfg.UiPath) + Cfg.FilesPath = getAbsPath(base, Cfg.FilesPath) + + fmt.Printf("ServerCfg: %+v \n", Cfg) +} + +func getAbsPath(base, cfile string) string { + abs := filepath.IsAbs(cfile) + if abs { + return cfile + } + return filepath.Join(base, cfile) +} + +func ServerCfg2Slice() interface{} { + ref := reflect.ValueOf(Cfg) + s := ref.Elem() + + type cfg struct { + Name string `json:"name"` + Info string `json:"info"` + Data interface{} `json:"data"` + } + var datas []cfg + + typ := s.Type() + numFields := s.NumField() + for i := 0; i < numFields; i++ { + field := typ.Field(i) + value := s.Field(i) + tag := field.Tag.Get("toml") + tags := strings.Split(tag, ",") + info := field.Tag.Get("info") + + datas = append(datas, cfg{Name: tags[0], Info: info, Data: value.Interface()}) + } + + return datas +} diff --git a/common/flag.go b/base/flag.go similarity index 66% rename from common/flag.go rename to base/flag.go index b0ada7d..800fa4f 100644 --- a/common/flag.go +++ b/base/flag.go @@ -1,10 +1,12 @@ -package common +package base import ( "flag" "fmt" "os" "runtime" + + "github.com/bjdgyc/anylink/pkg/utils" ) var ( @@ -12,24 +14,27 @@ var ( CommitId string // 配置文件 serverFile string + // pass明文 + passwd string // 显示版本信息 rev bool ) func initFlag() { flag.StringVar(&serverFile, "conf", "./conf/server.toml", "server config file path") + flag.StringVar(&passwd, "passwd", "", "the password plaintext") flag.BoolVar(&rev, "rev", false, "display version info") flag.Parse() + if passwd != "" { + pass, _ := utils.PasswordHash(passwd) + fmt.Printf("Passwd:%s\n", pass) + os.Exit(0) + } + if rev { fmt.Printf("%s v%s build on %s [%s, %s] commit_id(%s) \n", APP_NAME, APP_VER, runtime.Version(), runtime.GOOS, runtime.GOARCH, CommitId) os.Exit(0) } } - -func InitConfig() { - initFlag() - loadServer() - initLog() -} diff --git a/base/log.go b/base/log.go new file mode 100644 index 0000000..a4a7974 --- /dev/null +++ b/base/log.go @@ -0,0 +1,90 @@ +package base + +import ( + "fmt" + "log" + "os" + "strings" +) + +const ( + _Debug = iota + _Info + _Warn + _Error + _Fatal +) + +var ( + baseLog *log.Logger + baseLevel int + level map[int]string +) + +func initLog() { + baseLog = log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile) + baseLevel = logLevel2Int(Cfg.LogLevel) +} + +func logLevel2Int(l string) int { + level = map[int]string{ + _Debug: "Debug", + _Info: "Info", + _Warn: "Warn", + _Error: "Error", + _Fatal: "Fatal", + } + lvl := _Info + for k, v := range level { + if strings.ToLower(l) == strings.ToLower(v) { + lvl = k + } + } + return lvl +} + +func output(l int, s ...interface{}) { + lvl := fmt.Sprintf("[%s] ", level[l]) + baseLog.Output(3, lvl+fmt.Sprintln(s...)) +} + +func Debug(v ...interface{}) { + l := _Debug + if baseLevel > l { + return + } + output(l, v...) +} + +func Info(v ...interface{}) { + l := _Info + if baseLevel > l { + return + } + output(l, v...) +} + +func Warn(v ...interface{}) { + l := _Warn + if baseLevel > l { + return + } + output(l, v...) +} + +func Error(v ...interface{}) { + l := _Error + if baseLevel > l { + return + } + output(l, v...) +} + +func Fatal(v ...interface{}) { + l := _Fatal + if baseLevel > l { + return + } + output(l, v...) + os.Exit(1) +} diff --git a/base/start.go b/base/start.go new file mode 100644 index 0000000..61bd8bf --- /dev/null +++ b/base/start.go @@ -0,0 +1,7 @@ +package base + +func Start() { + initFlag() + initServerCfg() + initLog() +} diff --git a/bridge-init.sh b/bridge-init.sh index f3ca4ca..f60dac1 100644 --- a/bridge-init.sh +++ b/bridge-init.sh @@ -16,10 +16,10 @@ tap="tap0" # with TAP interface(s) above. eth="eth0" -eth_ip="192.168.1.4" +eth_ip="192.168.10.4" eth_netmask="255.255.255.0" -eth_broadcast="192.168.1.255" -eth_gateway="192.168.1.1" +eth_broadcast="192.168.10.255" +eth_gateway="192.168.10.1" brctl addbr $br diff --git a/common/app_ver.go b/common/app_ver.go deleted file mode 100644 index 85c3d0c..0000000 --- a/common/app_ver.go +++ /dev/null @@ -1,6 +0,0 @@ -package common - -const ( - APP_NAME = "AnyLink" - APP_VER = "0.0.3" -) diff --git a/common/cfg_server.go b/common/cfg_server.go deleted file mode 100644 index 27cb661..0000000 --- a/common/cfg_server.go +++ /dev/null @@ -1,89 +0,0 @@ -package common - -import ( - "fmt" - "io/ioutil" - "path/filepath" - - "github.com/pelletier/go-toml" -) - -const ( - LinkModeTUN = "tun" - LinkModeTAP = "tap" -) - -var ( - ServerCfg = &ServerConfig{} -) - -// # ReKey time (in seconds) -// rekey-time = 172800 -// # ReKey method -// # Valid options: ssl, new-tunnel -// # ssl: Will perform an efficient rehandshake on the channel allowing -// # a seamless connection during rekey. -// # new-tunnel: Will instruct the client to discard and re-establish the channel. -// # Use this option only if the connecting clients have issues with the ssl -// # option. -// rekey-method = ssl - -type ServerConfig struct { - ServerAddr string `toml:"server_addr"` - AdminAddr string `toml:"admin_addr"` - ProxyProtocol bool `toml:"proxy_protocol"` - DbFile string `toml:"db_file"` - CertFile string `toml:"cert_file"` - CertKey string `toml:"cert_key"` - LogLevel string `toml:"log_level"` - - LinkMode string `toml:"link_mode"` // tun tap - Ipv4Network string `toml:"ipv4_network"` // 192.168.1.0 - Ipv4Netmask string `toml:"ipv4_netmask"` // 255.255.255.0 - Ipv4Gateway string `toml:"ipv4_gateway"` - Ipv4Pool []string `toml:"ipv4_pool"` // Pool[0]=192.168.1.100 Pool[1]=192.168.1.200 - Include []string `toml:"include"` // 10.10.10.0/255.255.255.0 - Exclude []string `toml:"exclude"` // 192.168.5.0/255.255.255.0 - ClientDns []string `toml:"client_dns"` // 114.114.114.114 - AllowLan bool `toml:"allow_lan"` // 允许本地LAN访问vpn网络 - MaxClient int `toml:"max_client"` - MaxUserClient int `toml:"max_user_client"` - - UserGroups []string `toml:"user_groups"` - DefaultGroup string `toml:"default_group"` - Banner string `toml:"banner"` // 欢迎语 - CstpDpd int `toml:"cstp_dpd"` // Dead peer detection in seconds - MobileDpd int `toml:"mobile_dpd"` - CstpKeepalive int `toml:"cstp_keepalive"` // in seconds - SessionTimeout int `toml:"session_timeout"` // in seconds - AuthTimeout int `toml:"auth_timeout"` // in seconds -} - -func loadServer() { - b, err := ioutil.ReadFile(serverFile) - if err != nil { - panic(err) - } - err = toml.Unmarshal(b, ServerCfg) - if err != nil { - panic(err) - } - - sf, _ := filepath.Abs(serverFile) - base := filepath.Dir(sf) - - // 转换成绝对路径 - ServerCfg.DbFile = getAbsPath(base, ServerCfg.DbFile) - ServerCfg.CertFile = getAbsPath(base, ServerCfg.CertFile) - ServerCfg.CertKey = getAbsPath(base, ServerCfg.CertKey) - - fmt.Printf("ServerCfg: %+v \n", ServerCfg) -} - -func getAbsPath(base, cfile string) string { - abs := filepath.IsAbs(cfile) - if abs { - return cfile - } - return filepath.Join(base, cfile) -} diff --git a/common/log.go b/common/log.go deleted file mode 100644 index 8481bbf..0000000 --- a/common/log.go +++ /dev/null @@ -1,73 +0,0 @@ -package common - -import ( - "log" - "os" -) - -const ( - debug = iota - info - error - fatal -) - -var Log *logger - -type logger struct { - log *log.Logger - level int -} - -func initLog() { - // log.SetFlags(log.LstdFlags | log.Lshortfile) - l := log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile) - Log = &logger{log: l, level: logLevel2Int(ServerCfg.LogLevel)} -} - -func logLevel2Int(l string) int { - switch l { - case "debug": - return debug - case "info": - return info - case "error": - return error - case "fatal": - return fatal - default: - return info - } -} - -func (l *logger) Debug(v ...interface{}) { - if l.level > debug { - return - } - data := append([]interface{}{"[Debug]"}, v...) - l.log.Println(data...) -} - -func (l *logger) Info(v ...interface{}) { - if l.level > info { - return - } - data := append([]interface{}{"[Info]"}, v...) - l.log.Println(data...) -} - -func (l *logger) Error(v ...interface{}) { - if l.level > error { - return - } - data := append([]interface{}{"[Error]"}, v...) - l.log.Println(data...) -} - -func (l *logger) Fatal(v ...interface{}) { - if l.level > fatal { - return - } - data := append([]interface{}{"[Fatal]"}, v...) - l.log.Fatalln(data...) -} diff --git a/common/util.go b/common/util.go deleted file mode 100644 index 6f87d48..0000000 --- a/common/util.go +++ /dev/null @@ -1,39 +0,0 @@ -package common - -import "fmt" - -func InArrStr(arr []string, str string) bool { - for _, d := range arr { - if d == str { - return true - } - } - return false -} - -const ( - KB = 1024 - MB = 1024 * KB - GB = 1024 * MB - TB = 1024 * GB - PB = 1024 * TB -) - -func HumanByte(bAll float64) string { - var hb string - - switch { - case bAll >= TB: - hb = fmt.Sprintf("%0.2f TB", bAll/TB) - case bAll >= GB: - hb = fmt.Sprintf("%0.2f GB", bAll/GB) - case bAll >= MB: - hb = fmt.Sprintf("%0.2f MB", bAll/MB) - case bAll >= KB: - hb = fmt.Sprintf("%0.2f KB", bAll/KB) - default: - hb = fmt.Sprintf("%0.2f B", bAll) - } - - return hb -} diff --git a/conf/server.toml b/conf/server.toml index 197348b..a3c4d66 100644 --- a/conf/server.toml +++ b/conf/server.toml @@ -8,11 +8,26 @@ db_file = "./data.db" #证书文件 cert_file = "./vpn_cert.pem" cert_key = "./vpn_cert.key" +ui_path = "../ui" +files_path = "../downfiles" + log_level = "info" -#服务监听的地址 +#系统名称 +issuer = "XX公司VPN" +#后台管理用户 +admin_user = "admin" +#pass 123456 +admin_pass = "$2a$10$UQ7C.EoPifDeJh6d8.31TeSPQU7hM/NOM2nixmBucJpAuXDQNqNke" +jwt_secret = "7IrsKW3JuDJ68TPPrdsfweDFYJrO1Xg7JcdsfasMv3P3" + + +#vpn服务对外地址 +link_addr = "vpn.xx.com" + +#前台服务监听地址 server_addr = ":443" -#一般设置 本机地址 +#后台服务监听地址 admin_addr = ":8800" #开启tcp proxy protocol协议 proxy_protocol = false @@ -25,33 +40,21 @@ ipv4_netmask = "255.255.255.0" ipv4_gateway = "192.168.10.1" ipv4_pool = ["192.168.10.100", "192.168.10.200"] -#需加密传输的ip规则 -#include = ["10.10.10.0/255.255.255.0"] -#非加密传输的ip规则 -#exclude = ["192.168.5.0/255.255.255.0"] -#客户端使用的dns ios客户端必须配置 -client_dns = ["114.114.114.114"] -#是否允许本地LAN访问vpn网络 -allow_lan = true - #最大客户端数量 -max_client = 300 +max_client = 100 #单个用户同时在线数量 max_user_client = 3 +#IP租期(秒) +ip_lease = 1209600 - -#用户组 -user_groups = ["one", "two"] #默认选择的组 -default_group = "two" - -#登陆成功的欢迎语 -banner = "您已接入公司网络,请按照公司规定使用。\n请勿进行非工作下载及视频行为!" +default_group = "one" #客户端失效检测时间(秒) dpd > keepalive -cstp_dpd = 30 cstp_keepalive = 20 -mobile_dpd = 300 +cstp_dpd = 30 +mobile_keepalive = 50 +mobile_dpd = 60 #session过期时间,用于断线重连,0永不过期 session_timeout = 3600 auth_timeout = 0 diff --git a/dbdata/db.go b/dbdata/db.go index 67c5d8b..b5bb20b 100644 --- a/dbdata/db.go +++ b/dbdata/db.go @@ -1,137 +1,90 @@ package dbdata import ( - "encoding/json" - "errors" - "log" + "time" - "github.com/bjdgyc/anylink/common" + "github.com/asdine/storm/v3" + "github.com/asdine/storm/v3/codec/json" + "github.com/bjdgyc/anylink/base" bolt "go.etcd.io/bbolt" ) -const pageSize = 10 - var ( - db *bolt.DB - ErrNoKey = errors.New("db no this key") + sdb *storm.DB ) func initDb() { var err error - db, err = bolt.Open(common.ServerCfg.DbFile, 0666, nil) + sdb, err = storm.Open(base.Cfg.DbFile, storm.Codec(json.Codec), + storm.BoltOptions(0600, &bolt.Options{Timeout: 10 * time.Second})) if err != nil { - log.Fatal(err) + base.Fatal(err) } - // 创建bucket - err = db.Update(func(tx *bolt.Tx) error { - var err error - _, err = tx.CreateBucketIfNotExists([]byte(BucketUser)) - if err != nil { - return err - } - _, err = tx.CreateBucketIfNotExists([]byte(BucketGroup)) - if err != nil { - return err - } - _, err = tx.CreateBucketIfNotExists([]byte(BucketMacIp)) - if err != nil { - return err - } - return nil - }) - + // 初始化数据库 + err = sdb.Init(&User{}) if err != nil { - log.Fatal(err) + base.Fatal(err) } + + // fmt.Println("s1") } -func NextId(bucket string) int { - var i int - db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte(bucket)) - id, err := b.NextSequence() - i = int(id) - // discard error - return err - }) - return i +func initData() { + var ( + err error + install bool + ) + + // 判断是否初次使用 + err = Get(SettingBucket, Installed, &install) + if err == nil && install { + // 已经安装过 + return + } + + defer Set(SettingBucket, Installed, true) + + smtp := &SettingSmtp{ + Host: "127.0.0.1", + Port: 25, + From: "vpn@xx.com", + } + SettingSet(smtp) + + other := &SettingOther{ + Banner: "您已接入公司网络,请按照公司规定使用。\n请勿进行非工作下载及视频行为!", + AccountMail: accountMail, + } + SettingSet(other) + } -func GetCount(bucket string) int { - count := 0 - db.View(func(tx *bolt.Tx) error { - bkt := tx.Bucket([]byte(bucket)) - s := bkt.Stats() - // fmt.Printf("%+v \n", s) - count = s.KeyN - return nil - }) - return count +func CheckErrNotFound(err error) bool { + if err == storm.ErrNotFound { + return true + } + return false } -func Set(bucket, key string, v interface{}) error { - return db.Update(func(tx *bolt.Tx) error { - bkt := tx.Bucket([]byte(bucket)) - b, err := json.Marshal(v) - if err != nil { - return err - } - return bkt.Put([]byte(key), b) - }) -} - -func Del(bucket, key string) error { - return db.Update(func(tx *bolt.Tx) error { - bkt := tx.Bucket([]byte(bucket)) - return bkt.Delete([]byte(key)) - }) -} - -func Get(bucket, key string, v interface{}) error { - return db.View(func(tx *bolt.Tx) error { - bkt := tx.Bucket([]byte(bucket)) - b := bkt.Get([]byte(key)) - if b == nil { - return ErrNoKey - } - return json.Unmarshal(b, v) - }) -} - -// 分页获取 -func getList(bucket, lastKey string, prev bool) [][]byte { - res := make([][]byte, 0) - db.View(func(tx *bolt.Tx) error { - c := tx.Bucket([]byte(bucket)).Cursor() - size := pageSize - k, b := c.Seek([]byte(lastKey)) - - if prev { - for i := 0; i < size; i++ { - k, b = c.Prev() - if k == nil { - break - } - res = append(res, b) - } - return nil - } - - // next - if string(k) != lastKey { - // 不相同,说明找出其他的 - size -= 1 - res = append(res, b) - } - for i := 0; i < size; i++ { - k, b = c.Next() - if k == nil { - break - } - res = append(res, b) - } - return nil - }) - return res -} +const accountMail = `

您好:

+

  您的{{.Issuer}}账号已经审核开通。

+

+ 登陆地址: {{.LinkAddr}}
+ 用户组: {{.Group}}
+ 用户名: {{.Username}}
+ 用户PIN码: {{.PinCode}}
+ 用户动态码(3天后失效):
+ +

+
+ 使用说明: + +
+

+ 软件下载地址: https://gitee.com/bjdgyc/anylink-soft/blob/master/README.md +

` diff --git a/dbdata/db_orm.go b/dbdata/db_orm.go new file mode 100644 index 0000000..e7d9a1e --- /dev/null +++ b/dbdata/db_orm.go @@ -0,0 +1,66 @@ +package dbdata + +import "github.com/asdine/storm/v3/index" + +const PageSize = 10 + +func Save(data interface{}) error { + return sdb.Save(data) +} + +func Update(data interface{}) error { + return sdb.Update(data) +} + +func UpdateField(data interface{}, fieldName string, value interface{}) error { + return sdb.UpdateField(data, fieldName, value) +} + +func Del(data interface{}) error { + return sdb.DeleteStruct(data) +} + +func Set(bucket, key string, data interface{}) error { + return sdb.Set(bucket, key, data) +} + +func Get(bucket, key string, data interface{}) error { + return sdb.Get(bucket, key, data) +} + +func CountAll(data interface{}) int { + n, _ := sdb.Count(data) + return n +} + +func One(fieldName string, value interface{}, to interface{}) error { + return sdb.One(fieldName, value, to) +} + +func Find(fieldName string, value interface{}, to interface{}, options ...func(q *index.Options)) error { + return sdb.Find(fieldName, value, to, options...) +} + +func All(to interface{}, limit, page int) error { + opt := getOpt(limit, page) + return sdb.All(to, opt) +} + +func Prefix(fieldName string, prefix string, to interface{}, limit, page int) error { + opt := getOpt(limit, page) + return sdb.Prefix(fieldName, prefix, to, opt) +} + +func getOpt(limit, page int) func(*index.Options) { + skip := (page - 1) * limit + opt := func(opt *index.Options) { + opt.Reverse = true + if limit > 0 { + opt.Limit = limit + } + if skip > 0 { + opt.Skip = skip + } + } + return opt +} diff --git a/dbdata/db_test.go b/dbdata/db_test.go index 1423d1e..60c5c72 100644 --- a/dbdata/db_test.go +++ b/dbdata/db_test.go @@ -1,23 +1,22 @@ package dbdata import ( - "net" "os" "path" "testing" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" "github.com/stretchr/testify/assert" ) func preIpData() { tmpDb := path.Join(os.TempDir(), "anylink_test.db") - common.ServerCfg.DbFile = tmpDb + base.Cfg.DbFile = tmpDb initDb() } func closeIpdata() { - db.Close() + sdb.Close() tmpDb := path.Join(os.TempDir(), "anylink_test.db") os.Remove(tmpDb) } @@ -27,37 +26,8 @@ func TestDb(t *testing.T) { preIpData() defer closeIpdata() - Set(BucketUser, "a", User{Username: "a"}) - Set(BucketUser, "b", User{Username: "b"}) - Set(BucketUser, "c", User{Username: "c"}) - Set(BucketUser, "d", User{Username: "d"}) - Set(BucketUser, "e", User{Username: "e"}) - Set(BucketUser, "f", User{Username: "f"}) - Set(BucketUser, "g", User{Username: "g"}) + u := User{Username: "a"} + Save(&u) - c := GetCount(BucketUser) - assert.Equal(c, 7) - Del(BucketUser, "g") - c = GetCount(BucketUser) - assert.Equal(c, 6) - - // 分页查询 - us := GetUsers("d", false) - assert.Equal(us[0].Username, "e") - assert.Equal(us[1].Username, "f") - us = GetUsers("d", true) - assert.Equal(us[0].Username, "c") - assert.Equal(us[1].Username, "b") - assert.Equal(us[2].Username, "a") - - mac1 := MacIp{Ip: net.ParseIP("192.168.3.11"), MacAddr: "mac1"} - mac2 := MacIp{Ip: net.ParseIP("192.168.3.12"), MacAddr: "mac2"} - Set(BucketMacIp, "mac1", mac1) - Set(BucketMacIp, "mac2", mac2) - - mp := GetAllMacIp() - assert.Equal(mp[0].MacAddr, "mac1") - assert.Equal(mp[1].MacAddr, "mac2") - - os.Exit(0) + assert.Equal(u.Id, 1) } diff --git a/dbdata/group.go b/dbdata/group.go index 9a8ebe4..e21b07e 100644 --- a/dbdata/group.go +++ b/dbdata/group.go @@ -1,36 +1,133 @@ package dbdata import ( - "encoding/json" + "errors" + "fmt" "net" + "strings" "time" + + "github.com/bjdgyc/anylink/base" ) -const BucketGroup = "group" +const ( + Allow = "allow" + Deny = "deny" +) + +type GroupLinkAcl struct { + // 自上而下匹配 默认 allow * * + Action string `json:"action"` // allow、deny + Val string `json:"val"` + Port uint8 `json:"port"` + IpNet *net.IPNet `json:"-"` +} + +type ValData struct { + Val string `json:"val"` +} type Group struct { - Id int - Name string - RouteInclude []string - RouteExclude []string - AllowLan bool - LinkAcl []struct { - Action string // allow、deny - IpNet string - IPNet net.IPNet - } - Bandwidth int // 带宽限制 - CreatedAt time.Time - UpdatedAt time.Time + Id int `json:"id" storm:"id,increment"` + Name string `json:"name" storm:"unique"` + Note string `json:"note"` + AllowLan bool `json:"allow_lan"` + ClientDns []ValData `json:"client_dns"` + RouteInclude []ValData `json:"route_include"` + RouteExclude []ValData `json:"route_exclude"` + LinkAcl []GroupLinkAcl `json:"link_acl"` + Bandwidth int `json:"bandwidth"` // 带宽限制 + Status int8 `json:"status"` // 1正常 + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } -func GetGroups(lastKey string, prev bool) []Group { - res := getList(BucketUser, lastKey, prev) - datas := make([]Group, 0) - for _, data := range res { - d := Group{} - json.Unmarshal(data, &d) - datas = append(datas, d) +func GetGroupNames() []string { + var datas []Group + err := All(&datas, 0, 0) + if err != nil { + base.Error(err) + return nil } - return datas + var names []string + for _, v := range datas { + names = append(names, v.Name) + } + return names +} + +func SetGroup(g *Group) error { + var err error + if g.Name == "" { + return errors.New("用户组名错误") + } + + // 判断数据 + clientDns := []ValData{} + for _, v := range g.ClientDns { + if v.Val != "" { + clientDns = append(clientDns, v) + } + } + g.ClientDns = clientDns + + routeInclude := []ValData{} + for _, v := range g.RouteInclude { + if v.Val != "" { + v1, _ := parseIpNet(v.Val) + vn := ValData{Val: v1} + routeInclude = append(routeInclude, vn) + } + } + g.RouteInclude = routeInclude + routeExclude := []ValData{} + for _, v := range g.RouteExclude { + if v.Val != "" { + v1, _ := parseIpNet(v.Val) + vn := ValData{Val: v1} + routeExclude = append(routeExclude, vn) + } + } + g.RouteExclude = routeExclude + // 转换数据 + linkAcl := []GroupLinkAcl{} + for _, v := range g.LinkAcl { + if v.Val != "" { + v1, v2 := parseIpNet(v.Val) + if v2 != nil { + vn := v + vn.Val = v1 + vn.IpNet = v2 + linkAcl = append(linkAcl, vn) + } + } + } + g.LinkAcl = linkAcl + + g.UpdatedAt = time.Now() + err = Save(g) + + return err +} + +func parseIpNet(s string) (string, *net.IPNet) { + ips := strings.Split(s, "/") + if len(ips) != 2 { + return "", nil + } + ip := net.ParseIP(ips[0]) + mask := net.ParseIP(ips[1]) + + if strings.Contains(ips[0], ".") { + ip = ip.To4() + mask = mask.To4() + } + + ipmask := net.IPMask(mask) + ip0 := ip.Mask(ipmask) + + ipNetS := fmt.Sprintf("%s/%s", ip0, mask) + ipNet := &net.IPNet{IP: ip0, Mask: ipmask} + + return ipNetS, ipNet } diff --git a/dbdata/ip_map.go b/dbdata/ip_map.go new file mode 100644 index 0000000..2a81305 --- /dev/null +++ b/dbdata/ip_map.go @@ -0,0 +1,18 @@ +package dbdata + +import ( + "net" + "time" +) + +type IpMap struct { + Id int `json:"id" storm:"id,increment"` + IpAddr net.IP `json:"ip_addr" storm:"unique"` + MacAddr string `json:"mac_addr" storm:"unique"` + Username string `json:"username"` + Keep bool `json:"keep"` // 保留 ip-mac 绑定 + KeepTime time.Time `json:"keep_time"` + Note string `json:"note"` // 备注 + LastLogin time.Time `json:"last_login"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/dbdata/mac_ip.go b/dbdata/mac_ip.go deleted file mode 100644 index ab220f9..0000000 --- a/dbdata/mac_ip.go +++ /dev/null @@ -1,34 +0,0 @@ -package dbdata - -import ( - "encoding/json" - "net" - "time" - - bolt "go.etcd.io/bbolt" -) - -const BucketMacIp = "macIp" - -type MacIp struct { - IsActive bool // db存储没有使用 - Ip net.IP - MacAddr string - LastLogin time.Time -} - -func GetAllMacIp() []MacIp { - datas := make([]MacIp, 0) - db.View(func(tx *bolt.Tx) error { - bkt := tx.Bucket([]byte(BucketMacIp)) - bkt.ForEach(func(k, v []byte) error { - d := MacIp{} - json.Unmarshal(v, &d) - datas = append(datas, d) - return nil - }) - return nil - }) - - return datas -} diff --git a/dbdata/setting.go b/dbdata/setting.go new file mode 100644 index 0000000..acc89f0 --- /dev/null +++ b/dbdata/setting.go @@ -0,0 +1,46 @@ +package dbdata + +import ( + "reflect" +) + +const ( + SettingBucket = "SettingBucket" + Installed = "Installed" +) + +func StructName(data interface{}) string { + ref := reflect.ValueOf(data) + s := &ref + if s.Kind() == reflect.Ptr { + e := s.Elem() + s = &e + } + name := s.Type().Name() + return name +} + +func SettingSet(data interface{}) error { + key := StructName(data) + err := Set(SettingBucket, key, data) + return err +} + +func SettingGet(data interface{}) error { + key := StructName(data) + err := Get(SettingBucket, key, data) + return err +} + +type SettingSmtp struct { + Host string `json:"host"` + Port int `json:"port"` + Username string `json:"username"` + Password string `json:"password"` + From string `json:"from"` +} + +type SettingOther struct { + Banner string `json:"banner"` + AccountMail string `json:"account_mail"` +} diff --git a/dbdata/start.go b/dbdata/start.go index a9e14c2..c8110be 100644 --- a/dbdata/start.go +++ b/dbdata/start.go @@ -2,8 +2,9 @@ package dbdata func Start() { initDb() + initData() } func Stop() error { - return db.Close() + return sdb.Close() } diff --git a/dbdata/user.go b/dbdata/user.go index df59a00..cd54868 100644 --- a/dbdata/user.go +++ b/dbdata/user.go @@ -1,29 +1,98 @@ package dbdata import ( - "encoding/json" + "errors" "time" + + "github.com/bjdgyc/anylink/pkg/utils" + "github.com/xlzd/gotp" ) -const BucketUser = "user" - type User struct { - Id int - Username string - Password string - OtpSecret string - Group []string - // CreatedAt time.Time - UpdatedAt time.Time + Id int `json:"id" storm:"id,increment"` + Username string `json:"username" storm:"unique"` + Nickname string `json:"nickname"` + Email string `json:"email"` + // Password string `json:"password"` + PinCode string `json:"pin_code"` + OtpSecret string `json:"otp_secret"` + Groups []string `json:"groups"` + Status int8 `json:"status"` // 1正常 + SendEmail bool `json:"send_email"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } -func GetUsers(lastKey string, prev bool) []User { - res := getList(BucketUser, lastKey, prev) - datas := make([]User, 0) - for _, data := range res { - d := User{} - json.Unmarshal(data, &d) - datas = append(datas, d) +// 验证用户登陆信息 +func CheckUser(name, pwd, group string) error { + // return nil + + pl := len(pwd) + if name == "" || pl < 6 { + return errors.New("密码错误") } - return datas + v := &User{} + err := One("Username", name, v) + if err != nil || v.Status != 1 { + return errors.New("用户名错误") + } + pass := pwd[:pl-6] + // if !utils.PasswordVerify(pass, v.Password) { + if pass != v.PinCode { + return errors.New("密码错误") + } + otp := pwd[pl-6:] + totp := gotp.NewDefaultTOTP(v.OtpSecret) + unix := time.Now().Unix() + verify := totp.Verify(otp, int(unix)) + if !verify { + return errors.New("动态码错误") + } + + // 判断用户组信息 + if !utils.InArrStr(v.Groups, group) { + return errors.New("用户组错误") + } + groupData := &Group{} + err = One("Name", group, groupData) + if err != nil || groupData.Status != 1 { + return errors.New("用户组错误") + } + return nil +} + +func SetUser(v *User) error { + var err error + if v.Username == "" || len(v.Groups) == 0 { + return errors.New("用户名或组错误") + } + + planPass := v.PinCode + // 自动生成密码 + if len(planPass) < 6 { + planPass = utils.RandomNum(8) + } + v.PinCode = planPass + + if v.OtpSecret == "" { + v.OtpSecret = gotp.RandomSecret(24) + } + + // 判断组是否有效 + ng := []string{} + groups := GetGroupNames() + for _, g := range v.Groups { + if utils.InArrStr(groups, g) { + ng = append(ng, g) + } + } + if len(ng) == 0 { + return errors.New("用户名或组错误") + } + v.Groups = ng + + v.UpdatedAt = time.Now() + err = Save(v) + + return err } diff --git a/downfiles/.gitignore b/downfiles/.gitignore new file mode 100644 index 0000000..e2718ef --- /dev/null +++ b/downfiles/.gitignore @@ -0,0 +1,4 @@ +# Binaries for programs and plugins + +* +!.gitignore \ No newline at end of file diff --git a/go.mod b/go.mod index ef6046a..22af352 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,25 @@ module github.com/bjdgyc/anylink -go 1.14 +go 1.15 require ( - github.com/google/gopacket v1.1.17 - github.com/pelletier/go-toml v1.8.0 + github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect + github.com/asdine/storm/v3 v3.2.1 + github.com/dgrijalva/jwt-go v3.2.0+incompatible + github.com/go-ole/go-ole v1.2.4 // indirect + github.com/google/gopacket v1.1.19 + github.com/gorilla/mux v1.8.0 + github.com/mojocn/base64Captcha v1.3.1 + github.com/pelletier/go-toml v1.8.1 + github.com/shirou/gopsutil v3.20.11+incompatible + github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/songgao/packets v0.0.0-20160404182456-549a10cd4091 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 - github.com/stretchr/testify v1.5.1 + github.com/stretchr/testify v1.6.1 + github.com/xhit/go-simple-mail/v2 v2.6.0 github.com/xlzd/gotp v0.0.0-20181030022105-c8557ba2c119 go.etcd.io/bbolt v1.3.5 - golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 - golang.org/x/sys v0.0.0-20200817155316-9781c653f443 // indirect - golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e + golang.org/x/crypto v0.0.0-20201208171446-5f87f3452ae9 + golang.org/x/net v0.0.0-20201209123823-ac852fbbde11 + golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index 38a1be0..207da05 100644 --- a/go.sum +++ b/go.sum @@ -1,39 +1,96 @@ -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM= +github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwvZMJY1tzqBvQgpaZiQRuIDD40jM= +github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM= +github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d h1:G0m3OIz70MZUWq3EgK3CesDbo8upS2Vm9/P3FtgI+Jk= +github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= +github.com/asdine/storm/v3 v3.2.1 h1:I5AqhkPK6nBZ/qJXySdI7ot5BlXSZ7qvDY1zAn5ZJac= +github.com/asdine/storm/v3 v3.2.1/go.mod h1:LEpXwGt4pIqrE/XcTvCnZHT5MgZCV6Ub9q7yQzOFWr0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/google/gopacket v1.1.17 h1:rMrlX2ZY2UbvT+sdz3+6J+pp2z+msCq9MxTU6ymxbBY= -github.com/google/gopacket v1.1.17/go.mod h1:UdDNZ1OO62aGYVnPhxT1U6aI7ukYtA/kB8vaU0diBUM= -github.com/pelletier/go-toml v1.8.0 h1:Keo9qb7iRJs2voHvunFtuuYFsbWeOBh8/P9v/kVMFtw= -github.com/pelletier/go-toml v1.8.0/go.mod h1:D6yutnOGMveHEPV7VQOuvI/gXY61bv+9bAOTRnLElKs= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= +github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0= +github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY= +github.com/pelletier/go-toml v1.8.1 h1:1Nf83orprkJyknT6h7zbuEGUEjcyVlCxSUGTENmNCRM= +github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/shirou/gopsutil v3.20.11+incompatible h1:LJr4ZQK4mPpIV5gOa4jCOKOGb4ty4DZO54I4FGqIpto= +github.com/shirou/gopsutil v3.20.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/songgao/packets v0.0.0-20160404182456-549a10cd4091 h1:1zN6ImoqhSJhN8hGXFaJlSC8msLmIbX8bFqOfWLKw0w= github.com/songgao/packets v0.0.0-20160404182456-549a10cd4091/go.mod h1:N20Z5Y8oye9a7HmytmZ+tr8Q2vlP0tAHP13kTHzwvQY= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI= +github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= +github.com/xhit/go-simple-mail/v2 v2.6.0 h1:pvPmpDUUWy07cnTgwxwEe5fjdyYtETnxcvdGPQxtv/k= +github.com/xhit/go-simple-mail/v2 v2.6.0/go.mod h1:kA1XbQfCI4JxQ9ccSN6VFyIEkkugOm7YiPkA5hKiQn4= github.com/xlzd/gotp v0.0.0-20181030022105-c8557ba2c119 h1:YyPWX3jLOtYKulBR6AScGIs74lLrJcgeKRwcbAuQOG4= github.com/xlzd/gotp v0.0.0-20181030022105-c8557ba2c119/go.mod h1:/nuTSlK+okRfR/vnIPqR89fFKonnWPiZymN5ydRJkX8= +go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20201208171446-5f87f3452ae9 h1:sYNJzB4J8toYPQTM6pAkcmBRgw9SnQKP9oXCHfgy604= +golang.org/x/crypto v0.0.0-20201208171446-5f87f3452ae9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ= +golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201209123823-ac852fbbde11 h1:lwlPPsmjDKK0J6eG6xDWd5XPehI0R024zxjDnw3esPA= +golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200817155316-9781c653f443 h1:X18bCaipMcoJGm27Nv7zr4XYPKGUy92GtqboKC2Hxaw= -golang.org/x/sys v0.0.0-20200817155316-9781c653f443/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s= -golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handler/link_auth.go b/handler/link_auth.go index 0b369ac..9e82bdf 100644 --- a/handler/link_auth.go +++ b/handler/link_auth.go @@ -8,7 +8,8 @@ import ( "strings" "text/template" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" + "github.com/bjdgyc/anylink/dbdata" "github.com/bjdgyc/anylink/sessdata" ) @@ -40,7 +41,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { if cr.Type == "init" { w.WriteHeader(http.StatusOK) - data := RequestData{Group: cr.GroupSelect, Groups: common.ServerCfg.UserGroups} + data := RequestData{Group: cr.GroupSelect, Groups: dbdata.GetGroupNames()} tplRequest(tpl_request, w, data) return } @@ -52,22 +53,33 @@ func LinkAuth(w http.ResponseWriter, r *http.Request) { } // TODO 用户密码校验 - if !CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect) { + err = dbdata.CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect) + if err != nil { + base.Info(err) w.WriteHeader(http.StatusOK) - data := RequestData{Group: cr.GroupSelect, Groups: common.ServerCfg.UserGroups, Error: true} + data := RequestData{Group: cr.GroupSelect, Groups: dbdata.GetGroupNames(), Error: "用户名或密码错误"} tplRequest(tpl_request, w, data) return } + // if !ok { + // w.WriteHeader(http.StatusOK) + // data := RequestData{Group: cr.GroupSelect, Groups: base.Cfg.UserGroups, Error: "请先激活用户"} + // tplRequest(tpl_request, w, data) + // return + // } // 创建新的session信息 - sess := sessdata.NewSession() - sess.UserName = cr.Auth.Username + sess := sessdata.NewSession("") + sess.Username = cr.Auth.Username + sess.Group = cr.GroupSelect sess.MacAddr = strings.ToLower(cr.MacAddressList.MacAddress) sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal - cd := RequestData{SessionId: sess.Sid, SessionToken: sess.Sid + "@" + sess.Token, - Banner: common.ServerCfg.Banner} + other := &dbdata.SettingOther{} + dbdata.SettingGet(other) + rd := RequestData{SessionId: sess.Sid, SessionToken: sess.Sid + "@" + sess.Token, + Banner: other.Banner} w.WriteHeader(http.StatusOK) - tplRequest(tpl_complete, w, cd) + tplRequest(tpl_complete, w, rd) } const ( @@ -94,7 +106,8 @@ func tplRequest(typ int, w io.Writer, data RequestData) { type RequestData struct { Groups []string Group string - Error bool + Error string + // complete SessionId string SessionToken string @@ -116,7 +129,7 @@ var auth_request = ` 请输入你的用户名和密码 {{if .Error}} - 登陆失败: %s + 登陆失败: %s {{end}}
diff --git a/handler/base.go b/handler/link_base.go similarity index 100% rename from handler/base.go rename to handler/link_base.go diff --git a/handler/link_cstp.go b/handler/link_cstp.go index ec0a04b..4313c1a 100644 --- a/handler/link_cstp.go +++ b/handler/link_cstp.go @@ -2,72 +2,67 @@ package handler import ( "encoding/binary" - "log" "net" "time" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/sessdata" ) -func LinkCstp(conn net.Conn, sess *sessdata.ConnSession) { - log.Println("HandlerCstp") - sessdata.Sess = sess +func LinkCstp(conn net.Conn, cSess *sessdata.ConnSession) { defer func() { - log.Println("LinkCstp return") + // log.Println("LinkCstp return") conn.Close() - sess.Close() + cSess.Close() }() var ( err error n int dataLen uint16 - dead = time.Duration(common.ServerCfg.CstpDpd+2) * time.Second + dead = time.Duration(cSess.CstpDpd*2) * time.Second ) - go cstpWrite(conn, sess) + go cstpWrite(conn, cSess) for { // 设置超时限制 err = conn.SetReadDeadline(time.Now().Add(dead)) if err != nil { - log.Println("SetDeadline: ", err) + base.Error("SetDeadline: ", err) return } hdata := make([]byte, BufferSize) n, err = conn.Read(hdata) if err != nil { - log.Println("read hdata: ", err) + base.Error("read hdata: ", err) return } // 限流设置 - err = sess.RateLimit(n, true) + err = cSess.RateLimit(n, true) if err != nil { - log.Println(err) + base.Error(err) } switch hdata[6] { case 0x07: // KEEPALIVE // do nothing - // log.Println("recv keepalive") + base.Debug("recv keepalive", cSess.IpAddr) case 0x05: // DISCONNECT - // log.Println("DISCONNECT") + base.Debug("DISCONNECT", cSess.IpAddr) return case 0x03: // DPD-REQ - // log.Println("recv DPD-REQ") - if payloadOut(sess, sessdata.LTypeIPData, 0x04, nil) { + base.Debug("recv DPD-REQ", cSess.IpAddr) + if payloadOut(cSess, sessdata.LTypeIPData, 0x04, nil) { return } case 0x04: // log.Println("recv DPD-RESP") case 0x00: // DATA dataLen = binary.BigEndian.Uint16(hdata[4:6]) // 4,5 - data := hdata[8 : 8+dataLen] - - if payloadIn(sess, sessdata.LTypeIPData, 0x00, data) { + if payloadIn(cSess, sessdata.LTypeIPData, 0x00, hdata[8:8+dataLen]) { return } @@ -75,11 +70,11 @@ func LinkCstp(conn net.Conn, sess *sessdata.ConnSession) { } } -func cstpWrite(conn net.Conn, sess *sessdata.ConnSession) { +func cstpWrite(conn net.Conn, cSess *sessdata.ConnSession) { defer func() { - log.Println("cstpWrite return") + // log.Println("cstpWrite return") conn.Close() - sess.Close() + cSess.Close() }() var ( @@ -91,8 +86,8 @@ func cstpWrite(conn net.Conn, sess *sessdata.ConnSession) { for { select { - case payload = <-sess.PayloadOut: - case <-sess.CloseChan: + case payload = <-cSess.PayloadOut: + case <-cSess.CloseChan: return } @@ -107,14 +102,14 @@ func cstpWrite(conn net.Conn, sess *sessdata.ConnSession) { } n, err = conn.Write(header) if err != nil { - log.Println("write err", err) + base.Error("write err", err) return } // 限流设置 - err = sess.RateLimit(n, false) + err = cSess.RateLimit(n, false) if err != nil { - log.Println(err) + base.Error(err) } } } diff --git a/handler/link_home.go b/handler/link_home.go index bc76ae3..f927f6f 100644 --- a/handler/link_home.go +++ b/handler/link_home.go @@ -3,14 +3,15 @@ package handler import ( "fmt" "net/http" - "net/http/httputil" "strings" + + "github.com/bjdgyc/anylink/admin" ) func LinkHome(w http.ResponseWriter, r *http.Request) { - hu, _ := httputil.DumpRequest(r, true) - fmt.Println("DumpHome: ", string(hu)) - fmt.Println(r.RemoteAddr) + // fmt.Println(r.RemoteAddr) + // hu, _ := httputil.DumpRequest(r, true) + // fmt.Println("DumpHome: ", string(hu)) connection := strings.ToLower(r.Header.Get("Connection")) userAgent := strings.ToLower(r.UserAgent()) @@ -23,3 +24,16 @@ func LinkHome(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) fmt.Fprintln(w, "hello world") } + +func LinkOtpQr(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + idS := r.FormValue("id") + jwtToken := r.FormValue("jwt") + data, err := admin.GetJwtData(jwtToken) + if err != nil || idS != fmt.Sprint(data["id"]) { + w.WriteHeader(http.StatusForbidden) + return + } + + admin.UserOtpQr(w, r) +} diff --git a/handler/link_tap.go b/handler/link_tap.go index b7fd156..21389fb 100644 --- a/handler/link_tap.go +++ b/handler/link_tap.go @@ -2,10 +2,10 @@ package handler import ( "fmt" - "log" "net" - "github.com/bjdgyc/anylink/arpdis" + "github.com/bjdgyc/anylink/base" + "github.com/bjdgyc/anylink/pkg/arpdis" "github.com/bjdgyc/anylink/sessdata" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -16,13 +16,18 @@ import ( const bridgeName = "anylink0" +var ( + bridgeIp net.IP + bridgeHw net.HardwareAddr +) + func checkTap() { brFace, err := net.InterfaceByName(bridgeName) if err != nil { - log.Fatal("testTap err: ", err) + base.Fatal("testTap err: ", err) } - bridgeHw := brFace.HardwareAddr - var bridgeIp net.IP + bridgeHw = brFace.HardwareAddr + addrs, err := brFace.Addrs() for _, addr := range addrs { ip, _, err := net.ParseCIDR(addr.String()) @@ -32,112 +37,104 @@ func checkTap() { bridgeIp = ip } if bridgeIp == nil && bridgeHw == nil { - log.Fatalln("bridgeIp is err") + base.Fatal("bridgeIp is err") } if !sessdata.IpPool.Ipv4IPNet.Contains(bridgeIp) { - log.Fatalln("bridgeIp or Ip network err") + base.Fatal("bridgeIp or Ip network err") } - - // 设置本机ip arp为静态 - addr := &arpdis.Addr{IP: bridgeIp.To4(), HardwareAddr: bridgeHw, Type: arpdis.TypeStatic} - arpdis.Add(addr) } // 创建tap网卡 -func LinkTap(sess *sessdata.ConnSession) { - defer func() { - log.Println("LinkTap return") - sess.Close() - }() - +func LinkTap(cSess *sessdata.ConnSession) error { cfg := water.Config{ DeviceType: water.TAP, } ifce, err := water.New(cfg) if err != nil { - log.Println(err) - return + base.Error(err) + return err } - sess.TunName = ifce.Name() - defer ifce.Close() + + cSess.TunName = ifce.Name() // arp on - cmdstr1 := fmt.Sprintf("ip link set dev %s up mtu %d multicast on", ifce.Name(), sess.Mtu) + cmdstr1 := fmt.Sprintf("ip link set dev %s up mtu %d multicast on", ifce.Name(), cSess.Mtu) cmdstr2 := fmt.Sprintf("sysctl -w net.ipv6.conf.%s.disable_ipv6=1", ifce.Name()) cmdstr3 := fmt.Sprintf("ip link set dev %s master %s", ifce.Name(), bridgeName) cmdStrs := []string{cmdstr1, cmdstr2, cmdstr3} err = execCmd(cmdStrs) if err != nil { - return + base.Error(err) + ifce.Close() + return err } - // TODO 测试 - // sess.MacHw, _ = net.ParseMAC("3c:8c:40:a0:6a:3d") + go tapRead(ifce, cSess) + go tapWrite(ifce, cSess) + return nil +} - go loopArp(sess) - go tapRead(ifce, sess) +func tapWrite(ifce *water.Interface, cSess *sessdata.ConnSession) { + defer func() { + // log.Println("LinkTap return") + cSess.Close() + ifce.Close() + }() var ( + err error payload *sessdata.Payload ) for { select { - case payload = <-sess.PayloadIn: - case <-sess.CloseChan: + case payload = <-cSess.PayloadIn: + case <-cSess.CloseChan: return } var frame ethernet.Frame switch payload.LType { default: - log.Println(payload) + // log.Println(payload) case sessdata.LTypeEthernet: frame = payload.Data case sessdata.LTypeIPData: // 需要转换成 Ethernet 数据 data := payload.Data ip_src := waterutil.IPv4Source(data) - if waterutil.IsIPv6(data) || !ip_src.Equal(sess.Ip) { + if waterutil.IsIPv6(data) || !ip_src.Equal(cSess.IpAddr) { // 过滤掉IPv6的数据 // 非分配给客户端ip,直接丢弃 continue } + // packet := gopacket.NewPacket(data, layers.LayerTypeIPv4, gopacket.Default) + // fmt.Println("get:", packet) + ip_dst := waterutil.IPv4Destination(data) // fmt.Println("get:", ip_src, ip_dst) - var dstAddr *arpdis.Addr + var dstHw net.HardwareAddr if !sessdata.IpPool.Ipv4IPNet.Contains(ip_dst) || ip_dst.Equal(sessdata.IpPool.Ipv4Gateway) { // 不是同一网段,使用网关mac地址 - ip_dst = sessdata.IpPool.Ipv4Gateway - dstAddr = arpdis.Lookup(ip_dst, false) - if dstAddr == nil { - log.Println("Ipv4Gateway mac err", ip_dst) - return - } - // fmt.Println("Gateway", ip_dst, dstAddr.HardwareAddr) + dstAddr := arpdis.Lookup(sessdata.IpPool.Ipv4Gateway, false) + dstHw = dstAddr.HardwareAddr } else { - // 同一网段内的其他主机 - dstAddr = arpdis.Lookup(ip_dst, true) - // fmt.Println("other", ip_src, ip_dst, dstAddr) - if dstAddr == nil || dstAddr.Type == arpdis.TypeUnreachable { - // 异步检测发送数据包 - select { - case sess.PayloadArp <- payload: - case <-sess.CloseChan: - return - default: - // PayloadArp 容量已经满了 - log.Println("PayloadArp is full", sess.Ip, ip_dst) - } - continue + dstAddr := arpdis.Lookup(ip_dst, true) + // fmt.Println("dstAddr", dstAddr) + if dstAddr != nil { + dstHw = dstAddr.HardwareAddr + } else { + dstHw = bridgeHw } - } - frame.Prepare(dstAddr.HardwareAddr, sess.MacHw, ethernet.NotTagged, ethernet.IPv4, len(data)) + } + // fmt.Println("Gateway", ip_dst, dstAddr.HardwareAddr) + + frame.Prepare(dstHw, cSess.MacHw, ethernet.NotTagged, ethernet.IPv4, len(data)) copy(frame[12+2:], data) } @@ -145,52 +142,15 @@ func LinkTap(sess *sessdata.ConnSession) { // fmt.Println("write:", packet) _, err = ifce.Write(frame) if err != nil { - log.Println("tap Write err", err) + base.Error("tap Write err", err) return } } - } -// 异步处理获取ip对应的mac地址的数据 -func loopArp(sess *sessdata.ConnSession) { +func tapRead(ifce *water.Interface, cSess *sessdata.ConnSession) { defer func() { - log.Println("loopArp return") - }() - - var ( - payload *sessdata.Payload - dstAddr *arpdis.Addr - ip_dst net.IP - ) - - for { - select { - case payload = <-sess.PayloadArp: - case <-sess.CloseChan: - return - } - - ip_dst = waterutil.IPv4Destination(payload.Data) - dstAddr = arpdis.Lookup(ip_dst, false) - // 不可达数据包 - if dstAddr == nil || dstAddr.Type == arpdis.TypeUnreachable { - // 直接丢弃数据 - // fmt.Println("Lookup", ip_dst) - continue - } - - // 正常获取mac地址 - if payloadInData(sess, payload) { - return - } - - } -} - -func tapRead(ifce *water.Interface, sess *sessdata.ConnSession) { - defer func() { - log.Println("tapRead return") + // log.Println("tapRead return") ifce.Close() }() @@ -199,14 +159,13 @@ func tapRead(ifce *water.Interface, sess *sessdata.ConnSession) { n int buf []byte ) - fmt.Println(sess.MacHw) for { var frame ethernet.Frame frame.Resize(BufferSize) n, err = ifce.Read(frame) if err != nil { - log.Println("tap Read err", n, err) + base.Error("tap Read err", n, err) return } frame = frame[:n] @@ -223,47 +182,54 @@ func tapRead(ifce *water.Interface, sess *sessdata.ConnSession) { data := frame.Payload() ip_dst := waterutil.IPv4Destination(data) - if !ip_dst.Equal(sess.Ip) { + if !ip_dst.Equal(cSess.IpAddr) { // 过滤非本机地址 // log.Println(ip_dst, sess.Ip) continue } - if payloadOut(sess, sessdata.LTypeIPData, 0x00, data) { + // packet := gopacket.NewPacket(data, layers.LayerTypeIPv4, gopacket.Default) + // fmt.Println("put:", packet) + + if payloadOut(cSess, sessdata.LTypeIPData, 0x00, data) { return } case ethernet.ARP: // 暂时仅实现了ARP协议 - packet := gopacket.NewPacket(frame, layers.LayerTypeEthernet, gopacket.NoCopy) + packet := gopacket.NewPacket(frame, layers.LayerTypeEthernet, gopacket.Default) layer := packet.Layer(layers.LayerTypeARP) arpReq := layer.(*layers.ARP) - // fmt.Println("arp", net.IP(arpReq.SourceProtAddress), sess.Ip) - if !sess.Ip.Equal(arpReq.DstProtAddress) { + if !cSess.IpAddr.Equal(arpReq.DstProtAddress) { // 过滤非本机地址 continue } - // fmt.Println("arp", arpReq.SourceProtAddress, sess.Ip) + // fmt.Println("arp", net.IP(arpReq.SourceProtAddress), sess.Ip) // fmt.Println(packet) // 返回ARP数据 - src := &arpdis.Addr{IP: sess.Ip, HardwareAddr: sess.MacHw} + src := &arpdis.Addr{IP: cSess.IpAddr, HardwareAddr: cSess.MacHw} dst := &arpdis.Addr{IP: arpReq.SourceProtAddress, HardwareAddr: frame.Source()} buf, err = arpdis.NewARPReply(src, dst) if err != nil { - log.Println(err) + base.Error(err) return } // 从接受的arp信息添加arp地址 - addr := &arpdis.Addr{} + addr := &arpdis.Addr{ + IP: make([]byte, len(arpReq.SourceProtAddress)), + HardwareAddr: make([]byte, len(frame.Source())), + } + // addr.IP = arpReq.SourceProtAddress + // addr.HardwareAddr = frame.Source() copy(addr.IP, arpReq.SourceProtAddress) copy(addr.HardwareAddr, frame.Source()) arpdis.Add(addr) - if payloadIn(sess, sessdata.LTypeEthernet, 0x00, buf) { + if payloadIn(cSess, sessdata.LTypeEthernet, 0x00, buf) { return } diff --git a/handler/link_tun.go b/handler/link_tun.go index 91d04a2..536fc31 100644 --- a/handler/link_tun.go +++ b/handler/link_tun.go @@ -2,9 +2,8 @@ package handler import ( "fmt" - "log" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/sessdata" "github.com/songgao/water" ) @@ -17,7 +16,7 @@ func checkTun() { ifce, err := water.New(cfg) if err != nil { - log.Fatal("open tun err: ", err) + base.Fatal("open tun err: ", err) } defer ifce.Close() @@ -25,63 +24,72 @@ func checkTun() { cmdstr := fmt.Sprintf("ip link set dev %s up mtu %s multicast off", ifce.Name(), "1399") err = execCmd([]string{cmdstr}) if err != nil { - log.Fatal("testTun err: ", err) + base.Fatal("testTun err: ", err) } } // 创建tun网卡 -func LinkTun(sess *sessdata.ConnSession) { - defer func() { - log.Println("LinkTun return") - sess.Close() - }() - +func LinkTun(cSess *sessdata.ConnSession) error { cfg := water.Config{ DeviceType: water.TUN, } ifce, err := water.New(cfg) if err != nil { - log.Println(err) - return + base.Error(err) + return err } // log.Printf("Interface Name: %s\n", ifce.Name()) - sess.TunName = ifce.Name() - defer ifce.Close() + cSess.SetTunName(ifce.Name()) + // cSess.TunName = ifce.Name() - cmdstr1 := fmt.Sprintf("ip link set dev %s up mtu %d multicast off", ifce.Name(), sess.Mtu) + cmdstr1 := fmt.Sprintf("ip link set dev %s up mtu %d multicast off", ifce.Name(), cSess.Mtu) cmdstr2 := fmt.Sprintf("ip addr add dev %s local %s peer %s/32", - ifce.Name(), common.ServerCfg.Ipv4Gateway, sess.Ip) + ifce.Name(), base.Cfg.Ipv4Gateway, cSess.IpAddr) cmdstr3 := fmt.Sprintf("sysctl -w net.ipv6.conf.%s.disable_ipv6=1", ifce.Name()) cmdStrs := []string{cmdstr1, cmdstr2, cmdstr3} err = execCmd(cmdStrs) if err != nil { - return + base.Error(err) + ifce.Close() + return err } - go tunRead(ifce, sess) + go tunRead(ifce, cSess) + go tunWrite(ifce, cSess) + return nil +} - var payload *sessdata.Payload +func tunWrite(ifce *water.Interface, cSess *sessdata.ConnSession) { + defer func() { + // log.Println("LinkTun return") + cSess.Close() + ifce.Close() + }() + + var ( + err error + payload *sessdata.Payload + ) for { select { - case payload = <-sess.PayloadIn: - case <-sess.CloseChan: + case payload = <-cSess.PayloadIn: + case <-cSess.CloseChan: return } _, err = ifce.Write(payload.Data) if err != nil { - log.Println("tun Write err", err) + base.Error("tun Write err", err) return } } - } -func tunRead(ifce *water.Interface, sess *sessdata.ConnSession) { +func tunRead(ifce *water.Interface, cSess *sessdata.ConnSession) { defer func() { - log.Println("tunRead return") + // log.Println("tunRead return") ifce.Close() }() var ( @@ -93,7 +101,7 @@ func tunRead(ifce *water.Interface, sess *sessdata.ConnSession) { data := make([]byte, BufferSize) n, err = ifce.Read(data) if err != nil { - log.Println("tun Read err", n, err) + base.Error("tun Read err", n, err) return } @@ -106,7 +114,7 @@ func tunRead(ifce *water.Interface, sess *sessdata.ConnSession) { // packet := gopacket.NewPacket(data, layers.LayerTypeIPv4, gopacket.Default) // fmt.Println("read:", packet) - if payloadOut(sess, sessdata.LTypeIPData, 0x00, data) { + if payloadOut(cSess, sessdata.LTypeIPData, 0x00, data) { return } diff --git a/handler/link_tunnel.go b/handler/link_tunnel.go index 29c5b6e..57fc8d8 100644 --- a/handler/link_tunnel.go +++ b/handler/link_tunnel.go @@ -7,7 +7,8 @@ import ( "net/http" "os" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" + "github.com/bjdgyc/anylink/dbdata" "github.com/bjdgyc/anylink/sessdata" ) @@ -44,49 +45,62 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) return } - fmt.Println(cSess.Ip, cSess.MacHw) // 客户端信息 - cstp_mtu := r.Header.Get("X-CSTP-MTU") - master_Secret := r.Header.Get("X-DTLS-Master-Secret") - local_ip := r.Header.Get("X-Cstp-Local-Address-Ip4") + cstpMtu := r.Header.Get("X-CSTP-MTU") + masterSecret := r.Header.Get("X-DTLS-Master-Secret") + localIp := r.Header.Get("X-Cstp-Local-Address-Ip4") mobile := r.Header.Get("X-Cstp-License") - cSess.SetMtu(cstp_mtu) - cSess.MasterSecret = master_Secret + platform := r.Header.Get("X-AnyConnect-Identifier-Platform") + cSess.SetMtu(cstpMtu) + cSess.MasterSecret = masterSecret cSess.RemoteAddr = r.RemoteAddr - cSess.LocalIp = net.ParseIP(local_ip) - cstpDpd := common.ServerCfg.CstpDpd + cSess.LocalIp = net.ParseIP(localIp) + cstpKeepalive := base.Cfg.CstpKeepalive + cstpDpd := base.Cfg.CstpDpd + cSess.Client = "pc" if mobile == "mobile" { // 手机客户端 - cstpDpd = common.ServerCfg.MobileDpd + cstpKeepalive = base.Cfg.MobileKeepalive + cstpDpd = base.Cfg.MobileDpd + cSess.Client = "mobile" } + cSess.CstpDpd = cstpDpd + + // iPhone手机需要最少一个dns + if platform == "apple-ios" && len(cSess.Group.ClientDns) == 0 { + dnsVal := dbdata.ValData{Val: "114.114.114.114"} + cSess.Group.ClientDns = append(cSess.Group.ClientDns, dnsVal) + } + + base.Debug(cSess.IpAddr, cSess.MacHw, sess.Username, mobile) // 返回客户端数据 - w.Header().Set("Server", fmt.Sprintf("%s %s", common.APP_NAME, common.APP_VER)) + w.Header().Set("Server", fmt.Sprintf("%s %s", base.APP_NAME, base.APP_VER)) w.Header().Set("X-CSTP-Version", "1") w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.") - w.Header().Set("X-CSTP-Address", cSess.Ip.String()) // 分配的ip地址 - w.Header().Set("X-CSTP-Netmask", common.ServerCfg.Ipv4Netmask) // 子网掩码 - w.Header().Set("X-CSTP-Hostname", hn) // 机器名称 - for _, v := range common.ServerCfg.ClientDns { - w.Header().Add("X-CSTP-DNS", v) // dns地址 - } + w.Header().Set("X-CSTP-Address", cSess.IpAddr.String()) // 分配的ip地址 + w.Header().Set("X-CSTP-Netmask", base.Cfg.Ipv4Netmask) // 子网掩码 + w.Header().Set("X-CSTP-Hostname", hn) // 机器名称 + // 允许本地LAN访问vpn网络,必须放在路由的第一个 - if common.ServerCfg.AllowLan { + if cSess.Group.AllowLan { w.Header().Set("X-CSTP-Split-Exclude", "0.0.0.0/255.255.255.255") } + // dns地址 + for _, v := range cSess.Group.ClientDns { + w.Header().Add("X-CSTP-DNS", v.Val) + } // 允许的路由 - for _, v := range common.ServerCfg.Include { - w.Header().Add("X-CSTP-Split-Include", v) + for _, v := range cSess.Group.RouteInclude { + w.Header().Add("X-CSTP-Split-Include", v.Val) } // 不允许的路由 - for _, v := range common.ServerCfg.Exclude { - w.Header().Add("X-CSTP-Split-Exclude", v) + for _, v := range cSess.Group.RouteExclude { + w.Header().Add("X-CSTP-Split-Exclude", v.Val) } - // w.Header().Add("X-CSTP-Split-Include", "192.168.0.0/255.255.0.0") - // w.Header().Add("X-CSTP-Split-Exclude", "10.1.5.2/255.255.255.255") - w.Header().Set("X-CSTP-Lease-Duration", fmt.Sprintf("%d", sessdata.IpLease)) // ip地址租期 + w.Header().Set("X-CSTP-Lease-Duration", fmt.Sprintf("%d", base.Cfg.IpLease)) // ip地址租期 w.Header().Set("X-CSTP-Session-Timeout", "none") w.Header().Set("X-CSTP-Session-Timeout-Alert-Interval", "60") w.Header().Set("X-CSTP-Session-Timeout-Remaining", "none") @@ -98,9 +112,9 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-CSTP-Rekey-Time", "172800") w.Header().Set("X-CSTP-Rekey-Method", "new-tunnel") - w.Header().Set("X-CSTP-DPD", fmt.Sprintf("%d", cstpDpd)) // 30 Dead peer detection in seconds - w.Header().Set("X-CSTP-Keepalive", fmt.Sprintf("%d", common.ServerCfg.CstpKeepalive)) // 20 - w.Header().Set("X-CSTP-Banner", common.ServerCfg.Banner) // urlencode + w.Header().Set("X-CSTP-DPD", fmt.Sprintf("%d", cstpDpd)) + w.Header().Set("X-CSTP-Keepalive", fmt.Sprintf("%d", cstpKeepalive)) + // w.Header().Set("X-CSTP-Banner", banner.Banner) w.Header().Set("X-CSTP-MSIE-Proxy-Lockdown", "true") w.Header().Set("X-CSTP-Smartcard-Removal-Disconnect", "true") @@ -109,7 +123,7 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-DTLS-Session-ID", sess.DtlsSid) w.Header().Set("X-DTLS-Port", "4433") - w.Header().Set("X-DTLS-Keepalive", fmt.Sprintf("%d", common.ServerCfg.CstpKeepalive)) + w.Header().Set("X-DTLS-Keepalive", fmt.Sprintf("%d", base.Cfg.CstpKeepalive)) w.Header().Set("X-DTLS-Rekey-Time", "5400") w.Header().Set("X-DTLS12-CipherSuite", "ECDHE-ECDSA-AES128-GCM-SHA256") // w.Header().Set("X-DTLS12-CipherSuite", "ECDHE-RSA-AES128-GCM-SHA256") @@ -123,6 +137,9 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { // w.Header().Set("X-CSTP-Post-Auth-XML", ``) w.WriteHeader(http.StatusOK) + // h := w.Header().Clone() + // h.Write(os.Stdout) + hj := w.(http.Hijacker) conn, _, err := hj.Hijack() if err != nil { @@ -131,11 +148,15 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { } // 开始数据处理 - switch common.ServerCfg.LinkMode { - case common.LinkModeTUN: - go LinkTun(cSess) - case common.LinkModeTAP: - go LinkTap(cSess) + switch base.Cfg.LinkMode { + case base.LinkModeTUN: + err = LinkTun(cSess) + case base.LinkModeTAP: + err = LinkTap(cSess) + } + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return } go LinkCstp(conn, cSess) diff --git a/handler/payload.go b/handler/payload.go index 5bef3c8..6413598 100644 --- a/handler/payload.go +++ b/handler/payload.go @@ -2,44 +2,44 @@ package handler import "github.com/bjdgyc/anylink/sessdata" -func payloadIn(sess *sessdata.ConnSession, lType sessdata.LType, pType byte, data []byte) bool { +func payloadIn(cSess *sessdata.ConnSession, lType sessdata.LType, pType byte, data []byte) bool { payload := &sessdata.Payload{ LType: lType, PType: pType, Data: data, } - return payloadInData(sess, payload) + return payloadInData(cSess, payload) } -func payloadInData(sess *sessdata.ConnSession, payload *sessdata.Payload) bool { +func payloadInData(cSess *sessdata.ConnSession, payload *sessdata.Payload) bool { closed := false select { - case sess.PayloadIn <- payload: - case <-sess.CloseChan: + case cSess.PayloadIn <- payload: + case <-cSess.CloseChan: closed = true } return closed } -func payloadOut(sess *sessdata.ConnSession, lType sessdata.LType, pType byte, data []byte) bool { +func payloadOut(cSess *sessdata.ConnSession, lType sessdata.LType, pType byte, data []byte) bool { payload := &sessdata.Payload{ LType: lType, PType: pType, Data: data, } - return payloadOutData(sess, payload) + return payloadOutData(cSess, payload) } -func payloadOutData(sess *sessdata.ConnSession, payload *sessdata.Payload) bool { +func payloadOutData(cSess *sessdata.ConnSession, payload *sessdata.Payload) bool { closed := false select { - case sess.PayloadOut <- payload: - case <-sess.CloseChan: + case cSess.PayloadOut <- payload: + case <-cSess.CloseChan: closed = true } diff --git a/handler/server.go b/handler/server.go index df0e94d..941e506 100644 --- a/handler/server.go +++ b/handler/server.go @@ -6,35 +6,17 @@ import ( "log" "net" "net/http" - "net/http/pprof" "time" - "github.com/bjdgyc/anylink/common" - "github.com/bjdgyc/anylink/proxyproto" - "github.com/bjdgyc/anylink/router" + "github.com/bjdgyc/anylink/base" + "github.com/bjdgyc/anylink/pkg/proxyproto" + "github.com/gorilla/mux" ) -func startAdmin() { - mux := router.NewHttpMux() - mux.HandleFunc(router.ANY, "/", notFound) - // mux.ServeFile(router.ANY, "/static/*", http.Dir("./static")) - - // pprof - mux.HandleFunc(router.ANY, "/debug/pprof/*", pprof.Index) - mux.HandleFunc(router.ANY, "/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc(router.ANY, "/debug/pprof/profile", pprof.Profile) - mux.HandleFunc(router.ANY, "/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc(router.ANY, "/debug/pprof/trace", pprof.Trace) - - fmt.Println("Listen admin", common.ServerCfg.AdminAddr) - err := http.ListenAndServe(common.ServerCfg.AdminAddr, mux) - fmt.Println(err) -} - func startTls() { - addr := common.ServerCfg.ServerAddr - certFile := common.ServerCfg.CertFile - keyFile := common.ServerCfg.CertKey + addr := base.Cfg.ServerAddr + certFile := base.Cfg.CertFile + keyFile := base.Cfg.CertKey // 设置tls信息 tlsConfig := &tls.Config{ @@ -55,24 +37,30 @@ func startTls() { } defer ln.Close() - if common.ServerCfg.ProxyProtocol { + if base.Cfg.ProxyProtocol { ln = &proxyproto.Listener{Listener: ln, ProxyHeaderTimeout: time.Second * 5} } - fmt.Println("listen ", addr) + base.Info("listen server", addr) err = srv.ServeTLS(ln, certFile, keyFile) if err != nil { - log.Fatal(err) + base.Fatal(err) } } func initRoute() http.Handler { - mux := router.NewHttpMux() - mux.HandleFunc("GET", "/", checkLinkClient(LinkHome)) - mux.HandleFunc("POST", "/", checkLinkClient(LinkAuth)) - mux.HandleFunc("CONNECT", "/CSCOSSLC/tunnel", LinkTunnel) - mux.SetNotFound(http.HandlerFunc(notFound)) - return mux + r := mux.NewRouter() + // r.HandleFunc("/", checkLinkClient(LinkHome)).Methods(http.MethodGet) + r.HandleFunc("/", checkLinkClient(LinkAuth)).Methods(http.MethodPost) + r.HandleFunc("/CSCOSSLC/tunnel", LinkTunnel).Methods(http.MethodConnect) + r.HandleFunc("/otp_qr", LinkOtpQr).Methods(http.MethodGet) + r.PathPrefix("/files/").Handler( + http.StripPrefix("/files/", + http.FileServer(http.Dir(base.Cfg.FilesPath)), + ), + ) + r.NotFoundHandler = http.HandlerFunc(notFound) + return r } func notFound(w http.ResponseWriter, r *http.Request) { diff --git a/handler/start.go b/handler/start.go index ee80b2c..feee6ea 100644 --- a/handler/start.go +++ b/handler/start.go @@ -1,7 +1,8 @@ package handler import ( - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/admin" + "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/dbdata" "github.com/bjdgyc/anylink/sessdata" ) @@ -11,10 +12,10 @@ func Start() { sessdata.Start() checkTun() - if common.ServerCfg.LinkMode == common.LinkModeTAP { + if base.Cfg.LinkMode == base.LinkModeTAP { checkTap() } - go startAdmin() + go admin.StartAdmin() go startTls() go startDtls() } diff --git a/handler/user.go b/handler/user.go deleted file mode 100644 index e1bf044..0000000 --- a/handler/user.go +++ /dev/null @@ -1,71 +0,0 @@ -package handler - -import ( - "crypto/sha1" - "fmt" - "os" - "time" - - "github.com/bjdgyc/anylink/common" - "github.com/bjdgyc/anylink/dbdata" - "github.com/xlzd/gotp" -) - -func CheckUser(name, pwd, group string) bool { - return true - - pl := len(pwd) - if name == "" || pl < 6 { - return false - } - v := &dbdata.User{} - err := dbdata.Get(dbdata.BucketUser, name, v) - if err != nil { - return false - } - if !common.InArrStr(v.Group, group) { - return false - } - pass := pwd[:pl-6] - pwdHash := hashPass(pass) - if v.Password != pwdHash { - return false - } - otp := pwd[pl-6:] - totp := gotp.NewDefaultTOTP(v.OtpSecret) - unix := time.Now().Unix() - verify := totp.Verify(otp, int(unix)) - if !verify { - return false - } - return true -} - -func UserAdd(name, pwd string, group []string) dbdata.User { - v := dbdata.User{ - Id: dbdata.NextId(dbdata.BucketUser), - Username: name, - Password: hashPass(pwd), - OtpSecret: gotp.RandomSecret(32), - Group: group, - UpdatedAt: time.Now(), - } - fmt.Println(v) - secret := "WHH7WA6POOGGEYVIQYXLZU75QLM7YLUX" - totp := gotp.NewDefaultTOTP(secret) - s := totp.ProvisioningUri("bjdtest", "bjdpro") - fmt.Println(s) - - // qr, _ := qrcode.New(s, qrcode.Medium) - // a := qr.ToSmallString(false) - // fmt.Println(a) - // qr.WriteFile(512, "a.png") - - os.Exit(0) - return v -} - -func hashPass(pwd string) string { - sum := sha1.Sum([]byte(pwd)) - return fmt.Sprintf("%x", sum) -} diff --git a/main.go b/main.go index 5f52d7e..9157d22 100644 --- a/main.go +++ b/main.go @@ -1,40 +1,40 @@ +// AnyLink 是一个企业级远程办公vpn软件,可以支持多人同时在线使用。 package main import ( - "fmt" - "log" "os" "os/signal" "syscall" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/handler" ) +// 程序版本 var COMMIT_ID string func main() { - log.Println("start") - common.CommitId = COMMIT_ID - common.InitConfig() + base.CommitId = COMMIT_ID + base.Start() handler.Start() signalWatch() } func signalWatch() { - fmt.Println("Server pid: ", os.Getpid()) + base.Info("Server pid: ", os.Getpid()) sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGALRM, syscall.SIGUSR2) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGALRM) for { sig := <-sigs - fmt.Printf("Get signal: %v \n", sig) + base.Info("Get signal:", sig) switch sig { case syscall.SIGUSR2: // reload - fmt.Println("reload") + base.Info("Reload") default: // stop + base.Info("Stop") handler.Stop() return } diff --git a/arpdis/addr.go b/pkg/arpdis/addr.go similarity index 100% rename from arpdis/addr.go rename to pkg/arpdis/addr.go diff --git a/arpdis/addr_test.go b/pkg/arpdis/addr_test.go similarity index 100% rename from arpdis/addr_test.go rename to pkg/arpdis/addr_test.go diff --git a/arpdis/arp.go b/pkg/arpdis/arp.go similarity index 100% rename from arpdis/arp.go rename to pkg/arpdis/arp.go diff --git a/arpdis/icmp.go b/pkg/arpdis/icmp.go similarity index 67% rename from arpdis/icmp.go rename to pkg/arpdis/icmp.go index 56fe6b0..d89d7ae 100644 --- a/arpdis/icmp.go +++ b/pkg/arpdis/icmp.go @@ -6,50 +6,10 @@ import ( "os" "time" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" ) -var ( - ipv4Id uint16 = 1 - icmpSeq uint16 = 1 -) - -func NewIPv4Unreachable(src *Addr, dst *Addr) ([]byte, error) { - ipv4Id++ - icmpSeq++ - - ipv4 := layers.IPv4{ - Version: 4, - IHL: 5, - TOS: 0, - // Length : uint16, - Id: ipv4Id, - Flags: 0, - FragOffset: 0, - TTL: 10, - Protocol: layers.IPProtocolICMPv4, - // Checksum : uint16, - SrcIP: src.IP.To4(), - DstIP: dst.IP.To4(), - Options: nil, - Padding: nil, - } - - icmp := layers.ICMPv4{ - TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeDestinationUnreachable, layers.ICMPv4CodeNet), - Id: uint16(os.Getpid()), - Seq: icmpSeq, - } - - buf := gopacket.NewSerializeBuffer() - err := gopacket.SerializeLayers(buf, defaultSerializeOpts, &ipv4, &icmp) - - return buf.Bytes(), err -} - const ( ProtocolICMP = 1 ProtocolIPv6ICMP = 58 diff --git a/arpdis/lookup.go b/pkg/arpdis/lookup.go similarity index 98% rename from arpdis/lookup.go rename to pkg/arpdis/lookup.go index 9ffda6c..f5457be 100644 --- a/arpdis/lookup.go +++ b/pkg/arpdis/lookup.go @@ -1,7 +1,5 @@ // Currently only Darwin and Linux support this. -// +build darwin linux - package arpdis import ( diff --git a/proxyproto/protocol.go b/pkg/proxyproto/protocol.go similarity index 99% rename from proxyproto/protocol.go rename to pkg/proxyproto/protocol.go index c729063..c4468f0 100644 --- a/proxyproto/protocol.go +++ b/pkg/proxyproto/protocol.go @@ -1,5 +1,7 @@ // copy from: https://github.com/armon/go-proxyproto/blob/master/protocol.go // design: http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt + +// HAProxy proxy proto v1 package proxyproto import ( diff --git a/proxyproto/protocol_test.go b/pkg/proxyproto/protocol_test.go similarity index 100% rename from proxyproto/protocol_test.go rename to pkg/proxyproto/protocol_test.go diff --git a/pkg/utils/password_hash.go b/pkg/utils/password_hash.go new file mode 100644 index 0000000..ca2cf91 --- /dev/null +++ b/pkg/utils/password_hash.go @@ -0,0 +1,40 @@ +package utils + +import ( + "crypto/rand" + "encoding/base64" + mt "math/rand" + + "golang.org/x/crypto/bcrypt" +) + +func PasswordHash(password string) (string, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + return string(bytes), err +} + +func PasswordVerify(password, hash string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err == nil +} + +// $sha-256$salt-key$hash-abcd +// $sha-512$salt-key$hash-abcd +const ( + saltSize = 16 + delmiter = "$" +) + +func saltSecret() (string, error) { + rb := make([]byte, randInt(10, 100)) + _, err := rand.Read(rb) + if err != nil { + return "", err + } + + return base64.URLEncoding.EncodeToString(rb), nil +} + +func randInt(min int, max int) int { + return min + mt.Intn(max-min) +} diff --git a/pkg/utils/util.go b/pkg/utils/util.go new file mode 100644 index 0000000..0ea8b56 --- /dev/null +++ b/pkg/utils/util.go @@ -0,0 +1,74 @@ +package utils + +import ( + "fmt" + "math/rand" + "time" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func InArrStr(arr []string, str string) bool { + for _, d := range arr { + if d == str { + return true + } + } + return false +} + +const ( + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + TB = 1024 * GB + PB = 1024 * TB +) + +func HumanByte(bf interface{}) string { + var hb string + var bAll float64 + switch bi := bf.(type) { + case int: + bAll = float64(bi) + case int32: + bAll = float64(bi) + case uint32: + bAll = float64(bi) + case int64: + bAll = float64(bi) + case uint64: + bAll = float64(bi) + case float64: + bAll = float64(bi) + } + + switch { + case bAll >= TB: + hb = fmt.Sprintf("%0.2f TB", bAll/TB) + case bAll >= GB: + hb = fmt.Sprintf("%0.2f GB", bAll/GB) + case bAll >= MB: + hb = fmt.Sprintf("%0.2f MB", bAll/MB) + case bAll >= KB: + hb = fmt.Sprintf("%0.2f KB", bAll/KB) + default: + hb = fmt.Sprintf("%0.2f B", bAll) + } + + return hb +} + +func RandomNum(length int) string { + letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890") + + bytes := make([]rune, length) + + for i := range bytes { + bytes[i] = letterRunes[rand.Intn(len(letterRunes))] + } + + return string(bytes) +} diff --git a/common/util_test.go b/pkg/utils/util_test.go similarity index 97% rename from common/util_test.go rename to pkg/utils/util_test.go index a4c9e42..aa5df2a 100644 --- a/common/util_test.go +++ b/pkg/utils/util_test.go @@ -1,4 +1,4 @@ -package common +package utils import ( "testing" diff --git a/router/router.go b/router/router.go deleted file mode 100644 index 0f5e700..0000000 --- a/router/router.go +++ /dev/null @@ -1,167 +0,0 @@ -package router - -import ( - "net/http" - "path" - "sort" - "strings" - "sync" -) - -const ( - ANY = "ANY" // 包含所有 method -) - -type HttpMux struct { - no http.Handler // NotFoundHandler - mu sync.RWMutex - m map[string]muxEntry // example: GET/index:muxEntry{} - es []muxEntry // 模糊匹配,pattern需要添加后缀 * -} - -type muxEntry struct { - h http.Handler - pattern string - method string -} - -func NewHttpMux() *HttpMux { - http.NewServeMux() - return &HttpMux{ - m: make(map[string]muxEntry), - es: make([]muxEntry, 0), - } -} - -func (mux *HttpMux) SetNotFound(no http.Handler) { - mux.mu.Lock() - defer mux.mu.Unlock() - mux.no = no -} - -func (mux *HttpMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "*" { - w.Header().Set("Connection", "close") - w.WriteHeader(http.StatusBadRequest) - return - } - - h := mux.match(r.Method, r.URL.Path) - h.ServeHTTP(w, r) -} - -func (mux *HttpMux) match(method, rpath string) http.Handler { - mux.mu.RLock() - defer mux.mu.RUnlock() - - path := mux.cleanPath(rpath) - // any 路径 匹配 - p_a := ANY + path - if v, ok := mux.m[p_a]; ok { - return v.h - } - // method 路径 匹配 - method = strings.ToUpper(method) - p_m := method + path - if e, ok := mux.m[p_m]; ok { - return e.h - } - - // Check for longest valid match. mux.es contains all patterns - // that end in / sorted from longest to shortest. - for _, e := range mux.es { - // trim last * - pattern := e.pattern[:len(e.pattern)-1] - // fmt.Println(pattern, p_a, p_m) - if strings.HasPrefix(p_a, pattern) { - return e.h - } - if strings.HasPrefix(p_m, pattern) { - return e.h - } - } - - if mux.no != nil { - return mux.no - } - return http.NotFoundHandler() -} - -func (mux *HttpMux) cleanPath(p string) string { - if p == "" { - return "/" - } - if p[0] != '/' { - p = "/" + p - } - np := path.Clean(p) - // path.Clean removes trailing slash except for root; - // put the trailing slash back if necessary. - if p[len(p)-1] == '/' && np != "/" { - // Fast path for common case of p being the string we want: - if len(p) == len(np)+1 && strings.HasPrefix(p, np) { - np = p - } else { - np += "/" - } - } - return np -} - -func (mux *HttpMux) HandleFunc(method, pattern string, handler func(http.ResponseWriter, *http.Request)) { - if handler == nil { - panic("http: nil handler") - } - mux.Handle(method, pattern, http.HandlerFunc(handler)) -} - -func (mux *HttpMux) Handle(method, pattern string, handler http.Handler) { - mux.mu.Lock() - defer mux.mu.Unlock() - - if pattern == "" || method == "" { - panic("http: invalid pattern") - } - if handler == nil { - panic("http: nil handler") - } - method = strings.ToUpper(method) - p := method + pattern - if _, exist := mux.m[p]; exist { - panic("http: multiple registrations for " + p) - } - - e := muxEntry{h: handler, pattern: p} - mux.m[p] = e - if pattern[len(pattern)-1] == '*' { - mux.es = mux.appendSorted(mux.es, e) - } -} - -func (mux *HttpMux) appendSorted(es []muxEntry, e muxEntry) []muxEntry { - n := len(es) - i := sort.Search(n, func(i int) bool { - return len(es[i].pattern) < len(e.pattern) - }) - if i == n { - return append(es, e) - } - // we now know that i points at where we want to insert - es = append(es, muxEntry{}) // try to grow the slice in place, any entry works. - copy(es[i+1:], es[i:]) // Move shorter entries down - es[i] = e - return es -} - -// ANY /static/* /var/www -func (mux *HttpMux) ServeFile(method, pattern string, root http.FileSystem) { - fs := http.FileServer(root) - - // trim * - pt := pattern[:len(pattern)-1] - mux.HandleFunc(method, pattern, func(w http.ResponseWriter, r *http.Request) { - // 过滤前缀路径 - r.URL.Path = strings.TrimPrefix(r.URL.Path, pt) - fs.ServeHTTP(w, r) - }) -} diff --git a/router/router_test.go b/router/router_test.go deleted file mode 100644 index 42ad205..0000000 --- a/router/router_test.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Server unit tests - -package router - -import ( - "fmt" - "net/http" - "testing" -) - -func BenchmarkServerMatch(b *testing.B) { - fn := func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "OK") - } - mux := NewHttpMux() - mux.HandleFunc("GET", "/", fn) - mux.HandleFunc("GET", "/index", fn) - mux.HandleFunc("GET", "/home", fn) - mux.HandleFunc("GET", "/about", fn) - mux.HandleFunc("GET", "/contact", fn) - mux.HandleFunc("GET", "/robots.txt", fn) - mux.HandleFunc("GET", "/products/", fn) - mux.HandleFunc("GET", "/products/1", fn) - mux.HandleFunc("GET", "/products/2", fn) - mux.HandleFunc("GET", "/products/3", fn) - mux.HandleFunc("GET", "/products/3/image.jpg", fn) - mux.HandleFunc("GET", "/admin", fn) - mux.HandleFunc("GET", "/admin/products/", fn) - mux.HandleFunc("GET", "/admin/products/create", fn) - mux.HandleFunc("GET", "/admin/products/update", fn) - mux.HandleFunc("GET", "/admin/products/delete", fn) - - paths := []string{"/", "/notfound", "/admin/", "/admin/foo", "/contact", "/products", - "/products/", "/products/3/image.jpg"} - b.StartTimer() - for i := 0; i < b.N; i++ { - path := paths[i%len(paths)] - if h := mux.match("GET", path); h == nil { - b.Error("impossible", path) - } - } - b.StopTimer() -} diff --git a/sessdata/ip_pool.go b/sessdata/ip_pool.go index 1d511ea..55f3079 100644 --- a/sessdata/ip_pool.go +++ b/sessdata/ip_pool.go @@ -6,29 +6,16 @@ import ( "sync" "time" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/dbdata" ) -const ( - // ip租期 (秒) - IpLease = 1209600 -) - var ( - IpPool = &IpPoolConfig{} - macInfo = map[string]*MacIp{} - ipInfo = map[string]*MacIp{} + IpPool = &ipPoolConfig{} + ipActive = map[string]bool{} ) -type MacIp struct { - IsActive bool - Ip net.IP - MacAddr string - LastLogin time.Time -} - -type IpPoolConfig struct { +type ipPoolConfig struct { mux sync.Mutex // 计算动态ip Ipv4Gateway net.IP @@ -37,25 +24,15 @@ type IpPoolConfig struct { IpLongMax uint32 } -func initIpMac() { - macs := dbdata.GetAllMacIp() - for _, v := range macs { - mi := &MacIp{} - CopyStruct(mi, v) - macInfo[v.MacAddr] = mi - ipInfo[v.Ip.String()] = mi - } -} - func initIpPool() { // 地址处理 // ip地址 - ip := net.ParseIP(common.ServerCfg.Ipv4Network) + ip := net.ParseIP(base.Cfg.Ipv4Network) // 子网掩码 - maskIp := net.ParseIP(common.ServerCfg.Ipv4Netmask).To4() + maskIp := net.ParseIP(base.Cfg.Ipv4Netmask).To4() IpPool.Ipv4IPNet = net.IPNet{IP: ip, Mask: net.IPMask(maskIp)} - IpPool.Ipv4Gateway = net.ParseIP(common.ServerCfg.Ipv4Gateway) + IpPool.Ipv4Gateway = net.ParseIP(base.Cfg.Ipv4Gateway) // 网络地址零值 // zero := binary.BigEndian.Uint32(ip.Mask(mask)) @@ -64,8 +41,8 @@ func initIpPool() { // max := min | uint32(math.Pow(2, float64(32-one))-1) // ip地址池 - IpPool.IpLongMin = ip2long(net.ParseIP(common.ServerCfg.Ipv4Pool[0])) - IpPool.IpLongMax = ip2long(net.ParseIP(common.ServerCfg.Ipv4Pool[1])) + IpPool.IpLongMin = ip2long(net.ParseIP(base.Cfg.Ipv4Pool[0])) + IpPool.IpLongMax = ip2long(net.ParseIP(base.Cfg.Ipv4Pool[1])) } func long2ip(i uint32) net.IP { @@ -80,79 +57,96 @@ func ip2long(ip net.IP) uint32 { } // 获取动态ip -func AcquireIp(macAddr string) net.IP { +func AcquireIp(username, macAddr string) net.IP { IpPool.mux.Lock() defer IpPool.mux.Unlock() + tNow := time.Now() // 判断已经分配过 - if mi, ok := macInfo[macAddr]; ok { - ip := mi.Ip + mi := &dbdata.IpMap{} + err := dbdata.One("MacAddr", macAddr, mi) + if err == nil { + ip := mi.IpAddr + ipStr := ip.String() // 检测原有ip是否在新的ip池内 if IpPool.Ipv4IPNet.Contains(ip) { - mi.IsActive = true + mi.Username = username mi.LastLogin = tNow // 回写db数据 - dbdata.Set(dbdata.BucketMacIp, macAddr, mi) + dbdata.Save(mi) + ipActive[ipStr] = true return ip } else { - delete(macInfo, macAddr) - delete(ipInfo, ip.String()) - dbdata.Del(dbdata.BucketMacIp, macAddr) + dbdata.Del(mi) } } - farMac := &MacIp{LastLogin: tNow} // 全局遍历未分配ip + // 优先获取没有使用的ip for i := IpPool.IpLongMin; i <= IpPool.IpLongMax; i++ { ip := long2ip(i) ipStr := ip.String() - v, ok := ipInfo[ipStr] - // 该ip没有被使用 - if !ok { - mi := &MacIp{IsActive: true, Ip: ip, MacAddr: macAddr, LastLogin: tNow} - macInfo[macAddr] = mi - ipInfo[ipStr] = mi - // 回写db数据 - dbdata.Set(dbdata.BucketMacIp, macAddr, mi) + mi := &dbdata.IpMap{} + err := dbdata.One("IpAddr", ip, mi) + if err != nil && dbdata.CheckErrNotFound(err) { + // 该ip没有被使用 + mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow} + dbdata.Save(mi) + ipActive[ipStr] = true + return ip + } + } + + farIp := &dbdata.IpMap{LastLogin: tNow} + // 遍历超过租期ip + for i := IpPool.IpLongMin; i <= IpPool.IpLongMax; i++ { + ip := long2ip(i) + ipStr := ip.String() + + // 跳过活跃连接 + if _, ok := ipActive[ipStr]; ok { + continue + } + + v := &dbdata.IpMap{} + err := dbdata.One("IpAddr", ip, v) + if err != nil { + base.Error(err) + return nil + } + if v.Keep { + continue + } + + // 已经超过租期 + if tNow.Sub(v.LastLogin) > time.Duration(base.Cfg.IpLease)*time.Second { + dbdata.Del(v) + mi := &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow} + // 重写db数据 + dbdata.Save(mi) + ipActive[ipStr] = true return ip } - // 跳过活跃连接 - if v.IsActive { - continue - } - // 已经超过租期 - if tNow.Sub(v.LastLogin) > IpLease*time.Second { - delete(macInfo, v.MacAddr) - mi := &MacIp{IsActive: true, Ip: ip, MacAddr: macAddr, LastLogin: tNow} - macInfo[macAddr] = mi - ipInfo[ipStr] = mi - // 回写db数据 - dbdata.Del(dbdata.BucketMacIp, v.MacAddr) - dbdata.Set(dbdata.BucketMacIp, macAddr, mi) - return ip - } - // 其他情况判断最早登陆的mac - if v.LastLogin.Before(farMac.LastLogin) { - farMac = v + // 其他情况判断最早登陆 + if v.LastLogin.Before(farIp.LastLogin) { + farIp = v } } // 全都在线,没有数据可用 - if farMac.MacAddr == "" { + if farIp.Id == 0 { return nil } // 使用最早登陆的mac ip - delete(macInfo, farMac.MacAddr) - ip := farMac.Ip - mi := &MacIp{IsActive: true, Ip: ip, MacAddr: macAddr, LastLogin: tNow} - macInfo[macAddr] = mi - ipInfo[ip.String()] = mi + ip := farIp.IpAddr + ipStr := ip.String() + mi = &dbdata.IpMap{IpAddr: ip, MacAddr: macAddr, Username: username, LastLogin: tNow} // 回写db数据 - dbdata.Del(dbdata.BucketMacIp, farMac.MacAddr) - dbdata.Set(dbdata.BucketMacIp, macAddr, mi) + dbdata.Save(mi) + ipActive[ipStr] = true return ip } @@ -160,12 +154,12 @@ func AcquireIp(macAddr string) net.IP { func ReleaseIp(ip net.IP, macAddr string) { IpPool.mux.Lock() defer IpPool.mux.Unlock() - if mi, ok := macInfo[macAddr]; ok { - if mi.Ip.Equal(ip) { - mi.IsActive = false - mi.LastLogin = time.Now() - // 回写db数据 - dbdata.Set(dbdata.BucketMacIp, macAddr, mi) - } + + delete(ipActive, ip.String()) + mi := &dbdata.IpMap{} + err := dbdata.One("IpAddr", ip, mi) + if err == nil { + mi.LastLogin = time.Now() + dbdata.Save(mi) } } diff --git a/sessdata/ip_pool_test.go b/sessdata/ip_pool_test.go index dd20cbf..72d32f0 100644 --- a/sessdata/ip_pool_test.go +++ b/sessdata/ip_pool_test.go @@ -7,17 +7,17 @@ import ( "path" "testing" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/dbdata" "github.com/stretchr/testify/assert" ) func preIpData() { - common.ServerCfg.Ipv4Network = "192.168.3.0" - common.ServerCfg.Ipv4Netmask = "255.255.255.0" - common.ServerCfg.Ipv4Pool = []string{"192.168.3.1", "192.168.3.199"} + base.Cfg.Ipv4Network = "192.168.3.0" + base.Cfg.Ipv4Netmask = "255.255.255.0" + base.Cfg.Ipv4Pool = []string{"192.168.3.1", "192.168.3.199"} tmpDb := path.Join(os.TempDir(), "anylink_test.db") - common.ServerCfg.DbFile = tmpDb + base.Cfg.DbFile = tmpDb dbdata.Start() } @@ -32,27 +32,25 @@ func TestIpPool(t *testing.T) { preIpData() defer closeIpdata() - macInfo = map[string]*MacIp{} - ipInfo = map[string]*MacIp{} initIpPool() var ip net.IP for i := 1; i <= 100; i++ { - ip = AcquireIp(fmt.Sprintf("mac-%d", i)) + ip = AcquireIp("user", fmt.Sprintf("mac-%d", i)) } - ip = AcquireIp(fmt.Sprintf("mac-new")) + ip = AcquireIp("user", fmt.Sprintf("mac-new")) assert.True(net.IPv4(192, 168, 3, 101).Equal(ip)) for i := 102; i <= 199; i++ { - ip = AcquireIp(fmt.Sprintf("mac-%d", i)) + ip = AcquireIp("user", fmt.Sprintf("mac-%d", i)) } assert.True(net.IPv4(192, 168, 3, 199).Equal(ip)) - ip = AcquireIp(fmt.Sprintf("mac-nil")) + ip = AcquireIp("user", fmt.Sprintf("mac-nil")) assert.Nil(ip) ReleaseIp(net.IPv4(192, 168, 3, 88), "mac-88") ReleaseIp(net.IPv4(192, 168, 3, 77), "mac-77") // 最早过期的ip - ip = AcquireIp("mac-release-new") + ip = AcquireIp("user", "mac-release-new") assert.True(net.IPv4(192, 168, 3, 88).Equal(ip)) } diff --git a/sessdata/limit_client.go b/sessdata/limit_client.go index c308473..bd92a2f 100644 --- a/sessdata/limit_client.go +++ b/sessdata/limit_client.go @@ -3,7 +3,7 @@ package sessdata import ( "sync" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" ) const limitAllKey = "__ALL__" @@ -16,7 +16,6 @@ var ( func LimitClient(user string, close bool) bool { limitMux.Lock() defer limitMux.Unlock() - // defer fmt.Println(limitClient) _all := limitClient[limitAllKey] c, ok := limitClient[user] @@ -31,12 +30,12 @@ func LimitClient(user string, close bool) bool { } // 全局判断 - if _all >= common.ServerCfg.MaxClient { + if _all >= base.Cfg.MaxClient { return false } // 超出同一个用户限制 - if c >= common.ServerCfg.MaxUserClient { + if c >= base.Cfg.MaxUserClient { return false } diff --git a/sessdata/limit_rate.go b/sessdata/limit_rate.go index e100925..db94ce4 100644 --- a/sessdata/limit_rate.go +++ b/sessdata/limit_rate.go @@ -2,32 +2,10 @@ package sessdata import ( "context" - "fmt" - "time" - - "github.com/bjdgyc/anylink/common" "golang.org/x/time/rate" ) -var Sess = &ConnSession{} - -func init() { - return - tick := time.Tick(time.Second * 2) - go func() { - for range tick { - uP := common.HumanByte(float64(Sess.BandwidthUpPeriod / BandwidthPeriodSec)) - dP := common.HumanByte(float64(Sess.BandwidthDownPeriod / BandwidthPeriodSec)) - uA := common.HumanByte(float64(Sess.BandwidthUpAll)) - dA := common.HumanByte(float64(Sess.BandwidthDownAll)) - - fmt.Printf("rateUp:%s rateDown:%s allUp %s allDown %s \n", - uP, dP, uA, dA) - } - }() -} - type LimitRater struct { limit *rate.Limiter } diff --git a/sessdata/limit_test.go b/sessdata/limit_test.go index 908c496..04252e3 100644 --- a/sessdata/limit_test.go +++ b/sessdata/limit_test.go @@ -6,12 +6,12 @@ import ( "github.com/stretchr/testify/assert" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" ) // func TestCheckUser(t *testing.T) { -// users["user1"] = User{Password: "7c4a8d09ca3762af61e59520943dc26494f8941b"} -// users["user2"] = User{Password: "7c4a8d09ca3762af61e59520943dc26494f8941c"} +// user["user1"] = User{Password: "7c4a8d09ca3762af61e59520943dc26494f8941b"} +// user["user2"] = User{Password: "7c4a8d09ca3762af61e59520943dc26494f8941c"} // // var res bool // res = CheckUser("user1", "123456", "") @@ -23,8 +23,8 @@ import ( func TestLimitClient(t *testing.T) { assert := assert.New(t) - common.ServerCfg.MaxClient = 2 - common.ServerCfg.MaxUserClient = 1 + base.Cfg.MaxClient = 2 + base.Cfg.MaxUserClient = 1 res1 := LimitClient("user1", false) res2 := LimitClient("user1", false) diff --git a/sessdata/online.go b/sessdata/online.go new file mode 100644 index 0000000..634214d --- /dev/null +++ b/sessdata/online.go @@ -0,0 +1,76 @@ +package sessdata + +import ( + "bytes" + "net" + "sort" + "sync/atomic" + "time" + + "github.com/bjdgyc/anylink/pkg/utils" +) + +type Online struct { + Token string `json:"token"` + Username string `json:"username"` + Group string `json:"group"` + MacAddr string `json:"mac_addr"` + Ip net.IP `json:"ip"` + RemoteAddr string `json:"remote_addr"` + TunName string `json:"tun_name"` + Mtu int `json:"mtu"` + Client string `json:"client"` + BandwidthUp string `json:"bandwidth_up"` + BandwidthDown string `json:"bandwidth_down"` + BandwidthUpAll string `json:"bandwidth_up_all"` + BandwidthDownAll string `json:"bandwidth_down_all"` + LastLogin time.Time `json:"last_login"` +} + +type Onlines []Online + +func (o Onlines) Len() int { + return len(o) +} + +func (o Onlines) Less(i, j int) bool { + if bytes.Compare(o[i].Ip, o[j].Ip) < 0 { + return true + } + return false +} + +func (o Onlines) Swap(i, j int) { + o[i], o[j] = o[j], o[i] +} + +func OnlineSess() []Online { + var datas Onlines + sessMux.Lock() + for _, v := range sessions { + v.mux.Lock() + if v.IsActive { + val := Online{ + Token: v.Token, + Ip: v.CSess.IpAddr, + Username: v.Username, + Group: v.Group, + MacAddr: v.MacAddr, + RemoteAddr: v.CSess.RemoteAddr, + TunName: v.CSess.TunName, + Mtu: v.CSess.Mtu, + Client: v.CSess.Client, + BandwidthUp: utils.HumanByte(atomic.LoadUint32(&v.CSess.BandwidthUpPeriod)) + "/s", + BandwidthDown: utils.HumanByte(atomic.LoadUint32(&v.CSess.BandwidthDownPeriod)) + "/s", + BandwidthUpAll: utils.HumanByte(atomic.LoadUint32(&v.CSess.BandwidthUpAll)), + BandwidthDownAll: utils.HumanByte(atomic.LoadUint32(&v.CSess.BandwidthDownAll)), + LastLogin: v.LastLogin, + } + datas = append(datas, val) + } + v.mux.Unlock() + } + sessMux.Unlock() + sort.Sort(&datas) + return datas +} diff --git a/sessdata/session.go b/sessdata/session.go index 70725e4..d8ce5d4 100644 --- a/sessdata/session.go +++ b/sessdata/session.go @@ -12,26 +12,29 @@ import ( "sync/atomic" "time" - "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/base" + "github.com/bjdgyc/anylink/dbdata" ) -const BandwidthPeriodSec = 2 // 流量速率统计周期(秒) - var ( // session_token -> SessUser - sessions = sync.Map{} // make(map[string]*Session) + sessions = make(map[string]*Session) + sessMux sync.Mutex ) // 连接sess type ConnSession struct { Sess *Session MasterSecret string // dtls协议的 master_secret - Ip net.IP // 分配的ip地址 + IpAddr net.IP // 分配的ip地址 LocalIp net.IP MacHw net.HardwareAddr // 客户端mac地址,从Session取出 RemoteAddr string Mtu int TunName string + Client string // 客户端 mobile pc + CstpDpd int + Group *dbdata.Group Limit *LimitRater BandwidthUp uint32 // 使用上行带宽 Byte BandwidthDown uint32 // 使用下行带宽 Byte @@ -43,7 +46,6 @@ type ConnSession struct { CloseChan chan struct{} PayloadIn chan *Payload PayloadOut chan *Payload - PayloadArp chan *Payload } type Session struct { @@ -53,7 +55,10 @@ type Session struct { DtlsSid string // dtls协议的 session_id MacAddr string // 客户端mac地址 UniqueIdGlobal string // 客户端唯一标示 - UserName string // 用户名 + Username string // 用户名 + Group string + AuthStep string + AuthPass string LastLogin time.Time IsActive bool @@ -67,45 +72,48 @@ func init() { } func checkSession() { - // 检测过期的session go func() { - if common.ServerCfg.SessionTimeout == 0 { + if base.Cfg.SessionTimeout == 0 { return } - timeout := time.Duration(common.ServerCfg.SessionTimeout) * time.Second + timeout := time.Duration(base.Cfg.SessionTimeout) * time.Second tick := time.Tick(time.Second * 60) for range tick { + sessMux.Lock() t := time.Now() - - sessions.Range(func(key, value interface{}) bool { - v := value.(*Session) + for k, v := range sessions { v.mux.Lock() - defer v.mux.Unlock() - - if v.IsActive == true { - return true + if v.IsActive != true { + if t.Sub(v.LastLogin) > timeout { + delete(sessions, k) + } } - if t.Sub(v.LastLogin) > timeout { - sessions.Delete(key) - } - return true - }) - + v.mux.Unlock() + } + sessMux.Unlock() } }() } -func NewSession() *Session { +func GenToken() string { // 生成32位的 token btoken := make([]byte, 32) rand.Read(btoken) + return fmt.Sprintf("%x", btoken) +} + +func NewSession(token string) *Session { + if token == "" { + btoken := make([]byte, 32) + rand.Read(btoken) + token = fmt.Sprintf("%x", btoken) + } // 生成 dtlsn session_id dtlsid := make([]byte, 32) rand.Read(dtlsid) - token := fmt.Sprintf("%x", btoken) sess := &Session{ Sid: fmt.Sprintf("%d", time.Now().Unix()), Token: token, @@ -113,7 +121,9 @@ func NewSession() *Session { LastLogin: time.Now(), } - sessions.Store(token, sess) + sessMux.Lock() + sessions[token] = sess + sessMux.Unlock() return sess } @@ -121,12 +131,13 @@ func (s *Session) NewConn() *ConnSession { s.mux.Lock() active := s.IsActive macAddr := s.MacAddr + username := s.Username s.mux.Unlock() if active == true { s.CSess.Close() } - limit := LimitClient(s.UserName, false) + limit := LimitClient(username, false) if limit == false { return nil } @@ -138,21 +149,34 @@ func (s *Session) NewConn() *ConnSession { macHw = append([]byte{0x00}, macHw...) macAddr = macHw.String() } - ip := AcquireIp(macAddr) + ip := AcquireIp(username, macAddr) if ip == nil { + LimitClient(username, true) return nil } cSess := &ConnSession{ Sess: s, MacHw: macHw, - Ip: ip, + IpAddr: ip, closeOnce: sync.Once{}, CloseChan: make(chan struct{}), PayloadIn: make(chan *Payload), PayloadOut: make(chan *Payload), - PayloadArp: make(chan *Payload, 1000), - // Limit: NewLimitRater(1024 * 1024), + } + + // 查询group信息 + group := &dbdata.Group{} + err = dbdata.One("Name", s.Group, group) + if err != nil { + base.Error(err) + cSess.Close() + return nil + } + cSess.Group = group + if group.Bandwidth > 0 { + // 限流设置 + cSess.Limit = NewLimitRater(group.Bandwidth, group.Bandwidth) } go cSess.ratePeriod() @@ -167,7 +191,7 @@ func (s *Session) NewConn() *ConnSession { func (cs *ConnSession) Close() { cs.closeOnce.Do(func() { - log.Println("closeOnce:", cs.Ip) + log.Println("closeOnce:", cs.IpAddr) cs.Sess.mux.Lock() defer cs.Sess.mux.Unlock() @@ -176,11 +200,13 @@ func (cs *ConnSession) Close() { cs.Sess.LastLogin = time.Now() cs.Sess.CSess = nil - ReleaseIp(cs.Ip, cs.Sess.MacAddr) - LimitClient(cs.Sess.UserName, true) + ReleaseIp(cs.IpAddr, cs.Sess.MacAddr) + LimitClient(cs.Sess.Username, true) }) } +const BandwidthPeriodSec = 2 // 流量速率统计周期(秒) + func (cs *ConnSession) ratePeriod() { tick := time.Tick(time.Second * BandwidthPeriodSec) for range tick { @@ -193,9 +219,9 @@ func (cs *ConnSession) ratePeriod() { // 实时流量清零 rtUp := atomic.SwapUint32(&cs.BandwidthUp, 0) rtDown := atomic.SwapUint32(&cs.BandwidthDown, 0) - // 设置上一周期的流量 - atomic.SwapUint32(&cs.BandwidthUpPeriod, rtUp) - atomic.SwapUint32(&cs.BandwidthDownPeriod, rtDown) + // 设置上一周期每秒的流量 + atomic.SwapUint32(&cs.BandwidthUpPeriod, rtUp/BandwidthPeriodSec) + atomic.SwapUint32(&cs.BandwidthDownPeriod, rtDown/BandwidthPeriodSec) // 累加所有流量 atomic.AddUint32(&cs.BandwidthUpAll, rtUp) atomic.AddUint32(&cs.BandwidthDownAll, rtDown) @@ -217,6 +243,12 @@ func (cs *ConnSession) SetMtu(mtu string) { } } +func (cs *ConnSession) SetTunName(name string) { + cs.Sess.mux.Lock() + defer cs.Sess.mux.Unlock() + cs.TunName = name +} + func (cs *ConnSession) RateLimit(byt int, isUp bool) error { if isUp { atomic.AddUint32(&cs.BandwidthUp, uint32(byt)) @@ -235,11 +267,13 @@ func SToken2Sess(stoken string) *Session { sarr := strings.Split(stoken, "@") token := sarr[1] - if sess, ok := sessions.Load(token); ok { - return sess.(*Session) - } + return Token2Sess(token) +} - return nil +func Token2Sess(token string) *Session { + sessMux.Lock() + defer sessMux.Unlock() + return sessions[token] } func Dtls2Sess(dtlsid []byte) *Session { @@ -250,9 +284,23 @@ func DelSess(token string) { // sessions.Delete(token) } +func CloseSess(token string) { + sessMux.Lock() + defer sessMux.Unlock() + sess, ok := sessions[token] + if !ok { + return + } + + delete(sessions, token) + sess.CSess.Close() +} + func DelSessByStoken(stoken string) { stoken = strings.TrimSpace(stoken) sarr := strings.Split(stoken, "@") token := sarr[1] - sessions.Delete(token) + sessMux.Lock() + delete(sessions, token) + sessMux.Unlock() } diff --git a/sessdata/session_test.go b/sessdata/session_test.go index 76022f5..3f243eb 100644 --- a/sessdata/session_test.go +++ b/sessdata/session_test.go @@ -1,7 +1,6 @@ package sessdata import ( - "sync" "testing" "github.com/stretchr/testify/assert" @@ -9,10 +8,10 @@ import ( func TestNewSession(t *testing.T) { assert := assert.New(t) - sessions = sync.Map{} - sess := NewSession() + sessions = make(map[string]*Session) + sess := NewSession("") token := sess.Token - v, ok := sessions.Load(token) + v, ok := sessions[token] assert.True(ok) assert.Equal(sess, v) } @@ -21,7 +20,7 @@ func TestConnSession(t *testing.T) { assert := assert.New(t) preIpData() defer closeIpdata() - sess := NewSession() + sess := NewSession("") cSess := sess.NewConn() cSess.RateLimit(100, true) assert.Equal(cSess.BandwidthUp, uint32(100)) diff --git a/sessdata/start.go b/sessdata/start.go index fa8d86a..73773ce 100644 --- a/sessdata/start.go +++ b/sessdata/start.go @@ -2,6 +2,5 @@ package sessdata func Start() { initIpPool() - initIpMac() checkSession() }