From b0f5456d6833c8e57b220fec862014b8d89cd296 Mon Sep 17 00:00:00 2001 From: bjdgyc <bjdgyc@163.com> Date: Fri, 6 Jan 2023 10:03:44 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dtest=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/pkg/proxyproto/protocol.go | 290 --------------- server/pkg/proxyproto/protocol_test.go | 486 ------------------------- server/sessdata/ip_pool_test.go | 10 +- 3 files changed, 5 insertions(+), 781 deletions(-) delete mode 100644 server/pkg/proxyproto/protocol.go delete mode 100644 server/pkg/proxyproto/protocol_test.go diff --git a/server/pkg/proxyproto/protocol.go b/server/pkg/proxyproto/protocol.go deleted file mode 100644 index f91f0b0..0000000 --- a/server/pkg/proxyproto/protocol.go +++ /dev/null @@ -1,290 +0,0 @@ -// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol.go -// design: http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt - -// HAProxy proxy proto v1 -package proxyproto - -import ( - "bufio" - "bytes" - "errors" - "fmt" - "io" - "log" - "net" - "strconv" - "strings" - "sync" - "time" -) - -var ( - // prefix is the string we look for at the start of a connection - // to check if this connection is using the proxy protocol - prefix = []byte("PROXY ") - prefixLen = len(prefix) - - ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information") -) - -// SourceChecker can be used to decide whether to trust the PROXY info or pass -// the original connection address through. If set, the connecting address is -// passed in as an argument. If the function returns an error due to the source -// being disallowed, it should return ErrInvalidUpstream. -// -// If error is not nil, the call to Accept() will fail. If the reason for -// triggering this failure is due to a disallowed source, it should return -// ErrInvalidUpstream. -// -// If bool is true, the PROXY-set address is used. -// -// If bool is false, the connection's remote address is used, rather than the -// address claimed in the PROXY info. -type SourceChecker func(net.Addr) (bool, error) - -// Listener is used to wrap an underlying listener, -// whose connections may be using the HAProxy Proxy Protocol (version 1). -// If the connection is using the protocol, the RemoteAddr() will return -// the correct client address. -// -// Optionally define ProxyHeaderTimeout to set a maximum time to -// receive the Proxy Protocol Header. Zero means no timeout. -type Listener struct { - Listener net.Listener - ProxyHeaderTimeout time.Duration - SourceCheck SourceChecker - UnknownOK bool // allow PROXY UNKNOWN -} - -// Conn is used to wrap and underlying connection which -// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will -// return the address of the client instead of the proxy address. -type Conn struct { - bufReader *bufio.Reader - conn net.Conn - dstAddr *net.TCPAddr - srcAddr *net.TCPAddr - useConnAddr bool - once sync.Once - proxyHeaderTimeout time.Duration - unknownOK bool -} - -// Accept waits for and returns the next connection to the listener. -func (p *Listener) Accept() (net.Conn, error) { - // Get the underlying connection - conn, err := p.Listener.Accept() - if err != nil { - return nil, err - } - var useConnAddr bool - if p.SourceCheck != nil { - allowed, err := p.SourceCheck(conn.RemoteAddr()) - if err != nil { - return nil, err - } - if !allowed { - useConnAddr = true - } - } - newConn := NewConn(conn, p.ProxyHeaderTimeout) - newConn.useConnAddr = useConnAddr - newConn.unknownOK = p.UnknownOK - return newConn, nil -} - -// Close closes the underlying listener. -func (p *Listener) Close() error { - return p.Listener.Close() -} - -// Addr returns the underlying listener's network address. -func (p *Listener) Addr() net.Addr { - return p.Listener.Addr() -} - -// NewConn is used to wrap a net.Conn that may be speaking -// the proxy protocol into a proxyproto.Conn -func NewConn(conn net.Conn, timeout time.Duration) *Conn { - pConn := &Conn{ - bufReader: bufio.NewReader(conn), - conn: conn, - proxyHeaderTimeout: timeout, - } - return pConn -} - -// Read is check for the proxy protocol header when doing -// the initial scan. If there is an error parsing the header, -// it is returned and the socket is closed. -func (p *Conn) Read(b []byte) (int, error) { - var err error - p.once.Do(func() { err = p.checkPrefix() }) - if err != nil { - return 0, err - } - return p.bufReader.Read(b) -} - -func (p *Conn) ReadFrom(r io.Reader) (int64, error) { - if rf, ok := p.conn.(io.ReaderFrom); ok { - return rf.ReadFrom(r) - } - return io.Copy(p.conn, r) -} - -func (p *Conn) WriteTo(w io.Writer) (int64, error) { - var err error - p.once.Do(func() { err = p.checkPrefix() }) - if err != nil { - return 0, err - } - return p.bufReader.WriteTo(w) -} - -func (p *Conn) Write(b []byte) (int, error) { - return p.conn.Write(b) -} - -func (p *Conn) Close() error { - return p.conn.Close() -} - -func (p *Conn) LocalAddr() net.Addr { - p.checkPrefixOnce() - if p.dstAddr != nil && !p.useConnAddr { - return p.dstAddr - } - return p.conn.LocalAddr() -} - -// RemoteAddr returns the address of the client if the proxy -// protocol is being used, otherwise just returns the address of -// the socket peer. If there is an error parsing the header, the -// address of the client is not returned, and the socket is closed. -// Once implication of this is that the call could block if the -// client is slow. Using a Deadline is recommended if this is called -// before Read() -func (p *Conn) RemoteAddr() net.Addr { - p.checkPrefixOnce() - if p.srcAddr != nil && !p.useConnAddr { - return p.srcAddr - } - return p.conn.RemoteAddr() -} - -func (p *Conn) SetDeadline(t time.Time) error { - return p.conn.SetDeadline(t) -} - -func (p *Conn) SetReadDeadline(t time.Time) error { - return p.conn.SetReadDeadline(t) -} - -func (p *Conn) SetWriteDeadline(t time.Time) error { - return p.conn.SetWriteDeadline(t) -} - -func (p *Conn) checkPrefixOnce() { - p.once.Do(func() { - if err := p.checkPrefix(); err != nil && err != io.EOF { - log.Printf("[ERR] Failed to read proxy prefix: %v", err) - p.Close() - p.bufReader = bufio.NewReader(p.conn) - } - }) -} - -func (p *Conn) checkPrefix() error { - if p.proxyHeaderTimeout != 0 { - readDeadLine := time.Now().Add(p.proxyHeaderTimeout) - _ = p.conn.SetReadDeadline(readDeadLine) - defer func() { - _ = p.conn.SetReadDeadline(time.Time{}) - }() - } - - // Incrementally check each byte of the prefix - for i := 1; i <= prefixLen; i++ { - inp, err := p.bufReader.Peek(i) - - if err != nil { - if neterr, ok := err.(net.Error); ok && neterr.Timeout() { - return nil - } else { - return err - } - } - - // Check for a prefix mis-match, quit early - if !bytes.Equal(inp, prefix[:i]) { - return nil - } - } - - // Read the header line - header, err := p.bufReader.ReadString('\n') - if err != nil { - p.conn.Close() - return err - } - - // Strip the carriage return and new line - header = header[:len(header)-2] - - // Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>) - parts := strings.Split(header, " ") - if len(parts) < 2 { - p.conn.Close() - return fmt.Errorf("Invalid header line: %s", header) - } - - // Verify the type is known - switch parts[1] { - case "UNKNOWN": - if !p.unknownOK || len(parts) != 2 { - p.conn.Close() - return fmt.Errorf("Invalid UNKNOWN header line: %s", header) - } - p.useConnAddr = true - return nil - case "TCP4": - case "TCP6": - default: - p.conn.Close() - return fmt.Errorf("Unhandled address type: %s", parts[1]) - } - - if len(parts) != 6 { - p.conn.Close() - return fmt.Errorf("Invalid header line: %s", header) - } - - // Parse out the source address - ip := net.ParseIP(parts[2]) - if ip == nil { - p.conn.Close() - return fmt.Errorf("Invalid source ip: %s", parts[2]) - } - port, err := strconv.Atoi(parts[4]) - if err != nil { - p.conn.Close() - return fmt.Errorf("Invalid source port: %s", parts[4]) - } - p.srcAddr = &net.TCPAddr{IP: ip, Port: port} - - // Parse out the destination address - ip = net.ParseIP(parts[3]) - if ip == nil { - p.conn.Close() - return fmt.Errorf("Invalid destination ip: %s", parts[3]) - } - port, err = strconv.Atoi(parts[5]) - if err != nil { - p.conn.Close() - return fmt.Errorf("Invalid destination port: %s", parts[5]) - } - p.dstAddr = &net.TCPAddr{IP: ip, Port: port} - - return nil -} diff --git a/server/pkg/proxyproto/protocol_test.go b/server/pkg/proxyproto/protocol_test.go deleted file mode 100644 index 1ad37aa..0000000 --- a/server/pkg/proxyproto/protocol_test.go +++ /dev/null @@ -1,486 +0,0 @@ -// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol_test.go -package proxyproto - -import ( - "bytes" - "io" - "net" - "testing" - "time" -) - -const ( - goodAddr = "127.0.0.1" - badAddr = "127.0.0.2" - errAddr = "9999.0.0.2" -) - -var ( - checkAddr string -) - -func TestPassthrough(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } - - pl := &Listener{Listener: l} - - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - conn.Write([]byte("ping")) - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("pong")) { - t.Fatalf("bad: %v", recv) - } - }() - - conn, err := pl.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("ping")) { - t.Fatalf("bad: %v", recv) - } - - if _, err := conn.Write([]byte("pong")); err != nil { - t.Fatalf("err: %v", err) - } -} - -func TestTimeout(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } - - clientWriteDelay := 200 * time.Millisecond - proxyHeaderTimeout := 50 * time.Millisecond - pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout} - - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - // Do not send data for a while - time.Sleep(clientWriteDelay) - - conn.Write([]byte("ping")) - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("pong")) { - t.Fatalf("bad: %v", recv) - } - }() - - conn, err := pl.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - // Check the remote addr is the original 127.0.0.1 - remoteAddrStartTime := time.Now() - addr := conn.RemoteAddr().(*net.TCPAddr) - if addr.IP.String() != "127.0.0.1" { - t.Fatalf("bad: %v", addr) - } - remoteAddrDuration := time.Since(remoteAddrStartTime) - - // Check RemoteAddr() call did timeout - if remoteAddrDuration >= clientWriteDelay { - t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration) - } - - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("ping")) { - t.Fatalf("bad: %v", recv) - } - - if _, err := conn.Write([]byte("pong")); err != nil { - t.Fatalf("err: %v", err) - } -} - -func TestParse_ipv4(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } - - pl := &Listener{Listener: l} - - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - // Write out the header! - header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n" - conn.Write([]byte(header)) - - conn.Write([]byte("ping")) - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("pong")) { - t.Fatalf("bad: %v", recv) - } - }() - - conn, err := pl.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("ping")) { - t.Fatalf("bad: %v", recv) - } - - if _, err := conn.Write([]byte("pong")); err != nil { - t.Fatalf("err: %v", err) - } - - // Check the remote addr - addr := conn.RemoteAddr().(*net.TCPAddr) - if addr.IP.String() != "10.1.1.1" { - t.Fatalf("bad: %v", addr) - } - if addr.Port != 1000 { - t.Fatalf("bad: %v", addr) - } -} - -func TestParse_ipv6(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } - - pl := &Listener{Listener: l} - - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - // Write out the header! - header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n" - conn.Write([]byte(header)) - - conn.Write([]byte("ping")) - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("pong")) { - t.Fatalf("bad: %v", recv) - } - }() - - conn, err := pl.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("ping")) { - t.Fatalf("bad: %v", recv) - } - - if _, err := conn.Write([]byte("pong")); err != nil { - t.Fatalf("err: %v", err) - } - - // Check the remote addr - addr := conn.RemoteAddr().(*net.TCPAddr) - if addr.IP.String() != "ffff::ffff" { - t.Fatalf("bad: %v", addr) - } - if addr.Port != 1000 { - t.Fatalf("bad: %v", addr) - } -} - -func TestParse_Unknown(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } - - pl := &Listener{Listener: l, UnknownOK: true} - - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - // Write out the header! - header := "PROXY UNKNOWN\r\n" - conn.Write([]byte(header)) - - conn.Write([]byte("ping")) - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("pong")) { - t.Fatalf("bad: %v", recv) - } - }() - - conn, err := pl.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("ping")) { - t.Fatalf("bad: %v", recv) - } - - if _, err := conn.Write([]byte("pong")); err != nil { - t.Fatalf("err: %v", err) - } - -} - -func TestParse_BadHeader(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } - - pl := &Listener{Listener: l} - - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - // Write out the header! - header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n" - conn.Write([]byte(header)) - - conn.Write([]byte("ping")) - - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err == nil { - t.Fatalf("err: %v", err) - } - }() - - conn, err := pl.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - // Check the remote addr, should be the local addr - addr := conn.RemoteAddr().(*net.TCPAddr) - if addr.IP.String() != "127.0.0.1" { - t.Fatalf("bad: %v", addr) - } - - // Read should fail - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err == nil { - t.Fatalf("err: %v", err) - } -} - -func TestParse_ipv4_checkfunc(t *testing.T) { - checkAddr = goodAddr - testParse_ipv4_checkfunc(t) - checkAddr = badAddr - testParse_ipv4_checkfunc(t) - checkAddr = errAddr - testParse_ipv4_checkfunc(t) -} - -func testParse_ipv4_checkfunc(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) - } - - checkFunc := func(addr net.Addr) (bool, error) { - tcpAddr := addr.(*net.TCPAddr) - if tcpAddr.IP.String() == checkAddr { - return true, nil - } - return false, nil - } - - pl := &Listener{Listener: l, SourceCheck: checkFunc} - - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - // Write out the header! - header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n" - conn.Write([]byte(header)) - - conn.Write([]byte("ping")) - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("pong")) { - t.Fatalf("bad: %v", recv) - } - }() - - conn, err := pl.Accept() - if err != nil { - if checkAddr == badAddr { - return - } - t.Fatalf("err: %v", err) - } - defer conn.Close() - - recv := make([]byte, 4) - _, err = conn.Read(recv) - if err != nil { - t.Fatalf("err: %v", err) - } - if !bytes.Equal(recv, []byte("ping")) { - t.Fatalf("bad: %v", recv) - } - - if _, err := conn.Write([]byte("pong")); err != nil { - t.Fatalf("err: %v", err) - } - - // Check the remote addr - addr := conn.RemoteAddr().(*net.TCPAddr) - switch checkAddr { - case goodAddr: - if addr.IP.String() != "10.1.1.1" { - t.Fatalf("bad: %v", addr) - } - if addr.Port != 1000 { - t.Fatalf("bad: %v", addr) - } - case badAddr: - if addr.IP.String() != "127.0.0.1" { - t.Fatalf("bad: %v", addr) - } - if addr.Port == 1000 { - t.Fatalf("bad: %v", addr) - } - } -} - -type testConn struct { - readFromCalledWith io.Reader - net.Conn // nil; crash on any unexpected use -} - -func (c *testConn) ReadFrom(r io.Reader) (int64, error) { - c.readFromCalledWith = r - return 0, nil -} -func (c *testConn) Write(p []byte) (int, error) { - return len(p), nil -} -func (c *testConn) Read(p []byte) (int, error) { - return 1, nil -} - -func TestCopyToWrappedConnection(t *testing.T) { - innerConn := &testConn{} - wrappedConn := NewConn(innerConn, 0) - dummySrc := &testConn{} - - io.Copy(wrappedConn, dummySrc) - if innerConn.readFromCalledWith != dummySrc { - t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection") - } -} - -func TestCopyFromWrappedConnection(t *testing.T) { - wrappedConn := NewConn(&testConn{}, 0) - dummyDst := &testConn{} - - io.Copy(dummyDst, wrappedConn) - if dummyDst.readFromCalledWith != wrappedConn.conn { - t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination") - } -} - -func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) { - innerConn1 := &testConn{} - wrappedConn1 := NewConn(innerConn1, 0) - innerConn2 := &testConn{} - wrappedConn2 := NewConn(innerConn2, 0) - - io.Copy(wrappedConn1, wrappedConn2) - if innerConn1.readFromCalledWith != innerConn2 { - t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection") - } - -} diff --git a/server/sessdata/ip_pool_test.go b/server/sessdata/ip_pool_test.go index 5db01cd..5a2086e 100644 --- a/server/sessdata/ip_pool_test.go +++ b/server/sessdata/ip_pool_test.go @@ -47,21 +47,21 @@ func TestIpPool(t *testing.T) { var ip net.IP for i := 1; i <= 100; i++ { - _ = AcquireIp("user", fmt.Sprintf("mac-%d", i)) + _ = AcquireIp("user", fmt.Sprintf("mac-%d", i), true) } - ip = AcquireIp("user", "mac-new") + ip = AcquireIp("user", "mac-new", true) assert.True(net.IPv4(192, 168, 3, 101).Equal(ip)) for i := 102; i <= 199; i++ { - ip = AcquireIp("user", fmt.Sprintf("mac-%d", i)) + ip = AcquireIp("user", fmt.Sprintf("mac-%d", i), true) } assert.True(net.IPv4(192, 168, 3, 199).Equal(ip)) - ip = AcquireIp("user", "mac-nil") + ip = AcquireIp("user", "mac-nil", true) assert.Nil(ip) ReleaseIp(net.IPv4(192, 168, 3, 88), "mac-88") ReleaseIp(net.IPv4(192, 168, 3, 188), "mac-188") // 从头循环获取可用ip - ip = AcquireIp("user", "mac-188") + ip = AcquireIp("user", "mac-188", true) t.Log("mac-188", ip) assert.True(net.IPv4(192, 168, 3, 188).Equal(ip)) }