mirror of
https://github.com/bjdgyc/anylink.git
synced 2025-08-08 00:31:17 +08:00
更改目录结构
This commit is contained in:
99
server/pkg/arpdis/addr.go
Normal file
99
server/pkg/arpdis/addr.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package arpdis
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
StaleTimeNormal = time.Minute * 5
|
||||
StaleTimeUnreachable = time.Minute * 10
|
||||
|
||||
TypeNormal = 0
|
||||
TypeUnreachable = 1
|
||||
TypeStatic = 2
|
||||
)
|
||||
|
||||
var (
|
||||
table = make(map[string]*Addr)
|
||||
tableMu sync.RWMutex
|
||||
)
|
||||
|
||||
type Addr struct {
|
||||
IP net.IP
|
||||
HardwareAddr net.HardwareAddr
|
||||
disTime time.Time
|
||||
Type int8
|
||||
}
|
||||
|
||||
func Lookup(ip net.IP, onlyTable bool) *Addr {
|
||||
addr := tableLookup(ip.To4())
|
||||
if addr != nil || onlyTable {
|
||||
return addr
|
||||
}
|
||||
|
||||
addr = doLookup(ip.To4())
|
||||
Add(addr)
|
||||
return addr
|
||||
}
|
||||
|
||||
// Add adds a IP-MAC map to a runtime ARP table.
|
||||
func tableLookup(ip net.IP) *Addr {
|
||||
tableMu.Lock()
|
||||
addr := table[ip.To4().String()]
|
||||
tableMu.Unlock()
|
||||
if addr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 判断老化过期时间
|
||||
tsub := time.Since(addr.disTime)
|
||||
switch addr.Type {
|
||||
case TypeNormal:
|
||||
if tsub > StaleTimeNormal {
|
||||
return nil
|
||||
}
|
||||
case TypeUnreachable:
|
||||
if tsub > StaleTimeUnreachable {
|
||||
return nil
|
||||
}
|
||||
case TypeStatic:
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// Add adds a IP-MAC map to a runtime ARP table.
|
||||
func Add(addr *Addr) {
|
||||
if addr == nil {
|
||||
return
|
||||
}
|
||||
if addr.disTime.IsZero() {
|
||||
addr.disTime = time.Now()
|
||||
}
|
||||
ip := addr.IP.To4().String()
|
||||
tableMu.Lock()
|
||||
defer tableMu.Unlock()
|
||||
if a, ok := table[ip]; ok {
|
||||
// 静态地址只能设置一次
|
||||
if a.Type == TypeStatic {
|
||||
return
|
||||
}
|
||||
}
|
||||
table[ip] = addr
|
||||
}
|
||||
|
||||
// Delete removes an IP from the runtime ARP table.
|
||||
func Delete(ip net.IP) {
|
||||
tableMu.Lock()
|
||||
defer tableMu.Unlock()
|
||||
delete(table, ip.To4().String())
|
||||
}
|
||||
|
||||
// List returns the current runtime ARP table.
|
||||
func List() map[string]*Addr {
|
||||
tableMu.RLock()
|
||||
defer tableMu.RUnlock()
|
||||
return table
|
||||
}
|
35
server/pkg/arpdis/addr_test.go
Normal file
35
server/pkg/arpdis/addr_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package arpdis
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLookup(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
ip := net.IPv4(192, 168, 10, 2)
|
||||
hw, _ := net.ParseMAC("08:00:27:a0:17:42")
|
||||
now := time.Now()
|
||||
addr1 := &Addr{
|
||||
IP: ip,
|
||||
HardwareAddr: hw,
|
||||
Type: TypeStatic,
|
||||
disTime: now,
|
||||
}
|
||||
Add(addr1)
|
||||
addr2 := Lookup(ip, true)
|
||||
assert.Equal(addr1, addr2)
|
||||
addr3 := &Addr{
|
||||
IP: ip,
|
||||
HardwareAddr: hw,
|
||||
Type: TypeNormal,
|
||||
disTime: now,
|
||||
}
|
||||
Add(addr3)
|
||||
addr4 := Lookup(ip, true)
|
||||
// 静态地址只能设置一次
|
||||
assert.NotEqual(addr3, addr4)
|
||||
}
|
56
server/pkg/arpdis/arp.go
Normal file
56
server/pkg/arpdis/arp.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package arpdis
|
||||
|
||||
// Reference: github.com/malfunkt/arpfox
|
||||
// TODO now only support IPv4
|
||||
|
||||
import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
var defaultSerializeOpts = gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}
|
||||
|
||||
// NewARPRequest creates a bew ARP packet of type "request.
|
||||
func NewARPRequest(src *Addr, dst *Addr) ([]byte, error) {
|
||||
return buildPacket(src, dst, layers.ARPRequest)
|
||||
}
|
||||
|
||||
// NewARPReply creates a new ARP packet of type "reply".
|
||||
func NewARPReply(src *Addr, dst *Addr) ([]byte, error) {
|
||||
return buildPacket(src, dst, layers.ARPReply)
|
||||
}
|
||||
|
||||
// buildPacket creates an template ARP packet with the given source and
|
||||
// destination.
|
||||
func buildPacket(src *Addr, dst *Addr, opt uint16) ([]byte, error) {
|
||||
ether := layers.Ethernet{
|
||||
EthernetType: layers.EthernetTypeARP,
|
||||
SrcMAC: src.HardwareAddr,
|
||||
DstMAC: dst.HardwareAddr,
|
||||
}
|
||||
arp := layers.ARP{
|
||||
AddrType: layers.LinkTypeEthernet,
|
||||
Protocol: layers.EthernetTypeIPv4,
|
||||
|
||||
HwAddressSize: 6,
|
||||
ProtAddressSize: 4,
|
||||
Operation: opt,
|
||||
|
||||
SourceHwAddress: src.HardwareAddr,
|
||||
SourceProtAddress: src.IP.To4(),
|
||||
|
||||
DstHwAddress: dst.HardwareAddr,
|
||||
DstProtAddress: dst.IP.To4(),
|
||||
}
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
err := gopacket.SerializeLayers(buf, defaultSerializeOpts, ðer, &arp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
101
server/pkg/arpdis/icmp.go
Normal file
101
server/pkg/arpdis/icmp.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package arpdis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolICMP = 1
|
||||
ProtocolIPv6ICMP = 58
|
||||
)
|
||||
|
||||
func doPing(ip string) error {
|
||||
raddr, _ := net.ResolveIPAddr("ip4:icmp", ip)
|
||||
conn, err := icmp.ListenPacket("ip4:icmp", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ipv4Conn := conn.IPv4PacketConn()
|
||||
// 限制跳跃数
|
||||
err = ipv4Conn.SetTTL(10)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg := &icmp.Message{
|
||||
Type: ipv4.ICMPTypeEcho,
|
||||
Code: 0,
|
||||
Body: &icmp.Echo{
|
||||
ID: os.Getpid() & 0xffff,
|
||||
Seq: 1,
|
||||
Data: timeToBytes(time.Now()),
|
||||
},
|
||||
}
|
||||
|
||||
b, err := msg.Marshal(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = conn.WriteTo(b, raddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(time.Second * 2))
|
||||
|
||||
for {
|
||||
buf := make([]byte, 512)
|
||||
n, dst, err := conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if dst.String() != ip {
|
||||
continue
|
||||
}
|
||||
|
||||
var result *icmp.Message
|
||||
result, err = icmp.ParseMessage(ProtocolICMP, buf[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch result.Type {
|
||||
case ipv4.ICMPTypeEchoReply:
|
||||
// success
|
||||
if rply, ok := result.Body.(*icmp.Echo); ok {
|
||||
_ = rply
|
||||
// log.Printf("%+v \n", rply)
|
||||
}
|
||||
return nil
|
||||
|
||||
// case ipv4.ICMPTypeTimeExceeded:
|
||||
// case ipv4.ICMPTypeDestinationUnreachable:
|
||||
default:
|
||||
return errors.New("DestinationUnreachable")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func timeToBytes(t time.Time) []byte {
|
||||
nsec := t.UnixNano()
|
||||
b := make([]byte, 8)
|
||||
for i := uint8(0); i < 8; i++ {
|
||||
b[i] = byte((nsec >> ((7 - i) * 8)) & 0xff)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func bytesToTime(b []byte) time.Time {
|
||||
var nsec int64
|
||||
for i := uint8(0); i < 8; i++ {
|
||||
nsec += int64(b[i]) << ((7 - i) * 8)
|
||||
}
|
||||
return time.Unix(nsec/1000000000, nsec%1000000000)
|
||||
}
|
61
server/pkg/arpdis/lookup.go
Normal file
61
server/pkg/arpdis/lookup.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// Currently only Darwin and Linux support this.
|
||||
|
||||
package arpdis
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func doLookup(ip net.IP) *Addr {
|
||||
// ping := exec.Command("ping", "-c1", "-t1", ip.String())
|
||||
// if err := ping.Run(); err != nil {
|
||||
// addr := &Addr{IP: ip, Type: TypeUnreachable}
|
||||
// return addr
|
||||
// }
|
||||
|
||||
err := doPing(ip.String())
|
||||
if err != nil {
|
||||
// log.Println(err)
|
||||
addr := &Addr{IP: ip, Type: TypeUnreachable}
|
||||
return addr
|
||||
}
|
||||
|
||||
return doArpShow(ip)
|
||||
}
|
||||
|
||||
func doArpShow(ip net.IP) *Addr {
|
||||
cmd := exec.Command("ip", "n", "show", ip.String())
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Println("lookup show", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// os.Open("/proc/net/arp")
|
||||
// 192.168.1.2 0x1 0x2 e0:94:67:e2:42:5d * eth0
|
||||
// 192.168.1.2 dev eth0 lladdr 08:00:27:94:a5:a4 STALE
|
||||
outS := strings.ReplaceAll(string(out), " ", " ")
|
||||
outS = strings.TrimSpace(outS)
|
||||
arpArr := strings.Split(outS, " ")
|
||||
if len(arpArr) != 6 {
|
||||
log.Println("lookup arpArr", outS, ip)
|
||||
return nil
|
||||
}
|
||||
mac, err := net.ParseMAC(arpArr[4])
|
||||
if err != nil {
|
||||
log.Println("lookup mac", outS, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &Addr{IP: ip, HardwareAddr: mac}
|
||||
}
|
||||
|
||||
// IP address HW type Flags HW address Mask Device
|
||||
// 172.23.24.12 0x1 0x2 00:e0:4c:73:5c:48 * anylink0
|
||||
// 172.23.24.1 0x1 0x2 3c:8c:40:a0:7a:2d * anylink0
|
||||
// 172.23.24.13 0x1 0x2 00:1c:42:4d:33:46 * anylink0
|
||||
// 172.23.24.2 0x1 0x0 00:00:00:00:00:00 * anylink0
|
||||
// 172.23.24.14 0x1 0x0 00:00:00:00:00:00 * anylink0
|
290
server/pkg/proxyproto/protocol.go
Normal file
290
server/pkg/proxyproto/protocol.go
Normal file
@@ -0,0 +1,290 @@
|
||||
// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol.go
|
||||
// design: http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt
|
||||
|
||||
// HAProxy proxy proto v1
|
||||
package proxyproto
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// prefix is the string we look for at the start of a connection
|
||||
// to check if this connection is using the proxy protocol
|
||||
prefix = []byte("PROXY ")
|
||||
prefixLen = len(prefix)
|
||||
|
||||
ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")
|
||||
)
|
||||
|
||||
// SourceChecker can be used to decide whether to trust the PROXY info or pass
|
||||
// the original connection address through. If set, the connecting address is
|
||||
// passed in as an argument. If the function returns an error due to the source
|
||||
// being disallowed, it should return ErrInvalidUpstream.
|
||||
//
|
||||
// If error is not nil, the call to Accept() will fail. If the reason for
|
||||
// triggering this failure is due to a disallowed source, it should return
|
||||
// ErrInvalidUpstream.
|
||||
//
|
||||
// If bool is true, the PROXY-set address is used.
|
||||
//
|
||||
// If bool is false, the connection's remote address is used, rather than the
|
||||
// address claimed in the PROXY info.
|
||||
type SourceChecker func(net.Addr) (bool, error)
|
||||
|
||||
// Listener is used to wrap an underlying listener,
|
||||
// whose connections may be using the HAProxy Proxy Protocol (version 1).
|
||||
// If the connection is using the protocol, the RemoteAddr() will return
|
||||
// the correct client address.
|
||||
//
|
||||
// Optionally define ProxyHeaderTimeout to set a maximum time to
|
||||
// receive the Proxy Protocol Header. Zero means no timeout.
|
||||
type Listener struct {
|
||||
Listener net.Listener
|
||||
ProxyHeaderTimeout time.Duration
|
||||
SourceCheck SourceChecker
|
||||
UnknownOK bool // allow PROXY UNKNOWN
|
||||
}
|
||||
|
||||
// Conn is used to wrap and underlying connection which
|
||||
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
|
||||
// return the address of the client instead of the proxy address.
|
||||
type Conn struct {
|
||||
bufReader *bufio.Reader
|
||||
conn net.Conn
|
||||
dstAddr *net.TCPAddr
|
||||
srcAddr *net.TCPAddr
|
||||
useConnAddr bool
|
||||
once sync.Once
|
||||
proxyHeaderTimeout time.Duration
|
||||
unknownOK bool
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (p *Listener) Accept() (net.Conn, error) {
|
||||
// Get the underlying connection
|
||||
conn, err := p.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var useConnAddr bool
|
||||
if p.SourceCheck != nil {
|
||||
allowed, err := p.SourceCheck(conn.RemoteAddr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !allowed {
|
||||
useConnAddr = true
|
||||
}
|
||||
}
|
||||
newConn := NewConn(conn, p.ProxyHeaderTimeout)
|
||||
newConn.useConnAddr = useConnAddr
|
||||
newConn.unknownOK = p.UnknownOK
|
||||
return newConn, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying listener.
|
||||
func (p *Listener) Close() error {
|
||||
return p.Listener.Close()
|
||||
}
|
||||
|
||||
// Addr returns the underlying listener's network address.
|
||||
func (p *Listener) Addr() net.Addr {
|
||||
return p.Listener.Addr()
|
||||
}
|
||||
|
||||
// NewConn is used to wrap a net.Conn that may be speaking
|
||||
// the proxy protocol into a proxyproto.Conn
|
||||
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
|
||||
pConn := &Conn{
|
||||
bufReader: bufio.NewReader(conn),
|
||||
conn: conn,
|
||||
proxyHeaderTimeout: timeout,
|
||||
}
|
||||
return pConn
|
||||
}
|
||||
|
||||
// Read is check for the proxy protocol header when doing
|
||||
// the initial scan. If there is an error parsing the header,
|
||||
// it is returned and the socket is closed.
|
||||
func (p *Conn) Read(b []byte) (int, error) {
|
||||
var err error
|
||||
p.once.Do(func() { err = p.checkPrefix() })
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return p.bufReader.Read(b)
|
||||
}
|
||||
|
||||
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
|
||||
if rf, ok := p.conn.(io.ReaderFrom); ok {
|
||||
return rf.ReadFrom(r)
|
||||
}
|
||||
return io.Copy(p.conn, r)
|
||||
}
|
||||
|
||||
func (p *Conn) WriteTo(w io.Writer) (int64, error) {
|
||||
var err error
|
||||
p.once.Do(func() { err = p.checkPrefix() })
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return p.bufReader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (p *Conn) Write(b []byte) (int, error) {
|
||||
return p.conn.Write(b)
|
||||
}
|
||||
|
||||
func (p *Conn) Close() error {
|
||||
return p.conn.Close()
|
||||
}
|
||||
|
||||
func (p *Conn) LocalAddr() net.Addr {
|
||||
p.checkPrefixOnce()
|
||||
if p.dstAddr != nil && !p.useConnAddr {
|
||||
return p.dstAddr
|
||||
}
|
||||
return p.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the address of the client if the proxy
|
||||
// protocol is being used, otherwise just returns the address of
|
||||
// the socket peer. If there is an error parsing the header, the
|
||||
// address of the client is not returned, and the socket is closed.
|
||||
// Once implication of this is that the call could block if the
|
||||
// client is slow. Using a Deadline is recommended if this is called
|
||||
// before Read()
|
||||
func (p *Conn) RemoteAddr() net.Addr {
|
||||
p.checkPrefixOnce()
|
||||
if p.srcAddr != nil && !p.useConnAddr {
|
||||
return p.srcAddr
|
||||
}
|
||||
return p.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (p *Conn) SetDeadline(t time.Time) error {
|
||||
return p.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (p *Conn) SetReadDeadline(t time.Time) error {
|
||||
return p.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (p *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return p.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (p *Conn) checkPrefixOnce() {
|
||||
p.once.Do(func() {
|
||||
if err := p.checkPrefix(); err != nil && err != io.EOF {
|
||||
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
|
||||
p.Close()
|
||||
p.bufReader = bufio.NewReader(p.conn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Conn) checkPrefix() error {
|
||||
if p.proxyHeaderTimeout != 0 {
|
||||
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
|
||||
_ = p.conn.SetReadDeadline(readDeadLine)
|
||||
defer func() {
|
||||
_ = p.conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
}
|
||||
|
||||
// Incrementally check each byte of the prefix
|
||||
for i := 1; i <= prefixLen; i++ {
|
||||
inp, err := p.bufReader.Peek(i)
|
||||
|
||||
if err != nil {
|
||||
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check for a prefix mis-match, quit early
|
||||
if !bytes.Equal(inp, prefix[:i]) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Read the header line
|
||||
header, err := p.bufReader.ReadString('\n')
|
||||
if err != nil {
|
||||
p.conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Strip the carriage return and new line
|
||||
header = header[:len(header)-2]
|
||||
|
||||
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
|
||||
parts := strings.Split(header, " ")
|
||||
if len(parts) < 2 {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid header line: %s", header)
|
||||
}
|
||||
|
||||
// Verify the type is known
|
||||
switch parts[1] {
|
||||
case "UNKNOWN":
|
||||
if !p.unknownOK || len(parts) != 2 {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid UNKNOWN header line: %s", header)
|
||||
}
|
||||
p.useConnAddr = true
|
||||
return nil
|
||||
case "TCP4":
|
||||
case "TCP6":
|
||||
default:
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Unhandled address type: %s", parts[1])
|
||||
}
|
||||
|
||||
if len(parts) != 6 {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid header line: %s", header)
|
||||
}
|
||||
|
||||
// Parse out the source address
|
||||
ip := net.ParseIP(parts[2])
|
||||
if ip == nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid source ip: %s", parts[2])
|
||||
}
|
||||
port, err := strconv.Atoi(parts[4])
|
||||
if err != nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid source port: %s", parts[4])
|
||||
}
|
||||
p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||
|
||||
// Parse out the destination address
|
||||
ip = net.ParseIP(parts[3])
|
||||
if ip == nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid destination ip: %s", parts[3])
|
||||
}
|
||||
port, err = strconv.Atoi(parts[5])
|
||||
if err != nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid destination port: %s", parts[5])
|
||||
}
|
||||
p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||
|
||||
return nil
|
||||
}
|
486
server/pkg/proxyproto/protocol_test.go
Normal file
486
server/pkg/proxyproto/protocol_test.go
Normal file
@@ -0,0 +1,486 @@
|
||||
// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol_test.go
|
||||
package proxyproto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
goodAddr = "127.0.0.1"
|
||||
badAddr = "127.0.0.2"
|
||||
errAddr = "9999.0.0.2"
|
||||
)
|
||||
|
||||
var (
|
||||
checkAddr string
|
||||
)
|
||||
|
||||
func TestPassthrough(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
clientWriteDelay := 200 * time.Millisecond
|
||||
proxyHeaderTimeout := 50 * time.Millisecond
|
||||
pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Do not send data for a while
|
||||
time.Sleep(clientWriteDelay)
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Check the remote addr is the original 127.0.0.1
|
||||
remoteAddrStartTime := time.Now()
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
remoteAddrDuration := time.Since(remoteAddrStartTime)
|
||||
|
||||
// Check RemoteAddr() call did timeout
|
||||
if remoteAddrDuration >= clientWriteDelay {
|
||||
t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
|
||||
}
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ipv4(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the remote addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "10.1.1.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
if addr.Port != 1000 {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ipv6(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the remote addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "ffff::ffff" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
if addr.Port != 1000 {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_Unknown(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l, UnknownOK: true}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY UNKNOWN\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestParse_BadHeader(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err == nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Check the remote addr, should be the local addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
|
||||
// Read should fail
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err == nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ipv4_checkfunc(t *testing.T) {
|
||||
checkAddr = goodAddr
|
||||
testParse_ipv4_checkfunc(t)
|
||||
checkAddr = badAddr
|
||||
testParse_ipv4_checkfunc(t)
|
||||
checkAddr = errAddr
|
||||
testParse_ipv4_checkfunc(t)
|
||||
}
|
||||
|
||||
func testParse_ipv4_checkfunc(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
checkFunc := func(addr net.Addr) (bool, error) {
|
||||
tcpAddr := addr.(*net.TCPAddr)
|
||||
if tcpAddr.IP.String() == checkAddr {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pl := &Listener{Listener: l, SourceCheck: checkFunc}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
if checkAddr == badAddr {
|
||||
return
|
||||
}
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the remote addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
switch checkAddr {
|
||||
case goodAddr:
|
||||
if addr.IP.String() != "10.1.1.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
if addr.Port != 1000 {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
case badAddr:
|
||||
if addr.IP.String() != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
if addr.Port == 1000 {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type testConn struct {
|
||||
readFromCalledWith io.Reader
|
||||
net.Conn // nil; crash on any unexpected use
|
||||
}
|
||||
|
||||
func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
|
||||
c.readFromCalledWith = r
|
||||
return 0, nil
|
||||
}
|
||||
func (c *testConn) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
func (c *testConn) Read(p []byte) (int, error) {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func TestCopyToWrappedConnection(t *testing.T) {
|
||||
innerConn := &testConn{}
|
||||
wrappedConn := NewConn(innerConn, 0)
|
||||
dummySrc := &testConn{}
|
||||
|
||||
io.Copy(wrappedConn, dummySrc)
|
||||
if innerConn.readFromCalledWith != dummySrc {
|
||||
t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFromWrappedConnection(t *testing.T) {
|
||||
wrappedConn := NewConn(&testConn{}, 0)
|
||||
dummyDst := &testConn{}
|
||||
|
||||
io.Copy(dummyDst, wrappedConn)
|
||||
if dummyDst.readFromCalledWith != wrappedConn.conn {
|
||||
t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
|
||||
innerConn1 := &testConn{}
|
||||
wrappedConn1 := NewConn(innerConn1, 0)
|
||||
innerConn2 := &testConn{}
|
||||
wrappedConn2 := NewConn(innerConn2, 0)
|
||||
|
||||
io.Copy(wrappedConn1, wrappedConn2)
|
||||
if innerConn1.readFromCalledWith != innerConn2 {
|
||||
t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection")
|
||||
}
|
||||
|
||||
}
|
40
server/pkg/utils/password_hash.go
Normal file
40
server/pkg/utils/password_hash.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
mt "math/rand"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func PasswordHash(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
func PasswordVerify(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// $sha-256$salt-key$hash-abcd
|
||||
// $sha-512$salt-key$hash-abcd
|
||||
const (
|
||||
saltSize = 16
|
||||
delmiter = "$"
|
||||
)
|
||||
|
||||
func RandSecret(min int, max int) (string, error) {
|
||||
rb := make([]byte, randInt(min, max))
|
||||
_, err := rand.Read(rb)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.URLEncoding.EncodeToString(rb), nil
|
||||
}
|
||||
|
||||
func randInt(min int, max int) int {
|
||||
return min + mt.Intn(max-min)
|
||||
}
|
74
server/pkg/utils/util.go
Normal file
74
server/pkg/utils/util.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func InArrStr(arr []string, str string) bool {
|
||||
for _, d := range arr {
|
||||
if d == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
KB = 1024
|
||||
MB = 1024 * KB
|
||||
GB = 1024 * MB
|
||||
TB = 1024 * GB
|
||||
PB = 1024 * TB
|
||||
)
|
||||
|
||||
func HumanByte(bf interface{}) string {
|
||||
var hb string
|
||||
var bAll float64
|
||||
switch bi := bf.(type) {
|
||||
case int:
|
||||
bAll = float64(bi)
|
||||
case int32:
|
||||
bAll = float64(bi)
|
||||
case uint32:
|
||||
bAll = float64(bi)
|
||||
case int64:
|
||||
bAll = float64(bi)
|
||||
case uint64:
|
||||
bAll = float64(bi)
|
||||
case float64:
|
||||
bAll = float64(bi)
|
||||
}
|
||||
|
||||
switch {
|
||||
case bAll >= TB:
|
||||
hb = fmt.Sprintf("%0.2f TB", bAll/TB)
|
||||
case bAll >= GB:
|
||||
hb = fmt.Sprintf("%0.2f GB", bAll/GB)
|
||||
case bAll >= MB:
|
||||
hb = fmt.Sprintf("%0.2f MB", bAll/MB)
|
||||
case bAll >= KB:
|
||||
hb = fmt.Sprintf("%0.2f KB", bAll/KB)
|
||||
default:
|
||||
hb = fmt.Sprintf("%0.2f B", bAll)
|
||||
}
|
||||
|
||||
return hb
|
||||
}
|
||||
|
||||
func RandomNum(length int) string {
|
||||
letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890")
|
||||
|
||||
bytes := make([]rune, length)
|
||||
|
||||
for i := range bytes {
|
||||
bytes[i] = letterRunes[rand.Intn(len(letterRunes))]
|
||||
}
|
||||
|
||||
return string(bytes)
|
||||
}
|
29
server/pkg/utils/util_test.go
Normal file
29
server/pkg/utils/util_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInArrStr(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
arr := []string{"a", "b", "c"}
|
||||
assert.True(InArrStr(arr, "b"))
|
||||
assert.False(InArrStr(arr, "d"))
|
||||
}
|
||||
|
||||
func TestHumanByte(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
var s string
|
||||
s = HumanByte(999)
|
||||
assert.Equal(s, "999.00 B")
|
||||
s = HumanByte(10256)
|
||||
assert.Equal(s, "10.02 KB")
|
||||
s = HumanByte(99 * 1024 * 1024)
|
||||
assert.Equal(s, "99.00 MB")
|
||||
s = HumanByte(1023 * 1024 * 1024)
|
||||
assert.Equal(s, "1023.00 MB")
|
||||
s = HumanByte(1024 * 1024 * 1024)
|
||||
assert.Equal(s, "1.00 GB")
|
||||
}
|
Reference in New Issue
Block a user