mirror of
https://github.com/bjdgyc/anylink.git
synced 2025-08-08 21:02:33 +08:00
添加 github.com/pion/dtls 代码
This commit is contained in:
329
dtls-2.0.9/e2e/e2e_test.go
Normal file
329
dtls-2.0.9/e2e/e2e_test.go
Normal file
@@ -0,0 +1,329 @@
|
||||
// +build !js
|
||||
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/dtls/v2"
|
||||
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||
"github.com/pion/transport/test"
|
||||
)
|
||||
|
||||
const (
|
||||
testMessage = "Hello World"
|
||||
testTimeLimit = 5 * time.Second
|
||||
messageRetry = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
var errServerTimeout = errors.New("waiting on serverReady err: timeout")
|
||||
|
||||
func randomPort(t testing.TB) int {
|
||||
t.Helper()
|
||||
conn, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to pickPort: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
switch addr := conn.LocalAddr().(type) {
|
||||
case *net.UDPAddr:
|
||||
return addr.Port
|
||||
default:
|
||||
t.Fatalf("unknown addr type %T", addr)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) {
|
||||
go func() {
|
||||
buffer := make([]byte, 8192)
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
outChan <- string(buffer[:n])
|
||||
atomic.AddUint64(messageRecvCount, 1)
|
||||
}()
|
||||
|
||||
for {
|
||||
if atomic.LoadUint64(messageRecvCount) == 2 {
|
||||
break
|
||||
} else if _, err := conn.Write([]byte(testMessage)); err != nil {
|
||||
errChan <- err
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(messageRetry)
|
||||
}
|
||||
}
|
||||
|
||||
type comm struct {
|
||||
ctx context.Context
|
||||
clientConfig, serverConfig *dtls.Config
|
||||
serverPort int
|
||||
messageRecvCount *uint64 // Counter to make sure both sides got a message
|
||||
clientMutex *sync.Mutex
|
||||
clientConn net.Conn
|
||||
serverMutex *sync.Mutex
|
||||
serverConn net.Conn
|
||||
serverListener net.Listener
|
||||
serverReady chan struct{}
|
||||
errChan chan error
|
||||
clientChan chan string
|
||||
serverChan chan string
|
||||
client func(*comm)
|
||||
server func(*comm)
|
||||
}
|
||||
|
||||
func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm {
|
||||
messageRecvCount := uint64(0)
|
||||
c := &comm{
|
||||
ctx: ctx,
|
||||
clientConfig: clientConfig,
|
||||
serverConfig: serverConfig,
|
||||
serverPort: serverPort,
|
||||
messageRecvCount: &messageRecvCount,
|
||||
clientMutex: &sync.Mutex{},
|
||||
serverMutex: &sync.Mutex{},
|
||||
serverReady: make(chan struct{}),
|
||||
errChan: make(chan error),
|
||||
clientChan: make(chan string),
|
||||
serverChan: make(chan string),
|
||||
server: server,
|
||||
client: client,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *comm) assert(t *testing.T) {
|
||||
// DTLS Client
|
||||
go c.client(c)
|
||||
|
||||
// DTLS Server
|
||||
go c.server(c)
|
||||
|
||||
defer func() {
|
||||
if c.clientConn != nil {
|
||||
if err := c.clientConn.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if c.serverConn != nil {
|
||||
if err := c.serverConn.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if c.serverListener != nil {
|
||||
if err := c.serverListener.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
func() {
|
||||
seenClient, seenServer := false, false
|
||||
for {
|
||||
select {
|
||||
case err := <-c.errChan:
|
||||
t.Fatal(err)
|
||||
case <-time.After(testTimeLimit):
|
||||
t.Fatalf("Test timeout, seenClient %t seenServer %t", seenClient, seenServer)
|
||||
case clientMsg := <-c.clientChan:
|
||||
if clientMsg != testMessage {
|
||||
t.Fatalf("clientMsg does not equal test message: %s %s", clientMsg, testMessage)
|
||||
}
|
||||
|
||||
seenClient = true
|
||||
if seenClient && seenServer {
|
||||
return
|
||||
}
|
||||
case serverMsg := <-c.serverChan:
|
||||
if serverMsg != testMessage {
|
||||
t.Fatalf("serverMsg does not equal test message: %s %s", serverMsg, testMessage)
|
||||
}
|
||||
|
||||
seenServer = true
|
||||
if seenClient && seenServer {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func clientPion(c *comm) {
|
||||
select {
|
||||
case <-c.serverReady:
|
||||
// OK
|
||||
case <-time.After(time.Second):
|
||||
c.errChan <- errServerTimeout
|
||||
}
|
||||
|
||||
c.clientMutex.Lock()
|
||||
defer c.clientMutex.Unlock()
|
||||
|
||||
var err error
|
||||
c.clientConn, err = dtls.DialWithContext(c.ctx, "udp",
|
||||
&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
|
||||
c.clientConfig,
|
||||
)
|
||||
if err != nil {
|
||||
c.errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount)
|
||||
}
|
||||
|
||||
func serverPion(c *comm) {
|
||||
c.serverMutex.Lock()
|
||||
defer c.serverMutex.Unlock()
|
||||
|
||||
var err error
|
||||
c.serverListener, err = dtls.Listen("udp",
|
||||
&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
|
||||
c.serverConfig,
|
||||
)
|
||||
if err != nil {
|
||||
c.errChan <- err
|
||||
return
|
||||
}
|
||||
c.serverReady <- struct{}{}
|
||||
c.serverConn, err = c.serverListener.Accept()
|
||||
if err != nil {
|
||||
c.errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount)
|
||||
}
|
||||
|
||||
/*
|
||||
Simple DTLS Client/Server can communicate
|
||||
- Assert that you can send messages both ways
|
||||
- Assert that Close() on both ends work
|
||||
- Assert that no Goroutines are leaked
|
||||
*/
|
||||
func testPionE2ESimple(t *testing.T, server, client func(*comm)) {
|
||||
lim := test.TimeOut(time.Second * 30)
|
||||
defer lim.Stop()
|
||||
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
for _, cipherSuite := range []dtls.CipherSuiteID{
|
||||
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
} {
|
||||
cipherSuite := cipherSuite
|
||||
t.Run(cipherSuite.String(), func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg := &dtls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
CipherSuites: []dtls.CipherSuiteID{cipherSuite},
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
serverPort := randomPort(t)
|
||||
comm := newComm(ctx, cfg, cfg, serverPort, server, client)
|
||||
comm.assert(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) {
|
||||
lim := test.TimeOut(time.Second * 30)
|
||||
defer lim.Stop()
|
||||
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
for _, cipherSuite := range []dtls.CipherSuiteID{
|
||||
dtls.TLS_PSK_WITH_AES_128_CCM,
|
||||
dtls.TLS_PSK_WITH_AES_128_CCM_8,
|
||||
dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
|
||||
} {
|
||||
cipherSuite := cipherSuite
|
||||
t.Run(cipherSuite.String(), func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := &dtls.Config{
|
||||
PSK: func(hint []byte) ([]byte, error) {
|
||||
return []byte{0xAB, 0xC1, 0x23}, nil
|
||||
},
|
||||
PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
CipherSuites: []dtls.CipherSuiteID{cipherSuite},
|
||||
}
|
||||
serverPort := randomPort(t)
|
||||
comm := newComm(ctx, cfg, cfg, serverPort, server, client)
|
||||
comm.assert(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testPionE2EMTUs(t *testing.T, server, client func(*comm)) {
|
||||
lim := test.TimeOut(time.Second * 30)
|
||||
defer lim.Stop()
|
||||
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
for _, mtu := range []int{
|
||||
10000,
|
||||
1000,
|
||||
100,
|
||||
} {
|
||||
mtu := mtu
|
||||
t.Run(fmt.Sprintf("MTU%d", mtu), func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg := &dtls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||
InsecureSkipVerify: true,
|
||||
MTU: mtu,
|
||||
}
|
||||
serverPort := randomPort(t)
|
||||
comm := newComm(ctx, cfg, cfg, serverPort, server, client)
|
||||
comm.assert(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPionE2ESimple(t *testing.T) {
|
||||
testPionE2ESimple(t, serverPion, clientPion)
|
||||
}
|
||||
|
||||
func TestPionE2ESimplePSK(t *testing.T) {
|
||||
testPionE2ESimplePSK(t, serverPion, clientPion)
|
||||
}
|
||||
|
||||
func TestPionE2EMTUs(t *testing.T) {
|
||||
testPionE2EMTUs(t, serverPion, clientPion)
|
||||
}
|
Reference in New Issue
Block a user