mirror of https://github.com/bjdgyc/anylink.git
209 lines
3.9 KiB
Go
209 lines
3.9 KiB
Go
package dtls
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
|
"github.com/pion/transport/test"
|
|
)
|
|
|
|
// errMessageMissmatch is the sentinel wrapped (via %w) when an echoed payload
// does not match the payload that was sent.
var errMessageMissmatch error = errors.New("messages missmatch")
|
|
|
|
func TestResumeClient(t *testing.T) {
|
|
DoTestResume(t, Client, Server)
|
|
}
|
|
|
|
func TestResumeServer(t *testing.T) {
|
|
DoTestResume(t, Server, Client)
|
|
}
|
|
|
|
// fatal fails the test with err and stops the calling test goroutine.
//
// It closes errChan first, so that any pending or later receive on errChan
// returns immediately instead of blocking on a result that may never arrive.
//
// NOTE(review): a send on errChan after this close panics ("send on closed
// channel"). This is only safe if no other goroutine still sends on errChan.
func fatal(t *testing.T, errChan chan error, err error) {
	// Report the failure at the caller's line rather than inside this helper.
	t.Helper()
	close(errChan)
	t.Fatal(err)
}
|
|
|
|
func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Conn, error)) {
|
|
// Limit runtime in case of deadlocks
|
|
lim := test.TimeOut(time.Second * 20)
|
|
defer lim.Stop()
|
|
|
|
// Check for leaking routines
|
|
report := test.CheckRoutines(t)
|
|
defer report()
|
|
|
|
certificate, err := selfsign.GenerateSelfSigned()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Generate connections
|
|
localConn1, rc1 := net.Pipe()
|
|
localConn2, rc2 := net.Pipe()
|
|
remoteConn := &backupConn{curr: rc1, next: rc2}
|
|
|
|
// Launch remote in another goroutine
|
|
errChan := make(chan error, 1)
|
|
defer func() {
|
|
err = <-errChan
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}()
|
|
config := &Config{
|
|
Certificates: []tls.Certificate{certificate},
|
|
InsecureSkipVerify: true,
|
|
ExtendedMasterSecret: RequireExtendedMasterSecret,
|
|
}
|
|
go func() {
|
|
var remote *Conn
|
|
var errR error
|
|
remote, errR = newRemote(remoteConn, config)
|
|
if errR != nil {
|
|
errChan <- errR
|
|
}
|
|
|
|
// Loop of read write
|
|
for i := 0; i < 2; i++ {
|
|
recv := make([]byte, 1024)
|
|
var n int
|
|
n, errR = remote.Read(recv)
|
|
if errR != nil {
|
|
errChan <- errR
|
|
}
|
|
|
|
if _, errR = remote.Write(recv[:n]); errR != nil {
|
|
errChan <- errR
|
|
}
|
|
}
|
|
errChan <- nil
|
|
}()
|
|
|
|
var local *Conn
|
|
local, err = newLocal(localConn1, config)
|
|
if err != nil {
|
|
fatal(t, errChan, err)
|
|
}
|
|
defer func() {
|
|
_ = local.Close()
|
|
}()
|
|
|
|
// Test write and read
|
|
message := []byte("Hello")
|
|
if _, err = local.Write(message); err != nil {
|
|
fatal(t, errChan, err)
|
|
}
|
|
|
|
recv := make([]byte, 1024)
|
|
var n int
|
|
n, err = local.Read(recv)
|
|
if err != nil {
|
|
fatal(t, errChan, err)
|
|
}
|
|
|
|
if !bytes.Equal(message, recv[:n]) {
|
|
fatal(t, errChan, fmt.Errorf("%w: %s != %s", errMessageMissmatch, message, recv[:n]))
|
|
}
|
|
|
|
if err = localConn1.Close(); err != nil {
|
|
fatal(t, errChan, err)
|
|
}
|
|
|
|
// Serialize and deserialize state
|
|
state := local.ConnectionState()
|
|
var b []byte
|
|
b, err = state.MarshalBinary()
|
|
if err != nil {
|
|
fatal(t, errChan, err)
|
|
}
|
|
deserialized := &State{}
|
|
if err = deserialized.UnmarshalBinary(b); err != nil {
|
|
fatal(t, errChan, err)
|
|
}
|
|
|
|
// Resume dtls connection
|
|
var resumed net.Conn
|
|
resumed, err = Resume(deserialized, localConn2, config)
|
|
if err != nil {
|
|
fatal(t, errChan, err)
|
|
}
|
|
defer func() {
|
|
_ = resumed.Close()
|
|
}()
|
|
|
|
// Test write and read on resumed connection
|
|
if _, err = resumed.Write(message); err != nil {
|
|
fatal(t, errChan, err)
|
|
}
|
|
|
|
recv = make([]byte, 1024)
|
|
n, err = resumed.Read(recv)
|
|
if err != nil {
|
|
fatal(t, errChan, err)
|
|
}
|
|
|
|
if !bytes.Equal(message, recv[:n]) {
|
|
fatal(t, errChan, fmt.Errorf("%w: %s != %s", errMessageMissmatch, message, recv[:n]))
|
|
}
|
|
}
|
|
|
|
// backupConn is a net.Conn that forwards I/O to curr. When the first
// operation on curr fails and a backup is still pending, it switches to next
// once and retries the failed operation there.
//
// Read and Write may run on different goroutines. All access to curr and next
// therefore goes through mux, and the switch happens at most once even if
// both directions fail at the same time.
//
// Close, the address getters and the deadline setters are no-ops.
type backupConn struct {
	// curr is the connection currently used for I/O.
	curr net.Conn
	// next is the fallback connection; nil once the switch has happened.
	next net.Conn
	mux  sync.Mutex
}

var _ net.Conn = (*backupConn)(nil)

// active returns the connection to use for the next operation and whether a
// backup was still pending at that moment.
func (b *backupConn) active() (net.Conn, bool) {
	b.mux.Lock()
	defer b.mux.Unlock()
	return b.curr, b.next != nil
}

// failover promotes next to curr if no goroutine has done so yet and returns
// the connection to retry on. If another goroutine already switched, it just
// returns the current connection.
func (b *backupConn) failover() net.Conn {
	b.mux.Lock()
	defer b.mux.Unlock()
	if b.next != nil {
		b.curr, b.next = b.next, nil
	}
	return b.curr
}

// Read reads from the active connection. If that fails while a backup is
// pending, it retries once on the backup.
func (b *backupConn) Read(data []byte) (n int, err error) {
	conn, hasBackup := b.active()
	n, err = conn.Read(data)
	if err != nil && hasBackup {
		return b.failover().Read(data)
	}
	return n, err
}

// Write writes to the active connection. If that fails while a backup is
// pending, it retries the whole buffer once on the backup.
func (b *backupConn) Write(data []byte) (n int, err error) {
	conn, hasBackup := b.active()
	n, err = conn.Write(data)
	if err != nil && hasBackup {
		return b.failover().Write(data)
	}
	return n, err
}

// Close is a no-op; the underlying connections are owned by the test.
func (b *backupConn) Close() error {
	return nil
}

// LocalAddr returns nil.
func (b *backupConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil.
func (b *backupConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op.
func (b *backupConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (b *backupConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (b *backupConn) SetWriteDeadline(t time.Time) error {
	return nil
}
|