diff --git a/README.md b/README.md index 6864664..a7139e3 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ AnyLink 基于 [ietf-openconnect](https://tools.ietf.org/html/draft-mavrogiannop AnyLink 使用TLS/DTLS进行数据加密,因此需要RSA或ECC证书,可以通过 Let's Encrypt 和 TrustAsia 申请免费的SSL证书。 -AnyLink 服务端仅在CentOs7测试通过,如需要安装在其他系统,需要服务端支持tun功能、ip设置命令。 +AnyLink 服务端仅在CentOS7测试通过,如需要安装在其他系统,需要服务端支持tun功能、ip设置命令。 ## Installation @@ -26,6 +26,7 @@ sudo ./anylink -conf="conf/server.toml" - [x] TLS-TCP通道 - [x] 兼容AnyConnect - [x] 多用户支持 +- [x] 支持 [proxy protocol v1](http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt) 协议 - [ ] DTLS-UDP通道 - [ ] 后台管理界面 - [ ] 用户组支持 diff --git a/common/cfg_server.go b/common/cfg_server.go index a302ad2..4f6854f 100644 --- a/common/cfg_server.go +++ b/common/cfg_server.go @@ -27,6 +27,7 @@ type ServerConfig struct { UserFile string `toml:"user_file"` ServerAddr string `toml:"server_addr"` DebugAddr string `toml:"debug_addr"` + ProxyProtocol bool `toml:"proxy_protocol"` CertFile string `toml:"cert_file"` CertKey string `toml:"cert_key"` LinkGroups []string `toml:"link_groups"` diff --git a/conf/.gitignore b/conf/.gitignore deleted file mode 100644 index eef0a03..0000000 --- a/conf/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -#过滤本地证书文件 -vpn_cert.key -vpn_cert.pem \ No newline at end of file diff --git a/conf/server.toml b/conf/server.toml index 0d765b8..3a3db96 100644 --- a/conf/server.toml +++ b/conf/server.toml @@ -10,6 +10,8 @@ cert_key = "./vpn_cert.key" #服务监听的地址 server_addr = ":443" debug_addr = "127.0.0.1:8800" +#开启tcp proxy protocol协议 +proxy_protocol = false #用户组 link_groups = ["one", "two"] diff --git a/go.mod b/go.mod index bb165b6..3eaec73 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,5 @@ require ( github.com/julienschmidt/httprouter v1.3.0 github.com/pelletier/go-toml v1.8.0 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 - golang.org/x/sys v0.0.0-20200817155316-9781c653f443 // indirect + golang.org/x/sys v0.0.0-20200819171115-d785dc25833f // indirect ) diff --git a/go.sum b/go.sum index 605c35a..c74f6b7 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/pelletier/go-toml v1.8.0 h1:Keo9qb7iRJs2voHvunFtuuYFsbWeOBh8/P9v/kVMF github.com/pelletier/go-toml v1.8.0/go.mod h1:D6yutnOGMveHEPV7VQOuvI/gXY61bv+9bAOTRnLElKs= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= -golang.org/x/sys v0.0.0-20200817155316-9781c653f443 h1:X18bCaipMcoJGm27Nv7zr4XYPKGUy92GtqboKC2Hxaw= -golang.org/x/sys v0.0.0-20200817155316-9781c653f443/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200819171115-d785dc25833f h1:KJuwZVtZBVzDmEDtB2zro9CXkD9O0dpCv4o2LHbQIAw= +golang.org/x/sys v0.0.0-20200819171115-d785dc25833f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= diff --git a/handler/base.go b/handler/base.go index 4dcf5ba..e5e9943 100644 --- a/handler/base.go +++ b/handler/base.go @@ -47,6 +47,7 @@ func checkVpnClient(h httprouter.Handle) httprouter.Handle { // TODO 调试信息输出 // hd, _ := httputil.DumpRequest(r, true) // fmt.Println("DumpRequest: ", string(hd)) + fmt.Println(r.RemoteAddr) user_Agent := strings.ToLower(r.UserAgent()) x_Aggregate_Auth := r.Header.Get("X-Aggregate-Auth") diff --git a/handler/link_cstp.go b/handler/link_cstp.go index 4d8337e..4cdd11e 100644 --- a/handler/link_cstp.go +++ b/handler/link_cstp.go @@ -2,6 +2,7 @@ package handler import ( "encoding/binary" + "fmt" "log" "net" "time" @@ -9,12 +10,12 @@ import ( "github.com/bjdgyc/anylink/common" ) -func LinkCstp(conn net.Conn, sess *Session) { +func LinkCstp(conn net.Conn, sess *ConnSession) { // fmt.Println("HandlerCstp") defer func() { + log.Println("LinkCstp return") conn.Close() sess.Close() - log.Println("LinkCstp return") }() var ( @@ -47,7 +48,7 @@ func LinkCstp(conn net.Conn, sess *Session) { // fmt.Println("DISCONNECT") return case 0x03: // DPD-REQ - // fmt.Println("DPD-REQ") + fmt.Println("DPD-REQ") payload := &Payload{ ptype: 0x04, // DPD-RESP } @@ -73,11 +74,11 @@ func LinkCstp(conn net.Conn, sess *Session) { } } -func cstpWrite(conn net.Conn, sess *Session) { +func cstpWrite(conn net.Conn, sess *ConnSession) { defer func() { + log.Println("cstpWrite return") conn.Close() sess.Close() - log.Println("cstpWrite return") }() var ( diff --git a/handler/link_tun.go b/handler/link_tun.go index 567a32d..8fca1db 100644 --- a/handler/link_tun.go +++ b/handler/link_tun.go @@ -29,10 +29,10 @@ func testTun() { } // 创建tun网卡 -func LinkTun(sess *Session) { +func LinkTun(sess *ConnSession) { defer func() { - sess.Close() log.Println("LinkTun return") + sess.Close() }() cfg := water.Config{ @@ -84,7 +84,11 @@ func LinkTun(sess *Session) { } -func tunRead(ifce *water.Interface, sess *Session) { +func tunRead(ifce *water.Interface, sess *ConnSession) { + defer func() { + log.Println("tunRead return") + ifce.Close() + }() var ( err error n int diff --git a/handler/link_tunnel.go b/handler/link_tunnel.go index b8c66ca..82cfa87 100644 --- a/handler/link_tunnel.go +++ b/handler/link_tunnel.go @@ -2,6 +2,7 @@ package handler import ( "fmt" + "log" "net/http" "os" @@ -19,7 +20,7 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { // TODO 调试信息输出 // hd, _ := httputil.DumpRequest(r, true) // fmt.Println("DumpRequest: ", string(hd)) - // fmt.Println(r.RemoteAddr) + fmt.Println("LinkTunnel", r.RemoteAddr) // 判断session-token的值 cookie, err := r.Cookie("webvpn") @@ -35,8 +36,9 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { } // 开启link - sess.StartLink() - if sess.NetIp == nil { + cSess := sess.StartConn() + if cSess == nil { + log.Println(err) w.WriteHeader(http.StatusBadRequest) return } @@ -44,14 +46,14 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { // 客户端信息 cstp_mtu := r.Header.Get("X-CSTP-MTU") master_Secret := r.Header.Get("X-DTLS-Master-Secret") - sess.MasterSecret = master_Secret - sess.Mtu = cstp_mtu - sess.RemoteAddr = r.RemoteAddr + cSess.MasterSecret = master_Secret + cSess.Mtu = cstp_mtu + cSess.RemoteAddr = r.RemoteAddr w.Header().Set("Server", fmt.Sprintf("%s %s", common.APP_NAME, common.APP_VER)) w.Header().Set("X-CSTP-Version", "1") w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.") - w.Header().Set("X-CSTP-Address", sess.NetIp.String()) // 分配的ip地址 + w.Header().Set("X-CSTP-Address", cSess.NetIp.String()) // 分配的ip地址 w.Header().Set("X-CSTP-Netmask", common.ServerCfg.Ipv4Netmask) // 子网掩码 w.Header().Set("X-CSTP-Hostname", hn) // 机器名称 for _, v := range common.ServerCfg.ClientDns { @@ -115,6 +117,6 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { } // 开始数据处理 - go LinkTun(sess) - go LinkCstp(conn, sess) + go LinkTun(cSess) + go LinkCstp(conn, cSess) } diff --git a/handler/server.go b/handler/server.go index 99450e5..b7aae08 100644 --- a/handler/server.go +++ b/handler/server.go @@ -3,11 +3,13 @@ package handler import ( "crypto/tls" "fmt" + "github.com/bjdgyc/anylink/proxyproto" "log" "net" "net/http" "net/http/httputil" _ "net/http/pprof" + "time" "github.com/bjdgyc/anylink/common" "github.com/julienschmidt/httprouter" @@ -40,13 +42,18 @@ func startTls() { TLSConfig: tlsConfig, } + var ln net.Listener + ln, err := net.Listen("tcp", addr) if err != nil { log.Fatal(err) } defer ln.Close() - srv.SetKeepAlivesEnabled(true) + if common.ServerCfg.ProxyProtocol { + ln = &proxyproto.Listener{Listener: ln, ProxyHeaderTimeout: time.Second * 5} + } + fmt.Println("listen ", addr) err = srv.ServeTLS(ln, certFile, keyFile) if err != nil { diff --git a/handler/session.go b/handler/session.go index 6bea917..c67f59a 100644 --- a/handler/session.go +++ b/handler/session.go @@ -17,27 +17,33 @@ var ( sessions = make(map[string]*Session) // session_token -> SessUser ) -type Session struct { - Sid string // auth返回的 session-id - Token string // session信息的唯一token - DtlsSid string // dtls协议的 session_id - MacAddr string // 客户端mac地址 - - // 开启link需要设置的参数 +// 连接sess +type ConnSession struct { + Sess *Session MasterSecret string // dtls协议的 master_secret NetIp net.IP // 分配的ip地址 - UserName string // 用户名 RemoteAddr string Mtu string TunName string - IsActive bool - LastLogin time.Time closeOnce sync.Once Closed chan struct{} PayloadIn chan *Payload PayloadOut chan *Payload } +type Session struct { + Sid string // auth返回的 session-id + Token string // session信息的唯一token + DtlsSid string // dtls协议的 session_id + MacAddr string // 客户端mac地址 + UserName string // 用户名 + LastLogin time.Time + IsActive bool + + // 开启link需要设置的参数 + CSess *ConnSession +} + func init() { rand.Seed(time.Now().UnixNano()) @@ -86,28 +92,37 @@ func NewSession() *Session { return sess } -func (s *Session) StartLink() { +func (s *Session) StartConn() *ConnSession { + if s.IsActive == true { + s.CSess.Close() + } + limit := common.LimitClient(s.UserName, false) if limit == false { - s.NetIp = nil - return + // s.NetIp = nil + return nil } - s.NetIp = common.AcquireIp(s.MacAddr) s.IsActive = true - s.closeOnce = sync.Once{} - s.Closed = make(chan struct{}) - s.PayloadIn = make(chan *Payload) - s.PayloadOut = make(chan *Payload) + cSess := &ConnSession{ + Sess: s, + NetIp: common.AcquireIp(s.MacAddr), + closeOnce: sync.Once{}, + Closed: make(chan struct{}), + PayloadIn: make(chan *Payload), + PayloadOut: make(chan *Payload), + } + s.CSess = cSess + return cSess } -func (s *Session) Close() { - s.closeOnce.Do(func() { +func (cs *ConnSession) Close() { + cs.closeOnce.Do(func() { log.Println("closeOnce") - close(s.Closed) - s.IsActive = false - s.LastLogin = time.Now() - common.ReleaseIp(s.NetIp, s.MacAddr) - common.LimitClient(s.UserName, true) + close(cs.Closed) + cs.Sess.IsActive = false + cs.Sess.LastLogin = time.Now() + common.ReleaseIp(cs.NetIp, cs.Sess.MacAddr) + common.LimitClient(cs.Sess.UserName, true) }) } diff --git a/proxyproto/protocol.go b/proxyproto/protocol.go new file mode 100644 index 0000000..84a582e --- /dev/null +++ b/proxyproto/protocol.go @@ -0,0 +1,289 @@ +// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol.go +// design: http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt +package proxyproto + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "log" + "net" + "strconv" + "strings" + "sync" + "time" +) + +var ( + // prefix is the string we look for at the start of a connection + // to check if this connection is using the proxy protocol + prefix = []byte("PROXY ") + prefixLen = len(prefix) + + ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information") +) + +// SourceChecker can be used to decide whether to trust the PROXY info or pass +// the original connection address through. If set, the connecting address is +// passed in as an argument. If the function returns an error due to the source +// being disallowed, it should return ErrInvalidUpstream. +// +// If error is not nil, the call to Accept() will fail. If the reason for +// triggering this failure is due to a disallowed source, it should return +// ErrInvalidUpstream. +// +// If bool is true, the PROXY-set address is used. +// +// If bool is false, the connection's remote address is used, rather than the +// address claimed in the PROXY info. +type SourceChecker func(net.Addr) (bool, error) + +// Listener is used to wrap an underlying listener, +// whose connections may be using the HAProxy Proxy Protocol (version 1). +// If the connection is using the protocol, the RemoteAddr() will return +// the correct client address. +// +// Optionally define ProxyHeaderTimeout to set a maximum time to +// receive the Proxy Protocol Header. Zero means no timeout. +type Listener struct { + Listener net.Listener + ProxyHeaderTimeout time.Duration + SourceCheck SourceChecker + UnknownOK bool // allow PROXY UNKNOWN +} + +// Conn is used to wrap and underlying connection which +// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will +// return the address of the client instead of the proxy address. +type Conn struct { + bufReader *bufio.Reader + conn net.Conn + dstAddr *net.TCPAddr + srcAddr *net.TCPAddr + useConnAddr bool + once sync.Once + proxyHeaderTimeout time.Duration + unknownOK bool +} + +// Accept waits for and returns the next connection to the listener. +func (p *Listener) Accept() (net.Conn, error) { + // Get the underlying connection + conn, err := p.Listener.Accept() + if err != nil { + return nil, err + } + var useConnAddr bool + if p.SourceCheck != nil { + allowed, err := p.SourceCheck(conn.RemoteAddr()) + if err != nil { + return nil, err + } + if !allowed { + useConnAddr = true + } + } + newConn := NewConn(conn, p.ProxyHeaderTimeout) + newConn.useConnAddr = useConnAddr + newConn.unknownOK = p.UnknownOK + return newConn, nil +} + +// Close closes the underlying listener. +func (p *Listener) Close() error { + return p.Listener.Close() +} + +// Addr returns the underlying listener's network address. +func (p *Listener) Addr() net.Addr { + return p.Listener.Addr() +} + +// NewConn is used to wrap a net.Conn that may be speaking +// the proxy protocol into a proxyproto.Conn +func NewConn(conn net.Conn, timeout time.Duration) *Conn { + pConn := &Conn{ + bufReader: bufio.NewReader(conn), + conn: conn, + proxyHeaderTimeout: timeout, + } + return pConn +} + +// Read is check for the proxy protocol header when doing +// the initial scan. If there is an error parsing the header, +// it is returned and the socket is closed. +func (p *Conn) Read(b []byte) (int, error) { + var err error + p.once.Do(func() { err = p.checkPrefix() }) + if err != nil { + return 0, err + } + return p.bufReader.Read(b) +} + +func (p *Conn) ReadFrom(r io.Reader) (int64, error) { + if rf, ok := p.conn.(io.ReaderFrom); ok { + return rf.ReadFrom(r) + } + return io.Copy(p.conn, r) +} + +func (p *Conn) WriteTo(w io.Writer) (int64, error) { + var err error + p.once.Do(func() { err = p.checkPrefix() }) + if err != nil { + return 0, err + } + return p.bufReader.WriteTo(w) +} + +func (p *Conn) Write(b []byte) (int, error) { + return p.conn.Write(b) +} + +func (p *Conn) Close() error { + return p.conn.Close() +} + +func (p *Conn) LocalAddr() net.Addr { + p.checkPrefixOnce() + if p.dstAddr != nil && !p.useConnAddr { + return p.dstAddr + } + return p.conn.LocalAddr() +} + +// RemoteAddr returns the address of the client if the proxy +// protocol is being used, otherwise just returns the address of +// the socket peer. If there is an error parsing the header, the +// address of the client is not returned, and the socket is closed. +// Once implication of this is that the call could block if the +// client is slow. Using a Deadline is recommended if this is called +// before Read() +func (p *Conn) RemoteAddr() net.Addr { + p.checkPrefixOnce() + if p.srcAddr != nil && !p.useConnAddr { + return p.srcAddr + } + return p.conn.RemoteAddr() +} + +func (p *Conn) SetDeadline(t time.Time) error { + return p.conn.SetDeadline(t) +} + +func (p *Conn) SetReadDeadline(t time.Time) error { + return p.conn.SetReadDeadline(t) +} + +func (p *Conn) SetWriteDeadline(t time.Time) error { + return p.conn.SetWriteDeadline(t) +} + +func (p *Conn) checkPrefixOnce() { + p.once.Do(func() { + if err := p.checkPrefix(); err != nil && err != io.EOF { + log.Printf("[ERR] Failed to read proxy prefix: %v", err) + p.Close() + p.bufReader = bufio.NewReader(p.conn) + } + }) +} + +func (p *Conn) checkPrefix() error { + if p.proxyHeaderTimeout != 0 { + readDeadLine := time.Now().Add(p.proxyHeaderTimeout) + p.conn.SetReadDeadline(readDeadLine) + defer p.conn.SetReadDeadline(time.Time{}) + } + + // Incrementally check each byte of the prefix + for i := 1; i <= prefixLen; i++ { + inp, err := p.bufReader.Peek(i) + + if err != nil { + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + return nil + } else { + return err + } + } + + // Check for a prefix mis-match, quit early + if !bytes.Equal(inp, prefix[:i]) { + return nil + } + } + + // Read the header line + header, err := p.bufReader.ReadString('\n') + if err != nil { + p.conn.Close() + return err + } + + // Strip the carriage return and new line + header = header[:len(header)-2] + + // Split on spaces, should be (PROXY ) + parts := strings.Split(header, " ") + if len(parts) < 2 { + p.conn.Close() + return fmt.Errorf("Invalid header line: %s", header) + } + + // Verify the type is known + switch parts[1] { + case "UNKNOWN": + if !p.unknownOK || len(parts) != 2 { + p.conn.Close() + return fmt.Errorf("Invalid UNKNOWN header line: %s", header) + } + p.useConnAddr = true + return nil + case "TCP4": + case "TCP6": + default: + p.conn.Close() + return fmt.Errorf("Unhandled address type: %s", parts[1]) + } + + if len(parts) != 6 { + p.conn.Close() + return fmt.Errorf("Invalid header line: %s", header) + } + + // Parse out the source address + ip := net.ParseIP(parts[2]) + if ip == nil { + p.conn.Close() + return fmt.Errorf("Invalid source ip: %s", parts[2]) + } + port, err := strconv.Atoi(parts[4]) + if err != nil { + p.conn.Close() + return fmt.Errorf("Invalid source port: %s", parts[4]) + } + p.srcAddr = &net.TCPAddr{IP: ip, Port: port} + + // Parse out the destination address + ip = net.ParseIP(parts[3]) + if ip == nil { + p.conn.Close() + return fmt.Errorf("Invalid destination ip: %s", parts[3]) + } + port, err = strconv.Atoi(parts[5]) + if err != nil { + p.conn.Close() + return fmt.Errorf("Invalid destination port: %s", parts[5]) + } + p.dstAddr = &net.TCPAddr{IP: ip, Port: port} + + return nil +} + + + diff --git a/proxyproto/protocol_test.go b/proxyproto/protocol_test.go new file mode 100644 index 0000000..43484e7 --- /dev/null +++ b/proxyproto/protocol_test.go @@ -0,0 +1,488 @@ +// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol_test.go +package proxyproto + +import ( + "bytes" + "io" + "net" + "testing" + "time" +) + +const ( + goodAddr = "127.0.0.1" + badAddr = "127.0.0.2" + errAddr = "9999.0.0.2" +) + +var ( + checkAddr string +) + +func TestPassthrough(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{Listener: l} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + conn.Write([]byte("ping")) + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("pong")) { + t.Fatalf("bad: %v", recv) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("ping")) { + t.Fatalf("bad: %v", recv) + } + + if _, err := conn.Write([]byte("pong")); err != nil { + t.Fatalf("err: %v", err) + } +} + +func TestTimeout(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + clientWriteDelay := 200 * time.Millisecond + proxyHeaderTimeout := 50 * time.Millisecond + pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Do not send data for a while + time.Sleep(clientWriteDelay) + + conn.Write([]byte("ping")) + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("pong")) { + t.Fatalf("bad: %v", recv) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Check the remote addr is the original 127.0.0.1 + remoteAddrStartTime := time.Now() + addr := conn.RemoteAddr().(*net.TCPAddr) + if addr.IP.String() != "127.0.0.1" { + t.Fatalf("bad: %v", addr) + } + remoteAddrDuration := time.Since(remoteAddrStartTime) + + // Check RemoteAddr() call did timeout + if remoteAddrDuration >= clientWriteDelay { + t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration) + } + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("ping")) { + t.Fatalf("bad: %v", recv) + } + + if _, err := conn.Write([]byte("pong")); err != nil { + t.Fatalf("err: %v", err) + } +} + +func TestParse_ipv4(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{Listener: l} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Write out the header! + header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n" + conn.Write([]byte(header)) + + conn.Write([]byte("ping")) + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("pong")) { + t.Fatalf("bad: %v", recv) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("ping")) { + t.Fatalf("bad: %v", recv) + } + + if _, err := conn.Write([]byte("pong")); err != nil { + t.Fatalf("err: %v", err) + } + + // Check the remote addr + addr := conn.RemoteAddr().(*net.TCPAddr) + if addr.IP.String() != "10.1.1.1" { + t.Fatalf("bad: %v", addr) + } + if addr.Port != 1000 { + t.Fatalf("bad: %v", addr) + } +} + +func TestParse_ipv6(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{Listener: l} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Write out the header! + header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n" + conn.Write([]byte(header)) + + conn.Write([]byte("ping")) + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("pong")) { + t.Fatalf("bad: %v", recv) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("ping")) { + t.Fatalf("bad: %v", recv) + } + + if _, err := conn.Write([]byte("pong")); err != nil { + t.Fatalf("err: %v", err) + } + + // Check the remote addr + addr := conn.RemoteAddr().(*net.TCPAddr) + if addr.IP.String() != "ffff::ffff" { + t.Fatalf("bad: %v", addr) + } + if addr.Port != 1000 { + t.Fatalf("bad: %v", addr) + } +} + +func TestParse_Unknown(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{Listener: l, UnknownOK: true} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Write out the header! + header := "PROXY UNKNOWN\r\n" + conn.Write([]byte(header)) + + conn.Write([]byte("ping")) + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("pong")) { + t.Fatalf("bad: %v", recv) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("ping")) { + t.Fatalf("bad: %v", recv) + } + + if _, err := conn.Write([]byte("pong")); err != nil { + t.Fatalf("err: %v", err) + } + +} + +func TestParse_BadHeader(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{Listener: l} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Write out the header! + header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n" + conn.Write([]byte(header)) + + conn.Write([]byte("ping")) + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err == nil { + t.Fatalf("err: %v", err) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Check the remote addr, should be the local addr + addr := conn.RemoteAddr().(*net.TCPAddr) + if addr.IP.String() != "127.0.0.1" { + t.Fatalf("bad: %v", addr) + } + + // Read should fail + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err == nil { + t.Fatalf("err: %v", err) + } +} + +func TestParse_ipv4_checkfunc(t *testing.T) { + checkAddr = goodAddr + testParse_ipv4_checkfunc(t) + checkAddr = badAddr + testParse_ipv4_checkfunc(t) + checkAddr = errAddr + testParse_ipv4_checkfunc(t) +} + +func testParse_ipv4_checkfunc(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + checkFunc := func(addr net.Addr) (bool, error) { + tcpAddr := addr.(*net.TCPAddr) + if tcpAddr.IP.String() == checkAddr { + return true, nil + } + return false, nil + } + + pl := &Listener{Listener: l, SourceCheck: checkFunc} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Write out the header! + header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n" + conn.Write([]byte(header)) + + conn.Write([]byte("ping")) + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("pong")) { + t.Fatalf("bad: %v", recv) + } + }() + + conn, err := pl.Accept() + if err != nil { + if checkAddr == badAddr { + return + } + t.Fatalf("err: %v", err) + } + defer conn.Close() + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("ping")) { + t.Fatalf("bad: %v", recv) + } + + if _, err := conn.Write([]byte("pong")); err != nil { + t.Fatalf("err: %v", err) + } + + // Check the remote addr + addr := conn.RemoteAddr().(*net.TCPAddr) + switch checkAddr { + case goodAddr: + if addr.IP.String() != "10.1.1.1" { + t.Fatalf("bad: %v", addr) + } + if addr.Port != 1000 { + t.Fatalf("bad: %v", addr) + } + case badAddr: + if addr.IP.String() != "127.0.0.1" { + t.Fatalf("bad: %v", addr) + } + if addr.Port == 1000 { + t.Fatalf("bad: %v", addr) + } + } +} + +type testConn struct { + readFromCalledWith io.Reader + net.Conn // nil; crash on any unexpected use +} + +func (c *testConn) ReadFrom(r io.Reader) (int64, error) { + c.readFromCalledWith = r + return 0, nil +} +func (c *testConn) Write(p []byte) (int, error) { + return len(p), nil +} +func (c *testConn) Read(p []byte) (int, error) { + return 1, nil +} + +func TestCopyToWrappedConnection(t *testing.T) { + innerConn := &testConn{} + wrappedConn := NewConn(innerConn, 0) + dummySrc := &testConn{} + + io.Copy(wrappedConn, dummySrc) + if innerConn.readFromCalledWith != dummySrc { + t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection") + } +} + +func TestCopyFromWrappedConnection(t *testing.T) { + wrappedConn := NewConn(&testConn{}, 0) + dummyDst := &testConn{} + + io.Copy(dummyDst, wrappedConn) + if dummyDst.readFromCalledWith != wrappedConn.conn { + t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination") + } +} + +func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) { + innerConn1 := &testConn{} + wrappedConn1 := NewConn(innerConn1, 0) + innerConn2 := &testConn{} + wrappedConn2 := NewConn(innerConn2, 0) + + io.Copy(wrappedConn1, wrappedConn2) + if innerConn1.readFromCalledWith != innerConn2 { + t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection") + } + +} + +