mirror of https://github.com/bjdgyc/anylink.git
278 lines
7.3 KiB
Go
278 lines
7.3 KiB
Go
package dtls
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
|
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
|
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
|
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
|
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
|
"github.com/pion/logging"
|
|
"github.com/pion/transport/test"
|
|
)
|
|
|
|
// nonZeroRetransmitInterval is the retransmitInterval that TestHandshaker
// sets on both the client and the server handshakeConfig.
const nonZeroRetransmitInterval = 100 * time.Millisecond
|
|
|
|
// Test that writes to the key log are in the correct format and only applies
|
|
// when a key log writer is given.
|
|
func TestWriteKeyLog(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
cfg := handshakeConfig{
|
|
keyLogWriter: &buf,
|
|
}
|
|
cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF})
|
|
|
|
// Secrets follow the format <Label> <space> <ClientRandom> <space> <Secret>
|
|
// https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format
|
|
want := "LABEL aabbcc ddeeff\n"
|
|
if buf.String() != want {
|
|
t.Fatalf("Got %s want %s", buf.String(), want)
|
|
}
|
|
|
|
// no key log writer = no writes
|
|
cfg = handshakeConfig{}
|
|
cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF})
|
|
}
|
|
|
|
func TestHandshaker(t *testing.T) {
|
|
// Check for leaking routines
|
|
report := test.CheckRoutines(t)
|
|
defer report()
|
|
|
|
loggerFactory := logging.NewDefaultLoggerFactory()
|
|
logger := loggerFactory.NewLogger("dtls")
|
|
|
|
cipherSuites, err := parseCipherSuites(nil, nil, true, false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
clientCert, err := selfsign.GenerateSelfSigned()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
genFilters := map[string]func() (packetFilter, packetFilter, func(t *testing.T)){
|
|
"PassThrough": func() (packetFilter, packetFilter, func(t *testing.T)) {
|
|
return nil, nil, nil
|
|
},
|
|
"HelloVerifyRequestLost": func() (packetFilter, packetFilter, func(t *testing.T)) {
|
|
var (
|
|
cntHelloVerifyRequest = 0
|
|
cntClientHelloNoCookie = 0
|
|
)
|
|
const helloVerifyDrop = 5
|
|
return func(p *packet) bool {
|
|
h, ok := p.record.Content.(*handshake.Handshake)
|
|
if !ok {
|
|
return true
|
|
}
|
|
if hmch, ok := h.Message.(*handshake.MessageClientHello); ok {
|
|
if len(hmch.Cookie) == 0 {
|
|
cntClientHelloNoCookie++
|
|
}
|
|
}
|
|
return true
|
|
},
|
|
func(p *packet) bool {
|
|
h, ok := p.record.Content.(*handshake.Handshake)
|
|
if !ok {
|
|
return true
|
|
}
|
|
if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); ok {
|
|
cntHelloVerifyRequest++
|
|
return cntHelloVerifyRequest > helloVerifyDrop
|
|
}
|
|
return true
|
|
},
|
|
func(t *testing.T) {
|
|
if cntHelloVerifyRequest != helloVerifyDrop+1 {
|
|
t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest)
|
|
}
|
|
if cntClientHelloNoCookie != cntHelloVerifyRequest {
|
|
t.Errorf(
|
|
"HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times",
|
|
cntHelloVerifyRequest, cntClientHelloNoCookie,
|
|
)
|
|
}
|
|
}
|
|
},
|
|
}
|
|
|
|
for name, filters := range genFilters {
|
|
f1, f2, report := filters()
|
|
t.Run(name, func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
if report != nil {
|
|
defer report(t)
|
|
}
|
|
|
|
ca, cb := flightTestPipe(ctx, f1, f2)
|
|
ca.state.isClient = true
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
ctxCliFinished, cancelCli := context.WithCancel(ctx)
|
|
ctxSrvFinished, cancelSrv := context.WithCancel(ctx)
|
|
go func() {
|
|
defer wg.Done()
|
|
cfg := &handshakeConfig{
|
|
localCipherSuites: cipherSuites,
|
|
localCertificates: []tls.Certificate{clientCert},
|
|
localSignatureSchemes: signaturehash.Algorithms(),
|
|
insecureSkipVerify: true,
|
|
log: logger,
|
|
onFlightState: func(f flightVal, s handshakeState) {
|
|
if s == handshakeFinished {
|
|
cancelCli()
|
|
}
|
|
},
|
|
retransmitInterval: nonZeroRetransmitInterval,
|
|
}
|
|
|
|
fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1)
|
|
switch err := fsm.Run(ctx, ca, handshakePreparing); err {
|
|
case context.Canceled:
|
|
case context.DeadlineExceeded:
|
|
t.Error("Timeout")
|
|
default:
|
|
t.Error(err)
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
cfg := &handshakeConfig{
|
|
localCipherSuites: cipherSuites,
|
|
localCertificates: []tls.Certificate{clientCert},
|
|
localSignatureSchemes: signaturehash.Algorithms(),
|
|
insecureSkipVerify: true,
|
|
log: logger,
|
|
onFlightState: func(f flightVal, s handshakeState) {
|
|
if s == handshakeFinished {
|
|
cancelSrv()
|
|
}
|
|
},
|
|
retransmitInterval: nonZeroRetransmitInterval,
|
|
}
|
|
|
|
fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0)
|
|
switch err := fsm.Run(ctx, cb, handshakePreparing); err {
|
|
case context.Canceled:
|
|
case context.DeadlineExceeded:
|
|
t.Error("Timeout")
|
|
default:
|
|
t.Error(err)
|
|
}
|
|
}()
|
|
|
|
<-ctxCliFinished.Done()
|
|
<-ctxSrvFinished.Done()
|
|
|
|
cancel()
|
|
wg.Wait()
|
|
})
|
|
}
|
|
}
|
|
|
|
// packetFilter inspects an outgoing packet and reports whether it should be
// kept; flightTestConn.writePackets skips every packet for which it returns
// false.
type packetFilter func(*packet) bool
|
|
|
|
func flightTestPipe(ctx context.Context, filter1 packetFilter, filter2 packetFilter) (*flightTestConn, *flightTestConn) {
|
|
ca := newHandshakeCache()
|
|
cb := newHandshakeCache()
|
|
chA := make(chan chan struct{})
|
|
chB := make(chan chan struct{})
|
|
return &flightTestConn{
|
|
handshakeCache: ca,
|
|
otherEndCache: cb,
|
|
recv: chA,
|
|
otherEndRecv: chB,
|
|
done: ctx.Done(),
|
|
filter: filter1,
|
|
}, &flightTestConn{
|
|
handshakeCache: cb,
|
|
otherEndCache: ca,
|
|
recv: chB,
|
|
otherEndRecv: chA,
|
|
done: ctx.Done(),
|
|
filter: filter2,
|
|
}
|
|
}
|
|
|
|
type flightTestConn struct {
|
|
state State
|
|
handshakeCache *handshakeCache
|
|
recv chan chan struct{}
|
|
done <-chan struct{}
|
|
epoch uint16
|
|
|
|
filter packetFilter
|
|
|
|
otherEndCache *handshakeCache
|
|
otherEndRecv chan chan struct{}
|
|
}
|
|
|
|
func (c *flightTestConn) recvHandshake() <-chan chan struct{} {
|
|
return c.recv
|
|
}
|
|
|
|
func (c *flightTestConn) setLocalEpoch(epoch uint16) {
|
|
c.epoch = epoch
|
|
}
|
|
|
|
func (c *flightTestConn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
|
|
return nil
|
|
}
|
|
|
|
func (c *flightTestConn) writePackets(ctx context.Context, pkts []*packet) error {
|
|
for _, p := range pkts {
|
|
if c.filter != nil && !c.filter(p) {
|
|
continue
|
|
}
|
|
if h, ok := p.record.Content.(*handshake.Handshake); ok {
|
|
handshakeRaw, err := p.record.Marshal()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
|
|
|
|
content, err := h.Message.Marshal()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
h.Header.Length = uint32(len(content))
|
|
h.Header.FragmentLength = uint32(len(content))
|
|
hdr, err := h.Header.Marshal()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.otherEndCache.push(
|
|
append(hdr, content...), p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
|
|
}
|
|
}
|
|
go func() {
|
|
select {
|
|
case c.otherEndRecv <- make(chan struct{}):
|
|
case <-c.done:
|
|
}
|
|
}()
|
|
|
|
// Avoid deadlock on JS/WASM environment due to context switch problem.
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *flightTestConn) handleQueuedPackets(ctx context.Context) error {
|
|
return nil
|
|
}
|