mirror of https://github.com/bjdgyc/anylink.git
330 lines
7.6 KiB
Go
330 lines
7.6 KiB
Go
// +build !js
|
|
|
|
package e2e
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pion/dtls/v2"
|
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
|
"github.com/pion/transport/test"
|
|
)
|
|
|
|
// Shared parameters for the end-to-end exchanges in this file.
const (
	// testMessage is the payload each peer writes and expects to read back.
	testMessage = "Hello World"

	// testTimeLimit is the timeout assert applies while waiting for the two messages.
	testTimeLimit = 5 * time.Second

	// messageRetry is the pause between successive writes in simpleReadWrite.
	messageRetry = 200 * time.Millisecond
)

// errServerTimeout is reported by the client when the server never signals serverReady.
var errServerTimeout = errors.New("waiting on serverReady err: timeout")
|
|
|
|
// randomPort asks the OS for a free UDP port on 127.0.0.1 and returns it.
// The probe socket is closed before returning, so the port is only likely,
// not guaranteed, to still be free when the caller binds it.
func randomPort(t testing.TB) int {
	t.Helper()

	probe, err := net.ListenPacket("udp4", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to pickPort: %v", err)
	}
	defer func() {
		_ = probe.Close()
	}()

	local := probe.LocalAddr()
	udpAddr, ok := local.(*net.UDPAddr)
	if !ok {
		t.Fatalf("unknown addr type %T", local)
		return 0
	}
	return udpAddr.Port
}
|
|
|
|
func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) {
|
|
go func() {
|
|
buffer := make([]byte, 8192)
|
|
n, err := conn.Read(buffer)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
|
|
outChan <- string(buffer[:n])
|
|
atomic.AddUint64(messageRecvCount, 1)
|
|
}()
|
|
|
|
for {
|
|
if atomic.LoadUint64(messageRecvCount) == 2 {
|
|
break
|
|
} else if _, err := conn.Write([]byte(testMessage)); err != nil {
|
|
errChan <- err
|
|
break
|
|
}
|
|
|
|
time.Sleep(messageRetry)
|
|
}
|
|
}
|
|
|
|
type comm struct {
|
|
ctx context.Context
|
|
clientConfig, serverConfig *dtls.Config
|
|
serverPort int
|
|
messageRecvCount *uint64 // Counter to make sure both sides got a message
|
|
clientMutex *sync.Mutex
|
|
clientConn net.Conn
|
|
serverMutex *sync.Mutex
|
|
serverConn net.Conn
|
|
serverListener net.Listener
|
|
serverReady chan struct{}
|
|
errChan chan error
|
|
clientChan chan string
|
|
serverChan chan string
|
|
client func(*comm)
|
|
server func(*comm)
|
|
}
|
|
|
|
func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm {
|
|
messageRecvCount := uint64(0)
|
|
c := &comm{
|
|
ctx: ctx,
|
|
clientConfig: clientConfig,
|
|
serverConfig: serverConfig,
|
|
serverPort: serverPort,
|
|
messageRecvCount: &messageRecvCount,
|
|
clientMutex: &sync.Mutex{},
|
|
serverMutex: &sync.Mutex{},
|
|
serverReady: make(chan struct{}),
|
|
errChan: make(chan error),
|
|
clientChan: make(chan string),
|
|
serverChan: make(chan string),
|
|
server: server,
|
|
client: client,
|
|
}
|
|
return c
|
|
}
|
|
|
|
func (c *comm) assert(t *testing.T) {
|
|
// DTLS Client
|
|
go c.client(c)
|
|
|
|
// DTLS Server
|
|
go c.server(c)
|
|
|
|
defer func() {
|
|
if c.clientConn != nil {
|
|
if err := c.clientConn.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
if c.serverConn != nil {
|
|
if err := c.serverConn.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
if c.serverListener != nil {
|
|
if err := c.serverListener.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
}()
|
|
|
|
func() {
|
|
seenClient, seenServer := false, false
|
|
for {
|
|
select {
|
|
case err := <-c.errChan:
|
|
t.Fatal(err)
|
|
case <-time.After(testTimeLimit):
|
|
t.Fatalf("Test timeout, seenClient %t seenServer %t", seenClient, seenServer)
|
|
case clientMsg := <-c.clientChan:
|
|
if clientMsg != testMessage {
|
|
t.Fatalf("clientMsg does not equal test message: %s %s", clientMsg, testMessage)
|
|
}
|
|
|
|
seenClient = true
|
|
if seenClient && seenServer {
|
|
return
|
|
}
|
|
case serverMsg := <-c.serverChan:
|
|
if serverMsg != testMessage {
|
|
t.Fatalf("serverMsg does not equal test message: %s %s", serverMsg, testMessage)
|
|
}
|
|
|
|
seenServer = true
|
|
if seenClient && seenServer {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func clientPion(c *comm) {
|
|
select {
|
|
case <-c.serverReady:
|
|
// OK
|
|
case <-time.After(time.Second):
|
|
c.errChan <- errServerTimeout
|
|
}
|
|
|
|
c.clientMutex.Lock()
|
|
defer c.clientMutex.Unlock()
|
|
|
|
var err error
|
|
c.clientConn, err = dtls.DialWithContext(c.ctx, "udp",
|
|
&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
|
|
c.clientConfig,
|
|
)
|
|
if err != nil {
|
|
c.errChan <- err
|
|
return
|
|
}
|
|
|
|
simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount)
|
|
}
|
|
|
|
func serverPion(c *comm) {
|
|
c.serverMutex.Lock()
|
|
defer c.serverMutex.Unlock()
|
|
|
|
var err error
|
|
c.serverListener, err = dtls.Listen("udp",
|
|
&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
|
|
c.serverConfig,
|
|
)
|
|
if err != nil {
|
|
c.errChan <- err
|
|
return
|
|
}
|
|
c.serverReady <- struct{}{}
|
|
c.serverConn, err = c.serverListener.Accept()
|
|
if err != nil {
|
|
c.errChan <- err
|
|
return
|
|
}
|
|
|
|
simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount)
|
|
}
|
|
|
|
/*
|
|
Simple DTLS Client/Server can communicate
|
|
- Assert that you can send messages both ways
|
|
- Assert that Close() on both ends work
|
|
- Assert that no Goroutines are leaked
|
|
*/
|
|
func testPionE2ESimple(t *testing.T, server, client func(*comm)) {
|
|
lim := test.TimeOut(time.Second * 30)
|
|
defer lim.Stop()
|
|
|
|
report := test.CheckRoutines(t)
|
|
defer report()
|
|
|
|
for _, cipherSuite := range []dtls.CipherSuiteID{
|
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
|
} {
|
|
cipherSuite := cipherSuite
|
|
t.Run(cipherSuite.String(), func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
cfg := &dtls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
CipherSuites: []dtls.CipherSuiteID{cipherSuite},
|
|
InsecureSkipVerify: true,
|
|
}
|
|
serverPort := randomPort(t)
|
|
comm := newComm(ctx, cfg, cfg, serverPort, server, client)
|
|
comm.assert(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) {
|
|
lim := test.TimeOut(time.Second * 30)
|
|
defer lim.Stop()
|
|
|
|
report := test.CheckRoutines(t)
|
|
defer report()
|
|
|
|
for _, cipherSuite := range []dtls.CipherSuiteID{
|
|
dtls.TLS_PSK_WITH_AES_128_CCM,
|
|
dtls.TLS_PSK_WITH_AES_128_CCM_8,
|
|
dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
|
|
} {
|
|
cipherSuite := cipherSuite
|
|
t.Run(cipherSuite.String(), func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
cfg := &dtls.Config{
|
|
PSK: func(hint []byte) ([]byte, error) {
|
|
return []byte{0xAB, 0xC1, 0x23}, nil
|
|
},
|
|
PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
|
CipherSuites: []dtls.CipherSuiteID{cipherSuite},
|
|
}
|
|
serverPort := randomPort(t)
|
|
comm := newComm(ctx, cfg, cfg, serverPort, server, client)
|
|
comm.assert(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func testPionE2EMTUs(t *testing.T, server, client func(*comm)) {
|
|
lim := test.TimeOut(time.Second * 30)
|
|
defer lim.Stop()
|
|
|
|
report := test.CheckRoutines(t)
|
|
defer report()
|
|
|
|
for _, mtu := range []int{
|
|
10000,
|
|
1000,
|
|
100,
|
|
} {
|
|
mtu := mtu
|
|
t.Run(fmt.Sprintf("MTU%d", mtu), func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
cfg := &dtls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
|
InsecureSkipVerify: true,
|
|
MTU: mtu,
|
|
}
|
|
serverPort := randomPort(t)
|
|
comm := newComm(ctx, cfg, cfg, serverPort, server, client)
|
|
comm.assert(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPionE2ESimple(t *testing.T) {
|
|
testPionE2ESimple(t, serverPion, clientPion)
|
|
}
|
|
|
|
func TestPionE2ESimplePSK(t *testing.T) {
|
|
testPionE2ESimplePSK(t, serverPion, clientPion)
|
|
}
|
|
|
|
func TestPionE2EMTUs(t *testing.T) {
|
|
testPionE2EMTUs(t, serverPion, clientPion)
|
|
}
|