mirror of
				https://github.com/bjdgyc/anylink.git
				synced 2025-11-04 11:06:22 +08:00 
			
		
		
		
	修复断线重连后的bug
添加 proxy protocol v1 的支持
This commit is contained in:
		@@ -8,7 +8,7 @@ AnyLink 基于 [ietf-openconnect](https://tools.ietf.org/html/draft-mavrogiannop
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
AnyLink 使用TLS/DTLS进行数据加密,因此需要RSA或ECC证书,可以通过 Let's Encrypt 和 TrustAsia 申请免费的SSL证书。
 | 
					AnyLink 使用TLS/DTLS进行数据加密,因此需要RSA或ECC证书,可以通过 Let's Encrypt 和 TrustAsia 申请免费的SSL证书。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
AnyLink 服务端仅在CentOs7测试通过,如需要安装在其他系统,需要服务端支持tun功能、ip设置命令。
 | 
					AnyLink 服务端仅在CentOS7测试通过,如需要安装在其他系统,需要服务端支持tun功能、ip设置命令。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Installation
 | 
					## Installation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -26,6 +26,7 @@ sudo ./anylink -conf="conf/server.toml"
 | 
				
			|||||||
- [x] TLS-TCP通道
 | 
					- [x] TLS-TCP通道
 | 
				
			||||||
- [x] 兼容AnyConnect
 | 
					- [x] 兼容AnyConnect
 | 
				
			||||||
- [x] 多用户支持
 | 
					- [x] 多用户支持
 | 
				
			||||||
 | 
					- [x] 支持 [proxy protocol v1](http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt) 协议
 | 
				
			||||||
- [ ] DTLS-UDP通道
 | 
					- [ ] DTLS-UDP通道
 | 
				
			||||||
- [ ] 后台管理界面
 | 
					- [ ] 后台管理界面
 | 
				
			||||||
- [ ] 用户组支持
 | 
					- [ ] 用户组支持
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -27,6 +27,7 @@ type ServerConfig struct {
 | 
				
			|||||||
	UserFile       string   `toml:"user_file"`
 | 
						UserFile       string   `toml:"user_file"`
 | 
				
			||||||
	ServerAddr     string   `toml:"server_addr"`
 | 
						ServerAddr     string   `toml:"server_addr"`
 | 
				
			||||||
	DebugAddr      string   `toml:"debug_addr"`
 | 
						DebugAddr      string   `toml:"debug_addr"`
 | 
				
			||||||
 | 
						ProxyProtocol  bool     `toml:"proxy_protocol"`
 | 
				
			||||||
	CertFile       string   `toml:"cert_file"`
 | 
						CertFile       string   `toml:"cert_file"`
 | 
				
			||||||
	CertKey        string   `toml:"cert_key"`
 | 
						CertKey        string   `toml:"cert_key"`
 | 
				
			||||||
	LinkGroups     []string `toml:"link_groups"`
 | 
						LinkGroups     []string `toml:"link_groups"`
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										3
									
								
								conf/.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								conf/.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -1,3 +0,0 @@
 | 
				
			|||||||
#过滤本地证书文件
 | 
					 | 
				
			||||||
vpn_cert.key
 | 
					 | 
				
			||||||
vpn_cert.pem
 | 
					 | 
				
			||||||
@@ -10,6 +10,8 @@ cert_key = "./vpn_cert.key"
 | 
				
			|||||||
#服务监听的地址
 | 
					#服务监听的地址
 | 
				
			||||||
server_addr = ":443"
 | 
					server_addr = ":443"
 | 
				
			||||||
debug_addr = "127.0.0.1:8800"
 | 
					debug_addr = "127.0.0.1:8800"
 | 
				
			||||||
 | 
					#开启tcp proxy protocol协议
 | 
				
			||||||
 | 
					proxy_protocol = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#用户组
 | 
					#用户组
 | 
				
			||||||
link_groups = ["one", "two"]
 | 
					link_groups = ["one", "two"]
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@@ -6,5 +6,5 @@ require (
 | 
				
			|||||||
	github.com/julienschmidt/httprouter v1.3.0
 | 
						github.com/julienschmidt/httprouter v1.3.0
 | 
				
			||||||
	github.com/pelletier/go-toml v1.8.0
 | 
						github.com/pelletier/go-toml v1.8.0
 | 
				
			||||||
	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
 | 
						github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
 | 
				
			||||||
	golang.org/x/sys v0.0.0-20200817155316-9781c653f443 // indirect
 | 
						golang.org/x/sys v0.0.0-20200819171115-d785dc25833f // indirect
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							@@ -8,8 +8,8 @@ github.com/pelletier/go-toml v1.8.0 h1:Keo9qb7iRJs2voHvunFtuuYFsbWeOBh8/P9v/kVMF
 | 
				
			|||||||
github.com/pelletier/go-toml v1.8.0/go.mod h1:D6yutnOGMveHEPV7VQOuvI/gXY61bv+9bAOTRnLElKs=
 | 
					github.com/pelletier/go-toml v1.8.0/go.mod h1:D6yutnOGMveHEPV7VQOuvI/gXY61bv+9bAOTRnLElKs=
 | 
				
			||||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
 | 
					github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
 | 
				
			||||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
 | 
					github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
 | 
				
			||||||
golang.org/x/sys v0.0.0-20200817155316-9781c653f443 h1:X18bCaipMcoJGm27Nv7zr4XYPKGUy92GtqboKC2Hxaw=
 | 
					golang.org/x/sys v0.0.0-20200819171115-d785dc25833f h1:KJuwZVtZBVzDmEDtB2zro9CXkD9O0dpCv4o2LHbQIAw=
 | 
				
			||||||
golang.org/x/sys v0.0.0-20200817155316-9781c653f443/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
					golang.org/x/sys v0.0.0-20200819171115-d785dc25833f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
				
			||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 | 
					gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 | 
				
			||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 | 
					gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 | 
				
			||||||
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
 | 
					gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -47,6 +47,7 @@ func checkVpnClient(h httprouter.Handle) httprouter.Handle {
 | 
				
			|||||||
		// TODO 调试信息输出
 | 
							// TODO 调试信息输出
 | 
				
			||||||
		// hd, _ := httputil.DumpRequest(r, true)
 | 
							// hd, _ := httputil.DumpRequest(r, true)
 | 
				
			||||||
		// fmt.Println("DumpRequest: ", string(hd))
 | 
							// fmt.Println("DumpRequest: ", string(hd))
 | 
				
			||||||
 | 
							fmt.Println(r.RemoteAddr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		user_Agent := strings.ToLower(r.UserAgent())
 | 
							user_Agent := strings.ToLower(r.UserAgent())
 | 
				
			||||||
		x_Aggregate_Auth := r.Header.Get("X-Aggregate-Auth")
 | 
							x_Aggregate_Auth := r.Header.Get("X-Aggregate-Auth")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package handler
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/binary"
 | 
						"encoding/binary"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@@ -9,12 +10,12 @@ import (
 | 
				
			|||||||
	"github.com/bjdgyc/anylink/common"
 | 
						"github.com/bjdgyc/anylink/common"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func LinkCstp(conn net.Conn, sess *Session) {
 | 
					func LinkCstp(conn net.Conn, sess *ConnSession) {
 | 
				
			||||||
	// fmt.Println("HandlerCstp")
 | 
						// fmt.Println("HandlerCstp")
 | 
				
			||||||
	defer func() {
 | 
						defer func() {
 | 
				
			||||||
 | 
							log.Println("LinkCstp return")
 | 
				
			||||||
		conn.Close()
 | 
							conn.Close()
 | 
				
			||||||
		sess.Close()
 | 
							sess.Close()
 | 
				
			||||||
		log.Println("LinkCstp return")
 | 
					 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
@@ -47,7 +48,7 @@ func LinkCstp(conn net.Conn, sess *Session) {
 | 
				
			|||||||
			// fmt.Println("DISCONNECT")
 | 
								// fmt.Println("DISCONNECT")
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		case 0x03: // DPD-REQ
 | 
							case 0x03: // DPD-REQ
 | 
				
			||||||
			// fmt.Println("DPD-REQ")
 | 
								fmt.Println("DPD-REQ")
 | 
				
			||||||
			payload := &Payload{
 | 
								payload := &Payload{
 | 
				
			||||||
				ptype: 0x04, // DPD-RESP
 | 
									ptype: 0x04, // DPD-RESP
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -73,11 +74,11 @@ func LinkCstp(conn net.Conn, sess *Session) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func cstpWrite(conn net.Conn, sess *Session) {
 | 
					func cstpWrite(conn net.Conn, sess *ConnSession) {
 | 
				
			||||||
	defer func() {
 | 
						defer func() {
 | 
				
			||||||
 | 
							log.Println("cstpWrite return")
 | 
				
			||||||
		conn.Close()
 | 
							conn.Close()
 | 
				
			||||||
		sess.Close()
 | 
							sess.Close()
 | 
				
			||||||
		log.Println("cstpWrite return")
 | 
					 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -29,10 +29,10 @@ func testTun() {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 创建tun网卡
 | 
					// 创建tun网卡
 | 
				
			||||||
func LinkTun(sess *Session) {
 | 
					func LinkTun(sess *ConnSession) {
 | 
				
			||||||
	defer func() {
 | 
						defer func() {
 | 
				
			||||||
		sess.Close()
 | 
					 | 
				
			||||||
		log.Println("LinkTun return")
 | 
							log.Println("LinkTun return")
 | 
				
			||||||
 | 
							sess.Close()
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	cfg := water.Config{
 | 
						cfg := water.Config{
 | 
				
			||||||
@@ -84,7 +84,11 @@ func LinkTun(sess *Session) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func tunRead(ifce *water.Interface, sess *Session) {
 | 
					func tunRead(ifce *water.Interface, sess *ConnSession) {
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							log.Println("tunRead return")
 | 
				
			||||||
 | 
							ifce.Close()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		err error
 | 
							err error
 | 
				
			||||||
		n   int
 | 
							n   int
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package handler
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -19,7 +20,7 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	// TODO 调试信息输出
 | 
						// TODO 调试信息输出
 | 
				
			||||||
	// hd, _ := httputil.DumpRequest(r, true)
 | 
						// hd, _ := httputil.DumpRequest(r, true)
 | 
				
			||||||
	// fmt.Println("DumpRequest: ", string(hd))
 | 
						// fmt.Println("DumpRequest: ", string(hd))
 | 
				
			||||||
	// fmt.Println(r.RemoteAddr)
 | 
						fmt.Println("LinkTunnel", r.RemoteAddr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 判断session-token的值
 | 
						// 判断session-token的值
 | 
				
			||||||
	cookie, err := r.Cookie("webvpn")
 | 
						cookie, err := r.Cookie("webvpn")
 | 
				
			||||||
@@ -35,8 +36,9 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 开启link
 | 
						// 开启link
 | 
				
			||||||
	sess.StartLink()
 | 
						cSess := sess.StartConn()
 | 
				
			||||||
	if sess.NetIp == nil {
 | 
						if cSess == nil {
 | 
				
			||||||
 | 
							log.Println(err)
 | 
				
			||||||
		w.WriteHeader(http.StatusBadRequest)
 | 
							w.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -44,14 +46,14 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	// 客户端信息
 | 
						// 客户端信息
 | 
				
			||||||
	cstp_mtu := r.Header.Get("X-CSTP-MTU")
 | 
						cstp_mtu := r.Header.Get("X-CSTP-MTU")
 | 
				
			||||||
	master_Secret := r.Header.Get("X-DTLS-Master-Secret")
 | 
						master_Secret := r.Header.Get("X-DTLS-Master-Secret")
 | 
				
			||||||
	sess.MasterSecret = master_Secret
 | 
						cSess.MasterSecret = master_Secret
 | 
				
			||||||
	sess.Mtu = cstp_mtu
 | 
						cSess.Mtu = cstp_mtu
 | 
				
			||||||
	sess.RemoteAddr = r.RemoteAddr
 | 
						cSess.RemoteAddr = r.RemoteAddr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	w.Header().Set("Server", fmt.Sprintf("%s %s", common.APP_NAME, common.APP_VER))
 | 
						w.Header().Set("Server", fmt.Sprintf("%s %s", common.APP_NAME, common.APP_VER))
 | 
				
			||||||
	w.Header().Set("X-CSTP-Version", "1")
 | 
						w.Header().Set("X-CSTP-Version", "1")
 | 
				
			||||||
	w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.")
 | 
						w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.")
 | 
				
			||||||
	w.Header().Set("X-CSTP-Address", sess.NetIp.String())          // 分配的ip地址
 | 
						w.Header().Set("X-CSTP-Address", cSess.NetIp.String())    // 分配的ip地址
 | 
				
			||||||
	w.Header().Set("X-CSTP-Netmask", common.ServerCfg.Ipv4Netmask) // 子网掩码
 | 
						w.Header().Set("X-CSTP-Netmask", common.ServerCfg.Ipv4Netmask) // 子网掩码
 | 
				
			||||||
	w.Header().Set("X-CSTP-Hostname", hn)                          // 机器名称
 | 
						w.Header().Set("X-CSTP-Hostname", hn)                          // 机器名称
 | 
				
			||||||
	for _, v := range common.ServerCfg.ClientDns {
 | 
						for _, v := range common.ServerCfg.ClientDns {
 | 
				
			||||||
@@ -115,6 +117,6 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 开始数据处理
 | 
						// 开始数据处理
 | 
				
			||||||
	go LinkTun(sess)
 | 
						go LinkTun(cSess)
 | 
				
			||||||
	go LinkCstp(conn, sess)
 | 
						go LinkCstp(conn, cSess)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,11 +3,13 @@ package handler
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"crypto/tls"
 | 
						"crypto/tls"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/bjdgyc/anylink/proxyproto"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httputil"
 | 
						"net/http/httputil"
 | 
				
			||||||
	_ "net/http/pprof"
 | 
						_ "net/http/pprof"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/bjdgyc/anylink/common"
 | 
						"github.com/bjdgyc/anylink/common"
 | 
				
			||||||
	"github.com/julienschmidt/httprouter"
 | 
						"github.com/julienschmidt/httprouter"
 | 
				
			||||||
@@ -40,13 +42,18 @@ func startTls() {
 | 
				
			|||||||
		TLSConfig: tlsConfig,
 | 
							TLSConfig: tlsConfig,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var ln net.Listener
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ln, err := net.Listen("tcp", addr)
 | 
						ln, err := net.Listen("tcp", addr)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Fatal(err)
 | 
							log.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer ln.Close()
 | 
						defer ln.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	srv.SetKeepAlivesEnabled(true)
 | 
						if common.ServerCfg.ProxyProtocol {
 | 
				
			||||||
 | 
							ln = &proxyproto.Listener{Listener: ln, ProxyHeaderTimeout: time.Second * 5}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fmt.Println("listen ", addr)
 | 
						fmt.Println("listen ", addr)
 | 
				
			||||||
	err = srv.ServeTLS(ln, certFile, keyFile)
 | 
						err = srv.ServeTLS(ln, certFile, keyFile)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,27 +17,33 @@ var (
 | 
				
			|||||||
	sessions = make(map[string]*Session) // session_token -> SessUser
 | 
						sessions = make(map[string]*Session) // session_token -> SessUser
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Session struct {
 | 
					// 连接sess
 | 
				
			||||||
	Sid     string // auth返回的 session-id
 | 
					type ConnSession struct {
 | 
				
			||||||
	Token   string // session信息的唯一token
 | 
						Sess         *Session
 | 
				
			||||||
	DtlsSid string // dtls协议的 session_id
 | 
					 | 
				
			||||||
	MacAddr string // 客户端mac地址
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// 开启link需要设置的参数
 | 
					 | 
				
			||||||
	MasterSecret string // dtls协议的 master_secret
 | 
						MasterSecret string // dtls协议的 master_secret
 | 
				
			||||||
	NetIp        net.IP // 分配的ip地址
 | 
						NetIp        net.IP // 分配的ip地址
 | 
				
			||||||
	UserName     string // 用户名
 | 
					 | 
				
			||||||
	RemoteAddr   string
 | 
						RemoteAddr   string
 | 
				
			||||||
	Mtu          string
 | 
						Mtu          string
 | 
				
			||||||
	TunName      string
 | 
						TunName      string
 | 
				
			||||||
	IsActive     bool
 | 
					 | 
				
			||||||
	LastLogin    time.Time
 | 
					 | 
				
			||||||
	closeOnce    sync.Once
 | 
						closeOnce    sync.Once
 | 
				
			||||||
	Closed       chan struct{}
 | 
						Closed       chan struct{}
 | 
				
			||||||
	PayloadIn    chan *Payload
 | 
						PayloadIn    chan *Payload
 | 
				
			||||||
	PayloadOut   chan *Payload
 | 
						PayloadOut   chan *Payload
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Session struct {
 | 
				
			||||||
 | 
						Sid       string // auth返回的 session-id
 | 
				
			||||||
 | 
						Token     string // session信息的唯一token
 | 
				
			||||||
 | 
						DtlsSid   string // dtls协议的 session_id
 | 
				
			||||||
 | 
						MacAddr   string // 客户端mac地址
 | 
				
			||||||
 | 
						UserName  string // 用户名
 | 
				
			||||||
 | 
						LastLogin time.Time
 | 
				
			||||||
 | 
						IsActive  bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 开启link需要设置的参数
 | 
				
			||||||
 | 
						CSess *ConnSession
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
	rand.Seed(time.Now().UnixNano())
 | 
						rand.Seed(time.Now().UnixNano())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -86,28 +92,37 @@ func NewSession() *Session {
 | 
				
			|||||||
	return sess
 | 
						return sess
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *Session) StartLink() {
 | 
					func (s *Session) StartConn() *ConnSession {
 | 
				
			||||||
 | 
						if s.IsActive == true {
 | 
				
			||||||
 | 
							s.CSess.Close()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	limit := common.LimitClient(s.UserName, false)
 | 
						limit := common.LimitClient(s.UserName, false)
 | 
				
			||||||
	if limit == false {
 | 
						if limit == false {
 | 
				
			||||||
		s.NetIp = nil
 | 
							// s.NetIp = nil
 | 
				
			||||||
		return
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	s.NetIp = common.AcquireIp(s.MacAddr)
 | 
					 | 
				
			||||||
	s.IsActive = true
 | 
						s.IsActive = true
 | 
				
			||||||
	s.closeOnce = sync.Once{}
 | 
						cSess := &ConnSession{
 | 
				
			||||||
	s.Closed = make(chan struct{})
 | 
							Sess:       s,
 | 
				
			||||||
	s.PayloadIn = make(chan *Payload)
 | 
							NetIp:      common.AcquireIp(s.MacAddr),
 | 
				
			||||||
	s.PayloadOut = make(chan *Payload)
 | 
							closeOnce:  sync.Once{},
 | 
				
			||||||
 | 
							Closed:     make(chan struct{}),
 | 
				
			||||||
 | 
							PayloadIn:  make(chan *Payload),
 | 
				
			||||||
 | 
							PayloadOut: make(chan *Payload),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						s.CSess = cSess
 | 
				
			||||||
 | 
						return cSess
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *Session) Close() {
 | 
					func (cs *ConnSession) Close() {
 | 
				
			||||||
	s.closeOnce.Do(func() {
 | 
						cs.closeOnce.Do(func() {
 | 
				
			||||||
		log.Println("closeOnce")
 | 
							log.Println("closeOnce")
 | 
				
			||||||
		close(s.Closed)
 | 
							close(cs.Closed)
 | 
				
			||||||
		s.IsActive = false
 | 
							cs.Sess.IsActive = false
 | 
				
			||||||
		s.LastLogin = time.Now()
 | 
							cs.Sess.LastLogin = time.Now()
 | 
				
			||||||
		common.ReleaseIp(s.NetIp, s.MacAddr)
 | 
							common.ReleaseIp(cs.NetIp, cs.Sess.MacAddr)
 | 
				
			||||||
		common.LimitClient(s.UserName, true)
 | 
							common.LimitClient(cs.Sess.UserName, true)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										289
									
								
								proxyproto/protocol.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										289
									
								
								proxyproto/protocol.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,289 @@
 | 
				
			|||||||
 | 
					// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol.go
 | 
				
			||||||
 | 
					// design: http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt
 | 
				
			||||||
 | 
					package proxyproto
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bufio"
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						// prefix is the string we look for at the start of a connection
 | 
				
			||||||
 | 
						// to check if this connection is using the proxy protocol
 | 
				
			||||||
 | 
						prefix    = []byte("PROXY ")
 | 
				
			||||||
 | 
						prefixLen = len(prefix)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SourceChecker can be used to decide whether to trust the PROXY info or pass
 | 
				
			||||||
 | 
					// the original connection address through. If set, the connecting address is
 | 
				
			||||||
 | 
					// passed in as an argument. If the function returns an error due to the source
 | 
				
			||||||
 | 
					// being disallowed, it should return ErrInvalidUpstream.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// If error is not nil, the call to Accept() will fail. If the reason for
 | 
				
			||||||
 | 
					// triggering this failure is due to a disallowed source, it should return
 | 
				
			||||||
 | 
					// ErrInvalidUpstream.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// If bool is true, the PROXY-set address is used.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// If bool is false, the connection's remote address is used, rather than the
 | 
				
			||||||
 | 
					// address claimed in the PROXY info.
 | 
				
			||||||
 | 
					type SourceChecker func(net.Addr) (bool, error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Listener is used to wrap an underlying listener,
 | 
				
			||||||
 | 
					// whose connections may be using the HAProxy Proxy Protocol (version 1).
 | 
				
			||||||
 | 
					// If the connection is using the protocol, the RemoteAddr() will return
 | 
				
			||||||
 | 
					// the correct client address.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Optionally define ProxyHeaderTimeout to set a maximum time to
 | 
				
			||||||
 | 
					// receive the Proxy Protocol Header. Zero means no timeout.
 | 
				
			||||||
 | 
					type Listener struct {
 | 
				
			||||||
 | 
						Listener           net.Listener
 | 
				
			||||||
 | 
						ProxyHeaderTimeout time.Duration
 | 
				
			||||||
 | 
						SourceCheck        SourceChecker
 | 
				
			||||||
 | 
						UnknownOK          bool // allow PROXY UNKNOWN
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Conn is used to wrap and underlying connection which
 | 
				
			||||||
 | 
					// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
 | 
				
			||||||
 | 
					// return the address of the client instead of the proxy address.
 | 
				
			||||||
 | 
					type Conn struct {
 | 
				
			||||||
 | 
						bufReader          *bufio.Reader
 | 
				
			||||||
 | 
						conn               net.Conn
 | 
				
			||||||
 | 
						dstAddr            *net.TCPAddr
 | 
				
			||||||
 | 
						srcAddr            *net.TCPAddr
 | 
				
			||||||
 | 
						useConnAddr        bool
 | 
				
			||||||
 | 
						once               sync.Once
 | 
				
			||||||
 | 
						proxyHeaderTimeout time.Duration
 | 
				
			||||||
 | 
						unknownOK          bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Accept waits for and returns the next connection to the listener.
 | 
				
			||||||
 | 
					func (p *Listener) Accept() (net.Conn, error) {
 | 
				
			||||||
 | 
						// Get the underlying connection
 | 
				
			||||||
 | 
						conn, err := p.Listener.Accept()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var useConnAddr bool
 | 
				
			||||||
 | 
						if p.SourceCheck != nil {
 | 
				
			||||||
 | 
							allowed, err := p.SourceCheck(conn.RemoteAddr())
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !allowed {
 | 
				
			||||||
 | 
								useConnAddr = true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						newConn := NewConn(conn, p.ProxyHeaderTimeout)
 | 
				
			||||||
 | 
						newConn.useConnAddr = useConnAddr
 | 
				
			||||||
 | 
						newConn.unknownOK = p.UnknownOK
 | 
				
			||||||
 | 
						return newConn, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Close closes the underlying listener.
 | 
				
			||||||
 | 
					func (p *Listener) Close() error {
 | 
				
			||||||
 | 
						return p.Listener.Close()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Addr returns the underlying listener's network address.
 | 
				
			||||||
 | 
					func (p *Listener) Addr() net.Addr {
 | 
				
			||||||
 | 
						return p.Listener.Addr()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewConn is used to wrap a net.Conn that may be speaking
 | 
				
			||||||
 | 
					// the proxy protocol into a proxyproto.Conn
 | 
				
			||||||
 | 
					func NewConn(conn net.Conn, timeout time.Duration) *Conn {
 | 
				
			||||||
 | 
						pConn := &Conn{
 | 
				
			||||||
 | 
							bufReader:          bufio.NewReader(conn),
 | 
				
			||||||
 | 
							conn:               conn,
 | 
				
			||||||
 | 
							proxyHeaderTimeout: timeout,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return pConn
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Read is check for the proxy protocol header when doing
 | 
				
			||||||
 | 
					// the initial scan. If there is an error parsing the header,
 | 
				
			||||||
 | 
					// it is returned and the socket is closed.
 | 
				
			||||||
 | 
					func (p *Conn) Read(b []byte) (int, error) {
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						p.once.Do(func() { err = p.checkPrefix() })
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return 0, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return p.bufReader.Read(b)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
 | 
				
			||||||
 | 
						if rf, ok := p.conn.(io.ReaderFrom); ok {
 | 
				
			||||||
 | 
							return rf.ReadFrom(r)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return io.Copy(p.conn, r)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *Conn) WriteTo(w io.Writer) (int64, error) {
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						p.once.Do(func() { err = p.checkPrefix() })
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return 0, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return p.bufReader.WriteTo(w)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *Conn) Write(b []byte) (int, error) {
 | 
				
			||||||
 | 
						return p.conn.Write(b)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *Conn) Close() error {
 | 
				
			||||||
 | 
						return p.conn.Close()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *Conn) LocalAddr() net.Addr {
 | 
				
			||||||
 | 
						p.checkPrefixOnce()
 | 
				
			||||||
 | 
						if p.dstAddr != nil && !p.useConnAddr {
 | 
				
			||||||
 | 
							return p.dstAddr
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return p.conn.LocalAddr()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RemoteAddr returns the address of the client if the proxy
 | 
				
			||||||
 | 
					// protocol is being used, otherwise just returns the address of
 | 
				
			||||||
 | 
					// the socket peer. If there is an error parsing the header, the
 | 
				
			||||||
 | 
					// address of the client is not returned, and the socket is closed.
 | 
				
			||||||
 | 
					// Once implication of this is that the call could block if the
 | 
				
			||||||
 | 
					// client is slow. Using a Deadline is recommended if this is called
 | 
				
			||||||
 | 
					// before Read()
 | 
				
			||||||
 | 
					func (p *Conn) RemoteAddr() net.Addr {
 | 
				
			||||||
 | 
						p.checkPrefixOnce()
 | 
				
			||||||
 | 
						if p.srcAddr != nil && !p.useConnAddr {
 | 
				
			||||||
 | 
							return p.srcAddr
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return p.conn.RemoteAddr()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *Conn) SetDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						return p.conn.SetDeadline(t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *Conn) SetReadDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						return p.conn.SetReadDeadline(t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *Conn) SetWriteDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						return p.conn.SetWriteDeadline(t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *Conn) checkPrefixOnce() {
 | 
				
			||||||
 | 
						p.once.Do(func() {
 | 
				
			||||||
 | 
							if err := p.checkPrefix(); err != nil && err != io.EOF {
 | 
				
			||||||
 | 
								log.Printf("[ERR] Failed to read proxy prefix: %v", err)
 | 
				
			||||||
 | 
								p.Close()
 | 
				
			||||||
 | 
								p.bufReader = bufio.NewReader(p.conn)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *Conn) checkPrefix() error {
 | 
				
			||||||
 | 
						if p.proxyHeaderTimeout != 0 {
 | 
				
			||||||
 | 
							readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
 | 
				
			||||||
 | 
							p.conn.SetReadDeadline(readDeadLine)
 | 
				
			||||||
 | 
							defer p.conn.SetReadDeadline(time.Time{})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Incrementally check each byte of the prefix
 | 
				
			||||||
 | 
						for i := 1; i <= prefixLen; i++ {
 | 
				
			||||||
 | 
							inp, err := p.bufReader.Peek(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
 | 
				
			||||||
 | 
									return nil
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									return err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check for a prefix mis-match, quit early
 | 
				
			||||||
 | 
							if !bytes.Equal(inp, prefix[:i]) {
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read the header line
 | 
				
			||||||
 | 
						header, err := p.bufReader.ReadString('\n')
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							p.conn.Close()
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Strip the carriage return and new line
 | 
				
			||||||
 | 
						header = header[:len(header)-2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
 | 
				
			||||||
 | 
						parts := strings.Split(header, " ")
 | 
				
			||||||
 | 
						if len(parts) < 2 {
 | 
				
			||||||
 | 
							p.conn.Close()
 | 
				
			||||||
 | 
							return fmt.Errorf("Invalid header line: %s", header)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Verify the type is known
 | 
				
			||||||
 | 
						switch parts[1] {
 | 
				
			||||||
 | 
						case "UNKNOWN":
 | 
				
			||||||
 | 
							if !p.unknownOK || len(parts) != 2 {
 | 
				
			||||||
 | 
								p.conn.Close()
 | 
				
			||||||
 | 
								return fmt.Errorf("Invalid UNKNOWN header line: %s", header)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							p.useConnAddr = true
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						case "TCP4":
 | 
				
			||||||
 | 
						case "TCP6":
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							p.conn.Close()
 | 
				
			||||||
 | 
							return fmt.Errorf("Unhandled address type: %s", parts[1])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(parts) != 6 {
 | 
				
			||||||
 | 
							p.conn.Close()
 | 
				
			||||||
 | 
							return fmt.Errorf("Invalid header line: %s", header)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Parse out the source address
 | 
				
			||||||
 | 
						ip := net.ParseIP(parts[2])
 | 
				
			||||||
 | 
						if ip == nil {
 | 
				
			||||||
 | 
							p.conn.Close()
 | 
				
			||||||
 | 
							return fmt.Errorf("Invalid source ip: %s", parts[2])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						port, err := strconv.Atoi(parts[4])
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							p.conn.Close()
 | 
				
			||||||
 | 
							return fmt.Errorf("Invalid source port: %s", parts[4])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Parse out the destination address
 | 
				
			||||||
 | 
						ip = net.ParseIP(parts[3])
 | 
				
			||||||
 | 
						if ip == nil {
 | 
				
			||||||
 | 
							p.conn.Close()
 | 
				
			||||||
 | 
							return fmt.Errorf("Invalid destination ip: %s", parts[3])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						port, err = strconv.Atoi(parts[5])
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							p.conn.Close()
 | 
				
			||||||
 | 
							return fmt.Errorf("Invalid destination port: %s", parts[5])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										488
									
								
								proxyproto/protocol_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										488
									
								
								proxyproto/protocol_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,488 @@
 | 
				
			|||||||
 | 
					// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol_test.go
 | 
				
			||||||
 | 
					package proxyproto
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						goodAddr = "127.0.0.1"
 | 
				
			||||||
 | 
						badAddr  = "127.0.0.2"
 | 
				
			||||||
 | 
						errAddr  = "9999.0.0.2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						checkAddr string
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestPassthrough(t *testing.T) {
 | 
				
			||||||
 | 
						l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pl := &Listener{Listener: l}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							conn, err := net.Dial("tcp", pl.Addr().String())
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							conn.Write([]byte("ping"))
 | 
				
			||||||
 | 
							recv := make([]byte, 4)
 | 
				
			||||||
 | 
							_, err = conn.Read(recv)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !bytes.Equal(recv, []byte("pong")) {
 | 
				
			||||||
 | 
								t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, err := pl.Accept()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						recv := make([]byte, 4)
 | 
				
			||||||
 | 
						_, err = conn.Read(recv)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !bytes.Equal(recv, []byte("ping")) {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, err := conn.Write([]byte("pong")); err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestTimeout(t *testing.T) {
 | 
				
			||||||
 | 
						l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						clientWriteDelay := 200 * time.Millisecond
 | 
				
			||||||
 | 
						proxyHeaderTimeout := 50 * time.Millisecond
 | 
				
			||||||
 | 
						pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							conn, err := net.Dial("tcp", pl.Addr().String())
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Do not send data for a while
 | 
				
			||||||
 | 
							time.Sleep(clientWriteDelay)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							conn.Write([]byte("ping"))
 | 
				
			||||||
 | 
							recv := make([]byte, 4)
 | 
				
			||||||
 | 
							_, err = conn.Read(recv)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !bytes.Equal(recv, []byte("pong")) {
 | 
				
			||||||
 | 
								t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, err := pl.Accept()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check the remote addr is the original 127.0.0.1
 | 
				
			||||||
 | 
						remoteAddrStartTime := time.Now()
 | 
				
			||||||
 | 
						addr := conn.RemoteAddr().(*net.TCPAddr)
 | 
				
			||||||
 | 
						if addr.IP.String() != "127.0.0.1" {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", addr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						remoteAddrDuration := time.Since(remoteAddrStartTime)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check RemoteAddr() call did timeout
 | 
				
			||||||
 | 
						if remoteAddrDuration >= clientWriteDelay {
 | 
				
			||||||
 | 
							t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						recv := make([]byte, 4)
 | 
				
			||||||
 | 
						_, err = conn.Read(recv)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !bytes.Equal(recv, []byte("ping")) {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, err := conn.Write([]byte("pong")); err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestParse_ipv4(t *testing.T) {
 | 
				
			||||||
 | 
						l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pl := &Listener{Listener: l}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							conn, err := net.Dial("tcp", pl.Addr().String())
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Write out the header!
 | 
				
			||||||
 | 
							header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
 | 
				
			||||||
 | 
							conn.Write([]byte(header))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							conn.Write([]byte("ping"))
 | 
				
			||||||
 | 
							recv := make([]byte, 4)
 | 
				
			||||||
 | 
							_, err = conn.Read(recv)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !bytes.Equal(recv, []byte("pong")) {
 | 
				
			||||||
 | 
								t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, err := pl.Accept()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						recv := make([]byte, 4)
 | 
				
			||||||
 | 
						_, err = conn.Read(recv)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !bytes.Equal(recv, []byte("ping")) {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, err := conn.Write([]byte("pong")); err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check the remote addr
 | 
				
			||||||
 | 
						addr := conn.RemoteAddr().(*net.TCPAddr)
 | 
				
			||||||
 | 
						if addr.IP.String() != "10.1.1.1" {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", addr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if addr.Port != 1000 {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", addr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestParse_ipv6(t *testing.T) {
 | 
				
			||||||
 | 
						l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pl := &Listener{Listener: l}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							conn, err := net.Dial("tcp", pl.Addr().String())
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Write out the header!
 | 
				
			||||||
 | 
							header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n"
 | 
				
			||||||
 | 
							conn.Write([]byte(header))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							conn.Write([]byte("ping"))
 | 
				
			||||||
 | 
							recv := make([]byte, 4)
 | 
				
			||||||
 | 
							_, err = conn.Read(recv)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !bytes.Equal(recv, []byte("pong")) {
 | 
				
			||||||
 | 
								t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, err := pl.Accept()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						recv := make([]byte, 4)
 | 
				
			||||||
 | 
						_, err = conn.Read(recv)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !bytes.Equal(recv, []byte("ping")) {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, err := conn.Write([]byte("pong")); err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check the remote addr
 | 
				
			||||||
 | 
						addr := conn.RemoteAddr().(*net.TCPAddr)
 | 
				
			||||||
 | 
						if addr.IP.String() != "ffff::ffff" {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", addr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if addr.Port != 1000 {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", addr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestParse_Unknown(t *testing.T) {
 | 
				
			||||||
 | 
						l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pl := &Listener{Listener: l, UnknownOK: true}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							conn, err := net.Dial("tcp", pl.Addr().String())
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Write out the header!
 | 
				
			||||||
 | 
							header := "PROXY UNKNOWN\r\n"
 | 
				
			||||||
 | 
							conn.Write([]byte(header))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							conn.Write([]byte("ping"))
 | 
				
			||||||
 | 
							recv := make([]byte, 4)
 | 
				
			||||||
 | 
							_, err = conn.Read(recv)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !bytes.Equal(recv, []byte("pong")) {
 | 
				
			||||||
 | 
								t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, err := pl.Accept()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						recv := make([]byte, 4)
 | 
				
			||||||
 | 
						_, err = conn.Read(recv)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !bytes.Equal(recv, []byte("ping")) {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, err := conn.Write([]byte("pong")); err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestParse_BadHeader(t *testing.T) {
 | 
				
			||||||
 | 
						l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pl := &Listener{Listener: l}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							conn, err := net.Dial("tcp", pl.Addr().String())
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Write out the header!
 | 
				
			||||||
 | 
							header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n"
 | 
				
			||||||
 | 
							conn.Write([]byte(header))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							conn.Write([]byte("ping"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							recv := make([]byte, 4)
 | 
				
			||||||
 | 
							_, err = conn.Read(recv)
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, err := pl.Accept()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check the remote addr, should be the local addr
 | 
				
			||||||
 | 
						addr := conn.RemoteAddr().(*net.TCPAddr)
 | 
				
			||||||
 | 
						if addr.IP.String() != "127.0.0.1" {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", addr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read should fail
 | 
				
			||||||
 | 
						recv := make([]byte, 4)
 | 
				
			||||||
 | 
						_, err = conn.Read(recv)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestParse_ipv4_checkfunc(t *testing.T) {
 | 
				
			||||||
 | 
						checkAddr = goodAddr
 | 
				
			||||||
 | 
						testParse_ipv4_checkfunc(t)
 | 
				
			||||||
 | 
						checkAddr = badAddr
 | 
				
			||||||
 | 
						testParse_ipv4_checkfunc(t)
 | 
				
			||||||
 | 
						checkAddr = errAddr
 | 
				
			||||||
 | 
						testParse_ipv4_checkfunc(t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func testParse_ipv4_checkfunc(t *testing.T) {
 | 
				
			||||||
 | 
						l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkFunc := func(addr net.Addr) (bool, error) {
 | 
				
			||||||
 | 
							tcpAddr := addr.(*net.TCPAddr)
 | 
				
			||||||
 | 
							if tcpAddr.IP.String() == checkAddr {
 | 
				
			||||||
 | 
								return true, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return false, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pl := &Listener{Listener: l, SourceCheck: checkFunc}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							conn, err := net.Dial("tcp", pl.Addr().String())
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Write out the header!
 | 
				
			||||||
 | 
							header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
 | 
				
			||||||
 | 
							conn.Write([]byte(header))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							conn.Write([]byte("ping"))
 | 
				
			||||||
 | 
							recv := make([]byte, 4)
 | 
				
			||||||
 | 
							_, err = conn.Read(recv)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !bytes.Equal(recv, []byte("pong")) {
 | 
				
			||||||
 | 
								t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, err := pl.Accept()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							if checkAddr == badAddr {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						recv := make([]byte, 4)
 | 
				
			||||||
 | 
						_, err = conn.Read(recv)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !bytes.Equal(recv, []byte("ping")) {
 | 
				
			||||||
 | 
							t.Fatalf("bad: %v", recv)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, err := conn.Write([]byte("pong")); err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check the remote addr
 | 
				
			||||||
 | 
						addr := conn.RemoteAddr().(*net.TCPAddr)
 | 
				
			||||||
 | 
						switch checkAddr {
 | 
				
			||||||
 | 
						case goodAddr:
 | 
				
			||||||
 | 
							if addr.IP.String() != "10.1.1.1" {
 | 
				
			||||||
 | 
								t.Fatalf("bad: %v", addr)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if addr.Port != 1000 {
 | 
				
			||||||
 | 
								t.Fatalf("bad: %v", addr)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case badAddr:
 | 
				
			||||||
 | 
							if addr.IP.String() != "127.0.0.1" {
 | 
				
			||||||
 | 
								t.Fatalf("bad: %v", addr)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if addr.Port == 1000 {
 | 
				
			||||||
 | 
								t.Fatalf("bad: %v", addr)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type testConn struct {
 | 
				
			||||||
 | 
						readFromCalledWith io.Reader
 | 
				
			||||||
 | 
						net.Conn           // nil; crash on any unexpected use
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
 | 
				
			||||||
 | 
						c.readFromCalledWith = r
 | 
				
			||||||
 | 
						return 0, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (c *testConn) Write(p []byte) (int, error) {
 | 
				
			||||||
 | 
						return len(p), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (c *testConn) Read(p []byte) (int, error) {
 | 
				
			||||||
 | 
						return 1, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCopyToWrappedConnection(t *testing.T) {
 | 
				
			||||||
 | 
						innerConn := &testConn{}
 | 
				
			||||||
 | 
						wrappedConn := NewConn(innerConn, 0)
 | 
				
			||||||
 | 
						dummySrc := &testConn{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						io.Copy(wrappedConn, dummySrc)
 | 
				
			||||||
 | 
						if innerConn.readFromCalledWith != dummySrc {
 | 
				
			||||||
 | 
							t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCopyFromWrappedConnection(t *testing.T) {
 | 
				
			||||||
 | 
						wrappedConn := NewConn(&testConn{}, 0)
 | 
				
			||||||
 | 
						dummyDst := &testConn{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						io.Copy(dummyDst, wrappedConn)
 | 
				
			||||||
 | 
						if dummyDst.readFromCalledWith != wrappedConn.conn {
 | 
				
			||||||
 | 
							t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
 | 
				
			||||||
 | 
						innerConn1 := &testConn{}
 | 
				
			||||||
 | 
						wrappedConn1 := NewConn(innerConn1, 0)
 | 
				
			||||||
 | 
						innerConn2 := &testConn{}
 | 
				
			||||||
 | 
						wrappedConn2 := NewConn(innerConn2, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						io.Copy(wrappedConn1, wrappedConn2)
 | 
				
			||||||
 | 
						if innerConn1.readFromCalledWith != innerConn2 {
 | 
				
			||||||
 | 
							t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		Reference in New Issue
	
	Block a user