package dtls import ( "bytes" "context" "crypto/tls" "sync" "testing" "time" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/dtls/v2/pkg/crypto/signaturehash" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/test" ) const nonZeroRetransmitInterval = 100 * time.Millisecond // Test that writes to the key log are in the correct format and only applies // when a key log writer is given. func TestWriteKeyLog(t *testing.T) { var buf bytes.Buffer cfg := handshakeConfig{ keyLogWriter: &buf, } cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) // Secrets follow the format <Label> <space> <ClientRandom> <space> <Secret> // https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format want := "LABEL aabbcc ddeeff\n" if buf.String() != want { t.Fatalf("Got %s want %s", buf.String(), want) } // no key log writer = no writes cfg = handshakeConfig{} cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) } func TestHandshaker(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() loggerFactory := logging.NewDefaultLoggerFactory() logger := loggerFactory.NewLogger("dtls") cipherSuites, err := parseCipherSuites(nil, nil, true, false) if err != nil { t.Fatal(err) } clientCert, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatal(err) } genFilters := map[string]func() (packetFilter, packetFilter, func(t *testing.T)){ "PassThrough": func() (packetFilter, packetFilter, func(t *testing.T)) { return nil, nil, nil }, "HelloVerifyRequestLost": func() (packetFilter, packetFilter, func(t *testing.T)) { var ( cntHelloVerifyRequest = 0 cntClientHelloNoCookie = 0 ) const helloVerifyDrop = 5 return func(p *packet) bool { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return true } if hmch, ok := h.Message.(*handshake.MessageClientHello); ok { if len(hmch.Cookie) == 0 { cntClientHelloNoCookie++ } } return true }, func(p *packet) bool { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return true } if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); ok { cntHelloVerifyRequest++ return cntHelloVerifyRequest > helloVerifyDrop } return true }, func(t *testing.T) { if cntHelloVerifyRequest != helloVerifyDrop+1 { t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest) } if cntClientHelloNoCookie != cntHelloVerifyRequest { t.Errorf( "HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times", cntHelloVerifyRequest, cntClientHelloNoCookie, ) } } }, } for name, filters := range genFilters { f1, f2, report := filters() t.Run(name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if report != nil { defer report(t) } ca, cb := flightTestPipe(ctx, f1, f2) ca.state.isClient = true var wg sync.WaitGroup wg.Add(2) ctxCliFinished, cancelCli := context.WithCancel(ctx) ctxSrvFinished, cancelSrv := context.WithCancel(ctx) go func() { defer wg.Done() cfg := &handshakeConfig{ localCipherSuites: cipherSuites, localCertificates: []tls.Certificate{clientCert}, localSignatureSchemes: signaturehash.Algorithms(), insecureSkipVerify: true, log: logger, onFlightState: func(f flightVal, s handshakeState) { if s == handshakeFinished { cancelCli() } }, retransmitInterval: nonZeroRetransmitInterval, } fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1) switch err := fsm.Run(ctx, ca, handshakePreparing); err { case context.Canceled: case context.DeadlineExceeded: t.Error("Timeout") default: t.Error(err) } }() go func() { defer wg.Done() cfg := &handshakeConfig{ localCipherSuites: cipherSuites, localCertificates: []tls.Certificate{clientCert}, localSignatureSchemes: signaturehash.Algorithms(), insecureSkipVerify: true, log: logger, onFlightState: func(f flightVal, s handshakeState) { if s == handshakeFinished { cancelSrv() } }, retransmitInterval: nonZeroRetransmitInterval, } fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0) switch err := fsm.Run(ctx, cb, handshakePreparing); err { case context.Canceled: case context.DeadlineExceeded: t.Error("Timeout") default: t.Error(err) } }() <-ctxCliFinished.Done() <-ctxSrvFinished.Done() cancel() wg.Wait() }) } } type packetFilter func(*packet) bool func flightTestPipe(ctx context.Context, filter1 packetFilter, filter2 packetFilter) (*flightTestConn, *flightTestConn) { ca := newHandshakeCache() cb := newHandshakeCache() chA := make(chan chan struct{}) chB := make(chan chan struct{}) return &flightTestConn{ handshakeCache: ca, otherEndCache: cb, recv: chA, otherEndRecv: chB, done: ctx.Done(), filter: filter1, }, &flightTestConn{ handshakeCache: cb, otherEndCache: ca, recv: chB, otherEndRecv: chA, done: ctx.Done(), filter: filter2, } } type flightTestConn struct { state State handshakeCache *handshakeCache recv chan chan struct{} done <-chan struct{} epoch uint16 filter packetFilter otherEndCache *handshakeCache otherEndRecv chan chan struct{} } func (c *flightTestConn) recvHandshake() <-chan chan struct{} { return c.recv } func (c *flightTestConn) setLocalEpoch(epoch uint16) { c.epoch = epoch } func (c *flightTestConn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { return nil } func (c *flightTestConn) writePackets(ctx context.Context, pkts []*packet) error { for _, p := range pkts { if c.filter != nil && !c.filter(p) { continue } if h, ok := p.record.Content.(*handshake.Handshake); ok { handshakeRaw, err := p.record.Marshal() if err != nil { return err } c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) content, err := h.Message.Marshal() if err != nil { return err } h.Header.Length = uint32(len(content)) h.Header.FragmentLength = uint32(len(content)) hdr, err := h.Header.Marshal() if err != nil { return err } c.otherEndCache.push( append(hdr, content...), p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) } } go func() { select { case c.otherEndRecv <- make(chan struct{}): case <-c.done: } }() // Avoid deadlock on JS/WASM environment due to context switch problem. time.Sleep(10 * time.Millisecond) return nil } func (c *flightTestConn) handleQueuedPackets(ctx context.Context) error { return nil }