mirror of https://github.com/bjdgyc/anylink.git
parent
eee99fcb0d
commit
bb0de6690a
|
@ -8,7 +8,7 @@ AnyLink 基于 [ietf-openconnect](https://tools.ietf.org/html/draft-mavrogiannop
|
|||
|
||||
AnyLink 使用TLS/DTLS进行数据加密,因此需要RSA或ECC证书,可以通过 Let's Encrypt 和 TrustAsia 申请免费的SSL证书。
|
||||
|
||||
AnyLink 服务端仅在CentOs7测试通过,如需要安装在其他系统,需要服务端支持tun功能、ip设置命令。
|
||||
AnyLink 服务端仅在CentOS7测试通过,如需要安装在其他系统,需要服务端支持tun功能、ip设置命令。
|
||||
|
||||
## Installation
|
||||
|
||||
|
@ -26,6 +26,7 @@ sudo ./anylink -conf="conf/server.toml"
|
|||
- [x] TLS-TCP通道
|
||||
- [x] 兼容AnyConnect
|
||||
- [x] 多用户支持
|
||||
- [x] 支持 [proxy protocol v1](http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt) 协议
|
||||
- [ ] DTLS-UDP通道
|
||||
- [ ] 后台管理界面
|
||||
- [ ] 用户组支持
|
||||
|
|
|
@ -27,6 +27,7 @@ type ServerConfig struct {
|
|||
UserFile string `toml:"user_file"`
|
||||
ServerAddr string `toml:"server_addr"`
|
||||
DebugAddr string `toml:"debug_addr"`
|
||||
ProxyProtocol bool `toml:"proxy_protocol"`
|
||||
CertFile string `toml:"cert_file"`
|
||||
CertKey string `toml:"cert_key"`
|
||||
LinkGroups []string `toml:"link_groups"`
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
#过滤本地证书文件
|
||||
vpn_cert.key
|
||||
vpn_cert.pem
|
|
@ -10,6 +10,8 @@ cert_key = "./vpn_cert.key"
|
|||
#服务监听的地址
|
||||
server_addr = ":443"
|
||||
debug_addr = "127.0.0.1:8800"
|
||||
#开启tcp proxy protocol协议
|
||||
proxy_protocol = false
|
||||
|
||||
#用户组
|
||||
link_groups = ["one", "two"]
|
||||
|
|
2
go.mod
2
go.mod
|
@ -6,5 +6,5 @@ require (
|
|||
github.com/julienschmidt/httprouter v1.3.0
|
||||
github.com/pelletier/go-toml v1.8.0
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
||||
golang.org/x/sys v0.0.0-20200817155316-9781c653f443 // indirect
|
||||
golang.org/x/sys v0.0.0-20200819171115-d785dc25833f // indirect
|
||||
)
|
||||
|
|
4
go.sum
4
go.sum
|
@ -8,8 +8,8 @@ github.com/pelletier/go-toml v1.8.0 h1:Keo9qb7iRJs2voHvunFtuuYFsbWeOBh8/P9v/kVMF
|
|||
github.com/pelletier/go-toml v1.8.0/go.mod h1:D6yutnOGMveHEPV7VQOuvI/gXY61bv+9bAOTRnLElKs=
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
|
||||
golang.org/x/sys v0.0.0-20200817155316-9781c653f443 h1:X18bCaipMcoJGm27Nv7zr4XYPKGUy92GtqboKC2Hxaw=
|
||||
golang.org/x/sys v0.0.0-20200817155316-9781c653f443/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200819171115-d785dc25833f h1:KJuwZVtZBVzDmEDtB2zro9CXkD9O0dpCv4o2LHbQIAw=
|
||||
golang.org/x/sys v0.0.0-20200819171115-d785dc25833f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
|
||||
|
|
|
@ -47,6 +47,7 @@ func checkVpnClient(h httprouter.Handle) httprouter.Handle {
|
|||
// TODO 调试信息输出
|
||||
// hd, _ := httputil.DumpRequest(r, true)
|
||||
// fmt.Println("DumpRequest: ", string(hd))
|
||||
fmt.Println(r.RemoteAddr)
|
||||
|
||||
user_Agent := strings.ToLower(r.UserAgent())
|
||||
x_Aggregate_Auth := r.Header.Get("X-Aggregate-Auth")
|
||||
|
|
|
@ -2,6 +2,7 @@ package handler
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
@ -9,12 +10,12 @@ import (
|
|||
"github.com/bjdgyc/anylink/common"
|
||||
)
|
||||
|
||||
func LinkCstp(conn net.Conn, sess *Session) {
|
||||
func LinkCstp(conn net.Conn, sess *ConnSession) {
|
||||
// fmt.Println("HandlerCstp")
|
||||
defer func() {
|
||||
log.Println("LinkCstp return")
|
||||
conn.Close()
|
||||
sess.Close()
|
||||
log.Println("LinkCstp return")
|
||||
}()
|
||||
|
||||
var (
|
||||
|
@ -47,7 +48,7 @@ func LinkCstp(conn net.Conn, sess *Session) {
|
|||
// fmt.Println("DISCONNECT")
|
||||
return
|
||||
case 0x03: // DPD-REQ
|
||||
// fmt.Println("DPD-REQ")
|
||||
fmt.Println("DPD-REQ")
|
||||
payload := &Payload{
|
||||
ptype: 0x04, // DPD-RESP
|
||||
}
|
||||
|
@ -73,11 +74,11 @@ func LinkCstp(conn net.Conn, sess *Session) {
|
|||
}
|
||||
}
|
||||
|
||||
func cstpWrite(conn net.Conn, sess *Session) {
|
||||
func cstpWrite(conn net.Conn, sess *ConnSession) {
|
||||
defer func() {
|
||||
log.Println("cstpWrite return")
|
||||
conn.Close()
|
||||
sess.Close()
|
||||
log.Println("cstpWrite return")
|
||||
}()
|
||||
|
||||
var (
|
||||
|
|
|
@ -29,10 +29,10 @@ func testTun() {
|
|||
}
|
||||
|
||||
// 创建tun网卡
|
||||
func LinkTun(sess *Session) {
|
||||
func LinkTun(sess *ConnSession) {
|
||||
defer func() {
|
||||
sess.Close()
|
||||
log.Println("LinkTun return")
|
||||
sess.Close()
|
||||
}()
|
||||
|
||||
cfg := water.Config{
|
||||
|
@ -84,7 +84,11 @@ func LinkTun(sess *Session) {
|
|||
|
||||
}
|
||||
|
||||
func tunRead(ifce *water.Interface, sess *Session) {
|
||||
func tunRead(ifce *water.Interface, sess *ConnSession) {
|
||||
defer func() {
|
||||
log.Println("tunRead return")
|
||||
ifce.Close()
|
||||
}()
|
||||
var (
|
||||
err error
|
||||
n int
|
||||
|
|
|
@ -2,6 +2,7 @@ package handler
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
|
@ -19,7 +20,7 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
|||
// TODO 调试信息输出
|
||||
// hd, _ := httputil.DumpRequest(r, true)
|
||||
// fmt.Println("DumpRequest: ", string(hd))
|
||||
// fmt.Println(r.RemoteAddr)
|
||||
fmt.Println("LinkTunnel", r.RemoteAddr)
|
||||
|
||||
// 判断session-token的值
|
||||
cookie, err := r.Cookie("webvpn")
|
||||
|
@ -35,8 +36,9 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// 开启link
|
||||
sess.StartLink()
|
||||
if sess.NetIp == nil {
|
||||
cSess := sess.StartConn()
|
||||
if cSess == nil {
|
||||
log.Println(err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
@ -44,14 +46,14 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
|||
// 客户端信息
|
||||
cstp_mtu := r.Header.Get("X-CSTP-MTU")
|
||||
master_Secret := r.Header.Get("X-DTLS-Master-Secret")
|
||||
sess.MasterSecret = master_Secret
|
||||
sess.Mtu = cstp_mtu
|
||||
sess.RemoteAddr = r.RemoteAddr
|
||||
cSess.MasterSecret = master_Secret
|
||||
cSess.Mtu = cstp_mtu
|
||||
cSess.RemoteAddr = r.RemoteAddr
|
||||
|
||||
w.Header().Set("Server", fmt.Sprintf("%s %s", common.APP_NAME, common.APP_VER))
|
||||
w.Header().Set("X-CSTP-Version", "1")
|
||||
w.Header().Set("X-CSTP-Protocol", "Copyright (c) 2004 Cisco Systems, Inc.")
|
||||
w.Header().Set("X-CSTP-Address", sess.NetIp.String()) // 分配的ip地址
|
||||
w.Header().Set("X-CSTP-Address", cSess.NetIp.String()) // 分配的ip地址
|
||||
w.Header().Set("X-CSTP-Netmask", common.ServerCfg.Ipv4Netmask) // 子网掩码
|
||||
w.Header().Set("X-CSTP-Hostname", hn) // 机器名称
|
||||
for _, v := range common.ServerCfg.ClientDns {
|
||||
|
@ -115,6 +117,6 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// 开始数据处理
|
||||
go LinkTun(sess)
|
||||
go LinkCstp(conn, sess)
|
||||
go LinkTun(cSess)
|
||||
go LinkCstp(conn, cSess)
|
||||
}
|
||||
|
|
|
@ -3,11 +3,13 @@ package handler
|
|||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/bjdgyc/anylink/proxyproto"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
_ "net/http/pprof"
|
||||
"time"
|
||||
|
||||
"github.com/bjdgyc/anylink/common"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
@ -40,13 +42,18 @@ func startTls() {
|
|||
TLSConfig: tlsConfig,
|
||||
}
|
||||
|
||||
var ln net.Listener
|
||||
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
srv.SetKeepAlivesEnabled(true)
|
||||
if common.ServerCfg.ProxyProtocol {
|
||||
ln = &proxyproto.Listener{Listener: ln, ProxyHeaderTimeout: time.Second * 5}
|
||||
}
|
||||
|
||||
fmt.Println("listen ", addr)
|
||||
err = srv.ServeTLS(ln, certFile, keyFile)
|
||||
if err != nil {
|
||||
|
|
|
@ -17,25 +17,31 @@ var (
|
|||
sessions = make(map[string]*Session) // session_token -> SessUser
|
||||
)
|
||||
|
||||
// 连接sess
|
||||
type ConnSession struct {
|
||||
Sess *Session
|
||||
MasterSecret string // dtls协议的 master_secret
|
||||
NetIp net.IP // 分配的ip地址
|
||||
RemoteAddr string
|
||||
Mtu string
|
||||
TunName string
|
||||
closeOnce sync.Once
|
||||
Closed chan struct{}
|
||||
PayloadIn chan *Payload
|
||||
PayloadOut chan *Payload
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Sid string // auth返回的 session-id
|
||||
Token string // session信息的唯一token
|
||||
DtlsSid string // dtls协议的 session_id
|
||||
MacAddr string // 客户端mac地址
|
||||
UserName string // 用户名
|
||||
LastLogin time.Time
|
||||
IsActive bool
|
||||
|
||||
// 开启link需要设置的参数
|
||||
MasterSecret string // dtls协议的 master_secret
|
||||
NetIp net.IP // 分配的ip地址
|
||||
UserName string // 用户名
|
||||
RemoteAddr string
|
||||
Mtu string
|
||||
TunName string
|
||||
IsActive bool
|
||||
LastLogin time.Time
|
||||
closeOnce sync.Once
|
||||
Closed chan struct{}
|
||||
PayloadIn chan *Payload
|
||||
PayloadOut chan *Payload
|
||||
CSess *ConnSession
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -86,28 +92,37 @@ func NewSession() *Session {
|
|||
return sess
|
||||
}
|
||||
|
||||
func (s *Session) StartLink() {
|
||||
func (s *Session) StartConn() *ConnSession {
|
||||
if s.IsActive == true {
|
||||
s.CSess.Close()
|
||||
}
|
||||
|
||||
limit := common.LimitClient(s.UserName, false)
|
||||
if limit == false {
|
||||
s.NetIp = nil
|
||||
return
|
||||
// s.NetIp = nil
|
||||
return nil
|
||||
}
|
||||
s.NetIp = common.AcquireIp(s.MacAddr)
|
||||
s.IsActive = true
|
||||
s.closeOnce = sync.Once{}
|
||||
s.Closed = make(chan struct{})
|
||||
s.PayloadIn = make(chan *Payload)
|
||||
s.PayloadOut = make(chan *Payload)
|
||||
cSess := &ConnSession{
|
||||
Sess: s,
|
||||
NetIp: common.AcquireIp(s.MacAddr),
|
||||
closeOnce: sync.Once{},
|
||||
Closed: make(chan struct{}),
|
||||
PayloadIn: make(chan *Payload),
|
||||
PayloadOut: make(chan *Payload),
|
||||
}
|
||||
s.CSess = cSess
|
||||
return cSess
|
||||
}
|
||||
|
||||
func (s *Session) Close() {
|
||||
s.closeOnce.Do(func() {
|
||||
func (cs *ConnSession) Close() {
|
||||
cs.closeOnce.Do(func() {
|
||||
log.Println("closeOnce")
|
||||
close(s.Closed)
|
||||
s.IsActive = false
|
||||
s.LastLogin = time.Now()
|
||||
common.ReleaseIp(s.NetIp, s.MacAddr)
|
||||
common.LimitClient(s.UserName, true)
|
||||
close(cs.Closed)
|
||||
cs.Sess.IsActive = false
|
||||
cs.Sess.LastLogin = time.Now()
|
||||
common.ReleaseIp(cs.NetIp, cs.Sess.MacAddr)
|
||||
common.LimitClient(cs.Sess.UserName, true)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,289 @@
|
|||
// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol.go
|
||||
// design: http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt
|
||||
package proxyproto
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// prefix is the string we look for at the start of a connection
|
||||
// to check if this connection is using the proxy protocol
|
||||
prefix = []byte("PROXY ")
|
||||
prefixLen = len(prefix)
|
||||
|
||||
ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")
|
||||
)
|
||||
|
||||
// SourceChecker can be used to decide whether to trust the PROXY info or pass
|
||||
// the original connection address through. If set, the connecting address is
|
||||
// passed in as an argument. If the function returns an error due to the source
|
||||
// being disallowed, it should return ErrInvalidUpstream.
|
||||
//
|
||||
// If error is not nil, the call to Accept() will fail. If the reason for
|
||||
// triggering this failure is due to a disallowed source, it should return
|
||||
// ErrInvalidUpstream.
|
||||
//
|
||||
// If bool is true, the PROXY-set address is used.
|
||||
//
|
||||
// If bool is false, the connection's remote address is used, rather than the
|
||||
// address claimed in the PROXY info.
|
||||
type SourceChecker func(net.Addr) (bool, error)
|
||||
|
||||
// Listener is used to wrap an underlying listener,
|
||||
// whose connections may be using the HAProxy Proxy Protocol (version 1).
|
||||
// If the connection is using the protocol, the RemoteAddr() will return
|
||||
// the correct client address.
|
||||
//
|
||||
// Optionally define ProxyHeaderTimeout to set a maximum time to
|
||||
// receive the Proxy Protocol Header. Zero means no timeout.
|
||||
type Listener struct {
|
||||
Listener net.Listener
|
||||
ProxyHeaderTimeout time.Duration
|
||||
SourceCheck SourceChecker
|
||||
UnknownOK bool // allow PROXY UNKNOWN
|
||||
}
|
||||
|
||||
// Conn is used to wrap and underlying connection which
|
||||
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
|
||||
// return the address of the client instead of the proxy address.
|
||||
type Conn struct {
|
||||
bufReader *bufio.Reader
|
||||
conn net.Conn
|
||||
dstAddr *net.TCPAddr
|
||||
srcAddr *net.TCPAddr
|
||||
useConnAddr bool
|
||||
once sync.Once
|
||||
proxyHeaderTimeout time.Duration
|
||||
unknownOK bool
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (p *Listener) Accept() (net.Conn, error) {
|
||||
// Get the underlying connection
|
||||
conn, err := p.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var useConnAddr bool
|
||||
if p.SourceCheck != nil {
|
||||
allowed, err := p.SourceCheck(conn.RemoteAddr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !allowed {
|
||||
useConnAddr = true
|
||||
}
|
||||
}
|
||||
newConn := NewConn(conn, p.ProxyHeaderTimeout)
|
||||
newConn.useConnAddr = useConnAddr
|
||||
newConn.unknownOK = p.UnknownOK
|
||||
return newConn, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying listener.
|
||||
func (p *Listener) Close() error {
|
||||
return p.Listener.Close()
|
||||
}
|
||||
|
||||
// Addr returns the underlying listener's network address.
|
||||
func (p *Listener) Addr() net.Addr {
|
||||
return p.Listener.Addr()
|
||||
}
|
||||
|
||||
// NewConn is used to wrap a net.Conn that may be speaking
|
||||
// the proxy protocol into a proxyproto.Conn
|
||||
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
|
||||
pConn := &Conn{
|
||||
bufReader: bufio.NewReader(conn),
|
||||
conn: conn,
|
||||
proxyHeaderTimeout: timeout,
|
||||
}
|
||||
return pConn
|
||||
}
|
||||
|
||||
// Read is check for the proxy protocol header when doing
|
||||
// the initial scan. If there is an error parsing the header,
|
||||
// it is returned and the socket is closed.
|
||||
func (p *Conn) Read(b []byte) (int, error) {
|
||||
var err error
|
||||
p.once.Do(func() { err = p.checkPrefix() })
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return p.bufReader.Read(b)
|
||||
}
|
||||
|
||||
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
|
||||
if rf, ok := p.conn.(io.ReaderFrom); ok {
|
||||
return rf.ReadFrom(r)
|
||||
}
|
||||
return io.Copy(p.conn, r)
|
||||
}
|
||||
|
||||
func (p *Conn) WriteTo(w io.Writer) (int64, error) {
|
||||
var err error
|
||||
p.once.Do(func() { err = p.checkPrefix() })
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return p.bufReader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (p *Conn) Write(b []byte) (int, error) {
|
||||
return p.conn.Write(b)
|
||||
}
|
||||
|
||||
func (p *Conn) Close() error {
|
||||
return p.conn.Close()
|
||||
}
|
||||
|
||||
func (p *Conn) LocalAddr() net.Addr {
|
||||
p.checkPrefixOnce()
|
||||
if p.dstAddr != nil && !p.useConnAddr {
|
||||
return p.dstAddr
|
||||
}
|
||||
return p.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the address of the client if the proxy
|
||||
// protocol is being used, otherwise just returns the address of
|
||||
// the socket peer. If there is an error parsing the header, the
|
||||
// address of the client is not returned, and the socket is closed.
|
||||
// Once implication of this is that the call could block if the
|
||||
// client is slow. Using a Deadline is recommended if this is called
|
||||
// before Read()
|
||||
func (p *Conn) RemoteAddr() net.Addr {
|
||||
p.checkPrefixOnce()
|
||||
if p.srcAddr != nil && !p.useConnAddr {
|
||||
return p.srcAddr
|
||||
}
|
||||
return p.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (p *Conn) SetDeadline(t time.Time) error {
|
||||
return p.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (p *Conn) SetReadDeadline(t time.Time) error {
|
||||
return p.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (p *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return p.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (p *Conn) checkPrefixOnce() {
|
||||
p.once.Do(func() {
|
||||
if err := p.checkPrefix(); err != nil && err != io.EOF {
|
||||
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
|
||||
p.Close()
|
||||
p.bufReader = bufio.NewReader(p.conn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Conn) checkPrefix() error {
|
||||
if p.proxyHeaderTimeout != 0 {
|
||||
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
|
||||
p.conn.SetReadDeadline(readDeadLine)
|
||||
defer p.conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
// Incrementally check each byte of the prefix
|
||||
for i := 1; i <= prefixLen; i++ {
|
||||
inp, err := p.bufReader.Peek(i)
|
||||
|
||||
if err != nil {
|
||||
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check for a prefix mis-match, quit early
|
||||
if !bytes.Equal(inp, prefix[:i]) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Read the header line
|
||||
header, err := p.bufReader.ReadString('\n')
|
||||
if err != nil {
|
||||
p.conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Strip the carriage return and new line
|
||||
header = header[:len(header)-2]
|
||||
|
||||
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
|
||||
parts := strings.Split(header, " ")
|
||||
if len(parts) < 2 {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid header line: %s", header)
|
||||
}
|
||||
|
||||
// Verify the type is known
|
||||
switch parts[1] {
|
||||
case "UNKNOWN":
|
||||
if !p.unknownOK || len(parts) != 2 {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid UNKNOWN header line: %s", header)
|
||||
}
|
||||
p.useConnAddr = true
|
||||
return nil
|
||||
case "TCP4":
|
||||
case "TCP6":
|
||||
default:
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Unhandled address type: %s", parts[1])
|
||||
}
|
||||
|
||||
if len(parts) != 6 {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid header line: %s", header)
|
||||
}
|
||||
|
||||
// Parse out the source address
|
||||
ip := net.ParseIP(parts[2])
|
||||
if ip == nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid source ip: %s", parts[2])
|
||||
}
|
||||
port, err := strconv.Atoi(parts[4])
|
||||
if err != nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid source port: %s", parts[4])
|
||||
}
|
||||
p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||
|
||||
// Parse out the destination address
|
||||
ip = net.ParseIP(parts[3])
|
||||
if ip == nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid destination ip: %s", parts[3])
|
||||
}
|
||||
port, err = strconv.Atoi(parts[5])
|
||||
if err != nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid destination port: %s", parts[5])
|
||||
}
|
||||
p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,488 @@
|
|||
// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol_test.go
|
||||
package proxyproto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
goodAddr = "127.0.0.1"
|
||||
badAddr = "127.0.0.2"
|
||||
errAddr = "9999.0.0.2"
|
||||
)
|
||||
|
||||
var (
|
||||
checkAddr string
|
||||
)
|
||||
|
||||
func TestPassthrough(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
clientWriteDelay := 200 * time.Millisecond
|
||||
proxyHeaderTimeout := 50 * time.Millisecond
|
||||
pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Do not send data for a while
|
||||
time.Sleep(clientWriteDelay)
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Check the remote addr is the original 127.0.0.1
|
||||
remoteAddrStartTime := time.Now()
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
remoteAddrDuration := time.Since(remoteAddrStartTime)
|
||||
|
||||
// Check RemoteAddr() call did timeout
|
||||
if remoteAddrDuration >= clientWriteDelay {
|
||||
t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
|
||||
}
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ipv4(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the remote addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "10.1.1.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
if addr.Port != 1000 {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ipv6(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the remote addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "ffff::ffff" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
if addr.Port != 1000 {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_Unknown(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l, UnknownOK: true}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY UNKNOWN\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestParse_BadHeader(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err == nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Check the remote addr, should be the local addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
|
||||
// Read should fail
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err == nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ipv4_checkfunc(t *testing.T) {
|
||||
checkAddr = goodAddr
|
||||
testParse_ipv4_checkfunc(t)
|
||||
checkAddr = badAddr
|
||||
testParse_ipv4_checkfunc(t)
|
||||
checkAddr = errAddr
|
||||
testParse_ipv4_checkfunc(t)
|
||||
}
|
||||
|
||||
func testParse_ipv4_checkfunc(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
checkFunc := func(addr net.Addr) (bool, error) {
|
||||
tcpAddr := addr.(*net.TCPAddr)
|
||||
if tcpAddr.IP.String() == checkAddr {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l, SourceCheck: checkFunc}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
if checkAddr == badAddr {
|
||||
return
|
||||
}
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the remote addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
switch checkAddr {
|
||||
case goodAddr:
|
||||
if addr.IP.String() != "10.1.1.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
if addr.Port != 1000 {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
case badAddr:
|
||||
if addr.IP.String() != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
if addr.Port == 1000 {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type testConn struct {
|
||||
readFromCalledWith io.Reader
|
||||
net.Conn // nil; crash on any unexpected use
|
||||
}
|
||||
|
||||
func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
|
||||
c.readFromCalledWith = r
|
||||
return 0, nil
|
||||
}
|
||||
func (c *testConn) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
func (c *testConn) Read(p []byte) (int, error) {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func TestCopyToWrappedConnection(t *testing.T) {
|
||||
innerConn := &testConn{}
|
||||
wrappedConn := NewConn(innerConn, 0)
|
||||
dummySrc := &testConn{}
|
||||
|
||||
io.Copy(wrappedConn, dummySrc)
|
||||
if innerConn.readFromCalledWith != dummySrc {
|
||||
t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFromWrappedConnection(t *testing.T) {
|
||||
wrappedConn := NewConn(&testConn{}, 0)
|
||||
dummyDst := &testConn{}
|
||||
|
||||
io.Copy(dummyDst, wrappedConn)
|
||||
if dummyDst.readFromCalledWith != wrappedConn.conn {
|
||||
t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
|
||||
innerConn1 := &testConn{}
|
||||
wrappedConn1 := NewConn(innerConn1, 0)
|
||||
innerConn2 := &testConn{}
|
||||
wrappedConn2 := NewConn(innerConn2, 0)
|
||||
|
||||
io.Copy(wrappedConn1, wrappedConn2)
|
||||
if innerConn1.readFromCalledWith != innerConn2 {
|
||||
t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue