mirror of
				https://github.com/bjdgyc/anylink.git
				synced 2025-11-04 11:06:22 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			487 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			487 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol_test.go
 | 
						|
package proxyproto
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"io"
 | 
						|
	"net"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
)
 | 
						|
 | 
						|
// goodAddr is the loopback address the test client dials from, so a
// SourceCheck callback that accepts only this address trusts the client's
// PROXY header.
const goodAddr = "127.0.0.1"

// badAddr is a loopback address the test client does not use; a SourceCheck
// callback that accepts only this address returns false for the real client.
const badAddr = "127.0.0.2"

// errAddr is not a valid IPv4 address (9999 is out of range), so it likewise
// never matches the real client.
const errAddr = "9999.0.0.2"

// checkAddr is the single source address the SourceCheck callback in
// testParse_ipv4_checkfunc accepts. TestParse_ipv4_checkfunc assigns it before
// each scenario.
var checkAddr string
 | 
						|
 | 
						|
func TestPassthrough(t *testing.T) {
 | 
						|
	l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	pl := &Listener{Listener: l}
 | 
						|
 | 
						|
	go func() {
 | 
						|
		conn, err := net.Dial("tcp", pl.Addr().String())
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		defer conn.Close()
 | 
						|
 | 
						|
		conn.Write([]byte("ping"))
 | 
						|
		recv := make([]byte, 4)
 | 
						|
		_, err = conn.Read(recv)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		if !bytes.Equal(recv, []byte("pong")) {
 | 
						|
			t.Fatalf("bad: %v", recv)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	conn, err := pl.Accept()
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	defer conn.Close()
 | 
						|
 | 
						|
	recv := make([]byte, 4)
 | 
						|
	_, err = conn.Read(recv)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	if !bytes.Equal(recv, []byte("ping")) {
 | 
						|
		t.Fatalf("bad: %v", recv)
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := conn.Write([]byte("pong")); err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestTimeout(t *testing.T) {
 | 
						|
	l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	clientWriteDelay := 200 * time.Millisecond
 | 
						|
	proxyHeaderTimeout := 50 * time.Millisecond
 | 
						|
	pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
 | 
						|
 | 
						|
	go func() {
 | 
						|
		conn, err := net.Dial("tcp", pl.Addr().String())
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		defer conn.Close()
 | 
						|
 | 
						|
		// Do not send data for a while
 | 
						|
		time.Sleep(clientWriteDelay)
 | 
						|
 | 
						|
		conn.Write([]byte("ping"))
 | 
						|
		recv := make([]byte, 4)
 | 
						|
		_, err = conn.Read(recv)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		if !bytes.Equal(recv, []byte("pong")) {
 | 
						|
			t.Fatalf("bad: %v", recv)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	conn, err := pl.Accept()
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	defer conn.Close()
 | 
						|
 | 
						|
	// Check the remote addr is the original 127.0.0.1
 | 
						|
	remoteAddrStartTime := time.Now()
 | 
						|
	addr := conn.RemoteAddr().(*net.TCPAddr)
 | 
						|
	if addr.IP.String() != "127.0.0.1" {
 | 
						|
		t.Fatalf("bad: %v", addr)
 | 
						|
	}
 | 
						|
	remoteAddrDuration := time.Since(remoteAddrStartTime)
 | 
						|
 | 
						|
	// Check RemoteAddr() call did timeout
 | 
						|
	if remoteAddrDuration >= clientWriteDelay {
 | 
						|
		t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
 | 
						|
	}
 | 
						|
 | 
						|
	recv := make([]byte, 4)
 | 
						|
	_, err = conn.Read(recv)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	if !bytes.Equal(recv, []byte("ping")) {
 | 
						|
		t.Fatalf("bad: %v", recv)
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := conn.Write([]byte("pong")); err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParse_ipv4(t *testing.T) {
 | 
						|
	l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	pl := &Listener{Listener: l}
 | 
						|
 | 
						|
	go func() {
 | 
						|
		conn, err := net.Dial("tcp", pl.Addr().String())
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		defer conn.Close()
 | 
						|
 | 
						|
		// Write out the header!
 | 
						|
		header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
 | 
						|
		conn.Write([]byte(header))
 | 
						|
 | 
						|
		conn.Write([]byte("ping"))
 | 
						|
		recv := make([]byte, 4)
 | 
						|
		_, err = conn.Read(recv)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		if !bytes.Equal(recv, []byte("pong")) {
 | 
						|
			t.Fatalf("bad: %v", recv)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	conn, err := pl.Accept()
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	defer conn.Close()
 | 
						|
 | 
						|
	recv := make([]byte, 4)
 | 
						|
	_, err = conn.Read(recv)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	if !bytes.Equal(recv, []byte("ping")) {
 | 
						|
		t.Fatalf("bad: %v", recv)
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := conn.Write([]byte("pong")); err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Check the remote addr
 | 
						|
	addr := conn.RemoteAddr().(*net.TCPAddr)
 | 
						|
	if addr.IP.String() != "10.1.1.1" {
 | 
						|
		t.Fatalf("bad: %v", addr)
 | 
						|
	}
 | 
						|
	if addr.Port != 1000 {
 | 
						|
		t.Fatalf("bad: %v", addr)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParse_ipv6(t *testing.T) {
 | 
						|
	l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	pl := &Listener{Listener: l}
 | 
						|
 | 
						|
	go func() {
 | 
						|
		conn, err := net.Dial("tcp", pl.Addr().String())
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		defer conn.Close()
 | 
						|
 | 
						|
		// Write out the header!
 | 
						|
		header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n"
 | 
						|
		conn.Write([]byte(header))
 | 
						|
 | 
						|
		conn.Write([]byte("ping"))
 | 
						|
		recv := make([]byte, 4)
 | 
						|
		_, err = conn.Read(recv)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		if !bytes.Equal(recv, []byte("pong")) {
 | 
						|
			t.Fatalf("bad: %v", recv)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	conn, err := pl.Accept()
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	defer conn.Close()
 | 
						|
 | 
						|
	recv := make([]byte, 4)
 | 
						|
	_, err = conn.Read(recv)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	if !bytes.Equal(recv, []byte("ping")) {
 | 
						|
		t.Fatalf("bad: %v", recv)
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := conn.Write([]byte("pong")); err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Check the remote addr
 | 
						|
	addr := conn.RemoteAddr().(*net.TCPAddr)
 | 
						|
	if addr.IP.String() != "ffff::ffff" {
 | 
						|
		t.Fatalf("bad: %v", addr)
 | 
						|
	}
 | 
						|
	if addr.Port != 1000 {
 | 
						|
		t.Fatalf("bad: %v", addr)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParse_Unknown(t *testing.T) {
 | 
						|
	l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	pl := &Listener{Listener: l, UnknownOK: true}
 | 
						|
 | 
						|
	go func() {
 | 
						|
		conn, err := net.Dial("tcp", pl.Addr().String())
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		defer conn.Close()
 | 
						|
 | 
						|
		// Write out the header!
 | 
						|
		header := "PROXY UNKNOWN\r\n"
 | 
						|
		conn.Write([]byte(header))
 | 
						|
 | 
						|
		conn.Write([]byte("ping"))
 | 
						|
		recv := make([]byte, 4)
 | 
						|
		_, err = conn.Read(recv)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		if !bytes.Equal(recv, []byte("pong")) {
 | 
						|
			t.Fatalf("bad: %v", recv)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	conn, err := pl.Accept()
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	defer conn.Close()
 | 
						|
 | 
						|
	recv := make([]byte, 4)
 | 
						|
	_, err = conn.Read(recv)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	if !bytes.Equal(recv, []byte("ping")) {
 | 
						|
		t.Fatalf("bad: %v", recv)
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := conn.Write([]byte("pong")); err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
}
 | 
						|
 | 
						|
func TestParse_BadHeader(t *testing.T) {
 | 
						|
	l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	pl := &Listener{Listener: l}
 | 
						|
 | 
						|
	go func() {
 | 
						|
		conn, err := net.Dial("tcp", pl.Addr().String())
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		defer conn.Close()
 | 
						|
 | 
						|
		// Write out the header!
 | 
						|
		header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n"
 | 
						|
		conn.Write([]byte(header))
 | 
						|
 | 
						|
		conn.Write([]byte("ping"))
 | 
						|
 | 
						|
		recv := make([]byte, 4)
 | 
						|
		_, err = conn.Read(recv)
 | 
						|
		if err == nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	conn, err := pl.Accept()
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	defer conn.Close()
 | 
						|
 | 
						|
	// Check the remote addr, should be the local addr
 | 
						|
	addr := conn.RemoteAddr().(*net.TCPAddr)
 | 
						|
	if addr.IP.String() != "127.0.0.1" {
 | 
						|
		t.Fatalf("bad: %v", addr)
 | 
						|
	}
 | 
						|
 | 
						|
	// Read should fail
 | 
						|
	recv := make([]byte, 4)
 | 
						|
	_, err = conn.Read(recv)
 | 
						|
	if err == nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParse_ipv4_checkfunc(t *testing.T) {
 | 
						|
	checkAddr = goodAddr
 | 
						|
	testParse_ipv4_checkfunc(t)
 | 
						|
	checkAddr = badAddr
 | 
						|
	testParse_ipv4_checkfunc(t)
 | 
						|
	checkAddr = errAddr
 | 
						|
	testParse_ipv4_checkfunc(t)
 | 
						|
}
 | 
						|
 | 
						|
func testParse_ipv4_checkfunc(t *testing.T) {
 | 
						|
	l, err := net.Listen("tcp", "127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	checkFunc := func(addr net.Addr) (bool, error) {
 | 
						|
		tcpAddr := addr.(*net.TCPAddr)
 | 
						|
		if tcpAddr.IP.String() == checkAddr {
 | 
						|
			return true, nil
 | 
						|
		}
 | 
						|
		return false, nil
 | 
						|
	}
 | 
						|
 | 
						|
	pl := &Listener{Listener: l, SourceCheck: checkFunc}
 | 
						|
 | 
						|
	go func() {
 | 
						|
		conn, err := net.Dial("tcp", pl.Addr().String())
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		defer conn.Close()
 | 
						|
 | 
						|
		// Write out the header!
 | 
						|
		header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
 | 
						|
		conn.Write([]byte(header))
 | 
						|
 | 
						|
		conn.Write([]byte("ping"))
 | 
						|
		recv := make([]byte, 4)
 | 
						|
		_, err = conn.Read(recv)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
		if !bytes.Equal(recv, []byte("pong")) {
 | 
						|
			t.Fatalf("bad: %v", recv)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	conn, err := pl.Accept()
 | 
						|
	if err != nil {
 | 
						|
		if checkAddr == badAddr {
 | 
						|
			return
 | 
						|
		}
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	defer conn.Close()
 | 
						|
 | 
						|
	recv := make([]byte, 4)
 | 
						|
	_, err = conn.Read(recv)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
	if !bytes.Equal(recv, []byte("ping")) {
 | 
						|
		t.Fatalf("bad: %v", recv)
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := conn.Write([]byte("pong")); err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Check the remote addr
 | 
						|
	addr := conn.RemoteAddr().(*net.TCPAddr)
 | 
						|
	switch checkAddr {
 | 
						|
	case goodAddr:
 | 
						|
		if addr.IP.String() != "10.1.1.1" {
 | 
						|
			t.Fatalf("bad: %v", addr)
 | 
						|
		}
 | 
						|
		if addr.Port != 1000 {
 | 
						|
			t.Fatalf("bad: %v", addr)
 | 
						|
		}
 | 
						|
	case badAddr:
 | 
						|
		if addr.IP.String() != "127.0.0.1" {
 | 
						|
			t.Fatalf("bad: %v", addr)
 | 
						|
		}
 | 
						|
		if addr.Port == 1000 {
 | 
						|
			t.Fatalf("bad: %v", addr)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// testConn is a minimal connection double for the io.Copy delegation tests.
// It records the reader handed to ReadFrom and accepts every Write. Each Read
// reports one byte without touching the buffer.
type testConn struct {
	readFromCalledWith io.Reader

	// Left nil on purpose: calling any net.Conn method not defined below
	// panics, which surfaces unexpected use.
	net.Conn
}

// ReadFrom remembers src and reports that nothing was copied.
func (c *testConn) ReadFrom(src io.Reader) (int64, error) {
	c.readFromCalledWith = src
	var copied int64
	return copied, nil
}

// Write discards buf while claiming it was written in full.
func (c *testConn) Write(buf []byte) (int, error) {
	written := len(buf)
	return written, nil
}

// Read claims that a single byte was read; buf is left untouched.
func (c *testConn) Read(buf []byte) (int, error) {
	const oneByte = 1
	return oneByte, nil
}
 | 
						|
 | 
						|
func TestCopyToWrappedConnection(t *testing.T) {
 | 
						|
	innerConn := &testConn{}
 | 
						|
	wrappedConn := NewConn(innerConn, 0)
 | 
						|
	dummySrc := &testConn{}
 | 
						|
 | 
						|
	io.Copy(wrappedConn, dummySrc)
 | 
						|
	if innerConn.readFromCalledWith != dummySrc {
 | 
						|
		t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestCopyFromWrappedConnection(t *testing.T) {
 | 
						|
	wrappedConn := NewConn(&testConn{}, 0)
 | 
						|
	dummyDst := &testConn{}
 | 
						|
 | 
						|
	io.Copy(dummyDst, wrappedConn)
 | 
						|
	if dummyDst.readFromCalledWith != wrappedConn.conn {
 | 
						|
		t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
 | 
						|
	innerConn1 := &testConn{}
 | 
						|
	wrappedConn1 := NewConn(innerConn1, 0)
 | 
						|
	innerConn2 := &testConn{}
 | 
						|
	wrappedConn2 := NewConn(innerConn2, 0)
 | 
						|
 | 
						|
	io.Copy(wrappedConn1, wrappedConn2)
 | 
						|
	if innerConn1.readFromCalledWith != innerConn2 {
 | 
						|
		t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection")
 | 
						|
	}
 | 
						|
 | 
						|
}
 |