mirror of
				https://github.com/bjdgyc/anylink.git
				synced 2025-11-01 00:59:34 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			291 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			291 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // copy from: https://github.com/armon/go-proxyproto/blob/master/protocol.go
 | |
| // design: http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt
 | |
| 
 | |
| // HAProxy proxy proto v1
 | |
| package proxyproto
 | |
| 
 | |
| import (
 | |
| 	"bufio"
 | |
| 	"bytes"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"log"
 | |
| 	"net"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
var (
	// ErrInvalidUpstream is the error a SourceChecker is expected to return
	// when the connecting address is not allowed to supply PROXY information.
	ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")

	// prefix is the byte sequence every PROXY protocol v1 header starts with;
	// the first bytes of each connection are compared against it to decide
	// whether a header is present.
	prefix = []byte("PROXY ")

	// prefixLen is len(prefix), kept alongside it for the byte-by-byte check.
	prefixLen = len(prefix)
)
 | |
| 
 | |
// SourceChecker decides, for each accepted connection, whether the PROXY
// header sent on it may be trusted. Listener.Accept calls it with the
// socket's remote address and acts on the result as follows:
//
//   - (true, nil): the addresses carried in the PROXY header are reported.
//   - (false, nil): the header's addresses are ignored and the socket's own
//     addresses are reported instead.
//   - (_, err): Accept fails with err. When the failure is because the
//     source is not allowed, err should be ErrInvalidUpstream.
type SourceChecker func(net.Addr) (bool, error)
 | |
| 
 | |
| // Listener is used to wrap an underlying listener,
 | |
| // whose connections may be using the HAProxy Proxy Protocol (version 1).
 | |
| // If the connection is using the protocol, the RemoteAddr() will return
 | |
| // the correct client address.
 | |
| //
 | |
| // Optionally define ProxyHeaderTimeout to set a maximum time to
 | |
| // receive the Proxy Protocol Header. Zero means no timeout.
 | |
| type Listener struct {
 | |
| 	Listener           net.Listener
 | |
| 	ProxyHeaderTimeout time.Duration
 | |
| 	SourceCheck        SourceChecker
 | |
| 	UnknownOK          bool // allow PROXY UNKNOWN
 | |
| }
 | |
| 
 | |
// Conn wraps a net.Conn that may begin with a PROXY protocol v1 header. When
// a trusted header is present, RemoteAddr and LocalAddr report the addresses
// from it instead of the addresses of the proxy's socket.
type Conn struct {
	// bufReader reads from conn. The header is peeked and consumed through
	// it, and any bytes buffered after the header are served by Read/WriteTo.
	bufReader *bufio.Reader

	// conn is the wrapped connection.
	conn net.Conn

	// dstAddr and srcAddr hold the addresses parsed from a TCP4/TCP6 header;
	// they stay nil when no such header was parsed.
	dstAddr *net.TCPAddr
	srcAddr *net.TCPAddr

	// useConnAddr forces the socket's own addresses to be reported.
	useConnAddr bool

	// once guards the single header check shared by Read, WriteTo,
	// LocalAddr and RemoteAddr.
	once sync.Once

	// proxyHeaderTimeout is the read deadline used while checking for the
	// header; zero means none.
	proxyHeaderTimeout time.Duration

	// unknownOK permits "PROXY UNKNOWN" headers.
	unknownOK bool
}
 | |
| 
 | |
| // Accept waits for and returns the next connection to the listener.
 | |
| func (p *Listener) Accept() (net.Conn, error) {
 | |
| 	// Get the underlying connection
 | |
| 	conn, err := p.Listener.Accept()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	var useConnAddr bool
 | |
| 	if p.SourceCheck != nil {
 | |
| 		allowed, err := p.SourceCheck(conn.RemoteAddr())
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		if !allowed {
 | |
| 			useConnAddr = true
 | |
| 		}
 | |
| 	}
 | |
| 	newConn := NewConn(conn, p.ProxyHeaderTimeout)
 | |
| 	newConn.useConnAddr = useConnAddr
 | |
| 	newConn.unknownOK = p.UnknownOK
 | |
| 	return newConn, nil
 | |
| }
 | |
| 
 | |
| // Close closes the underlying listener.
 | |
| func (p *Listener) Close() error {
 | |
| 	return p.Listener.Close()
 | |
| }
 | |
| 
 | |
| // Addr returns the underlying listener's network address.
 | |
| func (p *Listener) Addr() net.Addr {
 | |
| 	return p.Listener.Addr()
 | |
| }
 | |
| 
 | |
| // NewConn is used to wrap a net.Conn that may be speaking
 | |
| // the proxy protocol into a proxyproto.Conn
 | |
| func NewConn(conn net.Conn, timeout time.Duration) *Conn {
 | |
| 	pConn := &Conn{
 | |
| 		bufReader:          bufio.NewReader(conn),
 | |
| 		conn:               conn,
 | |
| 		proxyHeaderTimeout: timeout,
 | |
| 	}
 | |
| 	return pConn
 | |
| }
 | |
| 
 | |
| // Read is check for the proxy protocol header when doing
 | |
| // the initial scan. If there is an error parsing the header,
 | |
| // it is returned and the socket is closed.
 | |
| func (p *Conn) Read(b []byte) (int, error) {
 | |
| 	var err error
 | |
| 	p.once.Do(func() { err = p.checkPrefix() })
 | |
| 	if err != nil {
 | |
| 		return 0, err
 | |
| 	}
 | |
| 	return p.bufReader.Read(b)
 | |
| }
 | |
| 
 | |
| func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
 | |
| 	if rf, ok := p.conn.(io.ReaderFrom); ok {
 | |
| 		return rf.ReadFrom(r)
 | |
| 	}
 | |
| 	return io.Copy(p.conn, r)
 | |
| }
 | |
| 
 | |
| func (p *Conn) WriteTo(w io.Writer) (int64, error) {
 | |
| 	var err error
 | |
| 	p.once.Do(func() { err = p.checkPrefix() })
 | |
| 	if err != nil {
 | |
| 		return 0, err
 | |
| 	}
 | |
| 	return p.bufReader.WriteTo(w)
 | |
| }
 | |
| 
 | |
| func (p *Conn) Write(b []byte) (int, error) {
 | |
| 	return p.conn.Write(b)
 | |
| }
 | |
| 
 | |
| func (p *Conn) Close() error {
 | |
| 	return p.conn.Close()
 | |
| }
 | |
| 
 | |
| func (p *Conn) LocalAddr() net.Addr {
 | |
| 	p.checkPrefixOnce()
 | |
| 	if p.dstAddr != nil && !p.useConnAddr {
 | |
| 		return p.dstAddr
 | |
| 	}
 | |
| 	return p.conn.LocalAddr()
 | |
| }
 | |
| 
 | |
| // RemoteAddr returns the address of the client if the proxy
 | |
| // protocol is being used, otherwise just returns the address of
 | |
| // the socket peer. If there is an error parsing the header, the
 | |
| // address of the client is not returned, and the socket is closed.
 | |
| // Once implication of this is that the call could block if the
 | |
| // client is slow. Using a Deadline is recommended if this is called
 | |
| // before Read()
 | |
| func (p *Conn) RemoteAddr() net.Addr {
 | |
| 	p.checkPrefixOnce()
 | |
| 	if p.srcAddr != nil && !p.useConnAddr {
 | |
| 		return p.srcAddr
 | |
| 	}
 | |
| 	return p.conn.RemoteAddr()
 | |
| }
 | |
| 
 | |
| func (p *Conn) SetDeadline(t time.Time) error {
 | |
| 	return p.conn.SetDeadline(t)
 | |
| }
 | |
| 
 | |
| func (p *Conn) SetReadDeadline(t time.Time) error {
 | |
| 	return p.conn.SetReadDeadline(t)
 | |
| }
 | |
| 
 | |
| func (p *Conn) SetWriteDeadline(t time.Time) error {
 | |
| 	return p.conn.SetWriteDeadline(t)
 | |
| }
 | |
| 
 | |
| func (p *Conn) checkPrefixOnce() {
 | |
| 	p.once.Do(func() {
 | |
| 		if err := p.checkPrefix(); err != nil && err != io.EOF {
 | |
| 			log.Printf("[ERR] Failed to read proxy prefix: %v", err)
 | |
| 			p.Close()
 | |
| 			p.bufReader = bufio.NewReader(p.conn)
 | |
| 		}
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func (p *Conn) checkPrefix() error {
 | |
| 	if p.proxyHeaderTimeout != 0 {
 | |
| 		readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
 | |
| 		_ = p.conn.SetReadDeadline(readDeadLine)
 | |
| 		defer func() {
 | |
| 			_ = p.conn.SetReadDeadline(time.Time{})
 | |
| 		}()
 | |
| 	}
 | |
| 
 | |
| 	// Incrementally check each byte of the prefix
 | |
| 	for i := 1; i <= prefixLen; i++ {
 | |
| 		inp, err := p.bufReader.Peek(i)
 | |
| 
 | |
| 		if err != nil {
 | |
| 			if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
 | |
| 				return nil
 | |
| 			} else {
 | |
| 				return err
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// Check for a prefix mis-match, quit early
 | |
| 		if !bytes.Equal(inp, prefix[:i]) {
 | |
| 			return nil
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Read the header line
 | |
| 	header, err := p.bufReader.ReadString('\n')
 | |
| 	if err != nil {
 | |
| 		p.conn.Close()
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// Strip the carriage return and new line
 | |
| 	header = header[:len(header)-2]
 | |
| 
 | |
| 	// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
 | |
| 	parts := strings.Split(header, " ")
 | |
| 	if len(parts) < 2 {
 | |
| 		p.conn.Close()
 | |
| 		return fmt.Errorf("Invalid header line: %s", header)
 | |
| 	}
 | |
| 
 | |
| 	// Verify the type is known
 | |
| 	switch parts[1] {
 | |
| 	case "UNKNOWN":
 | |
| 		if !p.unknownOK || len(parts) != 2 {
 | |
| 			p.conn.Close()
 | |
| 			return fmt.Errorf("Invalid UNKNOWN header line: %s", header)
 | |
| 		}
 | |
| 		p.useConnAddr = true
 | |
| 		return nil
 | |
| 	case "TCP4":
 | |
| 	case "TCP6":
 | |
| 	default:
 | |
| 		p.conn.Close()
 | |
| 		return fmt.Errorf("Unhandled address type: %s", parts[1])
 | |
| 	}
 | |
| 
 | |
| 	if len(parts) != 6 {
 | |
| 		p.conn.Close()
 | |
| 		return fmt.Errorf("Invalid header line: %s", header)
 | |
| 	}
 | |
| 
 | |
| 	// Parse out the source address
 | |
| 	ip := net.ParseIP(parts[2])
 | |
| 	if ip == nil {
 | |
| 		p.conn.Close()
 | |
| 		return fmt.Errorf("Invalid source ip: %s", parts[2])
 | |
| 	}
 | |
| 	port, err := strconv.Atoi(parts[4])
 | |
| 	if err != nil {
 | |
| 		p.conn.Close()
 | |
| 		return fmt.Errorf("Invalid source port: %s", parts[4])
 | |
| 	}
 | |
| 	p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
 | |
| 
 | |
| 	// Parse out the destination address
 | |
| 	ip = net.ParseIP(parts[3])
 | |
| 	if ip == nil {
 | |
| 		p.conn.Close()
 | |
| 		return fmt.Errorf("Invalid destination ip: %s", parts[3])
 | |
| 	}
 | |
| 	port, err = strconv.Atoi(parts[5])
 | |
| 	if err != nil {
 | |
| 		p.conn.Close()
 | |
| 		return fmt.Errorf("Invalid destination port: %s", parts[5])
 | |
| 	}
 | |
| 	p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
 | |
| 
 | |
| 	return nil
 | |
| }
 |