From 31b1f12dbe2bd19eeb00f86b4fbb66514c62e2ed Mon Sep 17 00:00:00 2001 From: bjdgyc Date: Mon, 14 Sep 2020 17:17:50 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=9F=BA=E4=BA=8Etap?= =?UTF-8?q?=E8=AE=BE=E5=A4=87=E7=9A=84=E6=A1=A5=E6=8E=A5=E8=AE=BF=E9=97=AE?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 70 ++++-- arpdis/addr.go | 99 +++++++++ arpdis/addr_test.go | 35 +++ arpdis/arp.go | 56 +++++ arpdis/icmp.go | 141 ++++++++++++ arpdis/lookup.go | 63 ++++++ bridge-init.sh | 42 ++++ common/app_ver.go | 2 +- common/assert_test.go | 10 - common/cfg_server.go | 49 ++-- common/cfg_user.go | 91 -------- common/cfg_user_test.go | 33 --- common/flag.go | 11 +- common/ip_pool.go | 139 ------------ common/ip_pool_test.go | 50 ----- common/log.go | 72 +++++- common/util.go | 39 ++++ common/util_test.go | 29 +++ conf/server.toml | 52 +++-- conf/user.toml | 17 -- dbdata/db.go | 137 ++++++++++++ dbdata/db_test.go | 63 ++++++ dbdata/group.go | 36 +++ dbdata/mac_ip.go | 34 +++ dbdata/start.go | 9 + dbdata/user.go | 29 +++ go.mod | 10 +- go.sum | 31 ++- handler/base.go | 30 ++- handler/dtls.go | 1 - handler/link_auth.go | 17 +- handler/link_cstp.go | 78 ++++--- handler/link_home.go | 7 +- handler/link_tap.go | 272 +++++++++++++++++++++++ handler/link_tun.go | 60 ++--- handler/link_tunnel.go | 44 +++- handler/payload.go | 47 ++++ handler/server.go | 46 ++-- handler/session.go | 157 ------------- handler/start.go | 24 ++ handler/user.go | 71 ++++++ main.go | 3 + proxyproto/protocol.go | 3 - proxyproto/protocol_test.go | 2 - router/router.go | 167 ++++++++++++++ router/router_test.go | 47 ++++ sessdata/copy_struct.go | 53 +++++ sessdata/copy_struct_test.go | 38 ++++ sessdata/ip_pool.go | 171 ++++++++++++++ sessdata/ip_pool_test.go | 58 +++++ sessdata/limit_client.go | 46 ++++ sessdata/limit_rate.go | 45 ++++ sessdata/limit_test.go | 55 +++++ handler/proto.go => sessdata/protocol.go | 14 +- sessdata/session.go | 258 +++++++++++++++++++++ sessdata/session_test.go | 31 +++ sessdata/start.go | 7 + 57 files changed, 2598 insertions(+), 703 deletions(-) create mode 100644 arpdis/addr.go create mode 100644 arpdis/addr_test.go create mode 100644 arpdis/arp.go create mode 100644 arpdis/icmp.go create mode 100644 arpdis/lookup.go create mode 100644 bridge-init.sh delete mode 100644 common/assert_test.go delete mode 100644 common/cfg_user.go delete mode 100644 common/cfg_user_test.go delete mode 100644 common/ip_pool.go delete mode 100644 common/ip_pool_test.go create mode 100644 common/util.go create mode 100644 common/util_test.go delete mode 100644 conf/user.toml create mode 100644 dbdata/db.go create mode 100644 dbdata/db_test.go create mode 100644 dbdata/group.go create mode 100644 dbdata/mac_ip.go create mode 100644 dbdata/start.go create mode 100644 dbdata/user.go create mode 100644 handler/link_tap.go create mode 100644 handler/payload.go delete mode 100644 handler/session.go create mode 100644 handler/start.go create mode 100644 handler/user.go create mode 100644 router/router.go create mode 100644 router/router_test.go create mode 100644 sessdata/copy_struct.go create mode 100644 sessdata/copy_struct_test.go create mode 100644 sessdata/ip_pool.go create mode 100644 sessdata/ip_pool_test.go create mode 100644 sessdata/limit_client.go create mode 100644 sessdata/limit_rate.go create mode 100644 sessdata/limit_test.go rename handler/proto.go => sessdata/protocol.go (95%) create mode 100644 sessdata/session.go create mode 100644 sessdata/session_test.go create mode 100644 sessdata/start.go diff --git a/README.md b/README.md index a7139e3..4f02bfa 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,8 @@ AnyLink 基于 [ietf-openconnect](https://tools.ietf.org/html/draft-mavrogiannop AnyLink 使用TLS/DTLS进行数据加密,因此需要RSA或ECC证书,可以通过 Let's Encrypt 和 TrustAsia 申请免费的SSL证书。 -AnyLink 服务端仅在CentOS7测试通过,如需要安装在其他系统,需要服务端支持tun功能、ip设置命令。 +AnyLink 服务端仅在CentOS7测试通过,如需要安装在其他系统,需要服务端支持tun/tap功能、ip设置命令。 + ## Installation @@ -22,43 +23,80 @@ sudo ./anylink -conf="conf/server.toml" ## Feature -- [x] IP分配 +- [x] IP分配(实现IP、MAC映射信息的持久化) - [x] TLS-TCP通道 - [x] 兼容AnyConnect +- [x] 基于tun设备的nat访问模式 +- [x] 基于tap设备的桥接访问模式 - [x] 多用户支持 - [x] 支持 [proxy protocol v1](http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt) 协议 -- [ ] DTLS-UDP通道 -- [ ] 后台管理界面 + - [ ] 用户组支持 - [ ] TOTP令牌支持 - [ ] 流量控制 - [ ] 访问权限管理 +- [ ] 后台管理界面 +- [ ] DTLS-UDP通道 + ## Config +默认配置文件内有详细的注释,根据注释填写配置即可。 + - [conf/server.toml](https://github.com/bjdgyc/anylink/blob/master/conf/server.toml) -- [conf/user.toml](https://github.com/bjdgyc/anylink/blob/master/conf/user.toml) ## Setting -1. 开启服务器转发 - ``` - # flie: /etc/sysctl.conf - net.ipv4.ip_forward = 1 +网络模式选择,需要配置 `link_mode` 参数,如 `link_mode="tun"`,`link_mode="tap"` 两种参数。 +不同的参数需要对服务器做相应的设置。 - #执行如下命令 - sysctl -w net.ipv4.ip_forward=1 - ``` +建议优先选择tun模式,因客户端传输的是IP层数据,无须进行数据转换。 +tap模式是在用户态做的链路层到IP层的数据互相转换,性能会有所下降。 + +### tun设置 + +1. 开启服务器转发 + ``` + # flie: /etc/sysctl.conf + net.ipv4.ip_forward = 1 + + #执行如下命令 + sysctl -w net.ipv4.ip_forward=1 + ``` 2. 设置nat转发规则 - ``` - # eth0为服务器内网网卡 - iptables -t nat -A POSTROUTING -s 192.168.10.0/255.255.255.0 -o eth0 -j MASQUERADE - ``` + ``` + # eth0为服务器内网网卡 + iptables -t nat -A POSTROUTING -s 192.168.10.0/255.255.255.0 -o eth0 -j MASQUERADE + ``` 3. 使用AnyConnect客户端连接即可 +### tap设置 + +1. 创建桥接网卡 + ``` + 注意 server.toml 的ip参数,需要与 bridge.sh 的配置参数一致 + ``` + +2. 修改 bridge-init.sh 内的参数 + ``` + # file: ./bridge.sh + eth="eth0" + eth_ip="192.168.1.4" + eth_netmask="255.255.255.0" + eth_broadcast="192.168.1.255" + eth_gateway="192.168.1.1" + ``` + +3. 执行 bridge.sh 文件 + ``` + sh bridge.sh + ``` + + + ## License 本项目采用 MIT 开源授权许可证,完整的授权说明已放置在 LICENSE 文件中。 diff --git a/arpdis/addr.go b/arpdis/addr.go new file mode 100644 index 0000000..8db9c10 --- /dev/null +++ b/arpdis/addr.go @@ -0,0 +1,99 @@ +package arpdis + +import ( + "net" + "sync" + "time" +) + +const ( + StaleTimeNormal = time.Minute * 5 + StaleTimeUnreachable = time.Minute * 10 + + TypeNormal = 0 + TypeUnreachable = 1 + TypeStatic = 2 +) + +var ( + table = make(map[string]*Addr) + tableMu sync.RWMutex +) + +type Addr struct { + IP net.IP + HardwareAddr net.HardwareAddr + disTime time.Time + Type int8 +} + +func Lookup(ip net.IP, onlyTable bool) *Addr { + addr := tableLookup(ip.To4()) + if addr != nil || onlyTable { + return addr + } + + addr = doLookup(ip.To4()) + Add(addr) + return addr +} + +// Add adds a IP-MAC map to a runtime ARP table. +func tableLookup(ip net.IP) *Addr { + tableMu.Lock() + addr := table[ip.To4().String()] + tableMu.Unlock() + if addr == nil { + return nil + } + + // 判断老化过期时间 + tsub := time.Now().Sub(addr.disTime) + switch addr.Type { + case TypeNormal: + if tsub > StaleTimeNormal { + return nil + } + case TypeUnreachable: + if tsub > StaleTimeUnreachable { + return nil + } + case TypeStatic: + } + + return addr +} + +// Add adds a IP-MAC map to a runtime ARP table. +func Add(addr *Addr) { + if addr == nil { + return + } + if addr.disTime.IsZero() { + addr.disTime = time.Now() + } + ip := addr.IP.To4().String() + tableMu.Lock() + defer tableMu.Unlock() + if a, ok := table[ip]; ok { + // 静态地址只能设置一次 + if a.Type == TypeStatic { + return + } + } + table[ip] = addr +} + +// Delete removes an IP from the runtime ARP table. +func Delete(ip net.IP) { + tableMu.Lock() + defer tableMu.Unlock() + delete(table, ip.To4().String()) +} + +// List returns the current runtime ARP table. +func List() map[string]*Addr { + tableMu.RLock() + defer tableMu.RUnlock() + return table +} diff --git a/arpdis/addr_test.go b/arpdis/addr_test.go new file mode 100644 index 0000000..70a629b --- /dev/null +++ b/arpdis/addr_test.go @@ -0,0 +1,35 @@ +package arpdis + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLookup(t *testing.T) { + assert := assert.New(t) + ip := net.IPv4(192, 168, 10, 2) + hw, _ := net.ParseMAC("08:00:27:a0:17:42") + now := time.Now() + addr1 := &Addr{ + IP: ip, + HardwareAddr: hw, + Type: TypeStatic, + disTime: now, + } + Add(addr1) + addr2 := Lookup(ip, true) + assert.Equal(addr1, addr2) + addr3 := &Addr{ + IP: ip, + HardwareAddr: hw, + Type: TypeNormal, + disTime: now, + } + Add(addr3) + addr4 := Lookup(ip, true) + // 静态地址只能设置一次 + assert.NotEqual(addr3, addr4) +} diff --git a/arpdis/arp.go b/arpdis/arp.go new file mode 100644 index 0000000..0df7198 --- /dev/null +++ b/arpdis/arp.go @@ -0,0 +1,56 @@ +package arpdis + +// Reference: github.com/malfunkt/arpfox +// TODO now only support IPv4 + +import ( + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +var defaultSerializeOpts = gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, +} + +// NewARPRequest creates a bew ARP packet of type "request. +func NewARPRequest(src *Addr, dst *Addr) ([]byte, error) { + return buildPacket(src, dst, layers.ARPRequest) +} + +// NewARPReply creates a new ARP packet of type "reply". +func NewARPReply(src *Addr, dst *Addr) ([]byte, error) { + return buildPacket(src, dst, layers.ARPReply) +} + +// buildPacket creates an template ARP packet with the given source and +// destination. +func buildPacket(src *Addr, dst *Addr, opt uint16) ([]byte, error) { + ether := layers.Ethernet{ + EthernetType: layers.EthernetTypeARP, + SrcMAC: src.HardwareAddr, + DstMAC: dst.HardwareAddr, + } + arp := layers.ARP{ + AddrType: layers.LinkTypeEthernet, + Protocol: layers.EthernetTypeIPv4, + + HwAddressSize: 6, + ProtAddressSize: 4, + Operation: opt, + + SourceHwAddress: src.HardwareAddr, + SourceProtAddress: src.IP.To4(), + + DstHwAddress: dst.HardwareAddr, + DstProtAddress: dst.IP.To4(), + } + + buf := gopacket.NewSerializeBuffer() + err := gopacket.SerializeLayers(buf, defaultSerializeOpts, ðer, &arp) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil +} diff --git a/arpdis/icmp.go b/arpdis/icmp.go new file mode 100644 index 0000000..56fe6b0 --- /dev/null +++ b/arpdis/icmp.go @@ -0,0 +1,141 @@ +package arpdis + +import ( + "errors" + "net" + "os" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" +) + +var ( + ipv4Id uint16 = 1 + icmpSeq uint16 = 1 +) + +func NewIPv4Unreachable(src *Addr, dst *Addr) ([]byte, error) { + ipv4Id++ + icmpSeq++ + + ipv4 := layers.IPv4{ + Version: 4, + IHL: 5, + TOS: 0, + // Length : uint16, + Id: ipv4Id, + Flags: 0, + FragOffset: 0, + TTL: 10, + Protocol: layers.IPProtocolICMPv4, + // Checksum : uint16, + SrcIP: src.IP.To4(), + DstIP: dst.IP.To4(), + Options: nil, + Padding: nil, + } + + icmp := layers.ICMPv4{ + TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeDestinationUnreachable, layers.ICMPv4CodeNet), + Id: uint16(os.Getpid()), + Seq: icmpSeq, + } + + buf := gopacket.NewSerializeBuffer() + err := gopacket.SerializeLayers(buf, defaultSerializeOpts, &ipv4, &icmp) + + return buf.Bytes(), err +} + +const ( + ProtocolICMP = 1 + ProtocolIPv6ICMP = 58 +) + +func doPing(ip string) error { + raddr, _ := net.ResolveIPAddr("ip4:icmp", ip) + conn, err := icmp.ListenPacket("ip4:icmp", "") + if err != nil { + return err + } + + ipv4Conn := conn.IPv4PacketConn() + // 限制跳跃数 + err = ipv4Conn.SetTTL(10) + if err != nil { + return err + } + + msg := &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: os.Getpid() & 0xffff, + Seq: 1, + Data: timeToBytes(time.Now()), + }, + } + + b, err := msg.Marshal(nil) + if err != nil { + return err + } + _, err = conn.WriteTo(b, raddr) + if err != nil { + return err + } + + conn.SetReadDeadline(time.Now().Add(time.Second * 2)) + + for { + buf := make([]byte, 512) + n, dst, err := conn.ReadFrom(buf) + if err != nil { + return err + } + if dst.String() != ip { + continue + } + + var result *icmp.Message + result, err = icmp.ParseMessage(ProtocolICMP, buf[:n]) + if err != nil { + return err + } + + switch result.Type { + case ipv4.ICMPTypeEchoReply: + // success + if rply, ok := result.Body.(*icmp.Echo); ok { + _ = rply + // log.Printf("%+v \n", rply) + } + return nil + + // case ipv4.ICMPTypeTimeExceeded: + // case ipv4.ICMPTypeDestinationUnreachable: + default: + return errors.New("DestinationUnreachable") + } + } +} + +func timeToBytes(t time.Time) []byte { + nsec := t.UnixNano() + b := make([]byte, 8) + for i := uint8(0); i < 8; i++ { + b[i] = byte((nsec >> ((7 - i) * 8)) & 0xff) + } + return b +} + +func bytesToTime(b []byte) time.Time { + var nsec int64 + for i := uint8(0); i < 8; i++ { + nsec += int64(b[i]) << ((7 - i) * 8) + } + return time.Unix(nsec/1000000000, nsec%1000000000) +} diff --git a/arpdis/lookup.go b/arpdis/lookup.go new file mode 100644 index 0000000..9ffda6c --- /dev/null +++ b/arpdis/lookup.go @@ -0,0 +1,63 @@ +// Currently only Darwin and Linux support this. + +// +build darwin linux + +package arpdis + +import ( + "log" + "net" + "os/exec" + "strings" +) + +func doLookup(ip net.IP) *Addr { + // ping := exec.Command("ping", "-c1", "-t1", ip.String()) + // if err := ping.Run(); err != nil { + // addr := &Addr{IP: ip, Type: TypeUnreachable} + // return addr + // } + + err := doPing(ip.String()) + if err != nil { + // log.Println(err) + addr := &Addr{IP: ip, Type: TypeUnreachable} + return addr + } + + return doArpShow(ip) +} + +func doArpShow(ip net.IP) *Addr { + cmd := exec.Command("ip", "n", "show", ip.String()) + out, err := cmd.Output() + if err != nil { + log.Println("lookup show", err) + return nil + } + + // os.Open("/proc/net/arp") + // 192.168.1.2 0x1 0x2 e0:94:67:e2:42:5d * eth0 + // 192.168.1.2 dev eth0 lladdr 08:00:27:94:a5:a4 STALE + outS := strings.ReplaceAll(string(out), " ", " ") + outS = strings.TrimSpace(outS) + arpArr := strings.Split(outS, " ") + if len(arpArr) != 6 { + log.Println("lookup arpArr", outS, ip) + return nil + } + mac, err := net.ParseMAC(arpArr[4]) + if err != nil { + log.Println("lookup mac", outS, err) + return nil + } + + return &Addr{IP: ip, HardwareAddr: mac} +} + +// IP address HW type Flags HW address Mask Device +// 172.23.24.12 0x1 0x2 00:e0:4c:73:5c:48 * anylink0 +// 172.23.24.1 0x1 0x2 3c:8c:40:a0:7a:2d * anylink0 +// 172.23.24.13 0x1 0x2 00:1c:42:4d:33:46 * anylink0 +// 172.23.24.2 0x1 0x0 00:00:00:00:00:00 * anylink0 +// 172.23.24.14 0x1 0x0 00:00:00:00:00:00 * anylink0 diff --git a/bridge-init.sh b/bridge-init.sh new file mode 100644 index 0000000..f3ca4ca --- /dev/null +++ b/bridge-init.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +################################# +# Set up Ethernet bridge on Linux +# Requires: bridge-utils +################################# + +# Define Bridge Interface +br="anylink0" + +# Define list of TAP interfaces to be bridged, +# for example tap="tap0 tap1 tap2". +tap="tap0" + +# Define physical ethernet interface to be bridged +# with TAP interface(s) above. + +eth="eth0" +eth_ip="192.168.1.4" +eth_netmask="255.255.255.0" +eth_broadcast="192.168.1.255" +eth_gateway="192.168.1.1" + + +brctl addbr $br +brctl addif $br $eth + +ifconfig $eth 0.0.0.0 up + +mac=`cat /sys/class/net/$eth/address` +ifconfig $br hw ether $mac +ifconfig $br $eth_ip netmask $eth_netmask broadcast $eth_broadcast up + +route add default gateway $eth_gateway + + + + + + + + diff --git a/common/app_ver.go b/common/app_ver.go index 485658c..85c3d0c 100644 --- a/common/app_ver.go +++ b/common/app_ver.go @@ -2,5 +2,5 @@ package common const ( APP_NAME = "AnyLink" - APP_VER = "0.0.1" + APP_VER = "0.0.3" ) diff --git a/common/assert_test.go b/common/assert_test.go deleted file mode 100644 index 90fe9c2..0000000 --- a/common/assert_test.go +++ /dev/null @@ -1,10 +0,0 @@ -package common - -import "testing" - -func AssertTrue(t *testing.T, a bool) { - t.Helper() - if !a { - t.Errorf("Not True %t", a) - } -} diff --git a/common/cfg_server.go b/common/cfg_server.go index 4f6854f..27cb661 100644 --- a/common/cfg_server.go +++ b/common/cfg_server.go @@ -8,6 +8,11 @@ import ( "github.com/pelletier/go-toml" ) +const ( + LinkModeTUN = "tun" + LinkModeTAP = "tap" +) + var ( ServerCfg = &ServerConfig{} ) @@ -24,28 +29,34 @@ var ( // rekey-method = ssl type ServerConfig struct { - UserFile string `toml:"user_file"` - ServerAddr string `toml:"server_addr"` - DebugAddr string `toml:"debug_addr"` - ProxyProtocol bool `toml:"proxy_protocol"` - CertFile string `toml:"cert_file"` - CertKey string `toml:"cert_key"` - LinkGroups []string `toml:"link_groups"` + ServerAddr string `toml:"server_addr"` + AdminAddr string `toml:"admin_addr"` + ProxyProtocol bool `toml:"proxy_protocol"` + DbFile string `toml:"db_file"` + CertFile string `toml:"cert_file"` + CertKey string `toml:"cert_key"` + LogLevel string `toml:"log_level"` + + LinkMode string `toml:"link_mode"` // tun tap + Ipv4Network string `toml:"ipv4_network"` // 192.168.1.0 + Ipv4Netmask string `toml:"ipv4_netmask"` // 255.255.255.0 + Ipv4Gateway string `toml:"ipv4_gateway"` + Ipv4Pool []string `toml:"ipv4_pool"` // Pool[0]=192.168.1.100 Pool[1]=192.168.1.200 + Include []string `toml:"include"` // 10.10.10.0/255.255.255.0 + Exclude []string `toml:"exclude"` // 192.168.5.0/255.255.255.0 + ClientDns []string `toml:"client_dns"` // 114.114.114.114 + AllowLan bool `toml:"allow_lan"` // 允许本地LAN访问vpn网络 + MaxClient int `toml:"max_client"` + MaxUserClient int `toml:"max_user_client"` + + UserGroups []string `toml:"user_groups"` DefaultGroup string `toml:"default_group"` - Banner string `toml:"banner"` // 欢迎语 - CstpDpd int `toml:"cstp_dpd"` // Dead peer detection in seconds + Banner string `toml:"banner"` // 欢迎语 + CstpDpd int `toml:"cstp_dpd"` // Dead peer detection in seconds + MobileDpd int `toml:"mobile_dpd"` CstpKeepalive int `toml:"cstp_keepalive"` // in seconds SessionTimeout int `toml:"session_timeout"` // in seconds AuthTimeout int `toml:"auth_timeout"` // in seconds - MaxClient int `toml:"max_client"` - MaxUserClient int `toml:"max_user_client"` - Ipv4Network string `toml:"ipv4_network"` // 192.168.1.0 - Ipv4Netmask string `toml:"ipv4_netmask"` // 255.255.255.0 - Ipv4GateWay string `toml:"-"` - Include []string `toml:"include"` // 10.10.10.0/255.255.255.0 - Exclude []string `toml:"exclude"` // 192.168.5.0/255.255.255.0 - ClientDns []string `toml:"client_dns"` // 114.114.114.114 - AllowLan bool `toml:"allow_lan"` // 允许本地LAN访问vpn网络 } func loadServer() { @@ -62,7 +73,7 @@ func loadServer() { base := filepath.Dir(sf) // 转换成绝对路径 - ServerCfg.UserFile = getAbsPath(base, ServerCfg.UserFile) + ServerCfg.DbFile = getAbsPath(base, ServerCfg.DbFile) ServerCfg.CertFile = getAbsPath(base, ServerCfg.CertFile) ServerCfg.CertKey = getAbsPath(base, ServerCfg.CertKey) diff --git a/common/cfg_user.go b/common/cfg_user.go deleted file mode 100644 index f023dc6..0000000 --- a/common/cfg_user.go +++ /dev/null @@ -1,91 +0,0 @@ -package common - -import ( - "crypto/sha1" - "fmt" - "io/ioutil" - "sync" - - "github.com/pelletier/go-toml" -) - -var ( - users = map[string]User{} - limitClient = map[string]int{"_all": 0} - limitMux = sync.Mutex{} -) - -type User struct { - Group string `toml:"group"` - Username string `toml:"-"` - Password string `toml:"password"` - OtpSecret string `toml:"otp_secret"` -} - -func CheckUser(name, pwd, group string) bool { - user, ok := users[name] - if !ok { - return false - } - pwdHash := hashPass(pwd) - if user.Password == pwdHash { - return true - } - return false -} - -func hashPass(pwd string) string { - sum := sha1.Sum([]byte(pwd)) - return fmt.Sprintf("%x", sum) -} - -func LimitClient(name string, close bool) bool { - limitMux.Lock() - defer limitMux.Unlock() - // defer fmt.Println(limitClient) - - _all := limitClient["_all"] - c, ok := limitClient[name] - if !ok { // 不存在用户 - limitClient[name] = 0 - } - - if close { - limitClient[name] = c - 1 - limitClient["_all"] = _all - 1 - return true - } - - // 全局判断 - if _all >= ServerCfg.MaxClient { - return false - } - - // 超出同一个用户限制 - if c >= ServerCfg.MaxUserClient { - return false - } - - limitClient[name] = c + 1 - limitClient["_all"] = _all + 1 - return true -} - -func loadUser() { - b, err := ioutil.ReadFile(ServerCfg.UserFile) - if err != nil { - panic(err) - } - err = toml.Unmarshal(b, &users) - if err != nil { - panic(err) - } - - // 添加用户名 - for k, v := range users { - v.Username = k - users[k] = v - } - - fmt.Println("users:", users) -} diff --git a/common/cfg_user_test.go b/common/cfg_user_test.go deleted file mode 100644 index 6466534..0000000 --- a/common/cfg_user_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package common - -import ( - "testing" -) - -func TestCheckUser(t *testing.T) { - users["user1"] = User{Password: "7c4a8d09ca3762af61e59520943dc26494f8941b"} - users["user2"] = User{Password: "7c4a8d09ca3762af61e59520943dc26494f8941c"} - - var res bool - res = CheckUser("user1", "123456", "") - AssertTrue(t, res == true) - - res = CheckUser("user2", "123457", "") - AssertTrue(t, res == false) -} - -func TestLimitClient(t *testing.T) { - ServerCfg.MaxClient = 2 - ServerCfg.MaxUserClient = 1 - - res1 := LimitClient("user1", false) - res2 := LimitClient("user1", false) - res3 := LimitClient("user2", false) - res4 := LimitClient("user3", false) - - AssertTrue(t, res1 == true) - AssertTrue(t, res2 == false) - AssertTrue(t, res3 == true) - AssertTrue(t, res4 == false) - -} diff --git a/common/flag.go b/common/flag.go index bf0c437..b0ada7d 100644 --- a/common/flag.go +++ b/common/flag.go @@ -12,23 +12,15 @@ var ( CommitId string // 配置文件 serverFile string - passwd string // 显示版本信息 rev bool ) func initFlag() { flag.StringVar(&serverFile, "conf", "./conf/server.toml", "server config file path") - flag.StringVar(&passwd, "pass", "", "generation a sha1 password") flag.BoolVar(&rev, "rev", false, "display version info") flag.Parse() - if passwd != "" { - pwdHash := hashPass(passwd) - fmt.Printf("passwd-sha1:%s\n", pwdHash) - os.Exit(0) - } - if rev { fmt.Printf("%s v%s build on %s [%s, %s] commit_id(%s) \n", APP_NAME, APP_VER, runtime.Version(), runtime.GOOS, runtime.GOARCH, CommitId) @@ -39,6 +31,5 @@ func initFlag() { func InitConfig() { initFlag() loadServer() - loadUser() - initIpPool() + initLog() } diff --git a/common/ip_pool.go b/common/ip_pool.go deleted file mode 100644 index 3f84e20..0000000 --- a/common/ip_pool.go +++ /dev/null @@ -1,139 +0,0 @@ -package common - -import ( - "encoding/binary" - "math" - "net" - "sync" - "time" -) - -const ( - // ip租期 (秒) - IpLease = 1209600 -) - -var ( - ipPool = &IpPoolConfig{} - macIps = map[string]*MacIp{} -) - -type MacIp struct { - IsActive bool - Ip net.IP - MacAddr string - LastLogin time.Time -} - -type IpPoolConfig struct { - mux sync.Mutex - // 计算动态ip - Ipv4Net *net.IPNet - Ipv4GateWay net.IP - IpLongMin uint32 - IpLongMax uint32 - IpLongNow uint32 -} - -func initIpPool() { - // ip地址 - ip := net.ParseIP(ServerCfg.Ipv4Network) - // 子网掩码 - maskIp := net.ParseIP(ServerCfg.Ipv4Netmask).To4() - mask := net.IPMask(maskIp) - - ipNet := &net.IPNet{IP: ip, Mask: mask} - ipPool.Ipv4Net = ipNet - - // 网络地址零值 - min := binary.BigEndian.Uint32(ip.Mask(mask)) - // 广播地址 - one, _ := ipNet.Mask.Size() - max := min | uint32(math.Pow(2, float64(32-one))-1) - - min += 1 // 网关 - ipPool.Ipv4GateWay = long2ip(min) - ServerCfg.Ipv4GateWay = ipPool.Ipv4GateWay.String() - // 第一个可用地址 - min += 1 - ipPool.IpLongMin = min - ipPool.IpLongMax = max - ipPool.IpLongNow = min -} - -func long2ip(i uint32) net.IP { - ip := make([]byte, 4) - binary.BigEndian.PutUint32(ip, i) - return ip -} - -// 获取动态ip -func AcquireIp(macAddr string) net.IP { - ipPool.mux.Lock() - defer ipPool.mux.Unlock() - tNow := time.Now() - - // 判断已经分配过 - if mi, ok := macIps[macAddr]; ok { - mi.IsActive = true - mi.LastLogin = tNow - return mi.Ip - } - - // ip池分配完之前 - if ipPool.IpLongNow < ipPool.IpLongMax { - // 递增分配一个ip - ip := long2ip(ipPool.IpLongNow) - mi := &MacIp{IsActive: true, Ip: ip, MacAddr: macAddr, LastLogin: tNow} - macIps[macAddr] = mi - ipPool.IpLongNow += 1 - return ip - } - - // 查找过期数据 - farMi := &MacIp{LastLogin: tNow} - for k, v := range macIps { - // 跳过活跃连接 - if v.IsActive { - continue - } - - // 已经超过租期 - if tNow.Sub(v.LastLogin) > IpLease*time.Second { - delete(macIps, k) - ip := v.Ip - mi := &MacIp{IsActive: true, Ip: ip, MacAddr: macAddr, LastLogin: tNow} - macIps[macAddr] = mi - return ip - } - - // 其他情况判断最早登陆的mac - if v.LastLogin.Before(farMi.LastLogin) { - farMi = v - } - } - - // 全都在线,没有数据可用 - if farMi.MacAddr == "" { - return nil - } - - // 使用最早登陆的mac地址 - delete(macIps, farMi.MacAddr) - ip := farMi.Ip - mi := &MacIp{IsActive: true, Ip: ip, MacAddr: macAddr, LastLogin: tNow} - macIps[macAddr] = mi - return ip -} - -// 回收ip -func ReleaseIp(ip net.IP, macAddr string) { - ipPool.mux.Lock() - defer ipPool.mux.Unlock() - if mi, ok := macIps[macAddr]; ok { - if mi.Ip.Equal(ip) { - mi.IsActive = false - mi.LastLogin = time.Now() - } - } -} diff --git a/common/ip_pool_test.go b/common/ip_pool_test.go deleted file mode 100644 index 94ba1ec..0000000 --- a/common/ip_pool_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package common - -import ( - "fmt" - "net" - "testing" -) - -func TestAcquireIp(t *testing.T) { - ServerCfg.Ipv4Network = "192.168.1.0" - ServerCfg.Ipv4Netmask = "255.255.255.0" - macIps = map[string]*MacIp{} - initIpPool() - - var ip net.IP - - for i := 2; i <= 100; i++ { - ip = AcquireIp(fmt.Sprintf("mac-%d", i)) - } - ip = AcquireIp(fmt.Sprintf("mac-new")) - AssertTrue(t, ip.Equal(net.IPv4(192, 168, 1, 101))) - for i := 102; i <= 254; i++ { - ip = AcquireIp(fmt.Sprintf("mac-%d", i)) - } - ip = AcquireIp(fmt.Sprintf("mac-nil")) - AssertTrue(t, ip == nil) -} - -func TestReleaseIp(t *testing.T) { - ServerCfg.Ipv4Network = "192.168.1.0" - ServerCfg.Ipv4Netmask = "255.255.255.0" - macIps = map[string]*MacIp{} - initIpPool() - - var ip net.IP - - // 分配完所有数据 - for i := 2; i <= 254; i++ { - ip = AcquireIp(fmt.Sprintf("mac-%d", i)) - } - - ip = AcquireIp(fmt.Sprintf("mac-more")) - AssertTrue(t, ip == nil) - - ReleaseIp(net.IPv4(192, 168, 1, 123), "mac-123") - ReleaseIp(net.IPv4(192, 168, 1, 100), "mac-100") - ip = AcquireIp(fmt.Sprintf("mac-new")) - // 最早过期的ip - AssertTrue(t, ip.Equal(net.IPv4(192, 168, 1, 123))) -} diff --git a/common/log.go b/common/log.go index 18b8cb0..8481bbf 100644 --- a/common/log.go +++ b/common/log.go @@ -1,7 +1,73 @@ package common -import "log" +import ( + "log" + "os" +) -func init() { - log.SetFlags(log.LstdFlags | log.Lshortfile) +const ( + debug = iota + info + error + fatal +) + +var Log *logger + +type logger struct { + log *log.Logger + level int +} + +func initLog() { + // log.SetFlags(log.LstdFlags | log.Lshortfile) + l := log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile) + Log = &logger{log: l, level: logLevel2Int(ServerCfg.LogLevel)} +} + +func logLevel2Int(l string) int { + switch l { + case "debug": + return debug + case "info": + return info + case "error": + return error + case "fatal": + return fatal + default: + return info + } +} + +func (l *logger) Debug(v ...interface{}) { + if l.level > debug { + return + } + data := append([]interface{}{"[Debug]"}, v...) + l.log.Println(data...) +} + +func (l *logger) Info(v ...interface{}) { + if l.level > info { + return + } + data := append([]interface{}{"[Info]"}, v...) + l.log.Println(data...) +} + +func (l *logger) Error(v ...interface{}) { + if l.level > error { + return + } + data := append([]interface{}{"[Error]"}, v...) + l.log.Println(data...) +} + +func (l *logger) Fatal(v ...interface{}) { + if l.level > fatal { + return + } + data := append([]interface{}{"[Fatal]"}, v...) + l.log.Fatalln(data...) } diff --git a/common/util.go b/common/util.go new file mode 100644 index 0000000..6f87d48 --- /dev/null +++ b/common/util.go @@ -0,0 +1,39 @@ +package common + +import "fmt" + +func InArrStr(arr []string, str string) bool { + for _, d := range arr { + if d == str { + return true + } + } + return false +} + +const ( + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + TB = 1024 * GB + PB = 1024 * TB +) + +func HumanByte(bAll float64) string { + var hb string + + switch { + case bAll >= TB: + hb = fmt.Sprintf("%0.2f TB", bAll/TB) + case bAll >= GB: + hb = fmt.Sprintf("%0.2f GB", bAll/GB) + case bAll >= MB: + hb = fmt.Sprintf("%0.2f MB", bAll/MB) + case bAll >= KB: + hb = fmt.Sprintf("%0.2f KB", bAll/KB) + default: + hb = fmt.Sprintf("%0.2f B", bAll) + } + + return hb +} diff --git a/common/util_test.go b/common/util_test.go new file mode 100644 index 0000000..a4c9e42 --- /dev/null +++ b/common/util_test.go @@ -0,0 +1,29 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInArrStr(t *testing.T) { + assert := assert.New(t) + arr := []string{"a", "b", "c"} + assert.True(InArrStr(arr, "b")) + assert.False(InArrStr(arr, "d")) +} + +func TestHumanByte(t *testing.T) { + assert := assert.New(t) + var s string + s = HumanByte(999) + assert.Equal(s, "999.00 B") + s = HumanByte(10256) + assert.Equal(s, "10.02 KB") + s = HumanByte(99 * 1024 * 1024) + assert.Equal(s, "99.00 MB") + s = HumanByte(1023 * 1024 * 1024) + assert.Equal(s, "1023.00 MB") + s = HumanByte(1024 * 1024 * 1024) + assert.Equal(s, "1.00 GB") +} diff --git a/conf/server.toml b/conf/server.toml index 3a3db96..197348b 100644 --- a/conf/server.toml +++ b/conf/server.toml @@ -2,21 +2,48 @@ #其他配置文件,可以使用绝对路径 #或者相对于server.toml的路径 -user_file = "./user.toml" + +#数据文件 +db_file = "./data.db" #证书文件 cert_file = "./vpn_cert.pem" cert_key = "./vpn_cert.key" +log_level = "info" #服务监听的地址 server_addr = ":443" -debug_addr = "127.0.0.1:8800" +#一般设置 本机地址 +admin_addr = ":8800" #开启tcp proxy protocol协议 proxy_protocol = false +link_mode = "tun" + +#客户端分配的ip地址池 +ipv4_network = "192.168.10.0" +ipv4_netmask = "255.255.255.0" +ipv4_gateway = "192.168.10.1" +ipv4_pool = ["192.168.10.100", "192.168.10.200"] + +#需加密传输的ip规则 +#include = ["10.10.10.0/255.255.255.0"] +#非加密传输的ip规则 +#exclude = ["192.168.5.0/255.255.255.0"] +#客户端使用的dns ios客户端必须配置 +client_dns = ["114.114.114.114"] +#是否允许本地LAN访问vpn网络 +allow_lan = true + +#最大客户端数量 +max_client = 300 +#单个用户同时在线数量 +max_user_client = 3 + + #用户组 -link_groups = ["one", "two"] +user_groups = ["one", "two"] #默认选择的组 -default_group = "one" +default_group = "two" #登陆成功的欢迎语 banner = "您已接入公司网络,请按照公司规定使用。\n请勿进行非工作下载及视频行为!" @@ -24,27 +51,12 @@ banner = "您已接入公司网络,请按照公司规定使用。\n请勿进 #客户端失效检测时间(秒) dpd > keepalive cstp_dpd = 30 cstp_keepalive = 20 +mobile_dpd = 300 #session过期时间,用于断线重连,0永不过期 session_timeout = 3600 auth_timeout = 0 -#最大客户端数量 -max_client = 300 -#单个用户同时在线数量 -max_user_client = 3 - -#客户端分配的ip地址池 -ipv4_network = "192.168.10.0" -ipv4_netmask = "255.255.255.0" -#需加密传输的ip规则 -#include = ["10.10.10.0/255.255.255.0"] -#非加密传输的ip规则 -#exclude = ["192.168.5.0/255.255.255.0"] -#客户端使用的dns -client_dns = ["114.114.114.114"] -#是否允许本地LAN访问vpn网络 -allow_lan = true diff --git a/conf/user.toml b/conf/user.toml deleted file mode 100644 index eb4aa57..0000000 --- a/conf/user.toml +++ /dev/null @@ -1,17 +0,0 @@ -#用户信息配置 -[test] -group = "group1" -#密码需要使用 sha1,以下密码为 123456 -password = "7c4a8d09ca3762af61e59520943dc26494f8941b" - - -[user] -group = "group2" -#以下密码为 123456 -password = "7c4a8d09ca3762af61e59520943dc26494f8941b" - - - - - - diff --git a/dbdata/db.go b/dbdata/db.go new file mode 100644 index 0000000..67c5d8b --- /dev/null +++ b/dbdata/db.go @@ -0,0 +1,137 @@ +package dbdata + +import ( + "encoding/json" + "errors" + "log" + + "github.com/bjdgyc/anylink/common" + bolt "go.etcd.io/bbolt" +) + +const pageSize = 10 + +var ( + db *bolt.DB + ErrNoKey = errors.New("db no this key") +) + +func initDb() { + var err error + db, err = bolt.Open(common.ServerCfg.DbFile, 0666, nil) + if err != nil { + log.Fatal(err) + } + + // 创建bucket + err = db.Update(func(tx *bolt.Tx) error { + var err error + _, err = tx.CreateBucketIfNotExists([]byte(BucketUser)) + if err != nil { + return err + } + _, err = tx.CreateBucketIfNotExists([]byte(BucketGroup)) + if err != nil { + return err + } + _, err = tx.CreateBucketIfNotExists([]byte(BucketMacIp)) + if err != nil { + return err + } + return nil + }) + + if err != nil { + log.Fatal(err) + } +} + +func NextId(bucket string) int { + var i int + db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(bucket)) + id, err := b.NextSequence() + i = int(id) + // discard error + return err + }) + return i +} + +func GetCount(bucket string) int { + count := 0 + db.View(func(tx *bolt.Tx) error { + bkt := tx.Bucket([]byte(bucket)) + s := bkt.Stats() + // fmt.Printf("%+v \n", s) + count = s.KeyN + return nil + }) + return count +} + +func Set(bucket, key string, v interface{}) error { + return db.Update(func(tx *bolt.Tx) error { + bkt := tx.Bucket([]byte(bucket)) + b, err := json.Marshal(v) + if err != nil { + return err + } + return bkt.Put([]byte(key), b) + }) +} + +func Del(bucket, key string) error { + return db.Update(func(tx *bolt.Tx) error { + bkt := tx.Bucket([]byte(bucket)) + return bkt.Delete([]byte(key)) + }) +} + +func Get(bucket, key string, v interface{}) error { + return db.View(func(tx *bolt.Tx) error { + bkt := tx.Bucket([]byte(bucket)) + b := bkt.Get([]byte(key)) + if b == nil { + return ErrNoKey + } + return json.Unmarshal(b, v) + }) +} + +// 分页获取 +func getList(bucket, lastKey string, prev bool) [][]byte { + res := make([][]byte, 0) + db.View(func(tx *bolt.Tx) error { + c := tx.Bucket([]byte(bucket)).Cursor() + size := pageSize + k, b := c.Seek([]byte(lastKey)) + + if prev { + for i := 0; i < size; i++ { + k, b = c.Prev() + if k == nil { + break + } + res = append(res, b) + } + return nil + } + + // next + if string(k) != lastKey { + // 不相同,说明找出其他的 + size -= 1 + res = append(res, b) + } + for i := 0; i < size; i++ { + k, b = c.Next() + if k == nil { + break + } + res = append(res, b) + } + return nil + }) + return res +} diff --git a/dbdata/db_test.go b/dbdata/db_test.go new file mode 100644 index 0000000..1423d1e --- /dev/null +++ b/dbdata/db_test.go @@ -0,0 +1,63 @@ +package dbdata + +import ( + "net" + "os" + "path" + "testing" + + "github.com/bjdgyc/anylink/common" + "github.com/stretchr/testify/assert" +) + +func preIpData() { + tmpDb := path.Join(os.TempDir(), "anylink_test.db") + common.ServerCfg.DbFile = tmpDb + initDb() +} + +func closeIpdata() { + db.Close() + tmpDb := path.Join(os.TempDir(), "anylink_test.db") + os.Remove(tmpDb) +} + +func TestDb(t *testing.T) { + assert := assert.New(t) + preIpData() + defer closeIpdata() + + Set(BucketUser, "a", User{Username: "a"}) + Set(BucketUser, "b", User{Username: "b"}) + Set(BucketUser, "c", User{Username: "c"}) + Set(BucketUser, "d", User{Username: "d"}) + Set(BucketUser, "e", User{Username: "e"}) + Set(BucketUser, "f", User{Username: "f"}) + Set(BucketUser, "g", User{Username: "g"}) + + c := GetCount(BucketUser) + assert.Equal(c, 7) + Del(BucketUser, "g") + c = GetCount(BucketUser) + assert.Equal(c, 6) + + // 分页查询 + us := GetUsers("d", false) + assert.Equal(us[0].Username, "e") + assert.Equal(us[1].Username, "f") + us = GetUsers("d", true) + assert.Equal(us[0].Username, "c") + assert.Equal(us[1].Username, "b") + assert.Equal(us[2].Username, "a") + + mac1 := MacIp{Ip: net.ParseIP("192.168.3.11"), MacAddr: "mac1"} + mac2 := MacIp{Ip: net.ParseIP("192.168.3.12"), MacAddr: "mac2"} + Set(BucketMacIp, "mac1", mac1) + Set(BucketMacIp, "mac2", mac2) + + mp := GetAllMacIp() + assert.Equal(mp[0].MacAddr, "mac1") + assert.Equal(mp[1].MacAddr, "mac2") + + os.Exit(0) +} diff --git a/dbdata/group.go b/dbdata/group.go new file mode 100644 index 0000000..9a8ebe4 --- /dev/null +++ b/dbdata/group.go @@ -0,0 +1,36 @@ +package dbdata + +import ( + "encoding/json" + "net" + "time" +) + +const BucketGroup = "group" + +type Group struct { + Id int + Name string + RouteInclude []string + RouteExclude []string + AllowLan bool + LinkAcl []struct { + Action string // allow、deny + IpNet string + IPNet net.IPNet + } + Bandwidth int // 带宽限制 + CreatedAt time.Time + UpdatedAt time.Time +} + +func GetGroups(lastKey string, prev bool) []Group { + res := getList(BucketUser, lastKey, prev) + datas := make([]Group, 0) + for _, data := range res { + d := Group{} + json.Unmarshal(data, &d) + datas = append(datas, d) + } + return datas +} diff --git a/dbdata/mac_ip.go b/dbdata/mac_ip.go new file mode 100644 index 0000000..ab220f9 --- /dev/null +++ b/dbdata/mac_ip.go @@ -0,0 +1,34 @@ +package dbdata + +import ( + "encoding/json" + "net" + "time" + + bolt "go.etcd.io/bbolt" +) + +const BucketMacIp = "macIp" + +type MacIp struct { + IsActive bool // db存储没有使用 + Ip net.IP + MacAddr string + LastLogin time.Time +} + +func GetAllMacIp() []MacIp { + datas := make([]MacIp, 0) + db.View(func(tx *bolt.Tx) error { + bkt := tx.Bucket([]byte(BucketMacIp)) + bkt.ForEach(func(k, v []byte) error { + d := MacIp{} + json.Unmarshal(v, &d) + datas = append(datas, d) + return nil + }) + return nil + }) + + return datas +} diff --git a/dbdata/start.go b/dbdata/start.go new file mode 100644 index 0000000..a9e14c2 --- /dev/null +++ b/dbdata/start.go @@ -0,0 +1,9 @@ +package dbdata + +func Start() { + initDb() +} + +func Stop() error { + return db.Close() +} diff --git a/dbdata/user.go b/dbdata/user.go new file mode 100644 index 0000000..df59a00 --- /dev/null +++ b/dbdata/user.go @@ -0,0 +1,29 @@ +package dbdata + +import ( + "encoding/json" + "time" +) + +const BucketUser = "user" + +type User struct { + Id int + Username string + Password string + OtpSecret string + Group []string + // CreatedAt time.Time + UpdatedAt time.Time +} + +func GetUsers(lastKey string, prev bool) []User { + res := getList(BucketUser, lastKey, prev) + datas := make([]User, 0) + for _, data := range res { + d := User{} + json.Unmarshal(data, &d) + datas = append(datas, d) + } + return datas +} diff --git a/go.mod b/go.mod index 3eaec73..ef6046a 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,14 @@ module github.com/bjdgyc/anylink go 1.14 require ( - github.com/julienschmidt/httprouter v1.3.0 + github.com/google/gopacket v1.1.17 github.com/pelletier/go-toml v1.8.0 + github.com/songgao/packets v0.0.0-20160404182456-549a10cd4091 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 - golang.org/x/sys v0.0.0-20200819171115-d785dc25833f // indirect + github.com/stretchr/testify v1.5.1 + github.com/xlzd/gotp v0.0.0-20181030022105-c8557ba2c119 + go.etcd.io/bbolt v1.3.5 + golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 + golang.org/x/sys v0.0.0-20200817155316-9781c653f443 // indirect + golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e ) diff --git a/go.sum b/go.sum index c74f6b7..38a1be0 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,39 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= -github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/google/gopacket v1.1.17 h1:rMrlX2ZY2UbvT+sdz3+6J+pp2z+msCq9MxTU6ymxbBY= +github.com/google/gopacket v1.1.17/go.mod h1:UdDNZ1OO62aGYVnPhxT1U6aI7ukYtA/kB8vaU0diBUM= github.com/pelletier/go-toml v1.8.0 h1:Keo9qb7iRJs2voHvunFtuuYFsbWeOBh8/P9v/kVMFtw= github.com/pelletier/go-toml v1.8.0/go.mod h1:D6yutnOGMveHEPV7VQOuvI/gXY61bv+9bAOTRnLElKs= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/songgao/packets v0.0.0-20160404182456-549a10cd4091 h1:1zN6ImoqhSJhN8hGXFaJlSC8msLmIbX8bFqOfWLKw0w= +github.com/songgao/packets v0.0.0-20160404182456-549a10cd4091/go.mod h1:N20Z5Y8oye9a7HmytmZ+tr8Q2vlP0tAHP13kTHzwvQY= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= -golang.org/x/sys v0.0.0-20200819171115-d785dc25833f h1:KJuwZVtZBVzDmEDtB2zro9CXkD9O0dpCv4o2LHbQIAw= -golang.org/x/sys v0.0.0-20200819171115-d785dc25833f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/xlzd/gotp v0.0.0-20181030022105-c8557ba2c119 h1:YyPWX3jLOtYKulBR6AScGIs74lLrJcgeKRwcbAuQOG4= +github.com/xlzd/gotp v0.0.0-20181030022105-c8557ba2c119/go.mod h1:/nuTSlK+okRfR/vnIPqR89fFKonnWPiZymN5ydRJkX8= +go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= +go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200817155316-9781c653f443 h1:X18bCaipMcoJGm27Nv7zr4XYPKGUy92GtqboKC2Hxaw= +golang.org/x/sys v0.0.0-20200817155316-9781c653f443/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s= +golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/handler/base.go b/handler/base.go index e5e9943..14f6409 100644 --- a/handler/base.go +++ b/handler/base.go @@ -3,12 +3,14 @@ package handler import ( "encoding/xml" "fmt" + "log" "net/http" + "os/exec" "strings" - - "github.com/julienschmidt/httprouter" ) +const BufferSize = 2048 + type ClientRequest struct { XMLName xml.Name `xml:"config-auth"` Client string `xml:"client,attr"` // 一般都是 vpn @@ -42,19 +44,19 @@ type macAddressList struct { } // 判断anyconnect客户端 -func checkVpnClient(h httprouter.Handle) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func checkLinkClient(h http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { // TODO 调试信息输出 // hd, _ := httputil.DumpRequest(r, true) // fmt.Println("DumpRequest: ", string(hd)) - fmt.Println(r.RemoteAddr) + // fmt.Println(r.RemoteAddr) - user_Agent := strings.ToLower(r.UserAgent()) + userAgent := strings.ToLower(r.UserAgent()) x_Aggregate_Auth := r.Header.Get("X-Aggregate-Auth") x_Transcend_Version := r.Header.Get("X-Transcend-Version") - if strings.Contains(user_Agent, "anyconnect") && + if strings.Contains(userAgent, "anyconnect") && x_Aggregate_Auth == "1" && x_Transcend_Version == "1" { - h(w, r, ps) + h(w, r) } else { w.WriteHeader(http.StatusForbidden) fmt.Fprintf(w, "error request") @@ -73,3 +75,15 @@ func setCommonHeader(w http.ResponseWriter) { w.Header().Set("X-Aggregate-Auth", "1") w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") } + +func execCmd(cmdStrs []string) error { + for _, cmdStr := range cmdStrs { + cmd := exec.Command("bash", "-c", cmdStr) + b, err := cmd.CombinedOutput() + if err != nil { + log.Println(string(b), err) + return err + } + } + return nil +} diff --git a/handler/dtls.go b/handler/dtls.go index 32c9470..9fe46bb 100644 --- a/handler/dtls.go +++ b/handler/dtls.go @@ -2,5 +2,4 @@ package handler // 暂时没有实现 func startDtls() { - } diff --git a/handler/link_auth.go b/handler/link_auth.go index af79224..0b369ac 100644 --- a/handler/link_auth.go +++ b/handler/link_auth.go @@ -9,10 +9,10 @@ import ( "text/template" "github.com/bjdgyc/anylink/common" - "github.com/julienschmidt/httprouter" + "github.com/bjdgyc/anylink/sessdata" ) -func LinkAuth(w http.ResponseWriter, r *http.Request, params httprouter.Params) { +func LinkAuth(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -32,7 +32,7 @@ func LinkAuth(w http.ResponseWriter, r *http.Request, params httprouter.Params) if cr.Type == "logout" { // 退出删除session信息 if cr.SessionToken != "" { - DelSessByStoken(cr.SessionToken) + sessdata.DelSessByStoken(cr.SessionToken) } w.WriteHeader(http.StatusOK) return @@ -40,29 +40,30 @@ func LinkAuth(w http.ResponseWriter, r *http.Request, params httprouter.Params) if cr.Type == "init" { w.WriteHeader(http.StatusOK) - data := RequestData{Group: cr.GroupSelect, Groups: common.ServerCfg.LinkGroups} + data := RequestData{Group: cr.GroupSelect, Groups: common.ServerCfg.UserGroups} tplRequest(tpl_request, w, data) return } // 登陆参数判断 - if cr.Type != "auth-reply" || cr.Auth.Username == "" || cr.Auth.Password == "" { + if cr.Type != "auth-reply" { w.WriteHeader(http.StatusBadRequest) return } // TODO 用户密码校验 - if !common.CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect) { + if !CheckUser(cr.Auth.Username, cr.Auth.Password, cr.GroupSelect) { w.WriteHeader(http.StatusOK) - data := RequestData{Group: cr.GroupSelect, Groups: common.ServerCfg.LinkGroups, Error: true} + data := RequestData{Group: cr.GroupSelect, Groups: common.ServerCfg.UserGroups, Error: true} tplRequest(tpl_request, w, data) return } // 创建新的session信息 - sess := NewSession() + sess := sessdata.NewSession() sess.UserName = cr.Auth.Username sess.MacAddr = strings.ToLower(cr.MacAddressList.MacAddress) + sess.UniqueIdGlobal = cr.DeviceId.UniqueIdGlobal cd := RequestData{SessionId: sess.Sid, SessionToken: sess.Sid + "@" + sess.Token, Banner: common.ServerCfg.Banner} w.WriteHeader(http.StatusOK) diff --git a/handler/link_cstp.go b/handler/link_cstp.go index 4cdd11e..ec0a04b 100644 --- a/handler/link_cstp.go +++ b/handler/link_cstp.go @@ -2,16 +2,17 @@ package handler import ( "encoding/binary" - "fmt" "log" "net" "time" "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/sessdata" ) -func LinkCstp(conn net.Conn, sess *ConnSession) { - // fmt.Println("HandlerCstp") +func LinkCstp(conn net.Conn, sess *sessdata.ConnSession) { + log.Println("HandlerCstp") + sessdata.Sess = sess defer func() { log.Println("LinkCstp return") conn.Close() @@ -20,6 +21,7 @@ func LinkCstp(conn net.Conn, sess *ConnSession) { var ( err error + n int dataLen uint16 dead = time.Duration(common.ServerCfg.CstpDpd+2) * time.Second ) @@ -27,54 +29,53 @@ func LinkCstp(conn net.Conn, sess *ConnSession) { go cstpWrite(conn, sess) for { + // 设置超时限制 - err = conn.SetDeadline(time.Now().Add(dead)) + err = conn.SetReadDeadline(time.Now().Add(dead)) if err != nil { log.Println("SetDeadline: ", err) return } - hdata := make([]byte, 1500) - _, err = conn.Read(hdata) + hdata := make([]byte, BufferSize) + n, err = conn.Read(hdata) if err != nil { log.Println("read hdata: ", err) return } + // 限流设置 + err = sess.RateLimit(n, true) + if err != nil { + log.Println(err) + } + switch hdata[6] { case 0x07: // KEEPALIVE // do nothing - // fmt.Println("keepalive") + // log.Println("recv keepalive") case 0x05: // DISCONNECT - // fmt.Println("DISCONNECT") + // log.Println("DISCONNECT") return case 0x03: // DPD-REQ - fmt.Println("DPD-REQ") - payload := &Payload{ - ptype: 0x04, // DPD-RESP - } - // 直接返回给客户端 resp - select { - case sess.PayloadOut <- payload: - case <-sess.Closed: + // log.Println("recv DPD-REQ") + if payloadOut(sess, sessdata.LTypeIPData, 0x04, nil) { return } - break - case 0x00: + case 0x04: + // log.Println("recv DPD-RESP") + case 0x00: // DATA dataLen = binary.BigEndian.Uint16(hdata[4:6]) // 4,5 - payload := &Payload{ - ptype: 0x00, // DPD-RESP - data: hdata[8 : 8+dataLen], - } - select { - case sess.PayloadIn <- payload: - case <-sess.Closed: + data := hdata[8 : 8+dataLen] + + if payloadIn(sess, sessdata.LTypeIPData, 0x00, data) { return } + } } } -func cstpWrite(conn net.Conn, sess *ConnSession) { +func cstpWrite(conn net.Conn, sess *sessdata.ConnSession) { defer func() { log.Println("cstpWrite return") conn.Close() @@ -83,26 +84,37 @@ func cstpWrite(conn net.Conn, sess *ConnSession) { var ( err error + n int header []byte - payload *Payload + payload *sessdata.Payload ) for { select { case payload = <-sess.PayloadOut: - case <-sess.Closed: + case <-sess.CloseChan: return } - header = []byte{'S', 'T', 'F', 0x01, 0x00, 0x00, payload.ptype, 0x00} - if payload.ptype == 0x00 { // data - binary.BigEndian.PutUint16(header[4:6], uint16(len(payload.data))) - header = append(header, payload.data...) + if payload.LType != sessdata.LTypeIPData { + continue } - _, err = conn.Write(header) + + header = []byte{'S', 'T', 'F', 0x01, 0x00, 0x00, payload.PType, 0x00} + if payload.PType == 0x00 { // data + binary.BigEndian.PutUint16(header[4:6], uint16(len(payload.Data))) + header = append(header, payload.Data...) + } + n, err = conn.Write(header) if err != nil { log.Println("write err", err) return } + + // 限流设置 + err = sess.RateLimit(n, false) + if err != nil { + log.Println(err) + } } } diff --git a/handler/link_home.go b/handler/link_home.go index 41aaa7c..bc76ae3 100644 --- a/handler/link_home.go +++ b/handler/link_home.go @@ -5,17 +5,16 @@ import ( "net/http" "net/http/httputil" "strings" - - "github.com/julienschmidt/httprouter" ) -func LinkHome(w http.ResponseWriter, r *http.Request, params httprouter.Params) { +func LinkHome(w http.ResponseWriter, r *http.Request) { hu, _ := httputil.DumpRequest(r, true) fmt.Println("DumpHome: ", string(hu)) fmt.Println(r.RemoteAddr) connection := strings.ToLower(r.Header.Get("Connection")) - if connection == "close" { + userAgent := strings.ToLower(r.UserAgent()) + if connection == "close" && strings.Contains(userAgent, "anyconnect") { w.Header().Set("Connection", "close") w.WriteHeader(http.StatusBadRequest) return diff --git a/handler/link_tap.go b/handler/link_tap.go new file mode 100644 index 0000000..b7fd156 --- /dev/null +++ b/handler/link_tap.go @@ -0,0 +1,272 @@ +package handler + +import ( + "fmt" + "log" + "net" + + "github.com/bjdgyc/anylink/arpdis" + "github.com/bjdgyc/anylink/sessdata" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/songgao/packets/ethernet" + "github.com/songgao/water" + "github.com/songgao/water/waterutil" +) + +const bridgeName = "anylink0" + +func checkTap() { + brFace, err := net.InterfaceByName(bridgeName) + if err != nil { + log.Fatal("testTap err: ", err) + } + bridgeHw := brFace.HardwareAddr + var bridgeIp net.IP + addrs, err := brFace.Addrs() + for _, addr := range addrs { + ip, _, err := net.ParseCIDR(addr.String()) + if err != nil || ip.To4() == nil { + continue + } + bridgeIp = ip + } + if bridgeIp == nil && bridgeHw == nil { + log.Fatalln("bridgeIp is err") + } + + if !sessdata.IpPool.Ipv4IPNet.Contains(bridgeIp) { + log.Fatalln("bridgeIp or Ip network err") + } + + // 设置本机ip arp为静态 + addr := &arpdis.Addr{IP: bridgeIp.To4(), HardwareAddr: bridgeHw, Type: arpdis.TypeStatic} + arpdis.Add(addr) +} + +// 创建tap网卡 +func LinkTap(sess *sessdata.ConnSession) { + defer func() { + log.Println("LinkTap return") + sess.Close() + }() + + cfg := water.Config{ + DeviceType: water.TAP, + } + + ifce, err := water.New(cfg) + if err != nil { + log.Println(err) + return + } + sess.TunName = ifce.Name() + defer ifce.Close() + + // arp on + cmdstr1 := fmt.Sprintf("ip link set dev %s up mtu %d multicast on", ifce.Name(), sess.Mtu) + cmdstr2 := fmt.Sprintf("sysctl -w net.ipv6.conf.%s.disable_ipv6=1", ifce.Name()) + cmdstr3 := fmt.Sprintf("ip link set dev %s master %s", ifce.Name(), bridgeName) + cmdStrs := []string{cmdstr1, cmdstr2, cmdstr3} + err = execCmd(cmdStrs) + if err != nil { + return + } + + // TODO 测试 + // sess.MacHw, _ = net.ParseMAC("3c:8c:40:a0:6a:3d") + + go loopArp(sess) + go tapRead(ifce, sess) + + var ( + payload *sessdata.Payload + ) + + for { + select { + case payload = <-sess.PayloadIn: + case <-sess.CloseChan: + return + } + + var frame ethernet.Frame + switch payload.LType { + default: + log.Println(payload) + case sessdata.LTypeEthernet: + frame = payload.Data + case sessdata.LTypeIPData: // 需要转换成 Ethernet 数据 + data := payload.Data + + ip_src := waterutil.IPv4Source(data) + if waterutil.IsIPv6(data) || !ip_src.Equal(sess.Ip) { + // 过滤掉IPv6的数据 + // 非分配给客户端ip,直接丢弃 + continue + } + + ip_dst := waterutil.IPv4Destination(data) + // fmt.Println("get:", ip_src, ip_dst) + + var dstAddr *arpdis.Addr + if !sessdata.IpPool.Ipv4IPNet.Contains(ip_dst) || ip_dst.Equal(sessdata.IpPool.Ipv4Gateway) { + // 不是同一网段,使用网关mac地址 + ip_dst = sessdata.IpPool.Ipv4Gateway + dstAddr = arpdis.Lookup(ip_dst, false) + if dstAddr == nil { + log.Println("Ipv4Gateway mac err", ip_dst) + return + } + // fmt.Println("Gateway", ip_dst, dstAddr.HardwareAddr) + } else { + // 同一网段内的其他主机 + dstAddr = arpdis.Lookup(ip_dst, true) + // fmt.Println("other", ip_src, ip_dst, dstAddr) + if dstAddr == nil || dstAddr.Type == arpdis.TypeUnreachable { + // 异步检测发送数据包 + select { + case sess.PayloadArp <- payload: + case <-sess.CloseChan: + return + default: + // PayloadArp 容量已经满了 + log.Println("PayloadArp is full", sess.Ip, ip_dst) + } + continue + } + } + + frame.Prepare(dstAddr.HardwareAddr, sess.MacHw, ethernet.NotTagged, ethernet.IPv4, len(data)) + copy(frame[12+2:], data) + } + + // packet := gopacket.NewPacket(frame, layers.LayerTypeEthernet, gopacket.Default) + // fmt.Println("write:", packet) + _, err = ifce.Write(frame) + if err != nil { + log.Println("tap Write err", err) + return + } + } + +} + +// 异步处理获取ip对应的mac地址的数据 +func loopArp(sess *sessdata.ConnSession) { + defer func() { + log.Println("loopArp return") + }() + + var ( + payload *sessdata.Payload + dstAddr *arpdis.Addr + ip_dst net.IP + ) + + for { + select { + case payload = <-sess.PayloadArp: + case <-sess.CloseChan: + return + } + + ip_dst = waterutil.IPv4Destination(payload.Data) + dstAddr = arpdis.Lookup(ip_dst, false) + // 不可达数据包 + if dstAddr == nil || dstAddr.Type == arpdis.TypeUnreachable { + // 直接丢弃数据 + // fmt.Println("Lookup", ip_dst) + continue + } + + // 正常获取mac地址 + if payloadInData(sess, payload) { + return + } + + } +} + +func tapRead(ifce *water.Interface, sess *sessdata.ConnSession) { + defer func() { + log.Println("tapRead return") + ifce.Close() + }() + + var ( + err error + n int + buf []byte + ) + fmt.Println(sess.MacHw) + + for { + var frame ethernet.Frame + frame.Resize(BufferSize) + n, err = ifce.Read(frame) + if err != nil { + log.Println("tap Read err", n, err) + return + } + frame = frame[:n] + + switch frame.Ethertype() { + default: + // packet := gopacket.NewPacket(frame, layers.LayerTypeEthernet, gopacket.Default) + // fmt.Println(packet) + continue + case ethernet.IPv6: + continue + case ethernet.IPv4: + // 发送IP数据 + data := frame.Payload() + + ip_dst := waterutil.IPv4Destination(data) + if !ip_dst.Equal(sess.Ip) { + // 过滤非本机地址 + // log.Println(ip_dst, sess.Ip) + continue + } + + if payloadOut(sess, sessdata.LTypeIPData, 0x00, data) { + return + } + + case ethernet.ARP: + // 暂时仅实现了ARP协议 + packet := gopacket.NewPacket(frame, layers.LayerTypeEthernet, gopacket.NoCopy) + layer := packet.Layer(layers.LayerTypeARP) + arpReq := layer.(*layers.ARP) + + // fmt.Println("arp", net.IP(arpReq.SourceProtAddress), sess.Ip) + if !sess.Ip.Equal(arpReq.DstProtAddress) { + // 过滤非本机地址 + continue + } + + // fmt.Println("arp", arpReq.SourceProtAddress, sess.Ip) + // fmt.Println(packet) + + // 返回ARP数据 + src := &arpdis.Addr{IP: sess.Ip, HardwareAddr: sess.MacHw} + dst := &arpdis.Addr{IP: arpReq.SourceProtAddress, HardwareAddr: frame.Source()} + buf, err = arpdis.NewARPReply(src, dst) + if err != nil { + log.Println(err) + return + } + + // 从接受的arp信息添加arp地址 + addr := &arpdis.Addr{} + copy(addr.IP, arpReq.SourceProtAddress) + copy(addr.HardwareAddr, frame.Source()) + arpdis.Add(addr) + + if payloadIn(sess, sessdata.LTypeEthernet, 0x00, buf) { + return + } + + } + } +} diff --git a/handler/link_tun.go b/handler/link_tun.go index 8fca1db..91d04a2 100644 --- a/handler/link_tun.go +++ b/handler/link_tun.go @@ -3,13 +3,13 @@ package handler import ( "fmt" "log" - "os/exec" "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/sessdata" "github.com/songgao/water" ) -func testTun() { +func checkTun() { // 测试tun cfg := water.Config{ DeviceType: water.TUN, @@ -19,17 +19,18 @@ func testTun() { if err != nil { log.Fatal("open tun err: ", err) } + defer ifce.Close() + // 测试ip命令 cmdstr := fmt.Sprintf("ip link set dev %s up mtu %s multicast off", ifce.Name(), "1399") err = execCmd([]string{cmdstr}) if err != nil { - log.Fatal("ip cmd err: ", err) + log.Fatal("testTun err: ", err) } - ifce.Close() } // 创建tun网卡 -func LinkTun(sess *ConnSession) { +func LinkTun(sess *sessdata.ConnSession) { defer func() { log.Println("LinkTun return") sess.Close() @@ -48,10 +49,9 @@ func LinkTun(sess *ConnSession) { sess.TunName = ifce.Name() defer ifce.Close() - // arp on - cmdstr1 := fmt.Sprintf("ip link set dev %s up mtu %s multicast off", ifce.Name(), sess.Mtu) + cmdstr1 := fmt.Sprintf("ip link set dev %s up mtu %d multicast off", ifce.Name(), sess.Mtu) cmdstr2 := fmt.Sprintf("ip addr add dev %s local %s peer %s/32", - ifce.Name(), common.ServerCfg.Ipv4GateWay, sess.NetIp) + ifce.Name(), common.ServerCfg.Ipv4Gateway, sess.Ip) cmdstr3 := fmt.Sprintf("sysctl -w net.ipv6.conf.%s.disable_ipv6=1", ifce.Name()) cmdStrs := []string{cmdstr1, cmdstr2, cmdstr3} err = execCmd(cmdStrs) @@ -61,21 +61,16 @@ func LinkTun(sess *ConnSession) { go tunRead(ifce, sess) - var payload *Payload + var payload *sessdata.Payload for { select { case payload = <-sess.PayloadIn: - case <-sess.Closed: + case <-sess.CloseChan: return } - // ip_src := waterutil.IPv4Source(payload.data) - // ip_des := waterutil.IPv4Destination(payload.data) - // ip_port := waterutil.IPv4DestinationPort(payload.data) - // fmt.Println("write: ", ip_src, ip_des.String(), ip_port, len(payload.data)) - - _, err = ifce.Write(payload.data) + _, err = ifce.Write(payload.Data) if err != nil { log.Println("tun Write err", err) return @@ -84,7 +79,7 @@ func LinkTun(sess *ConnSession) { } -func tunRead(ifce *water.Interface, sess *ConnSession) { +func tunRead(ifce *water.Interface, sess *sessdata.ConnSession) { defer func() { log.Println("tunRead return") ifce.Close() @@ -95,34 +90,25 @@ func tunRead(ifce *water.Interface, sess *ConnSession) { ) for { - packet := make([]byte, 1500) - n, err = ifce.Read(packet) + data := make([]byte, BufferSize) + n, err = ifce.Read(data) if err != nil { log.Println("tun Read err", n, err) return } - payload := &Payload{ - ptype: 0x00, - data: packet[:n], - } + data = data[:n] - select { - case sess.PayloadOut <- payload: - case <-sess.Closed: + // ip_src := waterutil.IPv4Source(data) + // ip_dst := waterutil.IPv4Destination(data) + // ip_port := waterutil.IPv4DestinationPort(data) + // fmt.Println("sent:", ip_src, ip_dst, ip_port) + // packet := gopacket.NewPacket(data, layers.LayerTypeIPv4, gopacket.Default) + // fmt.Println("read:", packet) + + if payloadOut(sess, sessdata.LTypeIPData, 0x00, data) { return } - } -} -func execCmd(cmdStrs []string) error { - for _, cmdStr := range cmdStrs { - cmd := exec.Command("bash", "-c", cmdStr) - b, err := cmd.CombinedOutput() - if err != nil { - log.Println(string(b), err) - return err - } } - return nil } diff --git a/handler/link_tunnel.go b/handler/link_tunnel.go index 82cfa87..29c5b6e 100644 --- a/handler/link_tunnel.go +++ b/handler/link_tunnel.go @@ -3,10 +3,12 @@ package handler import ( "fmt" "log" + "net" "net/http" "os" "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/sessdata" ) var hn string @@ -20,7 +22,7 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { // TODO 调试信息输出 // hd, _ := httputil.DumpRequest(r, true) // fmt.Println("DumpRequest: ", string(hd)) - fmt.Println("LinkTunnel", r.RemoteAddr) + // fmt.Println("LinkTunnel", r.RemoteAddr) // 判断session-token的值 cookie, err := r.Cookie("webvpn") @@ -29,31 +31,41 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { return } - sess := SToken2Sess(cookie.Value) + sess := sessdata.SToken2Sess(cookie.Value) if sess == nil { w.WriteHeader(http.StatusBadRequest) return } // 开启link - cSess := sess.StartConn() + cSess := sess.NewConn() if cSess == nil { log.Println(err) w.WriteHeader(http.StatusBadRequest) return } + fmt.Println(cSess.Ip, cSess.MacHw) // 客户端信息 cstp_mtu := r.Header.Get("X-CSTP-MTU") master_Secret := r.Header.Get("X-DTLS-Master-Secret") + local_ip := r.Header.Get("X-Cstp-Local-Address-Ip4") + mobile := r.Header.Get("X-Cstp-License") + cSess.SetMtu(cstp_mtu) cSess.MasterSecret = master_Secret - cSess.Mtu = cstp_mtu cSess.RemoteAddr = r.RemoteAddr + cSess.LocalIp = net.ParseIP(local_ip) + cstpDpd := common.ServerCfg.CstpDpd + if mobile == "mobile" { + // 手机客户端 + cstpDpd = common.ServerCfg.MobileDpd + } + // 返回客户端数据 w.Header().Set("Server", fmt.Sprintf("%s %s", common.APP_NAME, common.APP_VER)) w.Header().Set("X-CSTP-Version", "1") w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.") - w.Header().Set("X-CSTP-Address", cSess.NetIp.String()) // 分配的ip地址 + w.Header().Set("X-CSTP-Address", cSess.Ip.String()) // 分配的ip地址 w.Header().Set("X-CSTP-Netmask", common.ServerCfg.Ipv4Netmask) // 子网掩码 w.Header().Set("X-CSTP-Hostname", hn) // 机器名称 for _, v := range common.ServerCfg.ClientDns { @@ -74,7 +86,7 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { // w.Header().Add("X-CSTP-Split-Include", "192.168.0.0/255.255.0.0") // w.Header().Add("X-CSTP-Split-Exclude", "10.1.5.2/255.255.255.255") - w.Header().Set("X-CSTP-Lease-Duration", fmt.Sprintf("%d", common.IpLease)) // ip地址租期 + w.Header().Set("X-CSTP-Lease-Duration", fmt.Sprintf("%d", sessdata.IpLease)) // ip地址租期 w.Header().Set("X-CSTP-Session-Timeout", "none") w.Header().Set("X-CSTP-Session-Timeout-Alert-Interval", "60") w.Header().Set("X-CSTP-Session-Timeout-Remaining", "none") @@ -82,16 +94,18 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-CSTP-Disconnected-Timeout", "18000") w.Header().Set("X-CSTP-Keep", "true") w.Header().Set("X-CSTP-Tunnel-All-DNS", "false") - w.Header().Set("X-CSTP-Rekey-Time", "5400") + + w.Header().Set("X-CSTP-Rekey-Time", "172800") w.Header().Set("X-CSTP-Rekey-Method", "new-tunnel") - w.Header().Set("X-CSTP-DPD", fmt.Sprintf("%d", common.ServerCfg.CstpDpd)) // 30 Dead peer detection in seconds + + w.Header().Set("X-CSTP-DPD", fmt.Sprintf("%d", cstpDpd)) // 30 Dead peer detection in seconds w.Header().Set("X-CSTP-Keepalive", fmt.Sprintf("%d", common.ServerCfg.CstpKeepalive)) // 20 - w.Header().Set("X-CSTP-Banner", "welcome") // urlencode + w.Header().Set("X-CSTP-Banner", common.ServerCfg.Banner) // urlencode w.Header().Set("X-CSTP-MSIE-Proxy-Lockdown", "true") w.Header().Set("X-CSTP-Smartcard-Removal-Disconnect", "true") - w.Header().Set("X-CSTP-MTU", cstp_mtu) // 1399 - w.Header().Set("X-DTLS-MTU", cstp_mtu) + w.Header().Set("X-CSTP-MTU", fmt.Sprintf("%d", cSess.Mtu)) // 1399 + w.Header().Set("X-DTLS-MTU", fmt.Sprintf("%d", cSess.Mtu)) w.Header().Set("X-DTLS-Session-ID", sess.DtlsSid) w.Header().Set("X-DTLS-Port", "4433") @@ -117,6 +131,12 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { } // 开始数据处理 - go LinkTun(cSess) + switch common.ServerCfg.LinkMode { + case common.LinkModeTUN: + go LinkTun(cSess) + case common.LinkModeTAP: + go LinkTap(cSess) + } + go LinkCstp(conn, cSess) } diff --git a/handler/payload.go b/handler/payload.go new file mode 100644 index 0000000..5bef3c8 --- /dev/null +++ b/handler/payload.go @@ -0,0 +1,47 @@ +package handler + +import "github.com/bjdgyc/anylink/sessdata" + +func payloadIn(sess *sessdata.ConnSession, lType sessdata.LType, pType byte, data []byte) bool { + payload := &sessdata.Payload{ + LType: lType, + PType: pType, + Data: data, + } + + return payloadInData(sess, payload) +} + +func payloadInData(sess *sessdata.ConnSession, payload *sessdata.Payload) bool { + closed := false + + select { + case sess.PayloadIn <- payload: + case <-sess.CloseChan: + closed = true + } + + return closed +} + +func payloadOut(sess *sessdata.ConnSession, lType sessdata.LType, pType byte, data []byte) bool { + payload := &sessdata.Payload{ + LType: lType, + PType: pType, + Data: data, + } + + return payloadOutData(sess, payload) +} + +func payloadOutData(sess *sessdata.ConnSession, payload *sessdata.Payload) bool { + closed := false + + select { + case sess.PayloadOut <- payload: + case <-sess.CloseChan: + closed = true + } + + return closed +} diff --git a/handler/server.go b/handler/server.go index b7aae08..df0e94d 100644 --- a/handler/server.go +++ b/handler/server.go @@ -3,27 +3,32 @@ package handler import ( "crypto/tls" "fmt" - "github.com/bjdgyc/anylink/proxyproto" "log" "net" "net/http" - "net/http/httputil" - _ "net/http/pprof" + "net/http/pprof" "time" "github.com/bjdgyc/anylink/common" - "github.com/julienschmidt/httprouter" + "github.com/bjdgyc/anylink/proxyproto" + "github.com/bjdgyc/anylink/router" ) -func Start() { - testTun() - go startDebug() - go startDtls() - go startTls() -} +func startAdmin() { + mux := router.NewHttpMux() + mux.HandleFunc(router.ANY, "/", notFound) + // mux.ServeFile(router.ANY, "/static/*", http.Dir("./static")) -func startDebug() { - http.ListenAndServe(common.ServerCfg.DebugAddr, nil) + // pprof + mux.HandleFunc(router.ANY, "/debug/pprof/*", pprof.Index) + mux.HandleFunc(router.ANY, "/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc(router.ANY, "/debug/pprof/profile", pprof.Profile) + mux.HandleFunc(router.ANY, "/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc(router.ANY, "/debug/pprof/trace", pprof.Trace) + + fmt.Println("Listen admin", common.ServerCfg.AdminAddr) + err := http.ListenAndServe(common.ServerCfg.AdminAddr, mux) + fmt.Println(err) } func startTls() { @@ -62,17 +67,18 @@ func startTls() { } func initRoute() http.Handler { - router := httprouter.New() - router.GET("/", checkVpnClient(LinkHome)) - router.POST("/", checkVpnClient(LinkAuth)) - router.HandlerFunc("CONNECT", "/CSCOSSLC/tunnel", LinkTunnel) - router.NotFound = http.HandlerFunc(notFound) - return router + mux := router.NewHttpMux() + mux.HandleFunc("GET", "/", checkLinkClient(LinkHome)) + mux.HandleFunc("POST", "/", checkLinkClient(LinkAuth)) + mux.HandleFunc("CONNECT", "/CSCOSSLC/tunnel", LinkTunnel) + mux.SetNotFound(http.HandlerFunc(notFound)) + return mux } func notFound(w http.ResponseWriter, r *http.Request) { - hu, _ := httputil.DumpRequest(r, true) - fmt.Println("NotFound: ", string(hu)) + // fmt.Println(r.RemoteAddr) + // hu, _ := httputil.DumpRequest(r, true) + // fmt.Println("NotFound: ", string(hu)) w.WriteHeader(http.StatusNotFound) fmt.Fprintln(w, "404 page not found") diff --git a/handler/session.go b/handler/session.go deleted file mode 100644 index c67f59a..0000000 --- a/handler/session.go +++ /dev/null @@ -1,157 +0,0 @@ -package handler - -import ( - "fmt" - "log" - "math/rand" - "net" - "strings" - "sync" - "time" - - "github.com/bjdgyc/anylink/common" -) - -var ( - sessMux = sync.Mutex{} - sessions = make(map[string]*Session) // session_token -> SessUser -) - -// 连接sess -type ConnSession struct { - Sess *Session - MasterSecret string // dtls协议的 master_secret - NetIp net.IP // 分配的ip地址 - RemoteAddr string - Mtu string - TunName string - closeOnce sync.Once - Closed chan struct{} - PayloadIn chan *Payload - PayloadOut chan *Payload -} - -type Session struct { - Sid string // auth返回的 session-id - Token string // session信息的唯一token - DtlsSid string // dtls协议的 session_id - MacAddr string // 客户端mac地址 - UserName string // 用户名 - LastLogin time.Time - IsActive bool - - // 开启link需要设置的参数 - CSess *ConnSession -} - -func init() { - rand.Seed(time.Now().UnixNano()) - - // 检测过期的session - go func() { - if common.ServerCfg.SessionTimeout == 0 { - return - } - timeout := time.Duration(common.ServerCfg.SessionTimeout) * time.Second - tick := time.Tick(time.Second * 30) - for range tick { - t := time.Now() - sessMux.Lock() - for k, v := range sessions { - if v.IsActive == true { - continue - } - if t.Sub(v.LastLogin) > timeout { - delete(sessions, k) - } - } - sessMux.Unlock() - } - }() -} - -func NewSession() *Session { - // 生成32位的 token - btoken := make([]byte, 32) - rand.Read(btoken) - - // 生成 dtls session_id - dtlsid := make([]byte, 32) - rand.Read(dtlsid) - - token := fmt.Sprintf("%x", btoken) - sess := &Session{ - Sid: fmt.Sprintf("%d", time.Now().Unix()), - Token: token, - DtlsSid: fmt.Sprintf("%x", dtlsid), - LastLogin: time.Now(), - } - sessMux.Lock() - defer sessMux.Unlock() - sessions[token] = sess - return sess -} - -func (s *Session) StartConn() *ConnSession { - if s.IsActive == true { - s.CSess.Close() - } - - limit := common.LimitClient(s.UserName, false) - if limit == false { - // s.NetIp = nil - return nil - } - s.IsActive = true - cSess := &ConnSession{ - Sess: s, - NetIp: common.AcquireIp(s.MacAddr), - closeOnce: sync.Once{}, - Closed: make(chan struct{}), - PayloadIn: make(chan *Payload), - PayloadOut: make(chan *Payload), - } - s.CSess = cSess - return cSess -} - -func (cs *ConnSession) Close() { - cs.closeOnce.Do(func() { - log.Println("closeOnce") - close(cs.Closed) - cs.Sess.IsActive = false - cs.Sess.LastLogin = time.Now() - common.ReleaseIp(cs.NetIp, cs.Sess.MacAddr) - common.LimitClient(cs.Sess.UserName, true) - }) -} - -func SToken2Sess(stoken string) *Session { - stoken = strings.TrimSpace(stoken) - sarr := strings.Split(stoken, "@") - token := sarr[1] - sessMux.Lock() - defer sessMux.Unlock() - if sess, ok := sessions[token]; ok { - return sess - } - - return nil -} - -func Dtls2Sess(dtlsid []byte) *Session { - return nil -} - -func DelSess(token string) { - delete(sessions, token) -} - -func DelSessByStoken(stoken string) { - stoken = strings.TrimSpace(stoken) - sarr := strings.Split(stoken, "@") - token := sarr[1] - sessMux.Lock() - defer sessMux.Unlock() - delete(sessions, token) -} diff --git a/handler/start.go b/handler/start.go new file mode 100644 index 0000000..ee80b2c --- /dev/null +++ b/handler/start.go @@ -0,0 +1,24 @@ +package handler + +import ( + "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/dbdata" + "github.com/bjdgyc/anylink/sessdata" +) + +func Start() { + dbdata.Start() + sessdata.Start() + + checkTun() + if common.ServerCfg.LinkMode == common.LinkModeTAP { + checkTap() + } + go startAdmin() + go startTls() + go startDtls() +} + +func Stop() { + dbdata.Stop() +} diff --git a/handler/user.go b/handler/user.go new file mode 100644 index 0000000..e1bf044 --- /dev/null +++ b/handler/user.go @@ -0,0 +1,71 @@ +package handler + +import ( + "crypto/sha1" + "fmt" + "os" + "time" + + "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/dbdata" + "github.com/xlzd/gotp" +) + +func CheckUser(name, pwd, group string) bool { + return true + + pl := len(pwd) + if name == "" || pl < 6 { + return false + } + v := &dbdata.User{} + err := dbdata.Get(dbdata.BucketUser, name, v) + if err != nil { + return false + } + if !common.InArrStr(v.Group, group) { + return false + } + pass := pwd[:pl-6] + pwdHash := hashPass(pass) + if v.Password != pwdHash { + return false + } + otp := pwd[pl-6:] + totp := gotp.NewDefaultTOTP(v.OtpSecret) + unix := time.Now().Unix() + verify := totp.Verify(otp, int(unix)) + if !verify { + return false + } + return true +} + +func UserAdd(name, pwd string, group []string) dbdata.User { + v := dbdata.User{ + Id: dbdata.NextId(dbdata.BucketUser), + Username: name, + Password: hashPass(pwd), + OtpSecret: gotp.RandomSecret(32), + Group: group, + UpdatedAt: time.Now(), + } + fmt.Println(v) + secret := "WHH7WA6POOGGEYVIQYXLZU75QLM7YLUX" + totp := gotp.NewDefaultTOTP(secret) + s := totp.ProvisioningUri("bjdtest", "bjdpro") + fmt.Println(s) + + // qr, _ := qrcode.New(s, qrcode.Medium) + // a := qr.ToSmallString(false) + // fmt.Println(a) + // qr.WriteFile(512, "a.png") + + os.Exit(0) + return v +} + +func hashPass(pwd string) string { + sum := sha1.Sum([]byte(pwd)) + return fmt.Sprintf("%x", sum) +} diff --git a/main.go b/main.go index 6ba068b..5f52d7e 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "log" "os" "os/signal" "syscall" @@ -13,6 +14,7 @@ import ( var COMMIT_ID string func main() { + log.Println("start") common.CommitId = COMMIT_ID common.InitConfig() handler.Start() @@ -33,6 +35,7 @@ func signalWatch() { fmt.Println("reload") default: // stop + handler.Stop() return } } diff --git a/proxyproto/protocol.go b/proxyproto/protocol.go index 84a582e..c729063 100644 --- a/proxyproto/protocol.go +++ b/proxyproto/protocol.go @@ -284,6 +284,3 @@ func (p *Conn) checkPrefix() error { return nil } - - - diff --git a/proxyproto/protocol_test.go b/proxyproto/protocol_test.go index 43484e7..1ad37aa 100644 --- a/proxyproto/protocol_test.go +++ b/proxyproto/protocol_test.go @@ -484,5 +484,3 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) { } } - - diff --git a/router/router.go b/router/router.go new file mode 100644 index 0000000..0f5e700 --- /dev/null +++ b/router/router.go @@ -0,0 +1,167 @@ +package router + +import ( + "net/http" + "path" + "sort" + "strings" + "sync" +) + +const ( + ANY = "ANY" // 包含所有 method +) + +type HttpMux struct { + no http.Handler // NotFoundHandler + mu sync.RWMutex + m map[string]muxEntry // example: GET/index:muxEntry{} + es []muxEntry // 模糊匹配,pattern需要添加后缀 * +} + +type muxEntry struct { + h http.Handler + pattern string + method string +} + +func NewHttpMux() *HttpMux { + http.NewServeMux() + return &HttpMux{ + m: make(map[string]muxEntry), + es: make([]muxEntry, 0), + } +} + +func (mux *HttpMux) SetNotFound(no http.Handler) { + mux.mu.Lock() + defer mux.mu.Unlock() + mux.no = no +} + +func (mux *HttpMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "*" { + w.Header().Set("Connection", "close") + w.WriteHeader(http.StatusBadRequest) + return + } + + h := mux.match(r.Method, r.URL.Path) + h.ServeHTTP(w, r) +} + +func (mux *HttpMux) match(method, rpath string) http.Handler { + mux.mu.RLock() + defer mux.mu.RUnlock() + + path := mux.cleanPath(rpath) + // any 路径 匹配 + p_a := ANY + path + if v, ok := mux.m[p_a]; ok { + return v.h + } + // method 路径 匹配 + method = strings.ToUpper(method) + p_m := method + path + if e, ok := mux.m[p_m]; ok { + return e.h + } + + // Check for longest valid match. mux.es contains all patterns + // that end in / sorted from longest to shortest. + for _, e := range mux.es { + // trim last * + pattern := e.pattern[:len(e.pattern)-1] + // fmt.Println(pattern, p_a, p_m) + if strings.HasPrefix(p_a, pattern) { + return e.h + } + if strings.HasPrefix(p_m, pattern) { + return e.h + } + } + + if mux.no != nil { + return mux.no + } + return http.NotFoundHandler() +} + +func (mux *HttpMux) cleanPath(p string) string { + if p == "" { + return "/" + } + if p[0] != '/' { + p = "/" + p + } + np := path.Clean(p) + // path.Clean removes trailing slash except for root; + // put the trailing slash back if necessary. + if p[len(p)-1] == '/' && np != "/" { + // Fast path for common case of p being the string we want: + if len(p) == len(np)+1 && strings.HasPrefix(p, np) { + np = p + } else { + np += "/" + } + } + return np +} + +func (mux *HttpMux) HandleFunc(method, pattern string, handler func(http.ResponseWriter, *http.Request)) { + if handler == nil { + panic("http: nil handler") + } + mux.Handle(method, pattern, http.HandlerFunc(handler)) +} + +func (mux *HttpMux) Handle(method, pattern string, handler http.Handler) { + mux.mu.Lock() + defer mux.mu.Unlock() + + if pattern == "" || method == "" { + panic("http: invalid pattern") + } + if handler == nil { + panic("http: nil handler") + } + method = strings.ToUpper(method) + p := method + pattern + if _, exist := mux.m[p]; exist { + panic("http: multiple registrations for " + p) + } + + e := muxEntry{h: handler, pattern: p} + mux.m[p] = e + if pattern[len(pattern)-1] == '*' { + mux.es = mux.appendSorted(mux.es, e) + } +} + +func (mux *HttpMux) appendSorted(es []muxEntry, e muxEntry) []muxEntry { + n := len(es) + i := sort.Search(n, func(i int) bool { + return len(es[i].pattern) < len(e.pattern) + }) + if i == n { + return append(es, e) + } + // we now know that i points at where we want to insert + es = append(es, muxEntry{}) // try to grow the slice in place, any entry works. + copy(es[i+1:], es[i:]) // Move shorter entries down + es[i] = e + return es +} + +// ANY /static/* /var/www +func (mux *HttpMux) ServeFile(method, pattern string, root http.FileSystem) { + fs := http.FileServer(root) + + // trim * + pt := pattern[:len(pattern)-1] + mux.HandleFunc(method, pattern, func(w http.ResponseWriter, r *http.Request) { + // 过滤前缀路径 + r.URL.Path = strings.TrimPrefix(r.URL.Path, pt) + fs.ServeHTTP(w, r) + }) +} diff --git a/router/router_test.go b/router/router_test.go new file mode 100644 index 0000000..42ad205 --- /dev/null +++ b/router/router_test.go @@ -0,0 +1,47 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Server unit tests + +package router + +import ( + "fmt" + "net/http" + "testing" +) + +func BenchmarkServerMatch(b *testing.B) { + fn := func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "OK") + } + mux := NewHttpMux() + mux.HandleFunc("GET", "/", fn) + mux.HandleFunc("GET", "/index", fn) + mux.HandleFunc("GET", "/home", fn) + mux.HandleFunc("GET", "/about", fn) + mux.HandleFunc("GET", "/contact", fn) + mux.HandleFunc("GET", "/robots.txt", fn) + mux.HandleFunc("GET", "/products/", fn) + mux.HandleFunc("GET", "/products/1", fn) + mux.HandleFunc("GET", "/products/2", fn) + mux.HandleFunc("GET", "/products/3", fn) + mux.HandleFunc("GET", "/products/3/image.jpg", fn) + mux.HandleFunc("GET", "/admin", fn) + mux.HandleFunc("GET", "/admin/products/", fn) + mux.HandleFunc("GET", "/admin/products/create", fn) + mux.HandleFunc("GET", "/admin/products/update", fn) + mux.HandleFunc("GET", "/admin/products/delete", fn) + + paths := []string{"/", "/notfound", "/admin/", "/admin/foo", "/contact", "/products", + "/products/", "/products/3/image.jpg"} + b.StartTimer() + for i := 0; i < b.N; i++ { + path := paths[i%len(paths)] + if h := mux.match("GET", path); h == nil { + b.Error("impossible", path) + } + } + b.StopTimer() +} diff --git a/sessdata/copy_struct.go b/sessdata/copy_struct.go new file mode 100644 index 0000000..33db3c5 --- /dev/null +++ b/sessdata/copy_struct.go @@ -0,0 +1,53 @@ +package sessdata + +import ( + "fmt" + "reflect" +) + +// 用b的所有字段覆盖a的 +// 如果fields不为空, 表示用b的特定字段覆盖a的 +// a应该为结构体指针 +func CopyStruct(a interface{}, b interface{}, fields ...string) (err error) { + at := reflect.TypeOf(a) + av := reflect.ValueOf(a) + bt := reflect.TypeOf(b) + bv := reflect.ValueOf(b) + + // 简单判断下 + if at.Kind() != reflect.Ptr { + err = fmt.Errorf("a must be a struct pointer") + return + } + av = reflect.ValueOf(av.Interface()) + + // 要复制哪些字段 + _fields := make([]string, 0) + if len(fields) > 0 { + _fields = fields + } else { + for i := 0; i < bv.NumField(); i++ { + _fields = append(_fields, bt.Field(i).Name) + } + } + + if len(_fields) == 0 { + fmt.Println("no fields to copy") + return + } + + // 复制 + for i := 0; i < len(_fields); i++ { + name := _fields[i] + f := av.Elem().FieldByName(name) + bValue := bv.FieldByName(name) + + // a中有同名的字段并且类型一致才复制 + if f.IsValid() && f.Kind() == bValue.Kind() { + f.Set(bValue) + } else { + // fmt.Printf("no such field or different kind, fieldName: %s\n", name) + } + } + return +} diff --git a/sessdata/copy_struct_test.go b/sessdata/copy_struct_test.go new file mode 100644 index 0000000..5ba519b --- /dev/null +++ b/sessdata/copy_struct_test.go @@ -0,0 +1,38 @@ +package sessdata + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type A struct { + Id int + Name string + Age int + Addr string +} + +type B struct { + IdB int + NameB string + Age int + Addr string +} + +func TestCopyStruct(t *testing.T) { + assert := assert.New(t) + a := A{ + Id: 1, + Name: "bob", + Age: 15, + Addr: "American", + } + b := B{} + err := CopyStruct(&b, a) + assert.Nil(err) + assert.Equal(b.IdB, 0) + assert.Equal(b.NameB, "") + assert.Equal(b.Age, 15) + assert.Equal(b.Addr, "American") +} diff --git a/sessdata/ip_pool.go b/sessdata/ip_pool.go new file mode 100644 index 0000000..1d511ea --- /dev/null +++ b/sessdata/ip_pool.go @@ -0,0 +1,171 @@ +package sessdata + +import ( + "encoding/binary" + "net" + "sync" + "time" + + "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/dbdata" +) + +const ( + // ip租期 (秒) + IpLease = 1209600 +) + +var ( + IpPool = &IpPoolConfig{} + macInfo = map[string]*MacIp{} + ipInfo = map[string]*MacIp{} +) + +type MacIp struct { + IsActive bool + Ip net.IP + MacAddr string + LastLogin time.Time +} + +type IpPoolConfig struct { + mux sync.Mutex + // 计算动态ip + Ipv4Gateway net.IP + Ipv4IPNet net.IPNet + IpLongMin uint32 + IpLongMax uint32 +} + +func initIpMac() { + macs := dbdata.GetAllMacIp() + for _, v := range macs { + mi := &MacIp{} + CopyStruct(mi, v) + macInfo[v.MacAddr] = mi + ipInfo[v.Ip.String()] = mi + } +} + +func initIpPool() { + + // 地址处理 + // ip地址 + ip := net.ParseIP(common.ServerCfg.Ipv4Network) + // 子网掩码 + maskIp := net.ParseIP(common.ServerCfg.Ipv4Netmask).To4() + IpPool.Ipv4IPNet = net.IPNet{IP: ip, Mask: net.IPMask(maskIp)} + IpPool.Ipv4Gateway = net.ParseIP(common.ServerCfg.Ipv4Gateway) + + // 网络地址零值 + // zero := binary.BigEndian.Uint32(ip.Mask(mask)) + // 广播地址 + // one, _ := ipNet.Mask.Size() + // max := min | uint32(math.Pow(2, float64(32-one))-1) + + // ip地址池 + IpPool.IpLongMin = ip2long(net.ParseIP(common.ServerCfg.Ipv4Pool[0])) + IpPool.IpLongMax = ip2long(net.ParseIP(common.ServerCfg.Ipv4Pool[1])) +} + +func long2ip(i uint32) net.IP { + ip := make([]byte, 4) + binary.BigEndian.PutUint32(ip, i) + return ip +} + +func ip2long(ip net.IP) uint32 { + ip = ip.To4() + return binary.BigEndian.Uint32(ip) +} + +// 获取动态ip +func AcquireIp(macAddr string) net.IP { + IpPool.mux.Lock() + defer IpPool.mux.Unlock() + tNow := time.Now() + + // 判断已经分配过 + if mi, ok := macInfo[macAddr]; ok { + ip := mi.Ip + // 检测原有ip是否在新的ip池内 + if IpPool.Ipv4IPNet.Contains(ip) { + mi.IsActive = true + mi.LastLogin = tNow + // 回写db数据 + dbdata.Set(dbdata.BucketMacIp, macAddr, mi) + return ip + } else { + delete(macInfo, macAddr) + delete(ipInfo, ip.String()) + dbdata.Del(dbdata.BucketMacIp, macAddr) + } + } + + farMac := &MacIp{LastLogin: tNow} + // 全局遍历未分配ip + for i := IpPool.IpLongMin; i <= IpPool.IpLongMax; i++ { + ip := long2ip(i) + ipStr := ip.String() + v, ok := ipInfo[ipStr] + // 该ip没有被使用 + if !ok { + mi := &MacIp{IsActive: true, Ip: ip, MacAddr: macAddr, LastLogin: tNow} + macInfo[macAddr] = mi + ipInfo[ipStr] = mi + // 回写db数据 + dbdata.Set(dbdata.BucketMacIp, macAddr, mi) + return ip + } + + // 跳过活跃连接 + if v.IsActive { + continue + } + // 已经超过租期 + if tNow.Sub(v.LastLogin) > IpLease*time.Second { + delete(macInfo, v.MacAddr) + mi := &MacIp{IsActive: true, Ip: ip, MacAddr: macAddr, LastLogin: tNow} + macInfo[macAddr] = mi + ipInfo[ipStr] = mi + // 回写db数据 + dbdata.Del(dbdata.BucketMacIp, v.MacAddr) + dbdata.Set(dbdata.BucketMacIp, macAddr, mi) + return ip + } + // 其他情况判断最早登陆的mac + if v.LastLogin.Before(farMac.LastLogin) { + farMac = v + } + } + + // 全都在线,没有数据可用 + if farMac.MacAddr == "" { + return nil + } + + // 使用最早登陆的mac ip + delete(macInfo, farMac.MacAddr) + ip := farMac.Ip + mi := &MacIp{IsActive: true, Ip: ip, MacAddr: macAddr, LastLogin: tNow} + macInfo[macAddr] = mi + ipInfo[ip.String()] = mi + // 回写db数据 + dbdata.Del(dbdata.BucketMacIp, farMac.MacAddr) + dbdata.Set(dbdata.BucketMacIp, macAddr, mi) + return ip +} + +// 回收ip +func ReleaseIp(ip net.IP, macAddr string) { + IpPool.mux.Lock() + defer IpPool.mux.Unlock() + if mi, ok := macInfo[macAddr]; ok { + if mi.Ip.Equal(ip) { + mi.IsActive = false + mi.LastLogin = time.Now() + // 回写db数据 + dbdata.Set(dbdata.BucketMacIp, macAddr, mi) + } + } +} diff --git a/sessdata/ip_pool_test.go b/sessdata/ip_pool_test.go new file mode 100644 index 0000000..dd20cbf --- /dev/null +++ b/sessdata/ip_pool_test.go @@ -0,0 +1,58 @@ +package sessdata + +import ( + "fmt" + "net" + "os" + "path" + "testing" + + "github.com/bjdgyc/anylink/common" + "github.com/bjdgyc/anylink/dbdata" + "github.com/stretchr/testify/assert" +) + +func preIpData() { + common.ServerCfg.Ipv4Network = "192.168.3.0" + common.ServerCfg.Ipv4Netmask = "255.255.255.0" + common.ServerCfg.Ipv4Pool = []string{"192.168.3.1", "192.168.3.199"} + tmpDb := path.Join(os.TempDir(), "anylink_test.db") + common.ServerCfg.DbFile = tmpDb + dbdata.Start() +} + +func closeIpdata() { + dbdata.Stop() + tmpDb := path.Join(os.TempDir(), "anylink_test.db") + os.Remove(tmpDb) +} + +func TestIpPool(t *testing.T) { + assert := assert.New(t) + preIpData() + defer closeIpdata() + + macInfo = map[string]*MacIp{} + ipInfo = map[string]*MacIp{} + initIpPool() + + var ip net.IP + + for i := 1; i <= 100; i++ { + ip = AcquireIp(fmt.Sprintf("mac-%d", i)) + } + ip = AcquireIp(fmt.Sprintf("mac-new")) + assert.True(net.IPv4(192, 168, 3, 101).Equal(ip)) + for i := 102; i <= 199; i++ { + ip = AcquireIp(fmt.Sprintf("mac-%d", i)) + } + assert.True(net.IPv4(192, 168, 3, 199).Equal(ip)) + ip = AcquireIp(fmt.Sprintf("mac-nil")) + assert.Nil(ip) + + ReleaseIp(net.IPv4(192, 168, 3, 88), "mac-88") + ReleaseIp(net.IPv4(192, 168, 3, 77), "mac-77") + // 最早过期的ip + ip = AcquireIp("mac-release-new") + assert.True(net.IPv4(192, 168, 3, 88).Equal(ip)) +} diff --git a/sessdata/limit_client.go b/sessdata/limit_client.go new file mode 100644 index 0000000..c308473 --- /dev/null +++ b/sessdata/limit_client.go @@ -0,0 +1,46 @@ +package sessdata + +import ( + "sync" + + "github.com/bjdgyc/anylink/common" +) + +const limitAllKey = "__ALL__" + +var ( + limitClient = map[string]int{limitAllKey: 0} + limitMux = sync.Mutex{} +) + +func LimitClient(user string, close bool) bool { + limitMux.Lock() + defer limitMux.Unlock() + // defer fmt.Println(limitClient) + + _all := limitClient[limitAllKey] + c, ok := limitClient[user] + if !ok { // 不存在用户 + limitClient[user] = 0 + } + + if close { + limitClient[user] = c - 1 + limitClient[limitAllKey] = _all - 1 + return true + } + + // 全局判断 + if _all >= common.ServerCfg.MaxClient { + return false + } + + // 超出同一个用户限制 + if c >= common.ServerCfg.MaxUserClient { + return false + } + + limitClient[user] = c + 1 + limitClient[limitAllKey] = _all + 1 + return true +} diff --git a/sessdata/limit_rate.go b/sessdata/limit_rate.go new file mode 100644 index 0000000..e100925 --- /dev/null +++ b/sessdata/limit_rate.go @@ -0,0 +1,45 @@ +package sessdata + +import ( + "context" + "fmt" + "time" + + "github.com/bjdgyc/anylink/common" + + "golang.org/x/time/rate" +) + +var Sess = &ConnSession{} + +func init() { + return + tick := time.Tick(time.Second * 2) + go func() { + for range tick { + uP := common.HumanByte(float64(Sess.BandwidthUpPeriod / BandwidthPeriodSec)) + dP := common.HumanByte(float64(Sess.BandwidthDownPeriod / BandwidthPeriodSec)) + uA := common.HumanByte(float64(Sess.BandwidthUpAll)) + dA := common.HumanByte(float64(Sess.BandwidthDownAll)) + + fmt.Printf("rateUp:%s rateDown:%s allUp %s allDown %s \n", + uP, dP, uA, dA) + } + }() +} + +type LimitRater struct { + limit *rate.Limiter +} + +// lim: 令牌产生速率 +// burst: 允许的最大爆发速率 +func NewLimitRater(lim, burst int) *LimitRater { + limit := rate.NewLimiter(rate.Limit(lim), burst) + return &LimitRater{limit: limit} +} + +// bt 不能超过burst大小 +func (l *LimitRater) Wait(bt int) error { + return l.limit.WaitN(context.Background(), bt) +} diff --git a/sessdata/limit_test.go b/sessdata/limit_test.go new file mode 100644 index 0000000..908c496 --- /dev/null +++ b/sessdata/limit_test.go @@ -0,0 +1,55 @@ +package sessdata + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/bjdgyc/anylink/common" +) + +// func TestCheckUser(t *testing.T) { +// users["user1"] = User{Password: "7c4a8d09ca3762af61e59520943dc26494f8941b"} +// users["user2"] = User{Password: "7c4a8d09ca3762af61e59520943dc26494f8941c"} +// +// var res bool +// res = CheckUser("user1", "123456", "") +// AssertTrue(t, res == true) +// +// res = CheckUser("user2", "123457", "") +// AssertTrue(t, res == false) +// } + +func TestLimitClient(t *testing.T) { + assert := assert.New(t) + common.ServerCfg.MaxClient = 2 + common.ServerCfg.MaxUserClient = 1 + + res1 := LimitClient("user1", false) + res2 := LimitClient("user1", false) + res3 := LimitClient("user2", false) + res4 := LimitClient("user3", false) + res5 := LimitClient("user1", true) + + assert.True(res1) + assert.False(res2) + assert.True(res3) + assert.False(res4) + assert.True(res5) + +} + +func TestLimitWait(t *testing.T) { + assert := assert.New(t) + limit := NewLimitRater(1, 2) + limit.Wait(2) + start := time.Now() + err := limit.Wait(2) + assert.Nil(err) + err = limit.Wait(1) + assert.Nil(err) + end := time.Now() + sub := end.Sub(start) + assert.Equal(3, int(sub.Seconds())) +} diff --git a/handler/proto.go b/sessdata/protocol.go similarity index 95% rename from handler/proto.go rename to sessdata/protocol.go index 12b51ca..f17a71b 100644 --- a/handler/proto.go +++ b/sessdata/protocol.go @@ -1,8 +1,16 @@ -package handler +package sessdata + +type LType int8 + +const ( + LTypeEthernet LType = 1 + LTypeIPData LType = 2 +) type Payload struct { - ptype byte - data []byte + PType byte // payload types + LType LType // LinkType + Data []byte } /* diff --git a/sessdata/session.go b/sessdata/session.go new file mode 100644 index 0000000..70725e4 --- /dev/null +++ b/sessdata/session.go @@ -0,0 +1,258 @@ +package sessdata + +import ( + "crypto/md5" + "fmt" + "log" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/bjdgyc/anylink/common" +) + +const BandwidthPeriodSec = 2 // 流量速率统计周期(秒) + +var ( + // session_token -> SessUser + sessions = sync.Map{} // make(map[string]*Session) +) + +// 连接sess +type ConnSession struct { + Sess *Session + MasterSecret string // dtls协议的 master_secret + Ip net.IP // 分配的ip地址 + LocalIp net.IP + MacHw net.HardwareAddr // 客户端mac地址,从Session取出 + RemoteAddr string + Mtu int + TunName string + Limit *LimitRater + BandwidthUp uint32 // 使用上行带宽 Byte + BandwidthDown uint32 // 使用下行带宽 Byte + BandwidthUpPeriod uint32 // 前一周期的总量 + BandwidthDownPeriod uint32 + BandwidthUpAll uint32 // 使用上行带宽总量 + BandwidthDownAll uint32 // 使用下行带宽总量 + closeOnce sync.Once + CloseChan chan struct{} + PayloadIn chan *Payload + PayloadOut chan *Payload + PayloadArp chan *Payload +} + +type Session struct { + mux sync.Mutex + Sid string // auth返回的 session-id + Token string // session信息的唯一token + DtlsSid string // dtls协议的 session_id + MacAddr string // 客户端mac地址 + UniqueIdGlobal string // 客户端唯一标示 + UserName string // 用户名 + + LastLogin time.Time + IsActive bool + + // 开启link需要设置的参数 + CSess *ConnSession +} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func checkSession() { + + // 检测过期的session + go func() { + if common.ServerCfg.SessionTimeout == 0 { + return + } + timeout := time.Duration(common.ServerCfg.SessionTimeout) * time.Second + tick := time.Tick(time.Second * 60) + for range tick { + t := time.Now() + + sessions.Range(func(key, value interface{}) bool { + v := value.(*Session) + v.mux.Lock() + defer v.mux.Unlock() + + if v.IsActive == true { + return true + } + if t.Sub(v.LastLogin) > timeout { + sessions.Delete(key) + } + return true + }) + + } + }() +} + +func NewSession() *Session { + // 生成32位的 token + btoken := make([]byte, 32) + rand.Read(btoken) + + // 生成 dtlsn session_id + dtlsid := make([]byte, 32) + rand.Read(dtlsid) + + token := fmt.Sprintf("%x", btoken) + sess := &Session{ + Sid: fmt.Sprintf("%d", time.Now().Unix()), + Token: token, + DtlsSid: fmt.Sprintf("%x", dtlsid), + LastLogin: time.Now(), + } + + sessions.Store(token, sess) + return sess +} + +func (s *Session) NewConn() *ConnSession { + s.mux.Lock() + active := s.IsActive + macAddr := s.MacAddr + s.mux.Unlock() + if active == true { + s.CSess.Close() + } + + limit := LimitClient(s.UserName, false) + if limit == false { + return nil + } + // 获取客户端mac地址 + macHw, err := net.ParseMAC(macAddr) + if err != nil { + sum := md5.Sum([]byte(s.UniqueIdGlobal)) + macHw = sum[8:13] // 5个byte + macHw = append([]byte{0x00}, macHw...) + macAddr = macHw.String() + } + ip := AcquireIp(macAddr) + if ip == nil { + return nil + } + + cSess := &ConnSession{ + Sess: s, + MacHw: macHw, + Ip: ip, + closeOnce: sync.Once{}, + CloseChan: make(chan struct{}), + PayloadIn: make(chan *Payload), + PayloadOut: make(chan *Payload), + PayloadArp: make(chan *Payload, 1000), + // Limit: NewLimitRater(1024 * 1024), + } + + go cSess.ratePeriod() + + s.mux.Lock() + s.MacAddr = macAddr + s.IsActive = true + s.CSess = cSess + s.mux.Unlock() + return cSess +} + +func (cs *ConnSession) Close() { + cs.closeOnce.Do(func() { + log.Println("closeOnce:", cs.Ip) + cs.Sess.mux.Lock() + defer cs.Sess.mux.Unlock() + + close(cs.CloseChan) + cs.Sess.IsActive = false + cs.Sess.LastLogin = time.Now() + cs.Sess.CSess = nil + + ReleaseIp(cs.Ip, cs.Sess.MacAddr) + LimitClient(cs.Sess.UserName, true) + }) +} + +func (cs *ConnSession) ratePeriod() { + tick := time.Tick(time.Second * BandwidthPeriodSec) + for range tick { + select { + case <-cs.CloseChan: + return + default: + } + + // 实时流量清零 + rtUp := atomic.SwapUint32(&cs.BandwidthUp, 0) + rtDown := atomic.SwapUint32(&cs.BandwidthDown, 0) + // 设置上一周期的流量 + atomic.SwapUint32(&cs.BandwidthUpPeriod, rtUp) + atomic.SwapUint32(&cs.BandwidthDownPeriod, rtDown) + // 累加所有流量 + atomic.AddUint32(&cs.BandwidthUpAll, rtUp) + atomic.AddUint32(&cs.BandwidthDownAll, rtDown) + } +} + +const MaxMtu = 1460 + +func (cs *ConnSession) SetMtu(mtu string) { + cs.Mtu = MaxMtu + + mi, err := strconv.Atoi(mtu) + if err != nil || mi < 100 { + return + } + + if mi < MaxMtu { + cs.Mtu = mi + } +} + +func (cs *ConnSession) RateLimit(byt int, isUp bool) error { + if isUp { + atomic.AddUint32(&cs.BandwidthUp, uint32(byt)) + return nil + } + // 只对下行速率限制 + atomic.AddUint32(&cs.BandwidthDown, uint32(byt)) + if cs.Limit == nil { + return nil + } + return cs.Limit.Wait(byt) +} + +func SToken2Sess(stoken string) *Session { + stoken = strings.TrimSpace(stoken) + sarr := strings.Split(stoken, "@") + token := sarr[1] + + if sess, ok := sessions.Load(token); ok { + return sess.(*Session) + } + + return nil +} + +func Dtls2Sess(dtlsid []byte) *Session { + return nil +} + +func DelSess(token string) { + // sessions.Delete(token) +} + +func DelSessByStoken(stoken string) { + stoken = strings.TrimSpace(stoken) + sarr := strings.Split(stoken, "@") + token := sarr[1] + sessions.Delete(token) +} diff --git a/sessdata/session_test.go b/sessdata/session_test.go new file mode 100644 index 0000000..76022f5 --- /dev/null +++ b/sessdata/session_test.go @@ -0,0 +1,31 @@ +package sessdata + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewSession(t *testing.T) { + assert := assert.New(t) + sessions = sync.Map{} + sess := NewSession() + token := sess.Token + v, ok := sessions.Load(token) + assert.True(ok) + assert.Equal(sess, v) +} + +func TestConnSession(t *testing.T) { + assert := assert.New(t) + preIpData() + defer closeIpdata() + sess := NewSession() + cSess := sess.NewConn() + cSess.RateLimit(100, true) + assert.Equal(cSess.BandwidthUp, uint32(100)) + cSess.RateLimit(200, false) + assert.Equal(cSess.BandwidthDown, uint32(200)) + cSess.Close() +} diff --git a/sessdata/start.go b/sessdata/start.go new file mode 100644 index 0000000..fa8d86a --- /dev/null +++ b/sessdata/start.go @@ -0,0 +1,7 @@ +package sessdata + +func Start() { + initIpPool() + initIpMac() + checkSession() +}