修复断线重连后的bug

添加 proxy protocol v1 的支持
This commit is contained in:
bjdgyc 2020-08-20 19:32:19 +08:00
parent eee99fcb0d
commit bb0de6690a
14 changed files with 858 additions and 50 deletions

View File

@ -8,7 +8,7 @@ AnyLink 基于 [ietf-openconnect](https://tools.ietf.org/html/draft-mavrogiannop
AnyLink 使用TLS/DTLS进行数据加密因此需要RSA或ECC证书可以通过 Let's Encrypt 和 TrustAsia 申请免费的SSL证书。
AnyLink 服务端仅在CentOs7测试通过如需要安装在其他系统需要服务端支持tun功能、ip设置命令。
AnyLink 服务端仅在CentOS7测试通过如需要安装在其他系统需要服务端支持tun功能、ip设置命令。
## Installation
@ -26,6 +26,7 @@ sudo ./anylink -conf="conf/server.toml"
- [x] TLS-TCP通道
- [x] 兼容AnyConnect
- [x] 多用户支持
- [x] 支持 [proxy protocol v1](http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt) 协议
- [ ] DTLS-UDP通道
- [ ] 后台管理界面
- [ ] 用户组支持

View File

@ -27,6 +27,7 @@ type ServerConfig struct {
UserFile string `toml:"user_file"`
ServerAddr string `toml:"server_addr"`
DebugAddr string `toml:"debug_addr"`
ProxyProtocol bool `toml:"proxy_protocol"`
CertFile string `toml:"cert_file"`
CertKey string `toml:"cert_key"`
LinkGroups []string `toml:"link_groups"`

3
conf/.gitignore vendored
View File

@ -1,3 +0,0 @@
#过滤本地证书文件
vpn_cert.key
vpn_cert.pem

View File

@ -10,6 +10,8 @@ cert_key = "./vpn_cert.key"
#服务监听的地址
server_addr = ":443"
debug_addr = "127.0.0.1:8800"
#开启tcp proxy protocol协议
proxy_protocol = false
#用户组
link_groups = ["one", "two"]

2
go.mod
View File

@ -6,5 +6,5 @@ require (
github.com/julienschmidt/httprouter v1.3.0
github.com/pelletier/go-toml v1.8.0
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
golang.org/x/sys v0.0.0-20200817155316-9781c653f443 // indirect
golang.org/x/sys v0.0.0-20200819171115-d785dc25833f // indirect
)

4
go.sum
View File

@ -8,8 +8,8 @@ github.com/pelletier/go-toml v1.8.0 h1:Keo9qb7iRJs2voHvunFtuuYFsbWeOBh8/P9v/kVMF
github.com/pelletier/go-toml v1.8.0/go.mod h1:D6yutnOGMveHEPV7VQOuvI/gXY61bv+9bAOTRnLElKs=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
golang.org/x/sys v0.0.0-20200817155316-9781c653f443 h1:X18bCaipMcoJGm27Nv7zr4XYPKGUy92GtqboKC2Hxaw=
golang.org/x/sys v0.0.0-20200817155316-9781c653f443/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200819171115-d785dc25833f h1:KJuwZVtZBVzDmEDtB2zro9CXkD9O0dpCv4o2LHbQIAw=
golang.org/x/sys v0.0.0-20200819171115-d785dc25833f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=

View File

@ -47,6 +47,7 @@ func checkVpnClient(h httprouter.Handle) httprouter.Handle {
// TODO 调试信息输出
// hd, _ := httputil.DumpRequest(r, true)
// fmt.Println("DumpRequest: ", string(hd))
fmt.Println(r.RemoteAddr)
user_Agent := strings.ToLower(r.UserAgent())
x_Aggregate_Auth := r.Header.Get("X-Aggregate-Auth")

View File

@ -2,6 +2,7 @@ package handler
import (
"encoding/binary"
"fmt"
"log"
"net"
"time"
@ -9,12 +10,12 @@ import (
"github.com/bjdgyc/anylink/common"
)
func LinkCstp(conn net.Conn, sess *Session) {
func LinkCstp(conn net.Conn, sess *ConnSession) {
// fmt.Println("HandlerCstp")
defer func() {
log.Println("LinkCstp return")
conn.Close()
sess.Close()
log.Println("LinkCstp return")
}()
var (
@ -47,7 +48,7 @@ func LinkCstp(conn net.Conn, sess *Session) {
// fmt.Println("DISCONNECT")
return
case 0x03: // DPD-REQ
// fmt.Println("DPD-REQ")
fmt.Println("DPD-REQ")
payload := &Payload{
ptype: 0x04, // DPD-RESP
}
@ -73,11 +74,11 @@ func LinkCstp(conn net.Conn, sess *Session) {
}
}
func cstpWrite(conn net.Conn, sess *Session) {
func cstpWrite(conn net.Conn, sess *ConnSession) {
defer func() {
log.Println("cstpWrite return")
conn.Close()
sess.Close()
log.Println("cstpWrite return")
}()
var (

View File

@ -29,10 +29,10 @@ func testTun() {
}
// 创建tun网卡
func LinkTun(sess *Session) {
func LinkTun(sess *ConnSession) {
defer func() {
sess.Close()
log.Println("LinkTun return")
sess.Close()
}()
cfg := water.Config{
@ -84,7 +84,11 @@ func LinkTun(sess *Session) {
}
func tunRead(ifce *water.Interface, sess *Session) {
func tunRead(ifce *water.Interface, sess *ConnSession) {
defer func() {
log.Println("tunRead return")
ifce.Close()
}()
var (
err error
n int

View File

@ -2,6 +2,7 @@ package handler
import (
"fmt"
"log"
"net/http"
"os"
@ -19,7 +20,7 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
// TODO 调试信息输出
// hd, _ := httputil.DumpRequest(r, true)
// fmt.Println("DumpRequest: ", string(hd))
// fmt.Println(r.RemoteAddr)
fmt.Println("LinkTunnel", r.RemoteAddr)
// 判断session-token的值
cookie, err := r.Cookie("webvpn")
@ -35,8 +36,9 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
}
// 开启link
sess.StartLink()
if sess.NetIp == nil {
cSess := sess.StartConn()
if cSess == nil {
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
return
}
@ -44,14 +46,14 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
// 客户端信息
cstp_mtu := r.Header.Get("X-CSTP-MTU")
master_Secret := r.Header.Get("X-DTLS-Master-Secret")
sess.MasterSecret = master_Secret
sess.Mtu = cstp_mtu
sess.RemoteAddr = r.RemoteAddr
cSess.MasterSecret = master_Secret
cSess.Mtu = cstp_mtu
cSess.RemoteAddr = r.RemoteAddr
w.Header().Set("Server", fmt.Sprintf("%s %s", common.APP_NAME, common.APP_VER))
w.Header().Set("X-CSTP-Version", "1")
w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.")
w.Header().Set("X-CSTP-Address", sess.NetIp.String()) // 分配的ip地址
w.Header().Set("X-CSTP-Address", cSess.NetIp.String()) // 分配的ip地址
w.Header().Set("X-CSTP-Netmask", common.ServerCfg.Ipv4Netmask) // 子网掩码
w.Header().Set("X-CSTP-Hostname", hn) // 机器名称
for _, v := range common.ServerCfg.ClientDns {
@ -115,6 +117,6 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
}
// 开始数据处理
go LinkTun(sess)
go LinkCstp(conn, sess)
go LinkTun(cSess)
go LinkCstp(conn, cSess)
}

View File

@ -3,11 +3,13 @@ package handler
import (
"crypto/tls"
"fmt"
"github.com/bjdgyc/anylink/proxyproto"
"log"
"net"
"net/http"
"net/http/httputil"
_ "net/http/pprof"
"time"
"github.com/bjdgyc/anylink/common"
"github.com/julienschmidt/httprouter"
@ -40,13 +42,18 @@ func startTls() {
TLSConfig: tlsConfig,
}
var ln net.Listener
ln, err := net.Listen("tcp", addr)
if err != nil {
log.Fatal(err)
}
defer ln.Close()
srv.SetKeepAlivesEnabled(true)
if common.ServerCfg.ProxyProtocol {
ln = &proxyproto.Listener{Listener: ln, ProxyHeaderTimeout: time.Second * 5}
}
fmt.Println("listen ", addr)
err = srv.ServeTLS(ln, certFile, keyFile)
if err != nil {

View File

@ -17,25 +17,31 @@ var (
sessions = make(map[string]*Session) // session_token -> SessUser
)
// 连接sess
type ConnSession struct {
Sess *Session
MasterSecret string // dtls协议的 master_secret
NetIp net.IP // 分配的ip地址
RemoteAddr string
Mtu string
TunName string
closeOnce sync.Once
Closed chan struct{}
PayloadIn chan *Payload
PayloadOut chan *Payload
}
type Session struct {
Sid string // auth返回的 session-id
Token string // session信息的唯一token
DtlsSid string // dtls协议的 session_id
MacAddr string // 客户端mac地址
UserName string // 用户名
LastLogin time.Time
IsActive bool
// 开启link需要设置的参数
MasterSecret string // dtls协议的 master_secret
NetIp net.IP // 分配的ip地址
UserName string // 用户名
RemoteAddr string
Mtu string
TunName string
IsActive bool
LastLogin time.Time
closeOnce sync.Once
Closed chan struct{}
PayloadIn chan *Payload
PayloadOut chan *Payload
CSess *ConnSession
}
func init() {
@ -86,28 +92,37 @@ func NewSession() *Session {
return sess
}
func (s *Session) StartLink() {
func (s *Session) StartConn() *ConnSession {
if s.IsActive == true {
s.CSess.Close()
}
limit := common.LimitClient(s.UserName, false)
if limit == false {
s.NetIp = nil
return
// s.NetIp = nil
return nil
}
s.NetIp = common.AcquireIp(s.MacAddr)
s.IsActive = true
s.closeOnce = sync.Once{}
s.Closed = make(chan struct{})
s.PayloadIn = make(chan *Payload)
s.PayloadOut = make(chan *Payload)
cSess := &ConnSession{
Sess: s,
NetIp: common.AcquireIp(s.MacAddr),
closeOnce: sync.Once{},
Closed: make(chan struct{}),
PayloadIn: make(chan *Payload),
PayloadOut: make(chan *Payload),
}
s.CSess = cSess
return cSess
}
func (s *Session) Close() {
s.closeOnce.Do(func() {
func (cs *ConnSession) Close() {
cs.closeOnce.Do(func() {
log.Println("closeOnce")
close(s.Closed)
s.IsActive = false
s.LastLogin = time.Now()
common.ReleaseIp(s.NetIp, s.MacAddr)
common.LimitClient(s.UserName, true)
close(cs.Closed)
cs.Sess.IsActive = false
cs.Sess.LastLogin = time.Now()
common.ReleaseIp(cs.NetIp, cs.Sess.MacAddr)
common.LimitClient(cs.Sess.UserName, true)
})
}

289
proxyproto/protocol.go Normal file
View File

@ -0,0 +1,289 @@
// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol.go
// design: http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt
package proxyproto
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
)
var (
// prefix is the string we look for at the start of a connection
// to check if this connection is using the proxy protocol
prefix = []byte("PROXY ")
prefixLen = len(prefix)
ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")
)
// SourceChecker can be used to decide whether to trust the PROXY info or pass
// the original connection address through. If set, the connecting address is
// passed in as an argument. If the function returns an error due to the source
// being disallowed, it should return ErrInvalidUpstream.
//
// If error is not nil, the call to Accept() will fail. If the reason for
// triggering this failure is due to a disallowed source, it should return
// ErrInvalidUpstream.
//
// If bool is true, the PROXY-set address is used.
//
// If bool is false, the connection's remote address is used, rather than the
// address claimed in the PROXY info.
type SourceChecker func(net.Addr) (bool, error)
// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol (version 1).
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
//
// Optionally define ProxyHeaderTimeout to set a maximum time to
// receive the Proxy Protocol Header. Zero means no timeout.
type Listener struct {
Listener net.Listener
ProxyHeaderTimeout time.Duration
SourceCheck SourceChecker
UnknownOK bool // allow PROXY UNKNOWN
}
// Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
type Conn struct {
bufReader *bufio.Reader
conn net.Conn
dstAddr *net.TCPAddr
srcAddr *net.TCPAddr
useConnAddr bool
once sync.Once
proxyHeaderTimeout time.Duration
unknownOK bool
}
// Accept waits for and returns the next connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
return nil, err
}
var useConnAddr bool
if p.SourceCheck != nil {
allowed, err := p.SourceCheck(conn.RemoteAddr())
if err != nil {
return nil, err
}
if !allowed {
useConnAddr = true
}
}
newConn := NewConn(conn, p.ProxyHeaderTimeout)
newConn.useConnAddr = useConnAddr
newConn.unknownOK = p.UnknownOK
return newConn, nil
}
// Close closes the underlying listener.
func (p *Listener) Close() error {
return p.Listener.Close()
}
// Addr returns the underlying listener's network address.
func (p *Listener) Addr() net.Addr {
return p.Listener.Addr()
}
// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
pConn := &Conn{
bufReader: bufio.NewReader(conn),
conn: conn,
proxyHeaderTimeout: timeout,
}
return pConn
}
// Read is check for the proxy protocol header when doing
// the initial scan. If there is an error parsing the header,
// it is returned and the socket is closed.
func (p *Conn) Read(b []byte) (int, error) {
var err error
p.once.Do(func() { err = p.checkPrefix() })
if err != nil {
return 0, err
}
return p.bufReader.Read(b)
}
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
if rf, ok := p.conn.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
return io.Copy(p.conn, r)
}
func (p *Conn) WriteTo(w io.Writer) (int64, error) {
var err error
p.once.Do(func() { err = p.checkPrefix() })
if err != nil {
return 0, err
}
return p.bufReader.WriteTo(w)
}
func (p *Conn) Write(b []byte) (int, error) {
return p.conn.Write(b)
}
func (p *Conn) Close() error {
return p.conn.Close()
}
func (p *Conn) LocalAddr() net.Addr {
p.checkPrefixOnce()
if p.dstAddr != nil && !p.useConnAddr {
return p.dstAddr
}
return p.conn.LocalAddr()
}
// RemoteAddr returns the address of the client if the proxy
// protocol is being used, otherwise just returns the address of
// the socket peer. If there is an error parsing the header, the
// address of the client is not returned, and the socket is closed.
// Once implication of this is that the call could block if the
// client is slow. Using a Deadline is recommended if this is called
// before Read()
func (p *Conn) RemoteAddr() net.Addr {
p.checkPrefixOnce()
if p.srcAddr != nil && !p.useConnAddr {
return p.srcAddr
}
return p.conn.RemoteAddr()
}
func (p *Conn) SetDeadline(t time.Time) error {
return p.conn.SetDeadline(t)
}
func (p *Conn) SetReadDeadline(t time.Time) error {
return p.conn.SetReadDeadline(t)
}
func (p *Conn) SetWriteDeadline(t time.Time) error {
return p.conn.SetWriteDeadline(t)
}
func (p *Conn) checkPrefixOnce() {
p.once.Do(func() {
if err := p.checkPrefix(); err != nil && err != io.EOF {
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
p.Close()
p.bufReader = bufio.NewReader(p.conn)
}
})
}
func (p *Conn) checkPrefix() error {
if p.proxyHeaderTimeout != 0 {
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
p.conn.SetReadDeadline(readDeadLine)
defer p.conn.SetReadDeadline(time.Time{})
}
// Incrementally check each byte of the prefix
for i := 1; i <= prefixLen; i++ {
inp, err := p.bufReader.Peek(i)
if err != nil {
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
return nil
} else {
return err
}
}
// Check for a prefix mis-match, quit early
if !bytes.Equal(inp, prefix[:i]) {
return nil
}
}
// Read the header line
header, err := p.bufReader.ReadString('\n')
if err != nil {
p.conn.Close()
return err
}
// Strip the carriage return and new line
header = header[:len(header)-2]
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
parts := strings.Split(header, " ")
if len(parts) < 2 {
p.conn.Close()
return fmt.Errorf("Invalid header line: %s", header)
}
// Verify the type is known
switch parts[1] {
case "UNKNOWN":
if !p.unknownOK || len(parts) != 2 {
p.conn.Close()
return fmt.Errorf("Invalid UNKNOWN header line: %s", header)
}
p.useConnAddr = true
return nil
case "TCP4":
case "TCP6":
default:
p.conn.Close()
return fmt.Errorf("Unhandled address type: %s", parts[1])
}
if len(parts) != 6 {
p.conn.Close()
return fmt.Errorf("Invalid header line: %s", header)
}
// Parse out the source address
ip := net.ParseIP(parts[2])
if ip == nil {
p.conn.Close()
return fmt.Errorf("Invalid source ip: %s", parts[2])
}
port, err := strconv.Atoi(parts[4])
if err != nil {
p.conn.Close()
return fmt.Errorf("Invalid source port: %s", parts[4])
}
p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
// Parse out the destination address
ip = net.ParseIP(parts[3])
if ip == nil {
p.conn.Close()
return fmt.Errorf("Invalid destination ip: %s", parts[3])
}
port, err = strconv.Atoi(parts[5])
if err != nil {
p.conn.Close()
return fmt.Errorf("Invalid destination port: %s", parts[5])
}
p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
return nil
}

488
proxyproto/protocol_test.go Normal file
View File

@ -0,0 +1,488 @@
// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol_test.go
package proxyproto
import (
"bytes"
"io"
"net"
"testing"
"time"
)
const (
goodAddr = "127.0.0.1"
badAddr = "127.0.0.2"
errAddr = "9999.0.0.2"
)
var (
checkAddr string
)
func TestPassthrough(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{Listener: l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestTimeout(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
clientWriteDelay := 200 * time.Millisecond
proxyHeaderTimeout := 50 * time.Millisecond
pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Do not send data for a while
time.Sleep(clientWriteDelay)
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Check the remote addr is the original 127.0.0.1
remoteAddrStartTime := time.Now()
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
remoteAddrDuration := time.Since(remoteAddrStartTime)
// Check RemoteAddr() call did timeout
if remoteAddrDuration >= clientWriteDelay {
t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
}
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestParse_ipv4(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{Listener: l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
// Check the remote addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "10.1.1.1" {
t.Fatalf("bad: %v", addr)
}
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}
}
func TestParse_ipv6(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{Listener: l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
// Check the remote addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "ffff::ffff" {
t.Fatalf("bad: %v", addr)
}
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}
}
func TestParse_Unknown(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{Listener: l, UnknownOK: true}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY UNKNOWN\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestParse_BadHeader(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{Listener: l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err == nil {
t.Fatalf("err: %v", err)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Check the remote addr, should be the local addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
// Read should fail
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err == nil {
t.Fatalf("err: %v", err)
}
}
func TestParse_ipv4_checkfunc(t *testing.T) {
checkAddr = goodAddr
testParse_ipv4_checkfunc(t)
checkAddr = badAddr
testParse_ipv4_checkfunc(t)
checkAddr = errAddr
testParse_ipv4_checkfunc(t)
}
func testParse_ipv4_checkfunc(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
checkFunc := func(addr net.Addr) (bool, error) {
tcpAddr := addr.(*net.TCPAddr)
if tcpAddr.IP.String() == checkAddr {
return true, nil
}
return false, nil
}
pl := &Listener{Listener: l, SourceCheck: checkFunc}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
if checkAddr == badAddr {
return
}
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
// Check the remote addr
addr := conn.RemoteAddr().(*net.TCPAddr)
switch checkAddr {
case goodAddr:
if addr.IP.String() != "10.1.1.1" {
t.Fatalf("bad: %v", addr)
}
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}
case badAddr:
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
if addr.Port == 1000 {
t.Fatalf("bad: %v", addr)
}
}
}
type testConn struct {
readFromCalledWith io.Reader
net.Conn // nil; crash on any unexpected use
}
func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
c.readFromCalledWith = r
return 0, nil
}
func (c *testConn) Write(p []byte) (int, error) {
return len(p), nil
}
func (c *testConn) Read(p []byte) (int, error) {
return 1, nil
}
func TestCopyToWrappedConnection(t *testing.T) {
innerConn := &testConn{}
wrappedConn := NewConn(innerConn, 0)
dummySrc := &testConn{}
io.Copy(wrappedConn, dummySrc)
if innerConn.readFromCalledWith != dummySrc {
t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection")
}
}
func TestCopyFromWrappedConnection(t *testing.T) {
wrappedConn := NewConn(&testConn{}, 0)
dummyDst := &testConn{}
io.Copy(dummyDst, wrappedConn)
if dummyDst.readFromCalledWith != wrappedConn.conn {
t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination")
}
}
func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
innerConn1 := &testConn{}
wrappedConn1 := NewConn(innerConn1, 0)
innerConn2 := &testConn{}
wrappedConn2 := NewConn(innerConn2, 0)
io.Copy(wrappedConn1, wrappedConn2)
if innerConn1.readFromCalledWith != innerConn2 {
t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection")
}
}