更改目录结构

This commit is contained in:
bjdgyc
2021-03-01 15:46:08 +08:00
parent 3464d1d10e
commit 0f91c779e3
105 changed files with 29099 additions and 96 deletions

99
server/pkg/arpdis/addr.go Normal file
View File

@@ -0,0 +1,99 @@
package arpdis
import (
"net"
"sync"
"time"
)
const (
StaleTimeNormal = time.Minute * 5
StaleTimeUnreachable = time.Minute * 10
TypeNormal = 0
TypeUnreachable = 1
TypeStatic = 2
)
var (
table = make(map[string]*Addr)
tableMu sync.RWMutex
)
type Addr struct {
IP net.IP
HardwareAddr net.HardwareAddr
disTime time.Time
Type int8
}
func Lookup(ip net.IP, onlyTable bool) *Addr {
addr := tableLookup(ip.To4())
if addr != nil || onlyTable {
return addr
}
addr = doLookup(ip.To4())
Add(addr)
return addr
}
// Add adds a IP-MAC map to a runtime ARP table.
func tableLookup(ip net.IP) *Addr {
tableMu.Lock()
addr := table[ip.To4().String()]
tableMu.Unlock()
if addr == nil {
return nil
}
// 判断老化过期时间
tsub := time.Since(addr.disTime)
switch addr.Type {
case TypeNormal:
if tsub > StaleTimeNormal {
return nil
}
case TypeUnreachable:
if tsub > StaleTimeUnreachable {
return nil
}
case TypeStatic:
}
return addr
}
// Add adds a IP-MAC map to a runtime ARP table.
func Add(addr *Addr) {
if addr == nil {
return
}
if addr.disTime.IsZero() {
addr.disTime = time.Now()
}
ip := addr.IP.To4().String()
tableMu.Lock()
defer tableMu.Unlock()
if a, ok := table[ip]; ok {
// 静态地址只能设置一次
if a.Type == TypeStatic {
return
}
}
table[ip] = addr
}
// Delete removes an IP from the runtime ARP table.
func Delete(ip net.IP) {
tableMu.Lock()
defer tableMu.Unlock()
delete(table, ip.To4().String())
}
// List returns the current runtime ARP table.
func List() map[string]*Addr {
tableMu.RLock()
defer tableMu.RUnlock()
return table
}

View File

@@ -0,0 +1,35 @@
package arpdis
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestLookup(t *testing.T) {
assert := assert.New(t)
ip := net.IPv4(192, 168, 10, 2)
hw, _ := net.ParseMAC("08:00:27:a0:17:42")
now := time.Now()
addr1 := &Addr{
IP: ip,
HardwareAddr: hw,
Type: TypeStatic,
disTime: now,
}
Add(addr1)
addr2 := Lookup(ip, true)
assert.Equal(addr1, addr2)
addr3 := &Addr{
IP: ip,
HardwareAddr: hw,
Type: TypeNormal,
disTime: now,
}
Add(addr3)
addr4 := Lookup(ip, true)
// 静态地址只能设置一次
assert.NotEqual(addr3, addr4)
}

56
server/pkg/arpdis/arp.go Normal file
View File

@@ -0,0 +1,56 @@
package arpdis
// Reference: github.com/malfunkt/arpfox
// TODO now only support IPv4
import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
var defaultSerializeOpts = gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
// NewARPRequest creates a bew ARP packet of type "request.
func NewARPRequest(src *Addr, dst *Addr) ([]byte, error) {
return buildPacket(src, dst, layers.ARPRequest)
}
// NewARPReply creates a new ARP packet of type "reply".
func NewARPReply(src *Addr, dst *Addr) ([]byte, error) {
return buildPacket(src, dst, layers.ARPReply)
}
// buildPacket creates an template ARP packet with the given source and
// destination.
func buildPacket(src *Addr, dst *Addr, opt uint16) ([]byte, error) {
ether := layers.Ethernet{
EthernetType: layers.EthernetTypeARP,
SrcMAC: src.HardwareAddr,
DstMAC: dst.HardwareAddr,
}
arp := layers.ARP{
AddrType: layers.LinkTypeEthernet,
Protocol: layers.EthernetTypeIPv4,
HwAddressSize: 6,
ProtAddressSize: 4,
Operation: opt,
SourceHwAddress: src.HardwareAddr,
SourceProtAddress: src.IP.To4(),
DstHwAddress: dst.HardwareAddr,
DstProtAddress: dst.IP.To4(),
}
buf := gopacket.NewSerializeBuffer()
err := gopacket.SerializeLayers(buf, defaultSerializeOpts, &ether, &arp)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}

101
server/pkg/arpdis/icmp.go Normal file
View File

@@ -0,0 +1,101 @@
package arpdis
import (
"errors"
"net"
"os"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
const (
ProtocolICMP = 1
ProtocolIPv6ICMP = 58
)
func doPing(ip string) error {
raddr, _ := net.ResolveIPAddr("ip4:icmp", ip)
conn, err := icmp.ListenPacket("ip4:icmp", "")
if err != nil {
return err
}
ipv4Conn := conn.IPv4PacketConn()
// 限制跳跃数
err = ipv4Conn.SetTTL(10)
if err != nil {
return err
}
msg := &icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff,
Seq: 1,
Data: timeToBytes(time.Now()),
},
}
b, err := msg.Marshal(nil)
if err != nil {
return err
}
_, err = conn.WriteTo(b, raddr)
if err != nil {
return err
}
_ = conn.SetReadDeadline(time.Now().Add(time.Second * 2))
for {
buf := make([]byte, 512)
n, dst, err := conn.ReadFrom(buf)
if err != nil {
return err
}
if dst.String() != ip {
continue
}
var result *icmp.Message
result, err = icmp.ParseMessage(ProtocolICMP, buf[:n])
if err != nil {
return err
}
switch result.Type {
case ipv4.ICMPTypeEchoReply:
// success
if rply, ok := result.Body.(*icmp.Echo); ok {
_ = rply
// log.Printf("%+v \n", rply)
}
return nil
// case ipv4.ICMPTypeTimeExceeded:
// case ipv4.ICMPTypeDestinationUnreachable:
default:
return errors.New("DestinationUnreachable")
}
}
}
func timeToBytes(t time.Time) []byte {
nsec := t.UnixNano()
b := make([]byte, 8)
for i := uint8(0); i < 8; i++ {
b[i] = byte((nsec >> ((7 - i) * 8)) & 0xff)
}
return b
}
func bytesToTime(b []byte) time.Time {
var nsec int64
for i := uint8(0); i < 8; i++ {
nsec += int64(b[i]) << ((7 - i) * 8)
}
return time.Unix(nsec/1000000000, nsec%1000000000)
}

View File

@@ -0,0 +1,61 @@
// Currently only Darwin and Linux support this.
package arpdis
import (
"log"
"net"
"os/exec"
"strings"
)
func doLookup(ip net.IP) *Addr {
// ping := exec.Command("ping", "-c1", "-t1", ip.String())
// if err := ping.Run(); err != nil {
// addr := &Addr{IP: ip, Type: TypeUnreachable}
// return addr
// }
err := doPing(ip.String())
if err != nil {
// log.Println(err)
addr := &Addr{IP: ip, Type: TypeUnreachable}
return addr
}
return doArpShow(ip)
}
func doArpShow(ip net.IP) *Addr {
cmd := exec.Command("ip", "n", "show", ip.String())
out, err := cmd.Output()
if err != nil {
log.Println("lookup show", err)
return nil
}
// os.Open("/proc/net/arp")
// 192.168.1.2 0x1 0x2 e0:94:67:e2:42:5d * eth0
// 192.168.1.2 dev eth0 lladdr 08:00:27:94:a5:a4 STALE
outS := strings.ReplaceAll(string(out), " ", " ")
outS = strings.TrimSpace(outS)
arpArr := strings.Split(outS, " ")
if len(arpArr) != 6 {
log.Println("lookup arpArr", outS, ip)
return nil
}
mac, err := net.ParseMAC(arpArr[4])
if err != nil {
log.Println("lookup mac", outS, err)
return nil
}
return &Addr{IP: ip, HardwareAddr: mac}
}
// IP address HW type Flags HW address Mask Device
// 172.23.24.12 0x1 0x2 00:e0:4c:73:5c:48 * anylink0
// 172.23.24.1 0x1 0x2 3c:8c:40:a0:7a:2d * anylink0
// 172.23.24.13 0x1 0x2 00:1c:42:4d:33:46 * anylink0
// 172.23.24.2 0x1 0x0 00:00:00:00:00:00 * anylink0
// 172.23.24.14 0x1 0x0 00:00:00:00:00:00 * anylink0

View File

@@ -0,0 +1,290 @@
// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol.go
// design: http://www.haproxy.org/download/2.2/doc/proxy-protocol.txt
// HAProxy proxy proto v1
package proxyproto
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
)
var (
// prefix is the string we look for at the start of a connection
// to check if this connection is using the proxy protocol
prefix = []byte("PROXY ")
prefixLen = len(prefix)
ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")
)
// SourceChecker can be used to decide whether to trust the PROXY info or pass
// the original connection address through. If set, the connecting address is
// passed in as an argument. If the function returns an error due to the source
// being disallowed, it should return ErrInvalidUpstream.
//
// If error is not nil, the call to Accept() will fail. If the reason for
// triggering this failure is due to a disallowed source, it should return
// ErrInvalidUpstream.
//
// If bool is true, the PROXY-set address is used.
//
// If bool is false, the connection's remote address is used, rather than the
// address claimed in the PROXY info.
type SourceChecker func(net.Addr) (bool, error)
// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol (version 1).
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
//
// Optionally define ProxyHeaderTimeout to set a maximum time to
// receive the Proxy Protocol Header. Zero means no timeout.
type Listener struct {
Listener net.Listener
ProxyHeaderTimeout time.Duration
SourceCheck SourceChecker
UnknownOK bool // allow PROXY UNKNOWN
}
// Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
type Conn struct {
bufReader *bufio.Reader
conn net.Conn
dstAddr *net.TCPAddr
srcAddr *net.TCPAddr
useConnAddr bool
once sync.Once
proxyHeaderTimeout time.Duration
unknownOK bool
}
// Accept waits for and returns the next connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
return nil, err
}
var useConnAddr bool
if p.SourceCheck != nil {
allowed, err := p.SourceCheck(conn.RemoteAddr())
if err != nil {
return nil, err
}
if !allowed {
useConnAddr = true
}
}
newConn := NewConn(conn, p.ProxyHeaderTimeout)
newConn.useConnAddr = useConnAddr
newConn.unknownOK = p.UnknownOK
return newConn, nil
}
// Close closes the underlying listener.
func (p *Listener) Close() error {
return p.Listener.Close()
}
// Addr returns the underlying listener's network address.
func (p *Listener) Addr() net.Addr {
return p.Listener.Addr()
}
// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
pConn := &Conn{
bufReader: bufio.NewReader(conn),
conn: conn,
proxyHeaderTimeout: timeout,
}
return pConn
}
// Read is check for the proxy protocol header when doing
// the initial scan. If there is an error parsing the header,
// it is returned and the socket is closed.
func (p *Conn) Read(b []byte) (int, error) {
var err error
p.once.Do(func() { err = p.checkPrefix() })
if err != nil {
return 0, err
}
return p.bufReader.Read(b)
}
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
if rf, ok := p.conn.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
return io.Copy(p.conn, r)
}
func (p *Conn) WriteTo(w io.Writer) (int64, error) {
var err error
p.once.Do(func() { err = p.checkPrefix() })
if err != nil {
return 0, err
}
return p.bufReader.WriteTo(w)
}
func (p *Conn) Write(b []byte) (int, error) {
return p.conn.Write(b)
}
func (p *Conn) Close() error {
return p.conn.Close()
}
func (p *Conn) LocalAddr() net.Addr {
p.checkPrefixOnce()
if p.dstAddr != nil && !p.useConnAddr {
return p.dstAddr
}
return p.conn.LocalAddr()
}
// RemoteAddr returns the address of the client if the proxy
// protocol is being used, otherwise just returns the address of
// the socket peer. If there is an error parsing the header, the
// address of the client is not returned, and the socket is closed.
// Once implication of this is that the call could block if the
// client is slow. Using a Deadline is recommended if this is called
// before Read()
func (p *Conn) RemoteAddr() net.Addr {
p.checkPrefixOnce()
if p.srcAddr != nil && !p.useConnAddr {
return p.srcAddr
}
return p.conn.RemoteAddr()
}
func (p *Conn) SetDeadline(t time.Time) error {
return p.conn.SetDeadline(t)
}
func (p *Conn) SetReadDeadline(t time.Time) error {
return p.conn.SetReadDeadline(t)
}
func (p *Conn) SetWriteDeadline(t time.Time) error {
return p.conn.SetWriteDeadline(t)
}
func (p *Conn) checkPrefixOnce() {
p.once.Do(func() {
if err := p.checkPrefix(); err != nil && err != io.EOF {
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
p.Close()
p.bufReader = bufio.NewReader(p.conn)
}
})
}
func (p *Conn) checkPrefix() error {
if p.proxyHeaderTimeout != 0 {
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
_ = p.conn.SetReadDeadline(readDeadLine)
defer func() {
_ = p.conn.SetReadDeadline(time.Time{})
}()
}
// Incrementally check each byte of the prefix
for i := 1; i <= prefixLen; i++ {
inp, err := p.bufReader.Peek(i)
if err != nil {
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
return nil
} else {
return err
}
}
// Check for a prefix mis-match, quit early
if !bytes.Equal(inp, prefix[:i]) {
return nil
}
}
// Read the header line
header, err := p.bufReader.ReadString('\n')
if err != nil {
p.conn.Close()
return err
}
// Strip the carriage return and new line
header = header[:len(header)-2]
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
parts := strings.Split(header, " ")
if len(parts) < 2 {
p.conn.Close()
return fmt.Errorf("Invalid header line: %s", header)
}
// Verify the type is known
switch parts[1] {
case "UNKNOWN":
if !p.unknownOK || len(parts) != 2 {
p.conn.Close()
return fmt.Errorf("Invalid UNKNOWN header line: %s", header)
}
p.useConnAddr = true
return nil
case "TCP4":
case "TCP6":
default:
p.conn.Close()
return fmt.Errorf("Unhandled address type: %s", parts[1])
}
if len(parts) != 6 {
p.conn.Close()
return fmt.Errorf("Invalid header line: %s", header)
}
// Parse out the source address
ip := net.ParseIP(parts[2])
if ip == nil {
p.conn.Close()
return fmt.Errorf("Invalid source ip: %s", parts[2])
}
port, err := strconv.Atoi(parts[4])
if err != nil {
p.conn.Close()
return fmt.Errorf("Invalid source port: %s", parts[4])
}
p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
// Parse out the destination address
ip = net.ParseIP(parts[3])
if ip == nil {
p.conn.Close()
return fmt.Errorf("Invalid destination ip: %s", parts[3])
}
port, err = strconv.Atoi(parts[5])
if err != nil {
p.conn.Close()
return fmt.Errorf("Invalid destination port: %s", parts[5])
}
p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
return nil
}

View File

@@ -0,0 +1,486 @@
// copy from: https://github.com/armon/go-proxyproto/blob/master/protocol_test.go
package proxyproto
import (
"bytes"
"io"
"net"
"testing"
"time"
)
const (
goodAddr = "127.0.0.1"
badAddr = "127.0.0.2"
errAddr = "9999.0.0.2"
)
var (
checkAddr string
)
func TestPassthrough(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{Listener: l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestTimeout(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
clientWriteDelay := 200 * time.Millisecond
proxyHeaderTimeout := 50 * time.Millisecond
pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Do not send data for a while
time.Sleep(clientWriteDelay)
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Check the remote addr is the original 127.0.0.1
remoteAddrStartTime := time.Now()
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
remoteAddrDuration := time.Since(remoteAddrStartTime)
// Check RemoteAddr() call did timeout
if remoteAddrDuration >= clientWriteDelay {
t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
}
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestParse_ipv4(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{Listener: l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
// Check the remote addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "10.1.1.1" {
t.Fatalf("bad: %v", addr)
}
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}
}
func TestParse_ipv6(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{Listener: l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
// Check the remote addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "ffff::ffff" {
t.Fatalf("bad: %v", addr)
}
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}
}
func TestParse_Unknown(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{Listener: l, UnknownOK: true}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY UNKNOWN\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestParse_BadHeader(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{Listener: l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err == nil {
t.Fatalf("err: %v", err)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Check the remote addr, should be the local addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
// Read should fail
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err == nil {
t.Fatalf("err: %v", err)
}
}
func TestParse_ipv4_checkfunc(t *testing.T) {
checkAddr = goodAddr
testParse_ipv4_checkfunc(t)
checkAddr = badAddr
testParse_ipv4_checkfunc(t)
checkAddr = errAddr
testParse_ipv4_checkfunc(t)
}
func testParse_ipv4_checkfunc(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
checkFunc := func(addr net.Addr) (bool, error) {
tcpAddr := addr.(*net.TCPAddr)
if tcpAddr.IP.String() == checkAddr {
return true, nil
}
return false, nil
}
pl := &Listener{Listener: l, SourceCheck: checkFunc}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
if checkAddr == badAddr {
return
}
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
// Check the remote addr
addr := conn.RemoteAddr().(*net.TCPAddr)
switch checkAddr {
case goodAddr:
if addr.IP.String() != "10.1.1.1" {
t.Fatalf("bad: %v", addr)
}
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}
case badAddr:
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
if addr.Port == 1000 {
t.Fatalf("bad: %v", addr)
}
}
}
type testConn struct {
readFromCalledWith io.Reader
net.Conn // nil; crash on any unexpected use
}
func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
c.readFromCalledWith = r
return 0, nil
}
func (c *testConn) Write(p []byte) (int, error) {
return len(p), nil
}
func (c *testConn) Read(p []byte) (int, error) {
return 1, nil
}
func TestCopyToWrappedConnection(t *testing.T) {
innerConn := &testConn{}
wrappedConn := NewConn(innerConn, 0)
dummySrc := &testConn{}
io.Copy(wrappedConn, dummySrc)
if innerConn.readFromCalledWith != dummySrc {
t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection")
}
}
func TestCopyFromWrappedConnection(t *testing.T) {
wrappedConn := NewConn(&testConn{}, 0)
dummyDst := &testConn{}
io.Copy(dummyDst, wrappedConn)
if dummyDst.readFromCalledWith != wrappedConn.conn {
t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination")
}
}
func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
innerConn1 := &testConn{}
wrappedConn1 := NewConn(innerConn1, 0)
innerConn2 := &testConn{}
wrappedConn2 := NewConn(innerConn2, 0)
io.Copy(wrappedConn1, wrappedConn2)
if innerConn1.readFromCalledWith != innerConn2 {
t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection")
}
}

View File

@@ -0,0 +1,40 @@
package utils
import (
"crypto/rand"
"encoding/base64"
mt "math/rand"
"golang.org/x/crypto/bcrypt"
)
func PasswordHash(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(bytes), err
}
func PasswordVerify(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
// $sha-256$salt-key$hash-abcd
// $sha-512$salt-key$hash-abcd
const (
saltSize = 16
delmiter = "$"
)
func RandSecret(min int, max int) (string, error) {
rb := make([]byte, randInt(min, max))
_, err := rand.Read(rb)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(rb), nil
}
func randInt(min int, max int) int {
return min + mt.Intn(max-min)
}

74
server/pkg/utils/util.go Normal file
View File

@@ -0,0 +1,74 @@
package utils
import (
"fmt"
"math/rand"
"time"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func InArrStr(arr []string, str string) bool {
for _, d := range arr {
if d == str {
return true
}
}
return false
}
const (
KB = 1024
MB = 1024 * KB
GB = 1024 * MB
TB = 1024 * GB
PB = 1024 * TB
)
func HumanByte(bf interface{}) string {
var hb string
var bAll float64
switch bi := bf.(type) {
case int:
bAll = float64(bi)
case int32:
bAll = float64(bi)
case uint32:
bAll = float64(bi)
case int64:
bAll = float64(bi)
case uint64:
bAll = float64(bi)
case float64:
bAll = float64(bi)
}
switch {
case bAll >= TB:
hb = fmt.Sprintf("%0.2f TB", bAll/TB)
case bAll >= GB:
hb = fmt.Sprintf("%0.2f GB", bAll/GB)
case bAll >= MB:
hb = fmt.Sprintf("%0.2f MB", bAll/MB)
case bAll >= KB:
hb = fmt.Sprintf("%0.2f KB", bAll/KB)
default:
hb = fmt.Sprintf("%0.2f B", bAll)
}
return hb
}
func RandomNum(length int) string {
letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890")
bytes := make([]rune, length)
for i := range bytes {
bytes[i] = letterRunes[rand.Intn(len(letterRunes))]
}
return string(bytes)
}

View File

@@ -0,0 +1,29 @@
package utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestInArrStr(t *testing.T) {
assert := assert.New(t)
arr := []string{"a", "b", "c"}
assert.True(InArrStr(arr, "b"))
assert.False(InArrStr(arr, "d"))
}
func TestHumanByte(t *testing.T) {
assert := assert.New(t)
var s string
s = HumanByte(999)
assert.Equal(s, "999.00 B")
s = HumanByte(10256)
assert.Equal(s, "10.02 KB")
s = HumanByte(99 * 1024 * 1024)
assert.Equal(s, "99.00 MB")
s = HumanByte(1023 * 1024 * 1024)
assert.Equal(s, "1023.00 MB")
s = HumanByte(1024 * 1024 * 1024)
assert.Equal(s, "1.00 GB")
}