package dtls

import (
	"context"
	"net"
	"reflect"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/pion/dtls/v2/internal/net/dpipe"
	"github.com/pion/transport/test"
)

// TestReplayProtection connects two DTLS endpoints through a man-in-the-middle
// that forwards every packet and then sends the same bytes a second time
// shortly afterwards. Both endpoints send numMsgs one-byte messages; each
// receiver must end up with exactly the sent sequence, i.e. the replayed
// copies must never be delivered to the application.
func TestReplayProtection(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(5 * time.Second)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	c0, c1 := dpipe.Pipe()
	c2, c3 := dpipe.Pipe()
	conn := []net.Conn{c0, c1, c2, c3}

	var wgRoutines sync.WaitGroup

	// cntReplays counts outstanding work that must finish before the
	// connections may be closed: the sender itself (the initial 1), each
	// receiver still waiting for the final message, and each pending replay.
	// ctxReplayDone is cancelled when it drops to zero.
	var cntReplays int32 = 1

	ctxReplayDone, replayDone := context.WithCancel(context.Background())
	defer replayDone()

	replaySendDone := func() {
		if atomic.AddInt32(&cntReplays, -1) == 0 {
			replayDone()
		}
	}

	// replayer is the man in the middle: it forwards every packet read from
	// src to dst and re-sends the same bytes again a moment later.
	replayer := func(src, dst net.Conn) {
		defer wgRoutines.Done()
		for {
			// Fresh buffer per packet: the replay goroutine keeps a reference.
			b := make([]byte, 2048)
			n, rerr := src.Read(b)
			if rerr != nil {
				return
			}

			// Register the pending replay BEFORE forwarding the original, so a
			// receiver that observes the final message cannot drive the counter
			// to zero while this replay is still unsent.
			atomic.AddInt32(&cntReplays, 1)
			if _, werr := dst.Write(b[:n]); werr != nil {
				replaySendDone()
				t.Error(werr)
				return
			}

			// Tracked in wgRoutines so no replay can outlive the test. Adding
			// here is safe: this replayer still holds its own count.
			wgRoutines.Add(1)
			go func() {
				defer wgRoutines.Done()
				defer replaySendDone()
				// Replay bit later
				time.Sleep(time.Millisecond)
				if _, werr := dst.Write(b[:n]); werr != nil {
					t.Error(werr)
				}
			}()
		}
	}
	wgRoutines.Add(2)
	go replayer(conn[1], conn[2])
	go replayer(conn[2], conn[1])

	ca, cb, err := pipeConn(conn[0], conn[3])
	if err != nil {
		t.Fatal(err)
	}

	// shutdown closes the raw pipe endpoints (unblocking the replayers), then
	// the DTLS connections, and waits for every goroutine started by the test.
	shutdown := func() {
		for _, c := range conn {
			if cerr := c.Close(); cerr != nil {
				t.Error(cerr)
			}
		}
		if cerr := ca.Close(); cerr != nil {
			t.Error(cerr)
		}
		if cerr := cb.Close(); cerr != nil {
			t.Error(cerr)
		}
		wgRoutines.Wait()
	}

	const numMsgs = 10

	var received [2][][]byte
	for i, c := range []net.Conn{ca, cb} {
		i, c := i, c
		wgRoutines.Add(1)
		atomic.AddInt32(&cntReplays, 1) // Keep locked until the final message
		var lastMsgDone sync.Once
		go func() {
			defer wgRoutines.Done()
			for {
				b := make([]byte, 2048)
				n, rerr := c.Read(b)
				if rerr != nil {
					return
				}
				received[i] = append(received[i], b[:n])
				if n > 0 && b[0] == numMsgs-1 {
					// Final message received: release this receiver's hold once.
					lastMsgDone.Do(replaySendDone)
				}
			}
		}()
	}

	var sent [][]byte
	for i := 0; i < numMsgs; i++ {
		data := []byte{byte(i)}
		sent = append(sent, data)
		if _, werr := ca.Write(data); werr != nil {
			t.Error(werr)
			shutdown() // don't leak the replayer/receiver goroutines
			return
		}
		if _, werr := cb.Write(data); werr != nil {
			t.Error(werr)
			shutdown()
			return
		}
	}

	// Release the sender's hold, then wait until both receivers have seen the
	// final message and every replay has been written.
	replaySendDone()
	<-ctxReplayDone.Done()
	// Give the receivers time to process the last replayed packets, so a
	// broken replay check would surface as an extra entry in received.
	time.Sleep(10 * time.Millisecond)

	shutdown()

	for _, r := range received {
		if !reflect.DeepEqual(sent, r) {
			t.Errorf("Received data differs, expected: %v, got: %v", sent, r)
		}
	}
}