mirror of
https://github.com/bjdgyc/anylink.git
synced 2025-08-08 16:43:25 +08:00
添加 github.com/pion/dtls 代码
This commit is contained in:
343
dtls-2.0.9/handshaker.go
Normal file
343
dtls-2.0.9/handshaker.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package dtls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
|
||||
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
||||
"github.com/pion/logging"
|
||||
)
|
||||
|
||||
// [RFC6347 Section-4.2.4]
|
||||
// +-----------+
|
||||
// +---> | PREPARING | <--------------------+
|
||||
// | +-----------+ |
|
||||
// | | |
|
||||
// | | Buffer next flight |
|
||||
// | | |
|
||||
// | \|/ |
|
||||
// | +-----------+ |
|
||||
// | | SENDING |<------------------+ | Send
|
||||
// | +-----------+ | | HelloRequest
|
||||
// Receive | | | |
|
||||
// next | | Send flight | | or
|
||||
// flight | +--------+ | |
|
||||
// | | | Set retransmit timer | | Receive
|
||||
// | | \|/ | | HelloRequest
|
||||
// | | +-----------+ | | Send
|
||||
// +--)--| WAITING |-------------------+ | ClientHello
|
||||
// | | +-----------+ Timer expires | |
|
||||
// | | | | |
|
||||
// | | +------------------------+ |
|
||||
// Receive | | Send Read retransmit |
|
||||
// last | | last |
|
||||
// flight | | flight |
|
||||
// | | |
|
||||
// \|/\|/ |
|
||||
// +-----------+ |
|
||||
// | FINISHED | -------------------------------+
|
||||
// +-----------+
|
||||
// | /|\
|
||||
// | |
|
||||
// +---+
|
||||
// Read retransmit
|
||||
// Retransmit last flight
|
||||
|
||||
type handshakeState uint8
|
||||
|
||||
const (
|
||||
handshakeErrored handshakeState = iota
|
||||
handshakePreparing
|
||||
handshakeSending
|
||||
handshakeWaiting
|
||||
handshakeFinished
|
||||
)
|
||||
|
||||
func (s handshakeState) String() string {
|
||||
switch s {
|
||||
case handshakeErrored:
|
||||
return "Errored"
|
||||
case handshakePreparing:
|
||||
return "Preparing"
|
||||
case handshakeSending:
|
||||
return "Sending"
|
||||
case handshakeWaiting:
|
||||
return "Waiting"
|
||||
case handshakeFinished:
|
||||
return "Finished"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type handshakeFSM struct {
|
||||
currentFlight flightVal
|
||||
flights []*packet
|
||||
retransmit bool
|
||||
state *State
|
||||
cache *handshakeCache
|
||||
cfg *handshakeConfig
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
type handshakeConfig struct {
|
||||
localPSKCallback PSKCallback
|
||||
localPSKIdentityHint []byte
|
||||
localCiscoCompatCallback PSKCallback // TODO add cisco anyconnect support
|
||||
localCipherSuites []CipherSuite // Available CipherSuites
|
||||
localSignatureSchemes []signaturehash.Algorithm // Available signature schemes
|
||||
extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension
|
||||
localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support
|
||||
serverName string
|
||||
clientAuth ClientAuthType // If we are a client should we request a client certificate
|
||||
localCertificates []tls.Certificate
|
||||
nameToCertificate map[string]*tls.Certificate
|
||||
insecureSkipVerify bool
|
||||
verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
|
||||
rootCAs *x509.CertPool
|
||||
clientCAs *x509.CertPool
|
||||
retransmitInterval time.Duration
|
||||
customCipherSuites func() []CipherSuite
|
||||
|
||||
onFlightState func(flightVal, handshakeState)
|
||||
log logging.LeveledLogger
|
||||
keyLogWriter io.Writer
|
||||
|
||||
initialEpoch uint16
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type flightConn interface {
|
||||
notify(ctx context.Context, level alert.Level, desc alert.Description) error
|
||||
writePackets(context.Context, []*packet) error
|
||||
recvHandshake() <-chan chan struct{}
|
||||
setLocalEpoch(epoch uint16)
|
||||
handleQueuedPackets(context.Context) error
|
||||
}
|
||||
|
||||
func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) {
|
||||
if c.keyLogWriter == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, err := c.keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)))
|
||||
if err != nil {
|
||||
c.log.Debugf("failed to write key log file: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func srvCliStr(isClient bool) string {
|
||||
if isClient {
|
||||
return "client"
|
||||
}
|
||||
return "server"
|
||||
}
|
||||
|
||||
func newHandshakeFSM(
|
||||
s *State, cache *handshakeCache, cfg *handshakeConfig,
|
||||
initialFlight flightVal,
|
||||
) *handshakeFSM {
|
||||
return &handshakeFSM{
|
||||
currentFlight: initialFlight,
|
||||
state: s,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error {
|
||||
state := initialState
|
||||
defer func() {
|
||||
close(s.closed)
|
||||
}()
|
||||
for {
|
||||
s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
|
||||
if s.cfg.onFlightState != nil {
|
||||
s.cfg.onFlightState(s.currentFlight, state)
|
||||
}
|
||||
var err error
|
||||
switch state {
|
||||
case handshakePreparing:
|
||||
state, err = s.prepare(ctx, c)
|
||||
case handshakeSending:
|
||||
state, err = s.send(ctx, c)
|
||||
case handshakeWaiting:
|
||||
state, err = s.wait(ctx, c)
|
||||
case handshakeFinished:
|
||||
state, err = s.finish(ctx, c)
|
||||
default:
|
||||
return errInvalidFSMTransition
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO 添加 CiscoCompat 支持
|
||||
if s.cfg.localCiscoCompatCallback != nil {
|
||||
if s.currentFlight == flight4 && state == handshakeWaiting {
|
||||
s.currentFlight = flight6
|
||||
state = handshakePreparing
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *handshakeFSM) Done() <-chan struct{} {
|
||||
return s.closed
|
||||
}
|
||||
|
||||
func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) {
|
||||
s.flights = nil
|
||||
// Prepare flights
|
||||
var (
|
||||
a *alert.Alert
|
||||
err error
|
||||
pkts []*packet
|
||||
)
|
||||
gen, retransmit, errFlight := s.currentFlight.getFlightGenerator()
|
||||
if errFlight != nil {
|
||||
err = errFlight
|
||||
a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}
|
||||
} else {
|
||||
pkts, a, err = gen(c, s.state, s.cache, s.cfg)
|
||||
s.retransmit = retransmit
|
||||
}
|
||||
if a != nil {
|
||||
if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil {
|
||||
if err != nil {
|
||||
err = alertErr
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return handshakeErrored, err
|
||||
}
|
||||
|
||||
s.flights = pkts
|
||||
epoch := s.cfg.initialEpoch
|
||||
nextEpoch := epoch
|
||||
for _, p := range s.flights {
|
||||
p.record.Header.Epoch += epoch
|
||||
if p.record.Header.Epoch > nextEpoch {
|
||||
nextEpoch = p.record.Header.Epoch
|
||||
}
|
||||
if h, ok := p.record.Content.(*handshake.Handshake); ok {
|
||||
h.Header.MessageSequence = uint16(s.state.handshakeSendSequence)
|
||||
s.state.handshakeSendSequence++
|
||||
}
|
||||
}
|
||||
if epoch != nextEpoch {
|
||||
s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch)
|
||||
c.setLocalEpoch(nextEpoch)
|
||||
}
|
||||
return handshakeSending, nil
|
||||
}
|
||||
|
||||
func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) {
|
||||
// Send flights
|
||||
if err := c.writePackets(ctx, s.flights); err != nil {
|
||||
return handshakeErrored, err
|
||||
}
|
||||
|
||||
if s.currentFlight.isLastSendFlight() {
|
||||
return handshakeFinished, nil
|
||||
}
|
||||
return handshakeWaiting, nil
|
||||
}
|
||||
|
||||
func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit
|
||||
parse, errFlight := s.currentFlight.getFlightParser()
|
||||
if errFlight != nil {
|
||||
if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
|
||||
if errFlight != nil {
|
||||
return handshakeErrored, alertErr
|
||||
}
|
||||
}
|
||||
return handshakeErrored, errFlight
|
||||
}
|
||||
|
||||
retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
|
||||
for {
|
||||
select {
|
||||
case done := <-c.recvHandshake():
|
||||
nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
|
||||
close(done)
|
||||
if alert != nil {
|
||||
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
|
||||
if err != nil {
|
||||
err = alertErr
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return handshakeErrored, err
|
||||
}
|
||||
if nextFlight == 0 {
|
||||
break
|
||||
}
|
||||
s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String())
|
||||
if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
|
||||
return handshakeFinished, nil
|
||||
}
|
||||
s.currentFlight = nextFlight
|
||||
return handshakePreparing, nil
|
||||
|
||||
case <-retransmitTimer.C:
|
||||
if !s.retransmit {
|
||||
return handshakeWaiting, nil
|
||||
}
|
||||
return handshakeSending, nil
|
||||
case <-ctx.Done():
|
||||
return handshakeErrored, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) {
|
||||
parse, errFlight := s.currentFlight.getFlightParser()
|
||||
if errFlight != nil {
|
||||
if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
|
||||
if errFlight != nil {
|
||||
return handshakeErrored, alertErr
|
||||
}
|
||||
}
|
||||
return handshakeErrored, errFlight
|
||||
}
|
||||
|
||||
retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
|
||||
select {
|
||||
case done := <-c.recvHandshake():
|
||||
nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
|
||||
close(done)
|
||||
if alert != nil {
|
||||
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
|
||||
if err != nil {
|
||||
err = alertErr
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return handshakeErrored, err
|
||||
}
|
||||
if nextFlight == 0 {
|
||||
break
|
||||
}
|
||||
<-retransmitTimer.C
|
||||
// Retransmit last flight
|
||||
return handshakeSending, nil
|
||||
|
||||
case <-ctx.Done():
|
||||
return handshakeErrored, ctx.Err()
|
||||
}
|
||||
return handshakeFinished, nil
|
||||
}
|
Reference in New Issue
Block a user