添加 github.com/pion/dtls 代码

This commit is contained in:
bjdgyc
2021-05-21 19:03:00 +08:00
parent 54a0cb7928
commit 28b5119f50
380 changed files with 16870 additions and 0 deletions

View File

@@ -0,0 +1,160 @@
// Package alert implements TLS alert protocol https://tools.ietf.org/html/rfc5246#section-7.2
package alert
import (
"errors"
"fmt"
"github.com/pion/dtls/v2/pkg/protocol"
)
var errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
// Level is the level of the TLS Alert
type Level byte
// Level enums
const (
Warning Level = 1
Fatal Level = 2
)
func (l Level) String() string {
switch l {
case Warning:
return "Warning"
case Fatal:
return "Fatal"
default:
return "Invalid alert level"
}
}
// Description is the extended info of the TLS Alert
type Description byte
// Description enums
const (
CloseNotify Description = 0
UnexpectedMessage Description = 10
BadRecordMac Description = 20
DecryptionFailed Description = 21
RecordOverflow Description = 22
DecompressionFailure Description = 30
HandshakeFailure Description = 40
NoCertificate Description = 41
BadCertificate Description = 42
UnsupportedCertificate Description = 43
CertificateRevoked Description = 44
CertificateExpired Description = 45
CertificateUnknown Description = 46
IllegalParameter Description = 47
UnknownCA Description = 48
AccessDenied Description = 49
DecodeError Description = 50
DecryptError Description = 51
ExportRestriction Description = 60
ProtocolVersion Description = 70
InsufficientSecurity Description = 71
InternalError Description = 80
UserCanceled Description = 90
NoRenegotiation Description = 100
UnsupportedExtension Description = 110
)
func (d Description) String() string {
switch d {
case CloseNotify:
return "CloseNotify"
case UnexpectedMessage:
return "UnexpectedMessage"
case BadRecordMac:
return "BadRecordMac"
case DecryptionFailed:
return "DecryptionFailed"
case RecordOverflow:
return "RecordOverflow"
case DecompressionFailure:
return "DecompressionFailure"
case HandshakeFailure:
return "HandshakeFailure"
case NoCertificate:
return "NoCertificate"
case BadCertificate:
return "BadCertificate"
case UnsupportedCertificate:
return "UnsupportedCertificate"
case CertificateRevoked:
return "CertificateRevoked"
case CertificateExpired:
return "CertificateExpired"
case CertificateUnknown:
return "CertificateUnknown"
case IllegalParameter:
return "IllegalParameter"
case UnknownCA:
return "UnknownCA"
case AccessDenied:
return "AccessDenied"
case DecodeError:
return "DecodeError"
case DecryptError:
return "DecryptError"
case ExportRestriction:
return "ExportRestriction"
case ProtocolVersion:
return "ProtocolVersion"
case InsufficientSecurity:
return "InsufficientSecurity"
case InternalError:
return "InternalError"
case UserCanceled:
return "UserCanceled"
case NoRenegotiation:
return "NoRenegotiation"
case UnsupportedExtension:
return "UnsupportedExtension"
default:
return "Invalid alert description"
}
}
// Alert is one of the content types supported by the TLS record layer.
// Alert messages convey the severity of the message
// (warning or fatal) and a description of the alert. Alert messages
// with a level of fatal result in the immediate termination of the
// connection. In this case, other connections corresponding to the
// session may continue, but the session identifier MUST be invalidated,
// preventing the failed session from being used to establish new
// connections. Like other messages, alert messages are encrypted and
// compressed, as specified by the current connection state.
// https://tools.ietf.org/html/rfc5246#section-7.2
type Alert struct {
Level Level
Description Description
}
// ContentType returns the ContentType of this Content
func (a Alert) ContentType() protocol.ContentType {
return protocol.ContentTypeAlert
}
// Marshal returns the encoded alert
func (a *Alert) Marshal() ([]byte, error) {
return []byte{byte(a.Level), byte(a.Description)}, nil
}
// Unmarshal populates the alert from binary data
func (a *Alert) Unmarshal(data []byte) error {
if len(data) != 2 {
return errBufferTooSmall
}
a.Level = Level(data[0])
a.Description = Description(data[1])
return nil
}
func (a *Alert) String() string {
return fmt.Sprintf("Alert %s: %s", a.Level, a.Description)
}

View File

@@ -0,0 +1,49 @@
package alert
import (
"errors"
"reflect"
"testing"
)
func TestAlert(t *testing.T) {
for _, test := range []struct {
Name string
Data []byte
Want *Alert
WantUnmarshalError error
}{
{
Name: "Valid Alert",
Data: []byte{0x02, 0x0A},
Want: &Alert{
Level: Fatal,
Description: UnexpectedMessage,
},
},
{
Name: "Invalid alert length",
Data: []byte{0x00},
Want: &Alert{},
WantUnmarshalError: errBufferTooSmall,
},
} {
a := &Alert{}
if err := a.Unmarshal(test.Data); !errors.Is(err, test.WantUnmarshalError) {
t.Errorf("Unexpected Error %v: exp: %v got: %v", test.Name, test.WantUnmarshalError, err)
} else if !reflect.DeepEqual(test.Want, a) {
t.Errorf("%q alert.unmarshal: got %v, want %v", test.Name, a, test.Want)
}
if test.WantUnmarshalError != nil {
return
}
data, marshalErr := a.Marshal()
if marshalErr != nil {
t.Errorf("Unexpected Error %v: got: %v", test.Name, marshalErr)
} else if !reflect.DeepEqual(test.Data, data) {
t.Errorf("%q alert.marshal: got % 02x, want % 02x", test.Name, data, test.Data)
}
}
}

View File

@@ -0,0 +1,26 @@
package protocol
// ApplicationData messages are carried by the record layer and are
// fragmented, compressed, and encrypted based on the current connection
// state. The messages are treated as transparent data to the record
// layer.
// https://tools.ietf.org/html/rfc5246#section-10
type ApplicationData struct {
Data []byte
}
// ContentType returns the ContentType of this content
func (a ApplicationData) ContentType() ContentType {
return ContentTypeApplicationData
}
// Marshal encodes the ApplicationData to binary
func (a *ApplicationData) Marshal() ([]byte, error) {
return append([]byte{}, a.Data...), nil
}
// Unmarshal populates the ApplicationData from binary
func (a *ApplicationData) Unmarshal(data []byte) error {
a.Data = append([]byte{}, data...)
return nil
}

View File

@@ -0,0 +1,28 @@
package protocol
// ChangeCipherSpec protocol exists to signal transitions in
// ciphering strategies. The protocol consists of a single message,
// which is encrypted and compressed under the current (not the pending)
// connection state. The message consists of a single byte of value 1.
// https://tools.ietf.org/html/rfc5246#section-7.1
type ChangeCipherSpec struct {
}
// ContentType returns the ContentType of this content
func (c ChangeCipherSpec) ContentType() ContentType {
return ContentTypeChangeCipherSpec
}
// Marshal encodes the ChangeCipherSpec to binary
func (c *ChangeCipherSpec) Marshal() ([]byte, error) {
return []byte{0x01}, nil
}
// Unmarshal populates the ChangeCipherSpec from binary
func (c *ChangeCipherSpec) Unmarshal(data []byte) error {
if len(data) == 1 && data[0] == 0x01 {
return nil
}
return errInvalidCipherSpec
}

View File

@@ -0,0 +1,31 @@
package protocol
import (
"errors"
"reflect"
"testing"
)
func TestChangeCipherSpecRoundTrip(t *testing.T) {
c := ChangeCipherSpec{}
raw, err := c.Marshal()
if err != nil {
t.Error(err)
}
var cNew ChangeCipherSpec
if err := cNew.Unmarshal(raw); err != nil {
t.Error(err)
}
if !reflect.DeepEqual(c, cNew) {
t.Errorf("ChangeCipherSpec round trip: got %#v, want %#v", cNew, c)
}
}
func TestChangeCipherSpecInvalid(t *testing.T) {
c := ChangeCipherSpec{}
if err := c.Unmarshal([]byte{0x00}); !errors.Is(err, errInvalidCipherSpec) {
t.Errorf("ChangeCipherSpec invalid assert: got %#v, want %#v", err, errInvalidCipherSpec)
}
}

View File

@@ -0,0 +1,48 @@
package protocol
// CompressionMethodID is the ID for a CompressionMethod
type CompressionMethodID byte
const (
compressionMethodNull CompressionMethodID = 0
)
// CompressionMethod represents a TLS Compression Method
type CompressionMethod struct {
ID CompressionMethodID
}
// CompressionMethods returns all supported CompressionMethods
func CompressionMethods() map[CompressionMethodID]*CompressionMethod {
return map[CompressionMethodID]*CompressionMethod{
compressionMethodNull: {ID: compressionMethodNull},
}
}
// DecodeCompressionMethods the given compression methods
func DecodeCompressionMethods(buf []byte) ([]*CompressionMethod, error) {
if len(buf) < 1 {
return nil, errBufferTooSmall
}
compressionMethodsCount := int(buf[0])
c := []*CompressionMethod{}
for i := 0; i < compressionMethodsCount; i++ {
if len(buf) <= i+1 {
return nil, errBufferTooSmall
}
id := CompressionMethodID(buf[i+1])
if compressionMethod, ok := CompressionMethods()[id]; ok {
c = append(c, compressionMethod)
}
}
return c, nil
}
// EncodeCompressionMethods the given compression methods
func EncodeCompressionMethods(c []*CompressionMethod) []byte {
out := []byte{byte(len(c))}
for i := len(c); i > 0; i-- {
out = append(out, byte(c[i-1].ID))
}
return out
}

View File

@@ -0,0 +1,23 @@
package protocol
import (
"errors"
"testing"
)
func TestDecodeCompressionMethods(t *testing.T) {
testCases := []struct {
buf []byte
result []*CompressionMethod
err error
}{
{[]byte{}, nil, errBufferTooSmall},
}
for _, testCase := range testCases {
_, err := DecodeCompressionMethods(testCase.buf)
if !errors.Is(err, testCase.err) {
t.Fatal("Unexpected error", err)
}
}
}

View File

@@ -0,0 +1,21 @@
package protocol
// ContentType represents the IANA Registered ContentTypes
//
// https://tools.ietf.org/html/rfc4346#section-6.2.1
type ContentType uint8
// ContentType enums
const (
ContentTypeChangeCipherSpec ContentType = 20
ContentTypeAlert ContentType = 21
ContentTypeHandshake ContentType = 22
ContentTypeApplicationData ContentType = 23
)
// Content is the top level distinguisher for a DTLS Datagram
type Content interface {
ContentType() ContentType
Marshal() ([]byte, error)
Unmarshal(data []byte) error
}

View File

@@ -0,0 +1,104 @@
package protocol
import (
"errors"
"fmt"
"net"
)
var (
errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errInvalidCipherSpec = &FatalError{Err: errors.New("cipher spec invalid")} //nolint:goerr113
)
// FatalError indicates that the DTLS connection is no longer available.
// It is mainly caused by wrong configuration of server or client.
type FatalError struct {
Err error
}
// InternalError indicates and internal error caused by the implementation, and the DTLS connection is no longer available.
// It is mainly caused by bugs or tried to use unimplemented features.
type InternalError struct {
Err error
}
// TemporaryError indicates that the DTLS connection is still available, but the request was failed temporary.
type TemporaryError struct {
Err error
}
// TimeoutError indicates that the request was timed out.
type TimeoutError struct {
Err error
}
// HandshakeError indicates that the handshake failed.
type HandshakeError struct {
Err error
}
// Timeout implements net.Error.Timeout()
func (*FatalError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*FatalError) Temporary() bool { return false }
// Unwrap implements Go1.13 error unwrapper.
func (e *FatalError) Unwrap() error { return e.Err }
func (e *FatalError) Error() string { return fmt.Sprintf("dtls fatal: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*InternalError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*InternalError) Temporary() bool { return false }
// Unwrap implements Go1.13 error unwrapper.
func (e *InternalError) Unwrap() error { return e.Err }
func (e *InternalError) Error() string { return fmt.Sprintf("dtls internal: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*TemporaryError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*TemporaryError) Temporary() bool { return true }
// Unwrap implements Go1.13 error unwrapper.
func (e *TemporaryError) Unwrap() error { return e.Err }
func (e *TemporaryError) Error() string { return fmt.Sprintf("dtls temporary: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*TimeoutError) Timeout() bool { return true }
// Temporary implements net.Error.Temporary()
func (*TimeoutError) Temporary() bool { return true }
// Unwrap implements Go1.13 error unwrapper.
func (e *TimeoutError) Unwrap() error { return e.Err }
func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (e *HandshakeError) Timeout() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Timeout()
}
return false
}
// Temporary implements net.Error.Temporary()
func (e *HandshakeError) Temporary() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Temporary()
}
return false
}
// Unwrap implements Go1.13 error unwrapper.
func (e *HandshakeError) Unwrap() error { return e.Err }
func (e *HandshakeError) Error() string { return fmt.Sprintf("handshake error: %v", e.Err) }

View File

@@ -0,0 +1,14 @@
package extension
import (
"errors"
"github.com/pion/dtls/v2/pkg/protocol"
)
var (
errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errInvalidExtensionType = &protocol.FatalError{Err: errors.New("invalid extension type")} //nolint:goerr113
errInvalidSNIFormat = &protocol.FatalError{Err: errors.New("invalid server name format")} //nolint:goerr113
errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113
)

View File

@@ -0,0 +1,96 @@
// Package extension implements the extension values in the ClientHello/ServerHello
package extension
import "encoding/binary"
// TypeValue is the 2 byte value for a TLS Extension as registered in the IANA
//
// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml
type TypeValue uint16
// TypeValue constants
const (
ServerNameTypeValue TypeValue = 0
SupportedEllipticCurvesTypeValue TypeValue = 10
SupportedPointFormatsTypeValue TypeValue = 11
SupportedSignatureAlgorithmsTypeValue TypeValue = 13
UseSRTPTypeValue TypeValue = 14
UseExtendedMasterSecretTypeValue TypeValue = 23
RenegotiationInfoTypeValue TypeValue = 65281
)
// Extension represents a single TLS extension
type Extension interface {
Marshal() ([]byte, error)
Unmarshal(data []byte) error
TypeValue() TypeValue
}
// Unmarshal many extensions at once
func Unmarshal(buf []byte) ([]Extension, error) {
switch {
case len(buf) == 0:
return []Extension{}, nil
case len(buf) < 2:
return nil, errBufferTooSmall
}
declaredLen := binary.BigEndian.Uint16(buf)
if len(buf)-2 != int(declaredLen) {
return nil, errLengthMismatch
}
extensions := []Extension{}
unmarshalAndAppend := func(data []byte, e Extension) error {
err := e.Unmarshal(data)
if err != nil {
return err
}
extensions = append(extensions, e)
return nil
}
for offset := 2; offset < len(buf); {
if len(buf) < (offset + 2) {
return nil, errBufferTooSmall
}
var err error
switch TypeValue(binary.BigEndian.Uint16(buf[offset:])) {
case ServerNameTypeValue:
err = unmarshalAndAppend(buf[offset:], &ServerName{})
case SupportedEllipticCurvesTypeValue:
err = unmarshalAndAppend(buf[offset:], &SupportedEllipticCurves{})
case UseSRTPTypeValue:
err = unmarshalAndAppend(buf[offset:], &UseSRTP{})
case UseExtendedMasterSecretTypeValue:
err = unmarshalAndAppend(buf[offset:], &UseExtendedMasterSecret{})
case RenegotiationInfoTypeValue:
err = unmarshalAndAppend(buf[offset:], &RenegotiationInfo{})
default:
}
if err != nil {
return nil, err
}
if len(buf) < (offset + 4) {
return nil, errBufferTooSmall
}
extensionLength := binary.BigEndian.Uint16(buf[offset+2:])
offset += (4 + int(extensionLength))
}
return extensions, nil
}
// Marshal many extensions at once
func Marshal(e []Extension) ([]byte, error) {
extensions := []byte{}
for _, e := range e {
raw, err := e.Marshal()
if err != nil {
return nil, err
}
extensions = append(extensions, raw...)
}
out := []byte{0x00, 0x00}
binary.BigEndian.PutUint16(out, uint16(len(extensions)))
return append(out, extensions...), nil
}

View File

@@ -0,0 +1,22 @@
package extension
import (
"errors"
"testing"
)
func TestExtensions(t *testing.T) {
t.Run("Zero", func(t *testing.T) {
extensions, err := Unmarshal([]byte{})
if err != nil || len(extensions) != 0 {
t.Fatal("Failed to decode zero extensions")
}
})
t.Run("Invalid", func(t *testing.T) {
extensions, err := Unmarshal([]byte{0x00})
if !errors.Is(err, errBufferTooSmall) || len(extensions) != 0 {
t.Fatal("Failed to error on invalid extension")
}
})
}

View File

@@ -0,0 +1,43 @@
package extension
import "encoding/binary"
const (
renegotiationInfoHeaderSize = 5
)
// RenegotiationInfo allows a Client/Server to
// communicate their renegotation support
//
// https://tools.ietf.org/html/rfc5746
type RenegotiationInfo struct {
RenegotiatedConnection uint8
}
// TypeValue returns the extension TypeValue
func (r RenegotiationInfo) TypeValue() TypeValue {
return RenegotiationInfoTypeValue
}
// Marshal encodes the extension
func (r *RenegotiationInfo) Marshal() ([]byte, error) {
out := make([]byte, renegotiationInfoHeaderSize)
binary.BigEndian.PutUint16(out, uint16(r.TypeValue()))
binary.BigEndian.PutUint16(out[2:], uint16(1)) // length
out[4] = r.RenegotiatedConnection
return out, nil
}
// Unmarshal populates the extension from encoded data
func (r *RenegotiationInfo) Unmarshal(data []byte) error {
if len(data) < renegotiationInfoHeaderSize {
return errBufferTooSmall
} else if TypeValue(binary.BigEndian.Uint16(data)) != r.TypeValue() {
return errInvalidExtensionType
}
r.RenegotiatedConnection = data[4]
return nil
}

View File

@@ -0,0 +1,22 @@
package extension
import "testing"
func TestRenegotiationInfo(t *testing.T) {
extension := RenegotiationInfo{RenegotiatedConnection: 0}
raw, err := extension.Marshal()
if err != nil {
t.Fatal(err)
}
newExtension := RenegotiationInfo{}
err = newExtension.Unmarshal(raw)
if err != nil {
t.Fatal(err)
}
if newExtension.RenegotiatedConnection != extension.RenegotiatedConnection {
t.Errorf("extensionRenegotiationInfo marshal: got %d expected %d", newExtension.RenegotiatedConnection, extension.RenegotiatedConnection)
}
}

View File

@@ -0,0 +1,78 @@
package extension
import (
"strings"
"golang.org/x/crypto/cryptobyte"
)
const serverNameTypeDNSHostName = 0
// ServerName allows the client to inform the server the specific
// name it wishs to contact. Useful if multiple DNS names resolve
// to one IP
//
// https://tools.ietf.org/html/rfc6066#section-3
type ServerName struct {
ServerName string
}
// TypeValue returns the extension TypeValue
func (s ServerName) TypeValue() TypeValue {
return ServerNameTypeValue
}
// Marshal encodes the extension
func (s *ServerName) Marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(uint16(s.TypeValue()))
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(serverNameTypeDNSHostName)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(s.ServerName))
})
})
})
return b.Bytes()
}
// Unmarshal populates the extension from encoded data
func (s *ServerName) Unmarshal(data []byte) error {
val := cryptobyte.String(data)
var extension uint16
val.ReadUint16(&extension)
if TypeValue(extension) != s.TypeValue() {
return errInvalidExtensionType
}
var extData cryptobyte.String
val.ReadUint16LengthPrefixed(&extData)
var nameList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
return errInvalidSNIFormat
}
for !nameList.Empty() {
var nameType uint8
var serverName cryptobyte.String
if !nameList.ReadUint8(&nameType) ||
!nameList.ReadUint16LengthPrefixed(&serverName) ||
serverName.Empty() {
return errInvalidSNIFormat
}
if nameType != serverNameTypeDNSHostName {
continue
}
if len(s.ServerName) != 0 {
// Multiple names of the same name_type are prohibited.
return errInvalidSNIFormat
}
s.ServerName = string(serverName)
// An SNI value may not include a trailing dot.
if strings.HasSuffix(s.ServerName, ".") {
return errInvalidSNIFormat
}
}
return nil
}

View File

@@ -0,0 +1,22 @@
package extension
import "testing"
func TestServerName(t *testing.T) {
extension := ServerName{ServerName: "test.domain"}
raw, err := extension.Marshal()
if err != nil {
t.Fatal(err)
}
newExtension := ServerName{}
err = newExtension.Unmarshal(raw)
if err != nil {
t.Fatal(err)
}
if newExtension.ServerName != extension.ServerName {
t.Errorf("extensionServerName marshal: got %s expected %s", newExtension.ServerName, extension.ServerName)
}
}

View File

@@ -0,0 +1,21 @@
package extension
// SRTPProtectionProfile defines the parameters and options that are in effect for the SRTP processing
// https://tools.ietf.org/html/rfc5764#section-4.1.2
type SRTPProtectionProfile uint16
const (
SRTP_AES128_CM_HMAC_SHA1_80 SRTPProtectionProfile = 0x0001 // nolint
SRTP_AES128_CM_HMAC_SHA1_32 SRTPProtectionProfile = 0x0002 // nolint
SRTP_AEAD_AES_128_GCM SRTPProtectionProfile = 0x0007 // nolint
SRTP_AEAD_AES_256_GCM SRTPProtectionProfile = 0x0008 // nolint
)
func srtpProtectionProfiles() map[SRTPProtectionProfile]bool {
return map[SRTPProtectionProfile]bool{
SRTP_AES128_CM_HMAC_SHA1_80: true,
SRTP_AES128_CM_HMAC_SHA1_32: true,
SRTP_AEAD_AES_128_GCM: true,
SRTP_AEAD_AES_256_GCM: true,
}
}

View File

@@ -0,0 +1,62 @@
package extension
import (
"encoding/binary"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
)
const (
supportedGroupsHeaderSize = 6
)
// SupportedEllipticCurves allows a Client/Server to communicate
// what curves they both support
//
// https://tools.ietf.org/html/rfc8422#section-5.1.1
type SupportedEllipticCurves struct {
EllipticCurves []elliptic.Curve
}
// TypeValue returns the extension TypeValue
func (s SupportedEllipticCurves) TypeValue() TypeValue {
return SupportedEllipticCurvesTypeValue
}
// Marshal encodes the extension
func (s *SupportedEllipticCurves) Marshal() ([]byte, error) {
out := make([]byte, supportedGroupsHeaderSize)
binary.BigEndian.PutUint16(out, uint16(s.TypeValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.EllipticCurves)*2)))
binary.BigEndian.PutUint16(out[4:], uint16(len(s.EllipticCurves)*2))
for _, v := range s.EllipticCurves {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v))
}
return out, nil
}
// Unmarshal populates the extension from encoded data
func (s *SupportedEllipticCurves) Unmarshal(data []byte) error {
if len(data) <= supportedGroupsHeaderSize {
return errBufferTooSmall
} else if TypeValue(binary.BigEndian.Uint16(data)) != s.TypeValue() {
return errInvalidExtensionType
}
groupCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if supportedGroupsHeaderSize+(groupCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < groupCount; i++ {
supportedGroupID := elliptic.Curve(binary.BigEndian.Uint16(data[(supportedGroupsHeaderSize + (i * 2)):]))
if _, ok := elliptic.Curves()[supportedGroupID]; ok {
s.EllipticCurves = append(s.EllipticCurves, supportedGroupID)
}
}
return nil
}

View File

@@ -0,0 +1,22 @@
package extension
import (
"reflect"
"testing"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
)
func TestExtensionSupportedGroups(t *testing.T) {
rawSupportedGroups := []byte{0x0, 0xa, 0x0, 0x4, 0x0, 0x2, 0x0, 0x1d}
parsedSupportedGroups := &SupportedEllipticCurves{
EllipticCurves: []elliptic.Curve{elliptic.X25519},
}
raw, err := parsedSupportedGroups.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawSupportedGroups) {
t.Errorf("extensionSupportedGroups marshal: got %#v, want %#v", raw, rawSupportedGroups)
}
}

View File

@@ -0,0 +1,62 @@
package extension
import (
"encoding/binary"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
)
const (
supportedPointFormatsSize = 5
)
// SupportedPointFormats allows a Client/Server to negotiate
// the EllipticCurvePointFormats
//
// https://tools.ietf.org/html/rfc4492#section-5.1.2
type SupportedPointFormats struct {
PointFormats []elliptic.CurvePointFormat
}
// TypeValue returns the extension TypeValue
func (s SupportedPointFormats) TypeValue() TypeValue {
return SupportedPointFormatsTypeValue
}
// Marshal encodes the extension
func (s *SupportedPointFormats) Marshal() ([]byte, error) {
out := make([]byte, supportedPointFormatsSize)
binary.BigEndian.PutUint16(out, uint16(s.TypeValue()))
binary.BigEndian.PutUint16(out[2:], uint16(1+(len(s.PointFormats))))
out[4] = byte(len(s.PointFormats))
for _, v := range s.PointFormats {
out = append(out, byte(v))
}
return out, nil
}
// Unmarshal populates the extension from encoded data
func (s *SupportedPointFormats) Unmarshal(data []byte) error {
if len(data) <= supportedPointFormatsSize {
return errBufferTooSmall
} else if TypeValue(binary.BigEndian.Uint16(data)) != s.TypeValue() {
return errInvalidExtensionType
}
pointFormatCount := int(binary.BigEndian.Uint16(data[4:]))
if supportedGroupsHeaderSize+(pointFormatCount) > len(data) {
return errLengthMismatch
}
for i := 0; i < pointFormatCount; i++ {
p := elliptic.CurvePointFormat(data[supportedPointFormatsSize+i])
switch p {
case elliptic.CurvePointFormatUncompressed:
s.PointFormats = append(s.PointFormats, p)
default:
}
}
return nil
}

View File

@@ -0,0 +1,22 @@
package extension
import (
"reflect"
"testing"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
)
func TestExtensionSupportedPointFormats(t *testing.T) {
rawExtensionSupportedPointFormats := []byte{0x00, 0x0b, 0x00, 0x02, 0x01, 0x00}
parsedExtensionSupportedPointFormats := &SupportedPointFormats{
PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
}
raw, err := parsedExtensionSupportedPointFormats.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawExtensionSupportedPointFormats) {
t.Errorf("extensionSupportedPointFormats marshal: got %#v, want %#v", raw, rawExtensionSupportedPointFormats)
}
}

View File

@@ -0,0 +1,70 @@
package extension
import (
"encoding/binary"
"github.com/pion/dtls/v2/pkg/crypto/hash"
"github.com/pion/dtls/v2/pkg/crypto/signature"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
)
const (
supportedSignatureAlgorithmsHeaderSize = 6
)
// SupportedSignatureAlgorithms allows a Client/Server to
// negotiate what SignatureHash Algorithms they both support
//
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
type SupportedSignatureAlgorithms struct {
SignatureHashAlgorithms []signaturehash.Algorithm
}
// TypeValue returns the extension TypeValue
func (s SupportedSignatureAlgorithms) TypeValue() TypeValue {
return SupportedSignatureAlgorithmsTypeValue
}
// Marshal encodes the extension
func (s *SupportedSignatureAlgorithms) Marshal() ([]byte, error) {
out := make([]byte, supportedSignatureAlgorithmsHeaderSize)
binary.BigEndian.PutUint16(out, uint16(s.TypeValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.SignatureHashAlgorithms)*2)))
binary.BigEndian.PutUint16(out[4:], uint16(len(s.SignatureHashAlgorithms)*2))
for _, v := range s.SignatureHashAlgorithms {
out = append(out, []byte{0x00, 0x00}...)
out[len(out)-2] = byte(v.Hash)
out[len(out)-1] = byte(v.Signature)
}
return out, nil
}
// Unmarshal populates the extension from encoded data
func (s *SupportedSignatureAlgorithms) Unmarshal(data []byte) error {
if len(data) <= supportedSignatureAlgorithmsHeaderSize {
return errBufferTooSmall
} else if TypeValue(binary.BigEndian.Uint16(data)) != s.TypeValue() {
return errInvalidExtensionType
}
algorithmCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if supportedSignatureAlgorithmsHeaderSize+(algorithmCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < algorithmCount; i++ {
supportedHashAlgorithm := hash.Algorithm(data[supportedSignatureAlgorithmsHeaderSize+(i*2)])
supportedSignatureAlgorithm := signature.Algorithm(data[supportedSignatureAlgorithmsHeaderSize+(i*2)+1])
if _, ok := hash.Algorithms()[supportedHashAlgorithm]; ok {
if _, ok := signature.Algorithms()[supportedSignatureAlgorithm]; ok {
s.SignatureHashAlgorithms = append(s.SignatureHashAlgorithms, signaturehash.Algorithm{
Hash: supportedHashAlgorithm,
Signature: supportedSignatureAlgorithm,
})
}
}
}
return nil
}

View File

@@ -0,0 +1,35 @@
package extension
import (
"reflect"
"testing"
"github.com/pion/dtls/v2/pkg/crypto/hash"
"github.com/pion/dtls/v2/pkg/crypto/signature"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
)
func TestExtensionSupportedSignatureAlgorithms(t *testing.T) {
rawExtensionSupportedSignatureAlgorithms := []byte{
0x00, 0x0d,
0x00, 0x08,
0x00, 0x06,
0x04, 0x03,
0x05, 0x03,
0x06, 0x03,
}
parsedExtensionSupportedSignatureAlgorithms := &SupportedSignatureAlgorithms{
SignatureHashAlgorithms: []signaturehash.Algorithm{
{Hash: hash.SHA256, Signature: signature.ECDSA},
{Hash: hash.SHA384, Signature: signature.ECDSA},
{Hash: hash.SHA512, Signature: signature.ECDSA},
},
}
raw, err := parsedExtensionSupportedSignatureAlgorithms.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawExtensionSupportedSignatureAlgorithms) {
t.Errorf("extensionSupportedSignatureAlgorithms marshal: got %#v, want %#v", raw, rawExtensionSupportedSignatureAlgorithms)
}
}

View File

@@ -0,0 +1,45 @@
package extension
import "encoding/binary"
const (
useExtendedMasterSecretHeaderSize = 4
)
// UseExtendedMasterSecret defines a TLS extension that contextually binds the
// master secret to a log of the full handshake that computes it, thus
// preventing MITM attacks.
type UseExtendedMasterSecret struct {
Supported bool
}
// TypeValue returns the extension TypeValue
func (u UseExtendedMasterSecret) TypeValue() TypeValue {
return UseExtendedMasterSecretTypeValue
}
// Marshal encodes the extension
func (u *UseExtendedMasterSecret) Marshal() ([]byte, error) {
if !u.Supported {
return []byte{}, nil
}
out := make([]byte, useExtendedMasterSecretHeaderSize)
binary.BigEndian.PutUint16(out, uint16(u.TypeValue()))
binary.BigEndian.PutUint16(out[2:], uint16(0)) // length
return out, nil
}
// Unmarshal populates the extension from encoded data
func (u *UseExtendedMasterSecret) Unmarshal(data []byte) error {
if len(data) < useExtendedMasterSecretHeaderSize {
return errBufferTooSmall
} else if TypeValue(binary.BigEndian.Uint16(data)) != u.TypeValue() {
return errInvalidExtensionType
}
u.Supported = true
return nil
}

View File

@@ -0,0 +1,59 @@
package extension
import "encoding/binary"
const (
useSRTPHeaderSize = 6
)
// UseSRTP allows a Client/Server to negotiate what SRTPProtectionProfiles
// they both support
//
// https://tools.ietf.org/html/rfc8422
type UseSRTP struct {
ProtectionProfiles []SRTPProtectionProfile
}
// TypeValue returns the extension TypeValue
func (u UseSRTP) TypeValue() TypeValue {
return UseSRTPTypeValue
}
// Marshal encodes the extension
func (u *UseSRTP) Marshal() ([]byte, error) {
out := make([]byte, useSRTPHeaderSize)
binary.BigEndian.PutUint16(out, uint16(u.TypeValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(u.ProtectionProfiles)*2)+ /* MKI Length */ 1))
binary.BigEndian.PutUint16(out[4:], uint16(len(u.ProtectionProfiles)*2))
for _, v := range u.ProtectionProfiles {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v))
}
out = append(out, 0x00) /* MKI Length */
return out, nil
}
// Unmarshal populates the extension from encoded data
func (u *UseSRTP) Unmarshal(data []byte) error {
if len(data) <= useSRTPHeaderSize {
return errBufferTooSmall
} else if TypeValue(binary.BigEndian.Uint16(data)) != u.TypeValue() {
return errInvalidExtensionType
}
profileCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if supportedGroupsHeaderSize+(profileCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < profileCount; i++ {
supportedProfile := SRTPProtectionProfile(binary.BigEndian.Uint16(data[(useSRTPHeaderSize + (i * 2)):]))
if _, ok := srtpProtectionProfiles()[supportedProfile]; ok {
u.ProtectionProfiles = append(u.ProtectionProfiles, supportedProfile)
}
}
return nil
}

View File

@@ -0,0 +1,20 @@
package extension
import (
"reflect"
"testing"
)
func TestExtensionUseSRTP(t *testing.T) {
rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00}
parsedUseSRTP := &UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
}
raw, err := parsedUseSRTP.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawUseSRTP) {
t.Errorf("extensionUseSRTP marshal: got %#v, want %#v", raw, rawUseSRTP)
}
}

View File

@@ -0,0 +1,29 @@
package handshake
import "encoding/binary"
func decodeCipherSuiteIDs(buf []byte) ([]uint16, error) {
if len(buf) < 2 {
return nil, errBufferTooSmall
}
cipherSuitesCount := int(binary.BigEndian.Uint16(buf[0:])) / 2
rtrn := make([]uint16, cipherSuitesCount)
for i := 0; i < cipherSuitesCount; i++ {
if len(buf) < (i*2 + 4) {
return nil, errBufferTooSmall
}
rtrn[i] = binary.BigEndian.Uint16(buf[(i*2)+2:])
}
return rtrn, nil
}
func encodeCipherSuiteIDs(cipherSuiteIDs []uint16) []byte {
out := []byte{0x00, 0x00}
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(cipherSuiteIDs)*2))
for _, id := range cipherSuiteIDs {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], id)
}
return out
}

View File

@@ -0,0 +1,23 @@
package handshake
import (
"errors"
"testing"
)
func TestDecodeCipherSuiteIDs(t *testing.T) {
testCases := []struct {
buf []byte
result []uint16
err error
}{
{[]byte{}, nil, errBufferTooSmall},
}
for _, testCase := range testCases {
_, err := decodeCipherSuiteIDs(testCase.buf)
if !errors.Is(err, testCase.err) {
t.Fatal("Unexpected error", err)
}
}
}

View File

@@ -0,0 +1,25 @@
package handshake
import (
"errors"
"github.com/pion/dtls/v2/pkg/protocol"
)
// Typed errors
var (
errUnableToMarshalFragmented = &protocol.InternalError{Err: errors.New("unable to marshal fragmented handshakes")} //nolint:goerr113
errHandshakeMessageUnset = &protocol.InternalError{Err: errors.New("handshake message unset, unable to marshal")} //nolint:goerr113
errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113
errInvalidClientKeyExchange = &protocol.FatalError{Err: errors.New("unable to determine if ClientKeyExchange is a public key or PSK Identity")} //nolint:goerr113
errInvalidHashAlgorithm = &protocol.FatalError{Err: errors.New("invalid hash algorithm")} //nolint:goerr113
errInvalidSignatureAlgorithm = &protocol.FatalError{Err: errors.New("invalid signature algorithm")} //nolint:goerr113
errCookieTooLong = &protocol.FatalError{Err: errors.New("cookie must not be longer then 255 bytes")} //nolint:goerr113
errInvalidEllipticCurveType = &protocol.FatalError{Err: errors.New("invalid or unknown elliptic curve type")} //nolint:goerr113
errInvalidNamedCurve = &protocol.FatalError{Err: errors.New("invalid named curve")} //nolint:goerr113
errCipherSuiteUnset = &protocol.FatalError{Err: errors.New("server hello can not be created without a cipher suite")} //nolint:goerr113
errCompressionMethodUnset = &protocol.FatalError{Err: errors.New("server hello can not be created without a compression method")} //nolint:goerr113
errInvalidCompressionMethod = &protocol.FatalError{Err: errors.New("invalid or unknown compression method")} //nolint:goerr113
errNotImplemented = &protocol.InternalError{Err: errors.New("feature has not been implemented yet")} //nolint:goerr113
)

View File

@@ -0,0 +1,145 @@
// Package handshake provides the DTLS wire protocol for handshakes
package handshake
import (
"github.com/pion/dtls/v2/internal/util"
"github.com/pion/dtls/v2/pkg/protocol"
)
// Type is the unique identifier for each handshake message
// https://tools.ietf.org/html/rfc5246#section-7.4
type Type uint8
// Types of DTLS Handshake messages we know about
const (
TypeHelloRequest Type = 0
TypeClientHello Type = 1
TypeServerHello Type = 2
TypeHelloVerifyRequest Type = 3
TypeCertificate Type = 11
TypeServerKeyExchange Type = 12
TypeCertificateRequest Type = 13
TypeServerHelloDone Type = 14
TypeCertificateVerify Type = 15
TypeClientKeyExchange Type = 16
TypeFinished Type = 20
)
// String returns the string representation of this type
func (t Type) String() string {
switch t {
case TypeHelloRequest:
return "HelloRequest"
case TypeClientHello:
return "ClientHello"
case TypeServerHello:
return "ServerHello"
case TypeHelloVerifyRequest:
return "HelloVerifyRequest"
case TypeCertificate:
return "TypeCertificate"
case TypeServerKeyExchange:
return "ServerKeyExchange"
case TypeCertificateRequest:
return "CertificateRequest"
case TypeServerHelloDone:
return "ServerHelloDone"
case TypeCertificateVerify:
return "CertificateVerify"
case TypeClientKeyExchange:
return "ClientKeyExchange"
case TypeFinished:
return "Finished"
}
return ""
}
// Message is the body of a Handshake datagram
type Message interface {
Marshal() ([]byte, error)
Unmarshal(data []byte) error
Type() Type
}
// Handshake protocol is responsible for selecting a cipher spec and
// generating a master secret, which together comprise the primary
// cryptographic parameters associated with a secure session. The
// handshake protocol can also optionally authenticate parties who have
// certificates signed by a trusted certificate authority.
// https://tools.ietf.org/html/rfc5246#section-7.3
type Handshake struct {
Header Header
Message Message
}
// ContentType returns what kind of content this message is carying
func (h Handshake) ContentType() protocol.ContentType {
return protocol.ContentTypeHandshake
}
// Marshal encodes a handshake into a binary message
func (h *Handshake) Marshal() ([]byte, error) {
if h.Message == nil {
return nil, errHandshakeMessageUnset
} else if h.Header.FragmentOffset != 0 {
return nil, errUnableToMarshalFragmented
}
msg, err := h.Message.Marshal()
if err != nil {
return nil, err
}
h.Header.Length = uint32(len(msg))
h.Header.FragmentLength = h.Header.Length
h.Header.Type = h.Message.Type()
header, err := h.Header.Marshal()
if err != nil {
return nil, err
}
return append(header, msg...), nil
}
// Unmarshal decodes a handshake from a binary message
func (h *Handshake) Unmarshal(data []byte) error {
if err := h.Header.Unmarshal(data); err != nil {
return err
}
reportedLen := util.BigEndianUint24(data[1:])
if uint32(len(data)-HeaderLength) != reportedLen {
return errLengthMismatch
} else if reportedLen != h.Header.FragmentLength {
return errLengthMismatch
}
switch Type(data[0]) {
case TypeHelloRequest:
return errNotImplemented
case TypeClientHello:
h.Message = &MessageClientHello{}
case TypeHelloVerifyRequest:
h.Message = &MessageHelloVerifyRequest{}
case TypeServerHello:
h.Message = &MessageServerHello{}
case TypeCertificate:
h.Message = &MessageCertificate{}
case TypeServerKeyExchange:
h.Message = &MessageServerKeyExchange{}
case TypeCertificateRequest:
h.Message = &MessageCertificateRequest{}
case TypeServerHelloDone:
h.Message = &MessageServerHelloDone{}
case TypeClientKeyExchange:
h.Message = &MessageClientKeyExchange{}
case TypeFinished:
h.Message = &MessageFinished{}
case TypeCertificateVerify:
h.Message = &MessageCertificateVerify{}
default:
return errNotImplemented
}
return h.Message.Unmarshal(data[HeaderLength:])
}

View File

@@ -0,0 +1,50 @@
package handshake
import (
"encoding/binary"
"github.com/pion/dtls/v2/internal/util"
)
// HeaderLength msg_len for Handshake messages assumes an extra
// 12 bytes for sequence, fragment and version information vs TLS
const HeaderLength = 12
// Header is the static first 12 bytes of each RecordLayer
// of type Handshake. These fields allow us to support message loss, reordering, and
// message fragmentation,
//
// https://tools.ietf.org/html/rfc6347#section-4.2.2
type Header struct {
Type Type
Length uint32 // uint24 in spec
MessageSequence uint16
FragmentOffset uint32 // uint24 in spec
FragmentLength uint32 // uint24 in spec
}
// Marshal encodes the Header
func (h *Header) Marshal() ([]byte, error) {
out := make([]byte, HeaderLength)
out[0] = byte(h.Type)
util.PutBigEndianUint24(out[1:], h.Length)
binary.BigEndian.PutUint16(out[4:], h.MessageSequence)
util.PutBigEndianUint24(out[6:], h.FragmentOffset)
util.PutBigEndianUint24(out[9:], h.FragmentLength)
return out, nil
}
// Unmarshal populates the header from encoded data
func (h *Header) Unmarshal(data []byte) error {
if len(data) < HeaderLength {
return errBufferTooSmall
}
h.Type = Type(data[0])
h.Length = util.BigEndianUint24(data[1:])
h.MessageSequence = binary.BigEndian.Uint16(data[4:])
h.FragmentOffset = util.BigEndianUint24(data[6:])
h.FragmentLength = util.BigEndianUint24(data[9:])
return nil
}

View File

@@ -0,0 +1,66 @@
package handshake
import (
"github.com/pion/dtls/v2/internal/util"
)
// MessageCertificate is a DTLS Handshake Message
// it can contain either a Client or Server Certificate
//
// https://tools.ietf.org/html/rfc5246#section-7.4.2
type MessageCertificate struct {
Certificate [][]byte
}
// Type returns the Handshake Type
func (m MessageCertificate) Type() Type {
return TypeCertificate
}
const (
handshakeMessageCertificateLengthFieldSize = 3
)
// Marshal encodes the Handshake
func (m *MessageCertificate) Marshal() ([]byte, error) {
out := make([]byte, handshakeMessageCertificateLengthFieldSize)
for _, r := range m.Certificate {
// Certificate Length
out = append(out, make([]byte, handshakeMessageCertificateLengthFieldSize)...)
util.PutBigEndianUint24(out[len(out)-handshakeMessageCertificateLengthFieldSize:], uint32(len(r)))
// Certificate body
out = append(out, append([]byte{}, r...)...)
}
// Total Payload Size
util.PutBigEndianUint24(out[0:], uint32(len(out[handshakeMessageCertificateLengthFieldSize:])))
return out, nil
}
// Unmarshal populates the message from encoded data
func (m *MessageCertificate) Unmarshal(data []byte) error {
if len(data) < handshakeMessageCertificateLengthFieldSize {
return errBufferTooSmall
}
if certificateBodyLen := int(util.BigEndianUint24(data)); certificateBodyLen+handshakeMessageCertificateLengthFieldSize != len(data) {
return errLengthMismatch
}
offset := handshakeMessageCertificateLengthFieldSize
for offset < len(data) {
certificateLen := int(util.BigEndianUint24(data[offset:]))
offset += handshakeMessageCertificateLengthFieldSize
if offset+certificateLen > len(data) {
return errLengthMismatch
}
m.Certificate = append(m.Certificate, append([]byte{}, data[offset:offset+certificateLen]...))
offset += certificateLen
}
return nil
}

View File

@@ -0,0 +1,100 @@
package handshake
import (
"encoding/binary"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
"github.com/pion/dtls/v2/pkg/crypto/hash"
"github.com/pion/dtls/v2/pkg/crypto/signature"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
)
/*
MessageCertificateRequest is so a non-anonymous server can optionally
request a certificate from the client, if appropriate for the selected cipher
suite. This message, if sent, will immediately follow the ServerKeyExchange
message (if it is sent; otherwise, this message follows the
server's Certificate message).
https://tools.ietf.org/html/rfc5246#section-7.4.4
*/
type MessageCertificateRequest struct {
CertificateTypes []clientcertificate.Type
SignatureHashAlgorithms []signaturehash.Algorithm
}
const (
messageCertificateRequestMinLength = 5
)
// Type returns the Handshake Type
func (m MessageCertificateRequest) Type() Type {
return TypeCertificateRequest
}
// Marshal encodes the Handshake
func (m *MessageCertificateRequest) Marshal() ([]byte, error) {
out := []byte{byte(len(m.CertificateTypes))}
for _, v := range m.CertificateTypes {
out = append(out, byte(v))
}
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.SignatureHashAlgorithms)*2))
for _, v := range m.SignatureHashAlgorithms {
out = append(out, byte(v.Hash))
out = append(out, byte(v.Signature))
}
out = append(out, []byte{0x00, 0x00}...) // Distinguished Names Length
return out, nil
}
// Unmarshal populates the message from encoded data
func (m *MessageCertificateRequest) Unmarshal(data []byte) error {
if len(data) < messageCertificateRequestMinLength {
return errBufferTooSmall
}
offset := 0
certificateTypesLength := int(data[0])
offset++
if (offset + certificateTypesLength) > len(data) {
return errBufferTooSmall
}
for i := 0; i < certificateTypesLength; i++ {
certType := clientcertificate.Type(data[offset+i])
if _, ok := clientcertificate.Types()[certType]; ok {
m.CertificateTypes = append(m.CertificateTypes, certType)
}
}
offset += certificateTypesLength
if len(data) < offset+2 {
return errBufferTooSmall
}
signatureHashAlgorithmsLength := int(binary.BigEndian.Uint16(data[offset:]))
offset += 2
if (offset + signatureHashAlgorithmsLength) > len(data) {
return errBufferTooSmall
}
for i := 0; i < signatureHashAlgorithmsLength; i += 2 {
if len(data) < (offset + i + 2) {
return errBufferTooSmall
}
h := hash.Algorithm(data[offset+i])
s := signature.Algorithm(data[offset+i+1])
if _, ok := hash.Algorithms()[h]; !ok {
continue
} else if _, ok := signature.Algorithms()[s]; !ok {
continue
}
m.SignatureHashAlgorithms = append(m.SignatureHashAlgorithms, signaturehash.Algorithm{Signature: s, Hash: h})
}
return nil
}

View File

@@ -0,0 +1,46 @@
package handshake
import (
"reflect"
"testing"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
"github.com/pion/dtls/v2/pkg/crypto/hash"
"github.com/pion/dtls/v2/pkg/crypto/signature"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
)
func TestHandshakeMessageCertificateRequest(t *testing.T) {
rawCertificateRequest := []byte{
0x02, 0x01, 0x40, 0x00, 0x0C, 0x04, 0x03, 0x04, 0x01, 0x05,
0x03, 0x05, 0x01, 0x06, 0x01, 0x02, 0x01, 0x00, 0x00,
}
parsedCertificateRequest := &MessageCertificateRequest{
CertificateTypes: []clientcertificate.Type{
clientcertificate.RSASign,
clientcertificate.ECDSASign,
},
SignatureHashAlgorithms: []signaturehash.Algorithm{
{Hash: hash.SHA256, Signature: signature.ECDSA},
{Hash: hash.SHA256, Signature: signature.RSA},
{Hash: hash.SHA384, Signature: signature.ECDSA},
{Hash: hash.SHA384, Signature: signature.RSA},
{Hash: hash.SHA512, Signature: signature.RSA},
{Hash: hash.SHA1, Signature: signature.RSA},
},
}
c := &MessageCertificateRequest{}
if err := c.Unmarshal(rawCertificateRequest); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(c, parsedCertificateRequest) {
t.Errorf("parsedCertificateRequest unmarshal: got %#v, want %#v", c, parsedCertificateRequest)
}
raw, err := c.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawCertificateRequest) {
t.Errorf("parsedCertificateRequest marshal: got %#v, want %#v", raw, rawCertificateRequest)
}
}

View File

@@ -0,0 +1,99 @@
package handshake
import (
"crypto/x509"
"reflect"
"testing"
)
func TestHandshakeMessageCertificate(t *testing.T) {
// Not easy to mock out these members, just copy for now (since everything else matches)
copyCertificatePrivateMembers := func(src, dst *x509.Certificate) {
dst.PublicKey = src.PublicKey
dst.SerialNumber = src.SerialNumber
dst.Issuer = src.Issuer
dst.Subject = src.Subject
dst.NotBefore = src.NotBefore
dst.NotAfter = src.NotAfter
}
rawCertificate := []byte{
0x00, 0x01, 0x8c, 0x00, 0x01, 0x89, 0x30, 0x82, 0x01, 0x85, 0x30, 0x82, 0x01, 0x2b, 0x02, 0x14,
0x7d, 0x00, 0xcf, 0x07, 0xfc, 0xe2, 0xb6, 0xb8, 0x3f, 0x72, 0xeb, 0x11, 0x36, 0x1b, 0xf6, 0x39,
0xf1, 0x3c, 0x33, 0x41, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02,
0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31,
0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x0c, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53,
0x74, 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x0c, 0x18, 0x49,
0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20,
0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x38, 0x31, 0x30, 0x32,
0x35, 0x30, 0x38, 0x35, 0x31, 0x31, 0x32, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x30, 0x32, 0x35,
0x30, 0x38, 0x35, 0x31, 0x31, 0x32, 0x5a, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55,
0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x0c,
0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06,
0x03, 0x55, 0x04, 0x0a, 0x0c, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57,
0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30, 0x59,
0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, 0x06, 0x08, 0x2a, 0x86, 0x48,
0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, 0x04, 0xf9, 0xb1, 0x62, 0xd6, 0x07, 0xae, 0xc3,
0x36, 0x34, 0xf5, 0xa3, 0x09, 0x39, 0x86, 0xe7, 0x3b, 0x59, 0xf7, 0x4a, 0x1d, 0xf4, 0x97, 0x4f,
0x91, 0x40, 0x56, 0x1b, 0x3d, 0x6c, 0x5a, 0x38, 0x10, 0x15, 0x58, 0xf5, 0xa4, 0xcc, 0xdf, 0xd5,
0xf5, 0x4a, 0x35, 0x40, 0x0f, 0x9f, 0x54, 0xb7, 0xe9, 0xe2, 0xae, 0x63, 0x83, 0x6a, 0x4c, 0xfc,
0xc2, 0x5f, 0x78, 0xa0, 0xbb, 0x46, 0x54, 0xa4, 0xda, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48,
0xce, 0x3d, 0x04, 0x03, 0x02, 0x03, 0x48, 0x00, 0x30, 0x45, 0x02, 0x20, 0x47, 0x1a, 0x5f, 0x58,
0x2a, 0x74, 0x33, 0x6d, 0xed, 0xac, 0x37, 0x21, 0xfa, 0x76, 0x5a, 0x4d, 0x78, 0x68, 0x1a, 0xdd,
0x80, 0xa4, 0xd4, 0xb7, 0x7f, 0x7d, 0x78, 0xb3, 0xfb, 0xf3, 0x95, 0xfb, 0x02, 0x21, 0x00, 0xc0,
0x73, 0x30, 0xda, 0x2b, 0xc0, 0x0c, 0x9e, 0xb2, 0x25, 0x0d, 0x46, 0xb0, 0xbc, 0x66, 0x7f, 0x71,
0x66, 0xbf, 0x16, 0xb3, 0x80, 0x78, 0xd0, 0x0c, 0xef, 0xcc, 0xf5, 0xc1, 0x15, 0x0f, 0x58,
}
parsedCertificate := &x509.Certificate{
Raw: rawCertificate[6:],
RawTBSCertificate: rawCertificate[10:313],
RawSubjectPublicKeyInfo: rawCertificate[222:313],
RawSubject: rawCertificate[48:119],
RawIssuer: rawCertificate[48:119],
Signature: rawCertificate[328:],
SignatureAlgorithm: x509.ECDSAWithSHA256,
PublicKeyAlgorithm: x509.ECDSA,
Version: 1,
}
c := &MessageCertificate{}
if err := c.Unmarshal(rawCertificate); err != nil {
t.Error(err)
} else {
certificate, err := x509.ParseCertificate(c.Certificate[0])
if err != nil {
t.Error(err)
}
copyCertificatePrivateMembers(certificate, parsedCertificate)
if !reflect.DeepEqual(certificate, parsedCertificate) {
t.Errorf("handshakeMessageCertificate unmarshal: got %#v, want %#v", c, parsedCertificate)
}
}
raw, err := c.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawCertificate) {
t.Errorf("handshakeMessageCertificate marshal: got %#v, want %#v", raw, rawCertificate)
}
}
func TestEmptyHandshakeMessageCertificate(t *testing.T) {
rawCertificate := []byte{
0x00, 0x00, 0x00,
}
expectedCertificate := &MessageCertificate{
Certificate: nil,
}
c := &MessageCertificate{}
if err := c.Unmarshal(rawCertificate); err != nil {
t.Error(err)
}
if !reflect.DeepEqual(c, expectedCertificate) {
t.Errorf("handshakeMessageCertificate unmarshal: got %#v, want %#v", c, expectedCertificate)
}
}

View File

@@ -0,0 +1,61 @@
package handshake
import (
"encoding/binary"
"github.com/pion/dtls/v2/pkg/crypto/hash"
"github.com/pion/dtls/v2/pkg/crypto/signature"
)
// MessageCertificateVerify provide explicit verification of a
// client certificate.
//
// https://tools.ietf.org/html/rfc5246#section-7.4.8
type MessageCertificateVerify struct {
HashAlgorithm hash.Algorithm
SignatureAlgorithm signature.Algorithm
Signature []byte
}
const handshakeMessageCertificateVerifyMinLength = 4
// Type returns the Handshake Type
func (m MessageCertificateVerify) Type() Type {
return TypeCertificateVerify
}
// Marshal encodes the Handshake
func (m *MessageCertificateVerify) Marshal() ([]byte, error) {
out := make([]byte, 1+1+2+len(m.Signature))
out[0] = byte(m.HashAlgorithm)
out[1] = byte(m.SignatureAlgorithm)
binary.BigEndian.PutUint16(out[2:], uint16(len(m.Signature)))
copy(out[4:], m.Signature)
return out, nil
}
// Unmarshal populates the message from encoded data
func (m *MessageCertificateVerify) Unmarshal(data []byte) error {
if len(data) < handshakeMessageCertificateVerifyMinLength {
return errBufferTooSmall
}
m.HashAlgorithm = hash.Algorithm(data[0])
if _, ok := hash.Algorithms()[m.HashAlgorithm]; !ok {
return errInvalidHashAlgorithm
}
m.SignatureAlgorithm = signature.Algorithm(data[1])
if _, ok := signature.Algorithms()[m.SignatureAlgorithm]; !ok {
return errInvalidSignatureAlgorithm
}
signatureLength := int(binary.BigEndian.Uint16(data[2:]))
if (signatureLength + 4) != len(data) {
return errBufferTooSmall
}
m.Signature = append([]byte{}, data[4:]...)
return nil
}

View File

@@ -0,0 +1,38 @@
package handshake
import (
"reflect"
"testing"
"github.com/pion/dtls/v2/pkg/crypto/hash"
"github.com/pion/dtls/v2/pkg/crypto/signature"
)
func TestHandshakeMessageCertificateVerify(t *testing.T) {
rawCertificateVerify := []byte{
0x04, 0x03, 0x00, 0x47, 0x30, 0x45, 0x02, 0x20, 0x6b, 0x63, 0x17, 0xad, 0xbe, 0xb7, 0x7b, 0x0f,
0x86, 0x73, 0x39, 0x1e, 0xba, 0xb3, 0x50, 0x9c, 0xce, 0x9c, 0xe4, 0x8b, 0xe5, 0x13, 0x07, 0x59,
0x18, 0x1f, 0xe5, 0xa0, 0x2b, 0xca, 0xa6, 0xad, 0x02, 0x21, 0x00, 0xd3, 0xb5, 0x01, 0xbe, 0x87,
0x6c, 0x04, 0xa1, 0xdc, 0x28, 0xaa, 0x5f, 0xf7, 0x1e, 0x9c, 0xc0, 0x1e, 0x00, 0x2c, 0xe5, 0x94,
0xbb, 0x03, 0x0e, 0xf1, 0xcb, 0x28, 0x22, 0x33, 0x23, 0x88, 0xad,
}
parsedCertificateVerify := &MessageCertificateVerify{
HashAlgorithm: hash.Algorithm(rawCertificateVerify[0]),
SignatureAlgorithm: signature.Algorithm(rawCertificateVerify[1]),
Signature: rawCertificateVerify[4:],
}
c := &MessageCertificateVerify{}
if err := c.Unmarshal(rawCertificateVerify); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(c, parsedCertificateVerify) {
t.Errorf("handshakeMessageCertificate unmarshal: got %#v, want %#v", c, parsedCertificateVerify)
}
raw, err := c.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawCertificateVerify) {
t.Errorf("handshakeMessageCertificateVerify marshal: got %#v, want %#v", raw, rawCertificateVerify)
}
}

View File

@@ -0,0 +1,130 @@
package handshake
import (
"encoding/binary"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/extension"
)
/*
MessageClientHello is for when a client first connects to a server it is
required to send the client hello as its first message. The client can also send a
client hello in response to a hello request or on its own
initiative in order to renegotiate the security parameters in an
existing connection.
*/
type MessageClientHello struct {
Version protocol.Version
Random Random
Cookie []byte
SessionID []byte // TODO 添加anylink支持
CipherSuiteIDs []uint16
CompressionMethods []*protocol.CompressionMethod
Extensions []extension.Extension
}
const handshakeMessageClientHelloVariableWidthStart = 34
// Type returns the Handshake Type
func (m MessageClientHello) Type() Type {
return TypeClientHello
}
// Marshal encodes the Handshake
func (m *MessageClientHello) Marshal() ([]byte, error) {
if len(m.Cookie) > 255 {
return nil, errCookieTooLong
}
out := make([]byte, handshakeMessageClientHelloVariableWidthStart)
out[0] = m.Version.Major
out[1] = m.Version.Minor
rand := m.Random.MarshalFixed()
copy(out[2:], rand[:])
out = append(out, 0x00) // SessionID
out = append(out, byte(len(m.Cookie)))
out = append(out, m.Cookie...)
out = append(out, encodeCipherSuiteIDs(m.CipherSuiteIDs)...)
out = append(out, protocol.EncodeCompressionMethods(m.CompressionMethods)...)
extensions, err := extension.Marshal(m.Extensions)
if err != nil {
return nil, err
}
return append(out, extensions...), nil
}
// Unmarshal populates the message from encoded data
func (m *MessageClientHello) Unmarshal(data []byte) error {
if len(data) < 2+RandomLength {
return errBufferTooSmall
}
m.Version.Major = data[0]
m.Version.Minor = data[1]
var random [RandomLength]byte
copy(random[:], data[2:])
m.Random.UnmarshalFixed(random)
// rest of packet has variable width sections
currOffset := handshakeMessageClientHelloVariableWidthStart
currOffset += int(data[currOffset]) + 1 // SessionID
// TODO 添加SessionID
m.SessionID = data[handshakeMessageClientHelloVariableWidthStart+1 : currOffset]
currOffset++
if len(data) <= currOffset {
return errBufferTooSmall
}
n := int(data[currOffset-1])
if len(data) <= currOffset+n {
return errBufferTooSmall
}
m.Cookie = append([]byte{}, data[currOffset:currOffset+n]...)
currOffset += len(m.Cookie)
// Cipher Suites
if len(data) < currOffset {
return errBufferTooSmall
}
cipherSuiteIDs, err := decodeCipherSuiteIDs(data[currOffset:])
if err != nil {
return err
}
m.CipherSuiteIDs = cipherSuiteIDs
if len(data) < currOffset+2 {
return errBufferTooSmall
}
currOffset += int(binary.BigEndian.Uint16(data[currOffset:])) + 2
// Compression Methods
if len(data) < currOffset {
return errBufferTooSmall
}
compressionMethods, err := protocol.DecodeCompressionMethods(data[currOffset:])
if err != nil {
return err
}
m.CompressionMethods = compressionMethods
if len(data) < currOffset {
return errBufferTooSmall
}
currOffset += int(data[currOffset]) + 1
// Extensions
extensions, err := extension.Unmarshal(data[currOffset:])
if err != nil {
return err
}
m.Extensions = extensions
return nil
}

View File

@@ -0,0 +1,53 @@
package handshake
import (
"reflect"
"testing"
"time"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/extension"
)
func TestHandshakeMessageClientHello(t *testing.T) {
rawClientHello := []byte{
0xfe, 0xfd, 0xb6, 0x2f, 0xce, 0x5c, 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42,
0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec,
0xd8, 0x3d, 0xdc, 0x4b, 0x00, 0x14, 0xe6, 0x14, 0x3a, 0x1b, 0x04, 0xea, 0x9e, 0x7a, 0x14,
0xd6, 0x6c, 0x57, 0xd0, 0x0e, 0x32, 0x85, 0x76, 0x18, 0xde, 0xd8, 0x00, 0x04, 0xc0, 0x2b,
0xc0, 0x0a, 0x01, 0x00, 0x00, 0x08, 0x00, 0x0a, 0x00, 0x04, 0x00, 0x02, 0x00, 0x1d,
}
parsedClientHello := &MessageClientHello{
Version: protocol.Version{Major: 0xFE, Minor: 0xFD},
Random: Random{
GMTUnixTime: time.Unix(3056586332, 0),
RandomBytes: [28]byte{0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b},
},
Cookie: []byte{0xe6, 0x14, 0x3a, 0x1b, 0x04, 0xea, 0x9e, 0x7a, 0x14, 0xd6, 0x6c, 0x57, 0xd0, 0x0e, 0x32, 0x85, 0x76, 0x18, 0xde, 0xd8},
CipherSuiteIDs: []uint16{
0xc02b,
0xc00a,
},
CompressionMethods: []*protocol.CompressionMethod{
{},
},
Extensions: []extension.Extension{
&extension.SupportedEllipticCurves{EllipticCurves: []elliptic.Curve{elliptic.X25519}},
},
}
c := &MessageClientHello{}
if err := c.Unmarshal(rawClientHello); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(c, parsedClientHello) {
t.Errorf("handshakeMessageClientHello unmarshal: got %#v, want %#v", c, parsedClientHello)
}
raw, err := c.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawClientHello) {
t.Errorf("handshakeMessageClientHello marshal: got %#v, want %#v", raw, rawClientHello)
}
}

View File

@@ -0,0 +1,56 @@
package handshake
import (
"encoding/binary"
)
// MessageClientKeyExchange is a DTLS Handshake Message
// With this message, the premaster secret is set, either by direct
// transmission of the RSA-encrypted secret or by the transmission of
// Diffie-Hellman parameters that will allow each side to agree upon
// the same premaster secret.
//
// https://tools.ietf.org/html/rfc5246#section-7.4.7
type MessageClientKeyExchange struct {
IdentityHint []byte
PublicKey []byte
}
// Type returns the Handshake Type
func (m MessageClientKeyExchange) Type() Type {
return TypeClientKeyExchange
}
// Marshal encodes the Handshake
func (m *MessageClientKeyExchange) Marshal() ([]byte, error) {
switch {
case (m.IdentityHint != nil && m.PublicKey != nil) || (m.IdentityHint == nil && m.PublicKey == nil):
return nil, errInvalidClientKeyExchange
case m.PublicKey != nil:
return append([]byte{byte(len(m.PublicKey))}, m.PublicKey...), nil
default:
out := append([]byte{0x00, 0x00}, m.IdentityHint...)
binary.BigEndian.PutUint16(out, uint16(len(out)-2))
return out, nil
}
}
// Unmarshal populates the message from encoded data
func (m *MessageClientKeyExchange) Unmarshal(data []byte) error {
if len(data) < 2 {
return errBufferTooSmall
}
// If parsed as PSK return early and only populate PSK Identity Hint
if pskLength := binary.BigEndian.Uint16(data); len(data) == int(pskLength+2) {
m.IdentityHint = append([]byte{}, data[2:]...)
return nil
}
if publicKeyLength := int(data[0]); len(data) != publicKeyLength+1 {
return errBufferTooSmall
}
m.PublicKey = append([]byte{}, data[1:]...)
return nil
}

View File

@@ -0,0 +1,31 @@
package handshake
import (
"reflect"
"testing"
)
func TestHandshakeMessageClientKeyExchange(t *testing.T) {
rawClientKeyExchange := []byte{
0x20, 0x26, 0x78, 0x4a, 0x78, 0x70, 0xc1, 0xf9, 0x71, 0xea, 0x50, 0x4a, 0xb5, 0xbb, 0x00, 0x76,
0x02, 0x05, 0xda, 0xf7, 0xd0, 0x3f, 0xe3, 0xf7, 0x4e, 0x8a, 0x14, 0x6f, 0xb7, 0xe0, 0xc0, 0xff,
0x54,
}
parsedClientKeyExchange := &MessageClientKeyExchange{
PublicKey: rawClientKeyExchange[1:],
}
c := &MessageClientKeyExchange{}
if err := c.Unmarshal(rawClientKeyExchange); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(c, parsedClientKeyExchange) {
t.Errorf("handshakeMessageClientKeyExchange unmarshal: got %#v, want %#v", c, parsedClientKeyExchange)
}
raw, err := c.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawClientKeyExchange) {
t.Errorf("handshakeMessageClientKeyExchange marshal: got %#v, want %#v", raw, rawClientKeyExchange)
}
}

View File

@@ -0,0 +1,27 @@
package handshake
// MessageFinished is a DTLS Handshake Message
// this message is the first one protected with the just
// negotiated algorithms, keys, and secrets. Recipients of Finished
// messages MUST verify that the contents are correct.
//
// https://tools.ietf.org/html/rfc5246#section-7.4.9
type MessageFinished struct {
VerifyData []byte
}
// Type returns the Handshake Type
func (m MessageFinished) Type() Type {
return TypeFinished
}
// Marshal encodes the Handshake
func (m *MessageFinished) Marshal() ([]byte, error) {
return append([]byte{}, m.VerifyData...), nil
}
// Unmarshal populates the message from encoded data
func (m *MessageFinished) Unmarshal(data []byte) error {
m.VerifyData = append([]byte{}, data...)
return nil
}

View File

@@ -0,0 +1,29 @@
package handshake
import (
"reflect"
"testing"
)
func TestHandshakeMessageFinished(t *testing.T) {
rawFinished := []byte{
0x01, 0x01, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
}
parsedFinished := &MessageFinished{
VerifyData: rawFinished,
}
c := &MessageFinished{}
if err := c.Unmarshal(rawFinished); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(c, parsedFinished) {
t.Errorf("handshakeMessageFinished unmarshal: got %#v, want %#v", c, parsedFinished)
}
raw, err := c.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawFinished) {
t.Errorf("handshakeMessageFinished marshal: got %#v, want %#v", raw, rawFinished)
}
}

View File

@@ -0,0 +1,62 @@
package handshake
import (
"github.com/pion/dtls/v2/pkg/protocol"
)
// MessageHelloVerifyRequest is as follows:
//
// struct {
// ProtocolVersion server_version;
// opaque cookie<0..2^8-1>;
// } HelloVerifyRequest;
//
// The HelloVerifyRequest message type is hello_verify_request(3).
//
// When the client sends its ClientHello message to the server, the server
// MAY respond with a HelloVerifyRequest message. This message contains
// a stateless cookie generated using the technique of [PHOTURIS]. The
// client MUST retransmit the ClientHello with the cookie added.
//
// https://tools.ietf.org/html/rfc6347#section-4.2.1
type MessageHelloVerifyRequest struct {
Version protocol.Version
Cookie []byte
}
// Type returns the Handshake Type
func (m MessageHelloVerifyRequest) Type() Type {
return TypeHelloVerifyRequest
}
// Marshal encodes the Handshake
func (m *MessageHelloVerifyRequest) Marshal() ([]byte, error) {
if len(m.Cookie) > 255 {
return nil, errCookieTooLong
}
out := make([]byte, 3+len(m.Cookie))
out[0] = m.Version.Major
out[1] = m.Version.Minor
out[2] = byte(len(m.Cookie))
copy(out[3:], m.Cookie)
return out, nil
}
// Unmarshal populates the message from encoded data
func (m *MessageHelloVerifyRequest) Unmarshal(data []byte) error {
if len(data) < 3 {
return errBufferTooSmall
}
m.Version.Major = data[0]
m.Version.Minor = data[1]
cookieLength := data[2]
if len(data) < (int(cookieLength) + 3) {
return errBufferTooSmall
}
m.Cookie = make([]byte, cookieLength)
copy(m.Cookie, data[3:3+cookieLength])
return nil
}

View File

@@ -0,0 +1,33 @@
package handshake
import (
"reflect"
"testing"
"github.com/pion/dtls/v2/pkg/protocol"
)
func TestHandshakeMessageHelloVerifyRequest(t *testing.T) {
rawHelloVerifyRequest := []byte{
0xfe, 0xff, 0x14, 0x25, 0xfb, 0xee, 0xb3, 0x7c, 0x95, 0xcf, 0x00,
0xeb, 0xad, 0xe2, 0xef, 0xc7, 0xfd, 0xbb, 0xed, 0xf7, 0x1f, 0x6c, 0xcd,
}
parsedHelloVerifyRequest := &MessageHelloVerifyRequest{
Version: protocol.Version{Major: 0xFE, Minor: 0xFF},
Cookie: []byte{0x25, 0xfb, 0xee, 0xb3, 0x7c, 0x95, 0xcf, 0x00, 0xeb, 0xad, 0xe2, 0xef, 0xc7, 0xfd, 0xbb, 0xed, 0xf7, 0x1f, 0x6c, 0xcd},
}
h := &MessageHelloVerifyRequest{}
if err := h.Unmarshal(rawHelloVerifyRequest); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(h, parsedHelloVerifyRequest) {
t.Errorf("handshakeMessageClientHello unmarshal: got %#v, want %#v", h, parsedHelloVerifyRequest)
}
raw, err := h.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawHelloVerifyRequest) {
t.Errorf("handshakeMessageClientHello marshal: got %#v, want %#v", raw, rawHelloVerifyRequest)
}
}

View File

@@ -0,0 +1,111 @@
package handshake
import (
"encoding/binary"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/extension"
)
// MessageServerHello is sent in response to a ClientHello
// message when it was able to find an acceptable set of algorithms.
// If it cannot find such a match, it will respond with a handshake
// failure alert.
//
// https://tools.ietf.org/html/rfc5246#section-7.4.1.3
type MessageServerHello struct {
Version protocol.Version
Random Random
SessionID []byte // TODO 添加anylink支持
CipherSuiteID *uint16
CompressionMethod *protocol.CompressionMethod
Extensions []extension.Extension
}
const messageServerHelloVariableWidthStart = 2 + RandomLength
// Type returns the Handshake Type
func (m MessageServerHello) Type() Type {
return TypeServerHello
}
// Marshal encodes the Handshake
func (m *MessageServerHello) Marshal() ([]byte, error) {
if m.CipherSuiteID == nil {
return nil, errCipherSuiteUnset
} else if m.CompressionMethod == nil {
return nil, errCompressionMethodUnset
}
out := make([]byte, messageServerHelloVariableWidthStart)
out[0] = m.Version.Major
out[1] = m.Version.Minor
rand := m.Random.MarshalFixed()
copy(out[2:], rand[:])
// out = append(out, 0x00) // SessionID
// TODO 添加SessionID
out = append(out, byte(len(m.SessionID))) // SessionID
out = append(out, m.SessionID...)
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], *m.CipherSuiteID)
out = append(out, byte(m.CompressionMethod.ID))
extensions, err := extension.Marshal(m.Extensions)
if err != nil {
return nil, err
}
return append(out, extensions...), nil
}
// Unmarshal populates the message from encoded data
func (m *MessageServerHello) Unmarshal(data []byte) error {
if len(data) < 2+RandomLength {
return errBufferTooSmall
}
m.Version.Major = data[0]
m.Version.Minor = data[1]
var random [RandomLength]byte
copy(random[:], data[2:])
m.Random.UnmarshalFixed(random)
currOffset := messageServerHelloVariableWidthStart
currOffset += int(data[currOffset]) + 1 // SessionID
if len(data) < (currOffset + 2) {
return errBufferTooSmall
}
m.CipherSuiteID = new(uint16)
*m.CipherSuiteID = binary.BigEndian.Uint16(data[currOffset:])
currOffset += 2
if len(data) < currOffset {
return errBufferTooSmall
}
if compressionMethod, ok := protocol.CompressionMethods()[protocol.CompressionMethodID(data[currOffset])]; ok {
m.CompressionMethod = compressionMethod
currOffset++
} else {
return errInvalidCompressionMethod
}
if len(data) <= currOffset {
m.Extensions = []extension.Extension{}
return nil
}
extensions, err := extension.Unmarshal(data[currOffset:])
if err != nil {
return err
}
m.Extensions = extensions
return nil
}

View File

@@ -0,0 +1,22 @@
package handshake
// MessageServerHelloDone is final non-encrypted message from server
// this communicates server has sent all its handshake messages and next
// should be MessageFinished
type MessageServerHelloDone struct {
}
// Type returns the Handshake Type
func (m MessageServerHelloDone) Type() Type {
return TypeServerHelloDone
}
// Marshal encodes the Handshake
func (m *MessageServerHelloDone) Marshal() ([]byte, error) {
return []byte{}, nil
}
// Unmarshal populates the message from encoded data
func (m *MessageServerHelloDone) Unmarshal(data []byte) error {
return nil
}

View File

@@ -0,0 +1,25 @@
package handshake
import (
"reflect"
"testing"
)
func TestHandshakeMessageServerHelloDone(t *testing.T) {
rawServerHelloDone := []byte{}
parsedServerHelloDone := &MessageServerHelloDone{}
c := &MessageServerHelloDone{}
if err := c.Unmarshal(rawServerHelloDone); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(c, parsedServerHelloDone) {
t.Errorf("handshakeMessageServerHelloDone unmarshal: got %#v, want %#v", c, parsedServerHelloDone)
}
raw, err := c.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawServerHelloDone) {
t.Errorf("handshakeMessageServerHelloDone marshal: got %#v, want %#v", raw, rawServerHelloDone)
}
}

View File

@@ -0,0 +1,46 @@
package handshake
import (
"reflect"
"testing"
"time"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/extension"
)
func TestHandshakeMessageServerHello(t *testing.T) {
rawServerHello := []byte{
0xfe, 0xfd, 0x21, 0x63, 0x32, 0x21, 0x81, 0x0e, 0x98, 0x6c,
0x85, 0x3d, 0xa4, 0x39, 0xaf, 0x5f, 0xd6, 0x5c, 0xcc, 0x20,
0x7f, 0x7c, 0x78, 0xf1, 0x5f, 0x7e, 0x1c, 0xb7, 0xa1, 0x1e,
0xcf, 0x63, 0x84, 0x28, 0x00, 0xc0, 0x2b, 0x00, 0x00, 0x00,
}
cipherSuiteID := uint16(0xc02b)
parsedServerHello := &MessageServerHello{
Version: protocol.Version{Major: 0xFE, Minor: 0xFD},
Random: Random{
GMTUnixTime: time.Unix(560149025, 0),
RandomBytes: [28]byte{0x81, 0x0e, 0x98, 0x6c, 0x85, 0x3d, 0xa4, 0x39, 0xaf, 0x5f, 0xd6, 0x5c, 0xcc, 0x20, 0x7f, 0x7c, 0x78, 0xf1, 0x5f, 0x7e, 0x1c, 0xb7, 0xa1, 0x1e, 0xcf, 0x63, 0x84, 0x28},
},
CipherSuiteID: &cipherSuiteID,
CompressionMethod: &protocol.CompressionMethod{},
Extensions: []extension.Extension{},
}
c := &MessageServerHello{}
if err := c.Unmarshal(rawServerHello); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(c, parsedServerHello) {
t.Errorf("handshakeMessageServerHello unmarshal: got %#v, want %#v", c, parsedServerHello)
}
raw, err := c.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawServerHello) {
t.Errorf("handshakeMessageServerHello marshal: got %#v, want %#v", raw, rawServerHello)
}
}

View File

@@ -0,0 +1,119 @@
package handshake
import (
"encoding/binary"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/crypto/hash"
"github.com/pion/dtls/v2/pkg/crypto/signature"
)
// MessageServerKeyExchange supports ECDH and PSK
type MessageServerKeyExchange struct {
IdentityHint []byte
EllipticCurveType elliptic.CurveType
NamedCurve elliptic.Curve
PublicKey []byte
HashAlgorithm hash.Algorithm
SignatureAlgorithm signature.Algorithm
Signature []byte
}
// Type returns the Handshake Type
func (m MessageServerKeyExchange) Type() Type {
return TypeServerKeyExchange
}
// Marshal encodes the Handshake
func (m *MessageServerKeyExchange) Marshal() ([]byte, error) {
if m.IdentityHint != nil {
out := append([]byte{0x00, 0x00}, m.IdentityHint...)
binary.BigEndian.PutUint16(out, uint16(len(out)-2))
return out, nil
}
out := []byte{byte(m.EllipticCurveType), 0x00, 0x00}
binary.BigEndian.PutUint16(out[1:], uint16(m.NamedCurve))
out = append(out, byte(len(m.PublicKey)))
out = append(out, m.PublicKey...)
if m.HashAlgorithm == hash.None && m.SignatureAlgorithm == signature.Anonymous && len(m.Signature) == 0 {
return out, nil
}
out = append(out, []byte{byte(m.HashAlgorithm), byte(m.SignatureAlgorithm), 0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.Signature)))
out = append(out, m.Signature...)
return out, nil
}
// Unmarshal populates the message from encoded data
func (m *MessageServerKeyExchange) Unmarshal(data []byte) error {
if len(data) < 2 {
return errBufferTooSmall
}
// If parsed as PSK return early and only populate PSK Identity Hint
if pskLength := binary.BigEndian.Uint16(data); len(data) == int(pskLength+2) {
m.IdentityHint = append([]byte{}, data[2:]...)
return nil
}
if _, ok := elliptic.CurveTypes()[elliptic.CurveType(data[0])]; ok {
m.EllipticCurveType = elliptic.CurveType(data[0])
} else {
return errInvalidEllipticCurveType
}
if len(data[1:]) < 2 {
return errBufferTooSmall
}
m.NamedCurve = elliptic.Curve(binary.BigEndian.Uint16(data[1:3]))
if _, ok := elliptic.Curves()[m.NamedCurve]; !ok {
return errInvalidNamedCurve
}
if len(data) < 4 {
return errBufferTooSmall
}
publicKeyLength := int(data[3])
offset := 4 + publicKeyLength
if len(data) < offset {
return errBufferTooSmall
}
m.PublicKey = append([]byte{}, data[4:offset]...)
// Anon connection doesn't contains hashAlgorithm, signatureAlgorithm, signature
if len(data) == offset {
return nil
} else if len(data) <= offset {
return errBufferTooSmall
}
m.HashAlgorithm = hash.Algorithm(data[offset])
if _, ok := hash.Algorithms()[m.HashAlgorithm]; !ok {
return errInvalidHashAlgorithm
}
offset++
if len(data) <= offset {
return errBufferTooSmall
}
m.SignatureAlgorithm = signature.Algorithm(data[offset])
if _, ok := signature.Algorithms()[m.SignatureAlgorithm]; !ok {
return errInvalidSignatureAlgorithm
}
offset++
if len(data) < offset+2 {
return errBufferTooSmall
}
signatureLength := int(binary.BigEndian.Uint16(data[offset:]))
offset += 2
if len(data) < offset+signatureLength {
return errBufferTooSmall
}
m.Signature = append([]byte{}, data[offset:offset+signatureLength]...)
return nil
}

View File

@@ -0,0 +1,71 @@
package handshake
import (
"reflect"
"testing"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/crypto/hash"
"github.com/pion/dtls/v2/pkg/crypto/signature"
)
func TestHandshakeMessageServerKeyExchange(t *testing.T) {
test := func(rawServerKeyExchange []byte, parsedServerKeyExchange *MessageServerKeyExchange) {
c := &MessageServerKeyExchange{}
if err := c.Unmarshal(rawServerKeyExchange); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(c, parsedServerKeyExchange) {
t.Errorf("handshakeMessageServerKeyExchange unmarshal: got %#v, want %#v", c, parsedServerKeyExchange)
}
raw, err := c.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawServerKeyExchange) {
t.Errorf("handshakeMessageServerKeyExchange marshal: got %#v, want %#v", raw, rawServerKeyExchange)
}
}
t.Run("Hash+Signature", func(t *testing.T) {
rawServerKeyExchange := []byte{
0x03, 0x00, 0x1d, 0x41, 0x04, 0x0c, 0xb9, 0xa3, 0xb9, 0x90, 0x71, 0x35, 0x4a, 0x08, 0x66, 0xaf,
0xd6, 0x88, 0x58, 0x29, 0x69, 0x98, 0xf1, 0x87, 0x0f, 0xb5, 0xa8, 0xcd, 0x92, 0xf6, 0x2b, 0x08,
0x0c, 0xd4, 0x16, 0x5b, 0xcc, 0x81, 0xf2, 0x58, 0x91, 0x8e, 0x62, 0xdf, 0xc1, 0xec, 0x72, 0xe8,
0x47, 0x24, 0x42, 0x96, 0xb8, 0x7b, 0xee, 0xe7, 0x0d, 0xdc, 0x44, 0xec, 0xf3, 0x97, 0x6b, 0x1b,
0x45, 0x28, 0xac, 0x3f, 0x35, 0x02, 0x03, 0x00, 0x47, 0x30, 0x45, 0x02, 0x21, 0x00, 0xb2, 0x0b,
0x22, 0x95, 0x3d, 0x56, 0x57, 0x6a, 0x3f, 0x85, 0x30, 0x6f, 0x55, 0xc3, 0xf4, 0x24, 0x1b, 0x21,
0x07, 0xe5, 0xdf, 0xba, 0x24, 0x02, 0x68, 0x95, 0x1f, 0x6e, 0x13, 0xbd, 0x9f, 0xaa, 0x02, 0x20,
0x49, 0x9c, 0x9d, 0xdf, 0x84, 0x60, 0x33, 0x27, 0x96, 0x9e, 0x58, 0x6d, 0x72, 0x13, 0xe7, 0x3a,
0xe8, 0xdf, 0x43, 0x75, 0xc7, 0xb9, 0x37, 0x6e, 0x90, 0xe5, 0x3b, 0x81, 0xd4, 0xda, 0x68, 0xcd,
}
parsedServerKeyExchange := &MessageServerKeyExchange{
EllipticCurveType: elliptic.CurveTypeNamedCurve,
NamedCurve: elliptic.X25519,
PublicKey: rawServerKeyExchange[4:69],
HashAlgorithm: hash.SHA1,
SignatureAlgorithm: signature.ECDSA,
Signature: rawServerKeyExchange[73:144],
}
test(rawServerKeyExchange, parsedServerKeyExchange)
})
t.Run("Anonymous", func(t *testing.T) {
rawServerKeyExchange := []byte{
0x03, 0x00, 0x1d, 0x41, 0x04, 0x0c, 0xb9, 0xa3, 0xb9, 0x90, 0x71, 0x35, 0x4a, 0x08, 0x66, 0xaf,
0xd6, 0x88, 0x58, 0x29, 0x69, 0x98, 0xf1, 0x87, 0x0f, 0xb5, 0xa8, 0xcd, 0x92, 0xf6, 0x2b, 0x08,
0x0c, 0xd4, 0x16, 0x5b, 0xcc, 0x81, 0xf2, 0x58, 0x91, 0x8e, 0x62, 0xdf, 0xc1, 0xec, 0x72, 0xe8,
0x47, 0x24, 0x42, 0x96, 0xb8, 0x7b, 0xee, 0xe7, 0x0d, 0xdc, 0x44, 0xec, 0xf3, 0x97, 0x6b, 0x1b,
0x45, 0x28, 0xac, 0x3f, 0x35,
}
parsedServerKeyExchange := &MessageServerKeyExchange{
EllipticCurveType: elliptic.CurveTypeNamedCurve,
NamedCurve: elliptic.X25519,
PublicKey: rawServerKeyExchange[4:69],
HashAlgorithm: hash.None,
SignatureAlgorithm: signature.Anonymous,
}
test(rawServerKeyExchange, parsedServerKeyExchange)
})
}

View File

@@ -0,0 +1,49 @@
package handshake
import (
"crypto/rand"
"encoding/binary"
"time"
)
// Consts for Random in Handshake
const (
RandomBytesLength = 28
RandomLength = RandomBytesLength + 4
)
// Random value that is used in ClientHello and ServerHello
//
// https://tools.ietf.org/html/rfc4346#section-7.4.1.2
type Random struct {
GMTUnixTime time.Time
RandomBytes [RandomBytesLength]byte
}
// MarshalFixed encodes the Handshake
func (r *Random) MarshalFixed() [RandomLength]byte {
var out [RandomLength]byte
binary.BigEndian.PutUint32(out[0:], uint32(r.GMTUnixTime.Unix()))
copy(out[4:], r.RandomBytes[:])
return out
}
// UnmarshalFixed populates the message from encoded data
func (r *Random) UnmarshalFixed(data [RandomLength]byte) {
r.GMTUnixTime = time.Unix(int64(binary.BigEndian.Uint32(data[0:])), 0)
copy(r.RandomBytes[:], data[4:])
}
// Populate fills the handshakeRandom with random values
// may be called multiple times
func (r *Random) Populate() error {
r.GMTUnixTime = time.Now()
tmp := make([]byte, RandomBytesLength)
_, err := rand.Read(tmp)
copy(r.RandomBytes[:], tmp)
return err
}

View File

@@ -0,0 +1,16 @@
// Package recordlayer implements the TLS Record Layer https://tools.ietf.org/html/rfc5246#section-6
package recordlayer
import (
"errors"
"github.com/pion/dtls/v2/pkg/protocol"
)
var (
errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113
errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113
errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113
errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113
)

View File

@@ -0,0 +1,61 @@
package recordlayer
import (
"encoding/binary"
"github.com/pion/dtls/v2/internal/util"
"github.com/pion/dtls/v2/pkg/protocol"
)
// Header implements a TLS RecordLayer header
type Header struct {
ContentType protocol.ContentType
ContentLen uint16
Version protocol.Version
Epoch uint16
SequenceNumber uint64 // uint48 in spec
}
// RecordLayer enums
const (
HeaderSize = 13
MaxSequenceNumber = 0x0000FFFFFFFFFFFF
)
// Marshal encodes a TLS RecordLayer Header to binary
func (h *Header) Marshal() ([]byte, error) {
if h.SequenceNumber > MaxSequenceNumber {
return nil, errSequenceNumberOverflow
}
out := make([]byte, HeaderSize)
out[0] = byte(h.ContentType)
out[1] = h.Version.Major
out[2] = h.Version.Minor
binary.BigEndian.PutUint16(out[3:], h.Epoch)
util.PutBigEndianUint48(out[5:], h.SequenceNumber)
binary.BigEndian.PutUint16(out[HeaderSize-2:], h.ContentLen)
return out, nil
}
// Unmarshal populates a TLS RecordLayer Header from binary
func (h *Header) Unmarshal(data []byte) error {
if len(data) < HeaderSize {
return errBufferTooSmall
}
h.ContentType = protocol.ContentType(data[0])
h.Version.Major = data[1]
h.Version.Minor = data[2]
h.Epoch = binary.BigEndian.Uint16(data[3:])
// SequenceNumber is stored as uint48, make into uint64
seqCopy := make([]byte, 8)
copy(seqCopy[2:], data[5:11])
h.SequenceNumber = binary.BigEndian.Uint64(seqCopy)
if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
return errUnsupportedProtocolVersion
}
return nil
}

View File

@@ -0,0 +1,99 @@
package recordlayer
import (
"encoding/binary"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
)
// RecordLayer which handles all data transport.
// The record layer is assumed to sit directly on top of some
// reliable transport such as TCP. The record layer can carry four types of content:
//
// 1. Handshake messages—used for algorithm negotiation and key establishment.
// 2. ChangeCipherSpec messages—really part of the handshake but technically a separate kind of message.
// 3. Alert messages—used to signal that errors have occurred
// 4. Application layer data
//
// The DTLS record layer is extremely similar to that of TLS 1.1. The
// only change is the inclusion of an explicit sequence number in the
// record. This sequence number allows the recipient to correctly
// verify the TLS MAC.
//
// https://tools.ietf.org/html/rfc4347#section-4.1
type RecordLayer struct {
Header Header
Content protocol.Content
}
// Marshal encodes the RecordLayer to binary
func (r *RecordLayer) Marshal() ([]byte, error) {
contentRaw, err := r.Content.Marshal()
if err != nil {
return nil, err
}
r.Header.ContentLen = uint16(len(contentRaw))
r.Header.ContentType = r.Content.ContentType()
headerRaw, err := r.Header.Marshal()
if err != nil {
return nil, err
}
return append(headerRaw, contentRaw...), nil
}
// Unmarshal populates the RecordLayer from binary
func (r *RecordLayer) Unmarshal(data []byte) error {
if len(data) < HeaderSize {
return errBufferTooSmall
}
if err := r.Header.Unmarshal(data); err != nil {
return err
}
switch protocol.ContentType(data[0]) {
case protocol.ContentTypeChangeCipherSpec:
r.Content = &protocol.ChangeCipherSpec{}
case protocol.ContentTypeAlert:
r.Content = &alert.Alert{}
case protocol.ContentTypeHandshake:
r.Content = &handshake.Handshake{}
case protocol.ContentTypeApplicationData:
r.Content = &protocol.ApplicationData{}
default:
return errInvalidContentType
}
return r.Content.Unmarshal(data[HeaderSize:])
}
// UnpackDatagram extracts all RecordLayer messages from a single datagram.
// Note that as with TLS, multiple handshake messages may be placed in
// the same DTLS record, provided that there is room and that they are
// part of the same flight. Thus, there are two acceptable ways to pack
// two DTLS messages into the same datagram: in the same record or in
// separate records.
// https://tools.ietf.org/html/rfc6347#section-4.2.3
func UnpackDatagram(buf []byte) ([][]byte, error) {
out := [][]byte{}
for offset := 0; len(buf) != offset; {
if len(buf)-offset <= HeaderSize {
return nil, errInvalidPacketLength
}
pktLen := (HeaderSize + int(binary.BigEndian.Uint16(buf[offset+11:])))
if offset+pktLen > len(buf) {
return nil, errInvalidPacketLength
}
out = append(out, buf[offset:offset+pktLen])
offset += pktLen
}
return out, nil
}

View File

@@ -0,0 +1,92 @@
package recordlayer
import (
"errors"
"reflect"
"testing"
"github.com/pion/dtls/v2/pkg/protocol"
)
func TestUDPDecode(t *testing.T) {
for _, test := range []struct {
Name string
Data []byte
Want [][]byte
WantError error
}{
{
Name: "Change Cipher Spec, single packet",
Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
Want: [][]byte{
{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
},
},
{
Name: "Change Cipher Spec, multi packet",
Data: []byte{
0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01,
0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x01, 0x01,
},
Want: [][]byte{
{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x01, 0x01},
},
},
{
Name: "Invalid packet length",
Data: []byte{0x14, 0xfe},
WantError: errInvalidPacketLength,
},
{
Name: "Packet declared invalid length",
Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0xFF, 0x01},
WantError: errInvalidPacketLength,
},
} {
dtlsPkts, err := UnpackDatagram(test.Data)
if !errors.Is(err, test.WantError) {
t.Errorf("Unexpected Error %q: exp: %v got: %v", test.Name, test.WantError, err)
} else if !reflect.DeepEqual(test.Want, dtlsPkts) {
t.Errorf("%q UDP decode: got %q, want %q", test.Name, dtlsPkts, test.Want)
}
}
}
func TestRecordLayerRoundTrip(t *testing.T) {
for _, test := range []struct {
Name string
Data []byte
Want *RecordLayer
WantMarshalError error
WantUnmarshalError error
}{
{
Name: "Change Cipher Spec, single packet",
Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
Want: &RecordLayer{
Header: Header{
ContentType: protocol.ContentTypeChangeCipherSpec,
Version: protocol.Version{Major: 0xfe, Minor: 0xff},
Epoch: 0,
SequenceNumber: 18,
},
Content: &protocol.ChangeCipherSpec{},
},
},
} {
r := &RecordLayer{}
if err := r.Unmarshal(test.Data); !errors.Is(err, test.WantUnmarshalError) {
t.Errorf("Unexpected Error %q: exp: %v got: %v", test.Name, test.WantUnmarshalError, err)
} else if !reflect.DeepEqual(test.Want, r) {
t.Errorf("%q recordLayer.unmarshal: got %q, want %q", test.Name, r, test.Want)
}
data, marshalErr := r.Marshal()
if !errors.Is(marshalErr, test.WantMarshalError) {
t.Errorf("Unexpected Error %q: exp: %v got: %v", test.Name, test.WantMarshalError, marshalErr)
} else if !reflect.DeepEqual(test.Data, data) {
t.Errorf("%q recordLayer.marshal: got % 02x, want % 02x", test.Name, data, test.Data)
}
}
}

View File

@@ -0,0 +1,21 @@
// Package protocol provides the DTLS wire format
package protocol
// Version enums
var (
Version1_0 = Version{Major: 0xfe, Minor: 0xff} //nolint:gochecknoglobals
Version1_2 = Version{Major: 0xfe, Minor: 0xfd} //nolint:gochecknoglobals
)
// Version is the minor/major value in the RecordLayer
// and ClientHello/ServerHello
//
// https://tools.ietf.org/html/rfc4346#section-6.2.1
type Version struct {
Major, Minor uint8
}
// Equal determines if two protocol versions are equal
func (v Version) Equal(x Version) bool {
return v.Major == x.Major && v.Minor == x.Minor
}