mirror of https://github.com/bjdgyc/anylink.git
commit
c744983cbc
|
@ -60,6 +60,7 @@ sudo ./anylink -conf="conf/server.toml"
|
||||||
|
|
||||||
- [x] IP分配(实现IP、MAC映射信息的持久化)
|
- [x] IP分配(实现IP、MAC映射信息的持久化)
|
||||||
- [x] TLS-TCP通道
|
- [x] TLS-TCP通道
|
||||||
|
- [x] DTLS-UDP通道
|
||||||
- [x] 兼容AnyConnect
|
- [x] 兼容AnyConnect
|
||||||
- [x] 基于tun设备的nat访问模式
|
- [x] 基于tun设备的nat访问模式
|
||||||
- [x] 基于tap设备的桥接访问模式
|
- [x] 基于tap设备的桥接访问模式
|
||||||
|
@ -72,8 +73,6 @@ sudo ./anylink -conf="conf/server.toml"
|
||||||
- [x] 后台管理界面
|
- [x] 后台管理界面
|
||||||
- [x] 访问权限管理
|
- [x] 访问权限管理
|
||||||
|
|
||||||
- [ ] DTLS-UDP通道
|
|
||||||
|
|
||||||
## Config
|
## Config
|
||||||
|
|
||||||
默认配置文件内有详细的注释,根据注释填写配置即可。
|
默认配置文件内有详细的注释,根据注释填写配置即可。
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
# http://editorconfig.org/
|
||||||
|
|
||||||
|
root = true
|
||||||
|
|
||||||
|
[*]
|
||||||
|
charset = utf-8
|
||||||
|
insert_final_newline = true
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
end_of_line = lf
|
||||||
|
|
||||||
|
[*.go]
|
||||||
|
indent_style = tab
|
||||||
|
indent_size = 4
|
||||||
|
|
||||||
|
[{*.yml,*.yaml}]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 2
|
||||||
|
|
||||||
|
# Makefiles always use tabs for indentation
|
||||||
|
[Makefile]
|
||||||
|
indent_style = tab
|
|
@ -0,0 +1,61 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
#
|
||||||
|
# If you want to update the shared CI config, send a PR to
|
||||||
|
# https://github.com/pion/.goassets instead of this repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
SCRIPT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||||
|
|
||||||
|
if [ -f ${SCRIPT_PATH}/.ci.conf ]
|
||||||
|
then
|
||||||
|
. ${SCRIPT_PATH}/.ci.conf
|
||||||
|
fi
|
||||||
|
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS
|
||||||
|
#
|
||||||
|
EXCLUDED_CONTRIBUTORS+=('John R. Bradley' 'renovate[bot]' 'Renovate Bot' 'Pion Bot')
|
||||||
|
# If you want to exclude a name from all repositories, send a PR to
|
||||||
|
# https://github.com/pion/.goassets instead of this repository.
|
||||||
|
# If you want to exclude a name only from this repository,
|
||||||
|
# add EXCLUDED_CONTRIBUTORS=('name') to .github/.ci.conf
|
||||||
|
|
||||||
|
MISSING_CONTRIBUTORS=()
|
||||||
|
|
||||||
|
shouldBeIncluded () {
|
||||||
|
for i in "${EXCLUDED_CONTRIBUTORS[@]}"
|
||||||
|
do
|
||||||
|
if [ "$i" == "$1" ] ; then
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
IFS=$'\n' #Only split on newline
|
||||||
|
for contributor in $(git log --format='%aN' | sort -u)
|
||||||
|
do
|
||||||
|
if shouldBeIncluded $contributor; then
|
||||||
|
if ! grep -q "$contributor" "$SCRIPT_PATH/../README.md"; then
|
||||||
|
MISSING_CONTRIBUTORS+=("$contributor")
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
unset IFS
|
||||||
|
|
||||||
|
if [ ${#MISSING_CONTRIBUTORS[@]} -ne 0 ]; then
|
||||||
|
echo "Please add the following contributors to the README"
|
||||||
|
for i in "${MISSING_CONTRIBUTORS[@]}"
|
||||||
|
do
|
||||||
|
echo "$i"
|
||||||
|
done
|
||||||
|
exit 1
|
||||||
|
fi
|
|
@ -0,0 +1,11 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE DIRECTLY
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
.github/lint-commit-message.sh $1
|
|
@ -0,0 +1,12 @@
|
||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE DIRECTLY
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
# Redirect output to stderr.
|
||||||
|
exec 1>&2
|
||||||
|
|
||||||
|
.github/lint-disallowed-functions-in-library.sh
|
|
@ -0,0 +1,13 @@
|
||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE DIRECTLY
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
.github/assert-contributors.sh
|
||||||
|
|
||||||
|
exit 0
|
|
@ -0,0 +1,16 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
#
|
||||||
|
# If you want to update the shared CI config, send a PR to
|
||||||
|
# https://github.com/pion/.goassets instead of this repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
SCRIPT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||||
|
|
||||||
|
cp "$SCRIPT_PATH/hooks/commit-msg.sh" "$SCRIPT_PATH/../.git/hooks/commit-msg"
|
||||||
|
cp "$SCRIPT_PATH/hooks/pre-commit.sh" "$SCRIPT_PATH/../.git/hooks/pre-commit"
|
||||||
|
cp "$SCRIPT_PATH/hooks/pre-push.sh" "$SCRIPT_PATH/../.git/hooks/pre-push"
|
|
@ -0,0 +1,64 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
#
|
||||||
|
# If you want to update the shared CI config, send a PR to
|
||||||
|
# https://github.com/pion/.goassets instead of this repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
display_commit_message_error() {
|
||||||
|
cat << EndOfMessage
|
||||||
|
$1
|
||||||
|
|
||||||
|
-------------------------------------------------
|
||||||
|
The preceding commit message is invalid
|
||||||
|
it failed '$2' of the following checks
|
||||||
|
|
||||||
|
* Separate subject from body with a blank line
|
||||||
|
* Limit the subject line to 50 characters
|
||||||
|
* Capitalize the subject line
|
||||||
|
* Do not end the subject line with a period
|
||||||
|
* Wrap the body at 72 characters
|
||||||
|
EndOfMessage
|
||||||
|
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
lint_commit_message() {
|
||||||
|
if [[ "$(echo "$1" | awk 'NR == 2 {print $1;}' | wc -c)" -ne 1 ]]; then
|
||||||
|
display_commit_message_error "$1" 'Separate subject from body with a blank line'
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$(echo "$1" | head -n1 | awk '{print length}')" -gt 50 ]]; then
|
||||||
|
display_commit_message_error "$1" 'Limit the subject line to 50 characters'
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ ! $1 =~ ^[A-Z] ]]; then
|
||||||
|
display_commit_message_error "$1" 'Capitalize the subject line'
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$(echo "$1" | awk 'NR == 1 {print substr($0,length($0),1)}')" == "." ]]; then
|
||||||
|
display_commit_message_error "$1" 'Do not end the subject line with a period'
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$(echo "$1" | awk '{print length}' | sort -nr | head -1)" -gt 72 ]]; then
|
||||||
|
display_commit_message_error "$1" 'Wrap the body at 72 characters'
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ "$#" -eq 1 ]; then
|
||||||
|
if [ ! -f "$1" ]; then
|
||||||
|
echo "$0 was passed one argument, but was not a valid file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
lint_commit_message "$(sed -n '/# Please enter the commit message for your changes. Lines starting/q;p' "$1")"
|
||||||
|
else
|
||||||
|
for commit in $(git rev-list --no-merges origin/master..); do
|
||||||
|
lint_commit_message "$(git log --format="%B" -n 1 $commit)"
|
||||||
|
done
|
||||||
|
fi
|
|
@ -0,0 +1,48 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
#
|
||||||
|
# If you want to update the shared CI config, send a PR to
|
||||||
|
# https://github.com/pion/.goassets instead of this repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Disallow usages of functions that cause the program to exit in the library code
|
||||||
|
SCRIPT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||||
|
if [ -f ${SCRIPT_PATH}/.ci.conf ]
|
||||||
|
then
|
||||||
|
. ${SCRIPT_PATH}/.ci.conf
|
||||||
|
fi
|
||||||
|
|
||||||
|
EXCLUDE_DIRECTORIES=${DISALLOWED_FUNCTIONS_EXCLUDED_DIRECTORIES:-"examples"}
|
||||||
|
DISALLOWED_FUNCTIONS=('os.Exit(' 'panic(' 'Fatal(' 'Fatalf(' 'Fatalln(' 'fmt.Println(' 'fmt.Printf(' 'log.Print(' 'log.Println(' 'log.Printf(')
|
||||||
|
|
||||||
|
files=$(
|
||||||
|
find "$SCRIPT_PATH/.." -name "*.go" \
|
||||||
|
| grep -v -e '^.*_test.go$' \
|
||||||
|
| while read file
|
||||||
|
do
|
||||||
|
excluded=false
|
||||||
|
for ex in $EXCLUDE_DIRECTORIES
|
||||||
|
do
|
||||||
|
if [[ $file == */$ex/* ]]
|
||||||
|
then
|
||||||
|
excluded=true
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
$excluded || echo "$file"
|
||||||
|
done
|
||||||
|
)
|
||||||
|
|
||||||
|
for disallowedFunction in "${DISALLOWED_FUNCTIONS[@]}"
|
||||||
|
do
|
||||||
|
if grep -e "$disallowedFunction" $files | grep -v -e 'nolint'; then
|
||||||
|
echo "$disallowedFunction may only be used in example code"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
|
@ -0,0 +1,24 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
#
|
||||||
|
# If you want to update the shared CI config, send a PR to
|
||||||
|
# https://github.com/pion/.goassets instead of this repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
SCRIPT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||||
|
GO_REGEX="^[a-zA-Z][a-zA-Z0-9_]*\.go$"
|
||||||
|
|
||||||
|
find "$SCRIPT_PATH/.." -name "*.go" | while read fullpath; do
|
||||||
|
filename=$(basename -- "$fullpath")
|
||||||
|
|
||||||
|
if ! [[ $filename =~ $GO_REGEX ]]; then
|
||||||
|
echo "$filename is not a valid filename for Go code, only alpha, numbers and underscores are supported"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
|
@ -0,0 +1,20 @@
|
||||||
|
name: E2E
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
e2e-test:
|
||||||
|
name: Test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: checkout
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
- name: test
|
||||||
|
run: |
|
||||||
|
docker build -t pion-dtls-e2e -f e2e/Dockerfile .
|
||||||
|
docker run -i --rm pion-dtls-e2e
|
|
@ -0,0 +1,43 @@
|
||||||
|
name: Lint
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- edited
|
||||||
|
- synchronize
|
||||||
|
jobs:
|
||||||
|
lint-commit-message:
|
||||||
|
name: Metadata
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Commit Message
|
||||||
|
run: .github/lint-commit-message.sh
|
||||||
|
|
||||||
|
- name: File names
|
||||||
|
run: .github/lint-filename.sh
|
||||||
|
|
||||||
|
- name: Contributors
|
||||||
|
run: .github/assert-contributors.sh
|
||||||
|
|
||||||
|
- name: Functions
|
||||||
|
run: .github/lint-disallowed-functions-in-library.sh
|
||||||
|
|
||||||
|
lint-go:
|
||||||
|
name: Go
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v2
|
||||||
|
with:
|
||||||
|
version: v1.31
|
||||||
|
args: $GOLANGCI_LINT_EXRA_ARGS
|
|
@ -0,0 +1,33 @@
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
# If this repository should have package specific CI config,
|
||||||
|
# remove the repository name from .goassets/.github/workflows/assets-sync.yml.
|
||||||
|
#
|
||||||
|
# If you want to update the shared CI config, send a PR to
|
||||||
|
# https://github.com/pion/.goassets instead of this repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
name: go-mod-fix
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- renovate/*
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
go-mod-fix:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: checkout
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 2
|
||||||
|
- name: fix
|
||||||
|
uses: at-wat/go-sum-fix-action@v0
|
||||||
|
with:
|
||||||
|
git_user: Pion Bot
|
||||||
|
git_email: 59523206+pionbot@users.noreply.github.com
|
||||||
|
github_token: ${{ secrets.PIONBOT_PRIVATE_KEY }}
|
||||||
|
commit_style: squash
|
||||||
|
push: force
|
|
@ -0,0 +1,139 @@
|
||||||
|
name: Test
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
go: ["1.15", "1.16"]
|
||||||
|
fail-fast: false
|
||||||
|
name: Go ${{ matrix.go }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/go/pkg/mod
|
||||||
|
~/go/bin
|
||||||
|
~/.cache
|
||||||
|
key: ${{ runner.os }}-amd64-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-amd64-go-
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v2
|
||||||
|
with:
|
||||||
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
|
- name: Setup go-acc
|
||||||
|
run: |
|
||||||
|
go get github.com/ory/go-acc
|
||||||
|
git checkout go.mod go.sum
|
||||||
|
|
||||||
|
- name: Run test
|
||||||
|
run: |
|
||||||
|
go-acc -o cover.out ./... -- \
|
||||||
|
-bench=. \
|
||||||
|
-v -race
|
||||||
|
|
||||||
|
- uses: codecov/codecov-action@v1
|
||||||
|
with:
|
||||||
|
file: ./cover.out
|
||||||
|
name: codecov-umbrella
|
||||||
|
fail_ci_if_error: true
|
||||||
|
flags: go
|
||||||
|
|
||||||
|
test-i386:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
go: ["1.15", "1.16"]
|
||||||
|
fail-fast: false
|
||||||
|
name: Go i386 ${{ matrix.go }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/go/pkg/mod
|
||||||
|
~/.cache
|
||||||
|
key: ${{ runner.os }}-i386-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-i386-go-
|
||||||
|
|
||||||
|
- name: Run test
|
||||||
|
run: |
|
||||||
|
mkdir -p $HOME/go/pkg/mod $HOME/.cache
|
||||||
|
docker run \
|
||||||
|
-u $(id -u):$(id -g) \
|
||||||
|
-e "GO111MODULE=on" \
|
||||||
|
-e "CGO_ENABLED=0" \
|
||||||
|
-v $GITHUB_WORKSPACE:/go/src/github.com/pion/$(basename $GITHUB_WORKSPACE) \
|
||||||
|
-v $HOME/go/pkg/mod:/go/pkg/mod \
|
||||||
|
-v $HOME/.cache:/.cache \
|
||||||
|
-w /go/src/github.com/pion/$(basename $GITHUB_WORKSPACE) \
|
||||||
|
i386/golang:${{matrix.go}}-alpine \
|
||||||
|
/usr/local/go/bin/go test \
|
||||||
|
${TEST_EXTRA_ARGS:-} \
|
||||||
|
-v ./...
|
||||||
|
|
||||||
|
test-wasm:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
name: WASM
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Use Node.js
|
||||||
|
uses: actions/setup-node@v2
|
||||||
|
with:
|
||||||
|
node-version: '12.x'
|
||||||
|
|
||||||
|
- uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/go/pkg/mod
|
||||||
|
~/.cache
|
||||||
|
key: ${{ runner.os }}-wasm-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-wasm-go-
|
||||||
|
|
||||||
|
- name: Download Go
|
||||||
|
run: curl -sSfL https://dl.google.com/go/go${GO_VERSION}.linux-amd64.tar.gz | tar -C ~ -xzf -
|
||||||
|
env:
|
||||||
|
GO_VERSION: 1.16
|
||||||
|
|
||||||
|
- name: Set Go Root
|
||||||
|
run: echo "GOROOT=${HOME}/go" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Set Go Path
|
||||||
|
run: echo "GOPATH=${HOME}/go" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Set Go Path
|
||||||
|
run: echo "GO_JS_WASM_EXEC=${GOROOT}/misc/wasm/go_js_wasm_exec" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Insall NPM modules
|
||||||
|
run: yarn install
|
||||||
|
|
||||||
|
- name: Run Tests
|
||||||
|
run: |
|
||||||
|
GOOS=js GOARCH=wasm $GOPATH/bin/go test \
|
||||||
|
-coverprofile=cover.out -covermode=atomic \
|
||||||
|
-exec="${GO_JS_WASM_EXEC}" \
|
||||||
|
-v ./...
|
||||||
|
|
||||||
|
- uses: codecov/codecov-action@v1
|
||||||
|
with:
|
||||||
|
file: ./cover.out
|
||||||
|
name: codecov-umbrella
|
||||||
|
fail_ci_if_error: true
|
||||||
|
flags: wasm
|
|
@ -0,0 +1,37 @@
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
# If this repository should have package specific CI config,
|
||||||
|
# remove the repository name from .goassets/.github/workflows/assets-sync.yml.
|
||||||
|
#
|
||||||
|
# If you want to update the shared CI config, send a PR to
|
||||||
|
# https://github.com/pion/.goassets instead of this repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
name: Go mod tidy
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
Check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: checkout
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v2
|
||||||
|
- name: check
|
||||||
|
run: |
|
||||||
|
go mod download
|
||||||
|
go mod tidy
|
||||||
|
if ! git diff --exit-code
|
||||||
|
then
|
||||||
|
echo "Not go mod tidied"
|
||||||
|
exit 1
|
||||||
|
fi
|
|
@ -0,0 +1,24 @@
|
||||||
|
### JetBrains IDE ###
|
||||||
|
#####################
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
### Emacs Temporary Files ###
|
||||||
|
#############################
|
||||||
|
*~
|
||||||
|
|
||||||
|
### Folders ###
|
||||||
|
###############
|
||||||
|
bin/
|
||||||
|
vendor/
|
||||||
|
node_modules/
|
||||||
|
|
||||||
|
### Files ###
|
||||||
|
#############
|
||||||
|
*.ivf
|
||||||
|
*.ogg
|
||||||
|
tags
|
||||||
|
cover.out
|
||||||
|
*.sw[poe]
|
||||||
|
*.wasm
|
||||||
|
examples/sfu-ws/cert.pem
|
||||||
|
examples/sfu-ws/key.pem
|
|
@ -0,0 +1,89 @@
|
||||||
|
linters-settings:
|
||||||
|
govet:
|
||||||
|
check-shadowing: true
|
||||||
|
misspell:
|
||||||
|
locale: US
|
||||||
|
exhaustive:
|
||||||
|
default-signifies-exhaustive: true
|
||||||
|
gomodguard:
|
||||||
|
blocked:
|
||||||
|
modules:
|
||||||
|
- github.com/pkg/errors:
|
||||||
|
recommendations:
|
||||||
|
- errors
|
||||||
|
|
||||||
|
linters:
|
||||||
|
enable:
|
||||||
|
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
|
||||||
|
- bodyclose # checks whether HTTP response body is closed successfully
|
||||||
|
- deadcode # Finds unused code
|
||||||
|
- depguard # Go linter that checks if package imports are in a list of acceptable packages
|
||||||
|
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
|
||||||
|
- dupl # Tool for code clone detection
|
||||||
|
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
|
||||||
|
- exhaustive # check exhaustiveness of enum switch statements
|
||||||
|
- exportloopref # checks for pointers to enclosing loop variables
|
||||||
|
- gci # Gci control golang package import order and make it always deterministic.
|
||||||
|
- gochecknoglobals # Checks that no globals are present in Go code
|
||||||
|
- gochecknoinits # Checks that no init functions are present in Go code
|
||||||
|
- gocognit # Computes and checks the cognitive complexity of functions
|
||||||
|
- goconst # Finds repeated strings that could be replaced by a constant
|
||||||
|
- gocritic # The most opinionated Go source code linter
|
||||||
|
- godox # Tool for detection of FIXME, TODO and other comment keywords
|
||||||
|
- goerr113 # Golang linter to check the errors handling expressions
|
||||||
|
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
|
||||||
|
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
|
||||||
|
- goheader # Checks is file header matches to pattern
|
||||||
|
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
|
||||||
|
- golint # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
|
||||||
|
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
||||||
|
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
|
||||||
|
- gosec # Inspects source code for security problems
|
||||||
|
- gosimple # Linter for Go source code that specializes in simplifying a code
|
||||||
|
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||||
|
- ineffassign # Detects when assignments to existing variables are not used
|
||||||
|
- misspell # Finds commonly misspelled English words in comments
|
||||||
|
- nakedret # Finds naked returns in functions greater than a specified function length
|
||||||
|
- noctx # noctx finds sending http request without context.Context
|
||||||
|
- scopelint # Scopelint checks for unpinned variables in go programs
|
||||||
|
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
|
||||||
|
- structcheck # Finds unused struct fields
|
||||||
|
- stylecheck # Stylecheck is a replacement for golint
|
||||||
|
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
|
||||||
|
- unconvert # Remove unnecessary type conversions
|
||||||
|
- unparam # Reports unused function parameters
|
||||||
|
- unused # Checks Go code for unused constants, variables, functions and types
|
||||||
|
- varcheck # Finds unused global variables and constants
|
||||||
|
- whitespace # Tool for detection of leading and trailing whitespace
|
||||||
|
disable:
|
||||||
|
- funlen # Tool for detection of long functions
|
||||||
|
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
||||||
|
- godot # Check if comments end in a period
|
||||||
|
- gomnd # An analyzer to detect magic numbers.
|
||||||
|
- lll # Reports long lines
|
||||||
|
- maligned # Tool to detect Go structs that would take less memory if their fields were sorted
|
||||||
|
- nestif # Reports deeply nested if statements
|
||||||
|
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
|
||||||
|
- nolintlint # Reports ill-formed or insufficient nolint directives
|
||||||
|
- prealloc # Finds slice declarations that could potentially be preallocated
|
||||||
|
- rowserrcheck # checks whether Err of rows is checked successfully
|
||||||
|
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
|
||||||
|
- testpackage # linter that makes you use a separate _test package
|
||||||
|
- wsl # Whitespace Linter - Forces you to use empty lines!
|
||||||
|
|
||||||
|
issues:
|
||||||
|
exclude-use-default: false
|
||||||
|
exclude-rules:
|
||||||
|
# Allow complex tests, better to be self contained
|
||||||
|
- path: _test\.go
|
||||||
|
linters:
|
||||||
|
- gocognit
|
||||||
|
|
||||||
|
# Allow complex main function in examples
|
||||||
|
- path: examples
|
||||||
|
text: "of func `main` is high"
|
||||||
|
linters:
|
||||||
|
- gocognit
|
||||||
|
|
||||||
|
run:
|
||||||
|
skip-dirs-use-default: false
|
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2018
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -0,0 +1,6 @@
|
||||||
|
fuzz-build-record-layer: fuzz-prepare
|
||||||
|
go-fuzz-build -tags gofuzz -func FuzzRecordLayer
|
||||||
|
fuzz-run-record-layer:
|
||||||
|
go-fuzz -bin dtls-fuzz.zip -workdir fuzz
|
||||||
|
fuzz-prepare:
|
||||||
|
@GO111MODULE=on go mod vendor
|
|
@ -0,0 +1,156 @@
|
||||||
|
<h1 align="center">
|
||||||
|
<br>
|
||||||
|
Pion DTLS
|
||||||
|
<br>
|
||||||
|
</h1>
|
||||||
|
<h4 align="center">A Go implementation of DTLS</h4>
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://pion.ly"><img src="https://img.shields.io/badge/pion-dtls-gray.svg?longCache=true&colorB=brightgreen" alt="Pion DTLS"></a>
|
||||||
|
<a href="https://sourcegraph.com/github.com/pion/dtls"><img src="https://sourcegraph.com/github.com/pion/dtls/-/badge.svg" alt="Sourcegraph Widget"></a>
|
||||||
|
<a href="https://pion.ly/slack"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a>
|
||||||
|
<br>
|
||||||
|
<a href="https://travis-ci.org/pion/dtls"><img src="https://travis-ci.org/pion/dtls.svg?branch=master" alt="Build Status"></a>
|
||||||
|
<a href="https://pkg.go.dev/github.com/pion/dtls"><img src="https://godoc.org/github.com/pion/dtls?status.svg" alt="GoDoc"></a>
|
||||||
|
<a href="https://codecov.io/gh/pion/dtls"><img src="https://codecov.io/gh/pion/dtls/branch/master/graph/badge.svg" alt="Coverage Status"></a>
|
||||||
|
<a href="https://goreportcard.com/report/github.com/pion/dtls"><img src="https://goreportcard.com/badge/github.com/pion/dtls" alt="Go Report Card"></a>
|
||||||
|
<a href="https://www.codacy.com/app/Sean-Der/dtls"><img src="https://api.codacy.com/project/badge/Grade/18f4aec384894e6aac0b94effe51961d" alt="Codacy Badge"></a>
|
||||||
|
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License: MIT"></a>
|
||||||
|
</p>
|
||||||
|
<br>
|
||||||
|
|
||||||
|
Native [DTLS 1.2][rfc6347] implementation in the Go programming language.
|
||||||
|
|
||||||
|
A long term goal is a professional security review, and maye inclusion in stdlib.
|
||||||
|
|
||||||
|
[rfc6347]: https://tools.ietf.org/html/rfc6347
|
||||||
|
|
||||||
|
### Goals/Progress
|
||||||
|
This will only be targeting DTLS 1.2, and the most modern/common cipher suites.
|
||||||
|
We would love contributes that fall under the 'Planned Features' and fixing any bugs!
|
||||||
|
|
||||||
|
#### Current features
|
||||||
|
* DTLS 1.2 Client/Server
|
||||||
|
* Key Exchange via ECDHE(curve25519, nistp256, nistp384) and PSK
|
||||||
|
* Packet loss and re-ordering is handled during handshaking
|
||||||
|
* Key export ([RFC 5705][rfc5705])
|
||||||
|
* Serialization and Resumption of sessions
|
||||||
|
* Extended Master Secret extension ([RFC 7627][rfc7627])
|
||||||
|
|
||||||
|
[rfc5705]: https://tools.ietf.org/html/rfc5705
|
||||||
|
[rfc7627]: https://tools.ietf.org/html/rfc7627
|
||||||
|
|
||||||
|
#### Supported ciphers
|
||||||
|
|
||||||
|
##### ECDHE
|
||||||
|
* TLS_ECDHE_ECDSA_WITH_AES_128_CCM ([RFC 6655][rfc6655])
|
||||||
|
* TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655])
|
||||||
|
* TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289])
|
||||||
|
* TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289])
|
||||||
|
* TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422])
|
||||||
|
* TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422])
|
||||||
|
|
||||||
|
##### PSK
|
||||||
|
* TLS_PSK_WITH_AES_128_CCM ([RFC 6655][rfc6655])
|
||||||
|
* TLS_PSK_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655])
|
||||||
|
* TLS_PSK_WITH_AES_128_GCM_SHA256 ([RFC 5487][rfc5487])
|
||||||
|
* TLS_PSK_WITH_AES_128_CBC_SHA256 ([RFC 5487][rfc5487])
|
||||||
|
|
||||||
|
[rfc5289]: https://tools.ietf.org/html/rfc5289
|
||||||
|
[rfc8422]: https://tools.ietf.org/html/rfc8422
|
||||||
|
[rfc6655]: https://tools.ietf.org/html/rfc6655
|
||||||
|
[rfc5487]: https://tools.ietf.org/html/rfc5487
|
||||||
|
|
||||||
|
#### Planned Features
|
||||||
|
* Chacha20Poly1305
|
||||||
|
|
||||||
|
#### Excluded Features
|
||||||
|
* DTLS 1.0
|
||||||
|
* Renegotiation
|
||||||
|
* Compression
|
||||||
|
|
||||||
|
### Using
|
||||||
|
|
||||||
|
This library needs at least Go 1.13, and you should have [Go modules
|
||||||
|
enabled](https://github.com/golang/go/wiki/Modules).
|
||||||
|
|
||||||
|
#### Pion DTLS
|
||||||
|
For a DTLS 1.2 Server that listens on 127.0.0.1:4444
|
||||||
|
```sh
|
||||||
|
go run examples/listen/selfsign/main.go
|
||||||
|
```
|
||||||
|
|
||||||
|
For a DTLS 1.2 Client that connects to 127.0.0.1:4444
|
||||||
|
```sh
|
||||||
|
go run examples/dial/selfsign/main.go
|
||||||
|
```
|
||||||
|
|
||||||
|
#### OpenSSL
|
||||||
|
Pion DTLS can connect to itself and OpenSSL.
|
||||||
|
```
|
||||||
|
// Generate a certificate
|
||||||
|
openssl ecparam -out key.pem -name prime256v1 -genkey
|
||||||
|
openssl req -new -sha256 -key key.pem -out server.csr
|
||||||
|
openssl x509 -req -sha256 -days 365 -in server.csr -signkey key.pem -out cert.pem
|
||||||
|
|
||||||
|
// Use with examples/dial/selfsign/main.go
|
||||||
|
openssl s_server -dtls1_2 -cert cert.pem -key key.pem -accept 4444
|
||||||
|
|
||||||
|
// Use with examples/listen/selfsign/main.go
|
||||||
|
openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -debug -cert cert.pem -key key.pem
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using with PSK
|
||||||
|
Pion DTLS also comes with examples that do key exchange via PSK
|
||||||
|
|
||||||
|
|
||||||
|
#### Pion DTLS
|
||||||
|
```sh
|
||||||
|
go run examples/listen/psk/main.go
|
||||||
|
```
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go run examples/dial/psk/main.go
|
||||||
|
```
|
||||||
|
|
||||||
|
#### OpenSSL
|
||||||
|
```
|
||||||
|
// Use with examples/dial/psk/main.go
|
||||||
|
openssl s_server -dtls1_2 -accept 4444 -nocert -psk abc123 -cipher PSK-AES128-CCM8
|
||||||
|
|
||||||
|
// Use with examples/listen/psk/main.go
|
||||||
|
openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -psk abc123 -cipher PSK-AES128-CCM8
|
||||||
|
```
|
||||||
|
|
||||||
|
### Contributing
|
||||||
|
Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contributing)** to join the group of amazing people making this project possible:
|
||||||
|
|
||||||
|
* [Sean DuBois](https://github.com/Sean-Der) - *Original Author*
|
||||||
|
* [Michiel De Backker](https://github.com/backkem) - *Public API*
|
||||||
|
* [Chris Hiszpanski](https://github.com/thinkski) - *Support Signature Algorithms Extension*
|
||||||
|
* [Iñigo Garcia Olaizola](https://github.com/igolaizola) - *Serialization & resumption, cert verification, E2E*
|
||||||
|
* [Daniele Sluijters](https://github.com/daenney) - *AES-CCM support*
|
||||||
|
* [Jin Lei](https://github.com/jinleileiking) - *Logging*
|
||||||
|
* [Hugo Arregui](https://github.com/hugoArregui)
|
||||||
|
* [Lander Noterman](https://github.com/LanderN)
|
||||||
|
* [Aleksandr Razumov](https://github.com/ernado) - *Fuzzing*
|
||||||
|
* [Ryan Gordon](https://github.com/ryangordon)
|
||||||
|
* [Stefan Tatschner](https://rumpelsepp.org/contact.html)
|
||||||
|
* [Hayden James](https://github.com/hjames9)
|
||||||
|
* [Jozef Kralik](https://github.com/jkralik)
|
||||||
|
* [Robert Eperjesi](https://github.com/epes)
|
||||||
|
* [Atsushi Watanabe](https://github.com/at-wat)
|
||||||
|
* [Julien Salleyron](https://github.com/juliens) - *Server Name Indication*
|
||||||
|
* [Jeroen de Bruijn](https://github.com/vidavidorra)
|
||||||
|
* [bjdgyc](https://github.com/bjdgyc)
|
||||||
|
* [Jeffrey Stoke (Jeff Ctor)](https://github.com/jeffreystoke) - *Fragmentbuffer Fix*
|
||||||
|
* [Frank Olbricht](https://github.com/folbricht)
|
||||||
|
* [ZHENK](https://github.com/scorpionknifes)
|
||||||
|
* [Carson Hoffman](https://github.com/CarsonHoffman)
|
||||||
|
* [Vadim Filimonov](https://github.com/fffilimonov)
|
||||||
|
* [Jim Wert](https://github.com/bocajim)
|
||||||
|
* [Alvaro Viebrantz](https://github.com/alvarowolfx)
|
||||||
|
* [Kegan Dougal](https://github.com/Kegsay)
|
||||||
|
* [Michael Zabka](https://github.com/misak113)
|
||||||
|
|
||||||
|
### License
|
||||||
|
MIT License - see [LICENSE](LICENSE) for full text
|
|
@ -0,0 +1,118 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/internal/net/dpipe"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||||
|
"github.com/pion/logging"
|
||||||
|
"github.com/pion/transport/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSimpleReadWrite(t *testing.T) {
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ca, cb := dpipe.Pipe()
|
||||||
|
certificate, err := selfsign.GenerateSelfSigned()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
gotHello := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
server, sErr := testServer(ctx, cb, &Config{
|
||||||
|
Certificates: []tls.Certificate{certificate},
|
||||||
|
LoggerFactory: logging.NewDefaultLoggerFactory(),
|
||||||
|
}, false)
|
||||||
|
if sErr != nil {
|
||||||
|
t.Error(sErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
if _, sErr = server.Read(buf); sErr != nil {
|
||||||
|
t.Error(sErr)
|
||||||
|
}
|
||||||
|
gotHello <- struct{}{}
|
||||||
|
if sErr = server.Close(); sErr != nil {
|
||||||
|
t.Error(sErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := testClient(ctx, ca, &Config{
|
||||||
|
LoggerFactory: logging.NewDefaultLoggerFactory(),
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err = client.Write([]byte("hello")); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-gotHello:
|
||||||
|
// OK
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
t.Error("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = client.Close(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkConn(b *testing.B, n int64) {
|
||||||
|
b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
ca, cb := dpipe.Pipe()
|
||||||
|
certificate, err := selfsign.GenerateSelfSigned()
|
||||||
|
server := make(chan *Conn)
|
||||||
|
go func() {
|
||||||
|
s, sErr := testServer(ctx, cb, &Config{
|
||||||
|
Certificates: []tls.Certificate{certificate},
|
||||||
|
}, false)
|
||||||
|
if err != nil {
|
||||||
|
b.Error(sErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
server <- s
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
hw := make([]byte, n)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(hw)))
|
||||||
|
go func() {
|
||||||
|
client, cErr := testClient(ctx, ca, &Config{InsecureSkipVerify: true}, false)
|
||||||
|
if cErr != nil {
|
||||||
|
b.Error(err)
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
if _, cErr = client.Write(hw); cErr != nil {
|
||||||
|
b.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s := <-server
|
||||||
|
buf := make([]byte, 2048)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err = s.Read(buf); err != nil {
|
||||||
|
b.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkConnReadWrite(b *testing.B) {
|
||||||
|
for _, n := range []int64{16, 128, 512, 1024, 2048} {
|
||||||
|
benchmarkConn(b, n)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *handshakeConfig) getCertificate(serverName string) (*tls.Certificate, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.nameToCertificate == nil {
|
||||||
|
nameToCertificate := make(map[string]*tls.Certificate)
|
||||||
|
for i := range c.localCertificates {
|
||||||
|
cert := &c.localCertificates[i]
|
||||||
|
x509Cert := cert.Leaf
|
||||||
|
if x509Cert == nil {
|
||||||
|
var parseErr error
|
||||||
|
x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0])
|
||||||
|
if parseErr != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(x509Cert.Subject.CommonName) > 0 {
|
||||||
|
nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert
|
||||||
|
}
|
||||||
|
for _, san := range x509Cert.DNSNames {
|
||||||
|
nameToCertificate[strings.ToLower(san)] = cert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.nameToCertificate = nameToCertificate
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.localCertificates) == 0 {
|
||||||
|
return nil, errNoCertificates
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.localCertificates) == 1 {
|
||||||
|
// There's only one choice, so no point doing any work.
|
||||||
|
return &c.localCertificates[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(serverName) == 0 {
|
||||||
|
return &c.localCertificates[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := strings.TrimRight(strings.ToLower(serverName), ".")
|
||||||
|
|
||||||
|
if cert, ok := c.nameToCertificate[name]; ok {
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// try replacing labels in the name with wildcards until we get a
|
||||||
|
// match.
|
||||||
|
labels := strings.Split(name, ".")
|
||||||
|
for i := range labels {
|
||||||
|
labels[i] = "*"
|
||||||
|
candidate := strings.Join(labels, ".")
|
||||||
|
if cert, ok := c.nameToCertificate[candidate]; ok {
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If nothing matches, return the first certificate.
|
||||||
|
return &c.localCertificates[0], nil
|
||||||
|
}
|
|
@ -0,0 +1,79 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetCertificate(t *testing.T) {
|
||||||
|
certificateWildcard, err := selfsign.GenerateSelfSignedWithDNS("*.test.test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certificateTest, err := selfsign.GenerateSelfSignedWithDNS("test.test", "www.test.test", "pop.test.test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certificateRandom, err := selfsign.GenerateSelfSigned()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &handshakeConfig{
|
||||||
|
localCertificates: []tls.Certificate{
|
||||||
|
certificateRandom,
|
||||||
|
certificateTest,
|
||||||
|
certificateWildcard,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
serverName string
|
||||||
|
expectedCertificate tls.Certificate
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Simple match in CN",
|
||||||
|
serverName: "test.test",
|
||||||
|
expectedCertificate: certificateTest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Simple match in SANs",
|
||||||
|
serverName: "www.test.test",
|
||||||
|
expectedCertificate: certificateTest,
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
desc: "Wildcard match",
|
||||||
|
serverName: "foo.test.test",
|
||||||
|
expectedCertificate: certificateWildcard,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "No match return first",
|
||||||
|
serverName: "foo.bar",
|
||||||
|
expectedCertificate: certificateRandom,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cert, err := cfg.getCertificate(test.serverName)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(cert.Leaf, test.expectedCertificate.Leaf) {
|
||||||
|
t.Fatalf("Certificate does not match: expected(%v) actual(%v)", test.expectedCertificate.Leaf, cert.Leaf)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,213 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/internal/ciphersuite"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CipherSuiteID is an ID for our supported CipherSuites
|
||||||
|
type CipherSuiteID = ciphersuite.ID
|
||||||
|
|
||||||
|
// Supported Cipher Suites
|
||||||
|
const (
|
||||||
|
// AES-128-CCM
|
||||||
|
TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:golint,stylecheck
|
||||||
|
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 //nolint:golint,stylecheck
|
||||||
|
|
||||||
|
// AES-128-GCM-SHA256
|
||||||
|
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
|
||||||
|
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
|
||||||
|
|
||||||
|
// AES-256-CBC-SHA
|
||||||
|
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:golint,stylecheck
|
||||||
|
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:golint,stylecheck
|
||||||
|
|
||||||
|
TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:golint,stylecheck
|
||||||
|
TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:golint,stylecheck
|
||||||
|
TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
|
||||||
|
TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:golint,stylecheck
|
||||||
|
)
|
||||||
|
|
||||||
|
// CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite
|
||||||
|
type CipherSuiteAuthenticationType = ciphersuite.AuthenticationType
|
||||||
|
|
||||||
|
// AuthenticationType Enums
|
||||||
|
const (
|
||||||
|
CipherSuiteAuthenticationTypeCertificate CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeCertificate
|
||||||
|
CipherSuiteAuthenticationTypePreSharedKey CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypePreSharedKey
|
||||||
|
CipherSuiteAuthenticationTypeAnonymous CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeAnonymous
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14
|
||||||
|
|
||||||
|
// CipherSuite is an interface that all DTLS CipherSuites must satisfy
|
||||||
|
type CipherSuite interface {
|
||||||
|
// String of CipherSuite, only used for logging
|
||||||
|
String() string
|
||||||
|
|
||||||
|
// ID of CipherSuite.
|
||||||
|
ID() CipherSuiteID
|
||||||
|
|
||||||
|
// What type of Certificate does this CipherSuite use
|
||||||
|
CertificateType() clientcertificate.Type
|
||||||
|
|
||||||
|
// What Hash function is used during verification
|
||||||
|
HashFunc() func() hash.Hash
|
||||||
|
|
||||||
|
// AuthenticationType controls what authentication method is using during the handshake
|
||||||
|
AuthenticationType() CipherSuiteAuthenticationType
|
||||||
|
|
||||||
|
// Called when keying material has been generated, should initialize the internal cipher
|
||||||
|
Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error
|
||||||
|
IsInitialized() bool
|
||||||
|
|
||||||
|
Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error)
|
||||||
|
Decrypt(in []byte) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CipherSuiteName provides the same functionality as tls.CipherSuiteName
|
||||||
|
// that appeared first in Go 1.14.
|
||||||
|
//
|
||||||
|
// Our implementation differs slightly in that it takes in a CiperSuiteID,
|
||||||
|
// like the rest of our library, instead of a uint16 like crypto/tls.
|
||||||
|
func CipherSuiteName(id CipherSuiteID) string {
|
||||||
|
suite := cipherSuiteForID(id, nil)
|
||||||
|
if suite != nil {
|
||||||
|
return suite.String()
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("0x%04X", uint16(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
|
||||||
|
// A cipherSuite is a specific combination of key agreement, cipher and MAC
|
||||||
|
// function.
|
||||||
|
func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite {
|
||||||
|
switch id { //nolint:exhaustive
|
||||||
|
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM:
|
||||||
|
return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm()
|
||||||
|
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
|
||||||
|
return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8()
|
||||||
|
case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
|
||||||
|
return &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}
|
||||||
|
case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
|
||||||
|
return &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}
|
||||||
|
case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
|
||||||
|
return &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}
|
||||||
|
case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
|
||||||
|
return &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}
|
||||||
|
case TLS_PSK_WITH_AES_128_CCM:
|
||||||
|
return ciphersuite.NewTLSPskWithAes128Ccm()
|
||||||
|
case TLS_PSK_WITH_AES_128_CCM_8:
|
||||||
|
return ciphersuite.NewTLSPskWithAes128Ccm8()
|
||||||
|
case TLS_PSK_WITH_AES_128_GCM_SHA256:
|
||||||
|
return &ciphersuite.TLSPskWithAes128GcmSha256{}
|
||||||
|
case TLS_PSK_WITH_AES_128_CBC_SHA256:
|
||||||
|
return &ciphersuite.TLSPskWithAes128CbcSha256{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if customCiphers != nil {
|
||||||
|
for _, c := range customCiphers() {
|
||||||
|
if c.ID() == id {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CipherSuites we support in order of preference
|
||||||
|
func defaultCipherSuites() []CipherSuite {
|
||||||
|
return []CipherSuite{
|
||||||
|
&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
|
||||||
|
&ciphersuite.TLSEcdheRsaWithAes128GcmSha256{},
|
||||||
|
&ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{},
|
||||||
|
&ciphersuite.TLSEcdheRsaWithAes256CbcSha{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func allCipherSuites() []CipherSuite {
|
||||||
|
return []CipherSuite{
|
||||||
|
ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm(),
|
||||||
|
ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8(),
|
||||||
|
&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
|
||||||
|
&ciphersuite.TLSEcdheRsaWithAes128GcmSha256{},
|
||||||
|
&ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{},
|
||||||
|
&ciphersuite.TLSEcdheRsaWithAes256CbcSha{},
|
||||||
|
ciphersuite.NewTLSPskWithAes128Ccm(),
|
||||||
|
ciphersuite.NewTLSPskWithAes128Ccm8(),
|
||||||
|
&ciphersuite.TLSPskWithAes128GcmSha256{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cipherSuiteIDs(cipherSuites []CipherSuite) []uint16 {
|
||||||
|
rtrn := []uint16{}
|
||||||
|
for _, c := range cipherSuites {
|
||||||
|
rtrn = append(rtrn, uint16(c.ID()))
|
||||||
|
}
|
||||||
|
return rtrn
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCipherSuites(userSelectedSuites []CipherSuiteID, customCipherSuites func() []CipherSuite, includeCertificateSuites, includePSKSuites bool) ([]CipherSuite, error) {
|
||||||
|
cipherSuitesForIDs := func(ids []CipherSuiteID) ([]CipherSuite, error) {
|
||||||
|
cipherSuites := []CipherSuite{}
|
||||||
|
for _, id := range ids {
|
||||||
|
c := cipherSuiteForID(id, nil)
|
||||||
|
if c == nil {
|
||||||
|
return nil, &invalidCipherSuite{id}
|
||||||
|
}
|
||||||
|
cipherSuites = append(cipherSuites, c)
|
||||||
|
}
|
||||||
|
return cipherSuites, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
cipherSuites []CipherSuite
|
||||||
|
err error
|
||||||
|
i int
|
||||||
|
)
|
||||||
|
if userSelectedSuites != nil {
|
||||||
|
cipherSuites, err = cipherSuitesForIDs(userSelectedSuites)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cipherSuites = defaultCipherSuites()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put CustomCipherSuites before ID selected suites
|
||||||
|
if customCipherSuites != nil {
|
||||||
|
cipherSuites = append(customCipherSuites(), cipherSuites...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var foundCertificateSuite, foundPSKSuite, foundAnonymousSuite bool
|
||||||
|
for _, c := range cipherSuites {
|
||||||
|
switch {
|
||||||
|
case includeCertificateSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate:
|
||||||
|
foundCertificateSuite = true
|
||||||
|
case includePSKSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey:
|
||||||
|
foundPSKSuite = true
|
||||||
|
case c.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous:
|
||||||
|
foundAnonymousSuite = true
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cipherSuites[i] = c
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case includeCertificateSuites && !foundCertificateSuite && !foundAnonymousSuite:
|
||||||
|
return nil, errNoAvailableCertificateCipherSuite
|
||||||
|
case includePSKSuites && !foundPSKSuite:
|
||||||
|
return nil, errNoAvailablePSKCipherSuite
|
||||||
|
case i == 0:
|
||||||
|
return nil, errNoAvailableCipherSuites
|
||||||
|
}
|
||||||
|
|
||||||
|
return cipherSuites[:i], nil
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
// +build go1.14
|
||||||
|
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VersionDTLS12 is the DTLS version in the same style as
|
||||||
|
// VersionTLSXX from crypto/tls
|
||||||
|
const VersionDTLS12 = 0xfefd
|
||||||
|
|
||||||
|
// Convert from our cipherSuite interface to a tls.CipherSuite struct
|
||||||
|
func toTLSCipherSuite(c CipherSuite) *tls.CipherSuite {
|
||||||
|
return &tls.CipherSuite{
|
||||||
|
ID: uint16(c.ID()),
|
||||||
|
Name: c.String(),
|
||||||
|
SupportedVersions: []uint16{VersionDTLS12},
|
||||||
|
Insecure: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CipherSuites returns a list of cipher suites currently implemented by this
|
||||||
|
// package, excluding those with security issues, which are returned by
|
||||||
|
// InsecureCipherSuites.
|
||||||
|
func CipherSuites() []*tls.CipherSuite {
|
||||||
|
suites := allCipherSuites()
|
||||||
|
res := make([]*tls.CipherSuite, len(suites))
|
||||||
|
for i, c := range suites {
|
||||||
|
res[i] = toTLSCipherSuite(c)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsecureCipherSuites returns a list of cipher suites currently implemented by
|
||||||
|
// this package and which have security issues.
|
||||||
|
func InsecureCipherSuites() []*tls.CipherSuite {
|
||||||
|
var res []*tls.CipherSuite
|
||||||
|
return res
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
// +build go1.14
|
||||||
|
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInsecureCipherSuites(t *testing.T) {
|
||||||
|
r := InsecureCipherSuites()
|
||||||
|
|
||||||
|
if len(r) != 0 {
|
||||||
|
t.Fatalf("Expected no insecure ciphersuites, got %d", len(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCipherSuites(t *testing.T) {
|
||||||
|
ours := allCipherSuites()
|
||||||
|
theirs := CipherSuites()
|
||||||
|
|
||||||
|
if len(ours) != len(theirs) {
|
||||||
|
t.Fatalf("Expected %d CipherSuites, got %d", len(ours), len(theirs))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, s := range ours {
|
||||||
|
i := i
|
||||||
|
s := s
|
||||||
|
t.Run(s.String(), func(t *testing.T) {
|
||||||
|
c := theirs[i]
|
||||||
|
if c.ID != uint16(s.ID()) {
|
||||||
|
t.Fatalf("Expected ID: 0x%04X, got 0x%04X", s.ID(), c.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Name != s.String() {
|
||||||
|
t.Fatalf("Expected Name: %s, got %s", s.String(), c.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.SupportedVersions) != 1 {
|
||||||
|
t.Fatalf("Expected %d SupportedVersion, got %d", 1, len(c.SupportedVersions))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.SupportedVersions[0] != VersionDTLS12 {
|
||||||
|
t.Fatalf("Expected SupportedVersions 0x%04X, got 0x%04X", VersionDTLS12, c.SupportedVersions[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Insecure {
|
||||||
|
t.Fatalf("Expected Insecure %t, got %t", false, c.Insecure)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,108 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/internal/ciphersuite"
|
||||||
|
"github.com/pion/dtls/v2/internal/net/dpipe"
|
||||||
|
"github.com/pion/transport/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCipherSuiteName(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
suite CipherSuiteID
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{TLS_ECDHE_ECDSA_WITH_AES_128_CCM, "TLS_ECDHE_ECDSA_WITH_AES_128_CCM"},
|
||||||
|
{CipherSuiteID(0x0000), "0x0000"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
res := CipherSuiteName(testCase.suite)
|
||||||
|
if res != testCase.expected {
|
||||||
|
t.Fatalf("Expected: %s, got %s", testCase.expected, res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllCipherSuites(t *testing.T) {
|
||||||
|
actual := len(allCipherSuites())
|
||||||
|
if actual == 0 {
|
||||||
|
t.Fatal()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CustomCipher that is just used to assert Custom IDs work
|
||||||
|
type testCustomCipherSuite struct {
|
||||||
|
ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256
|
||||||
|
authenticationType CipherSuiteAuthenticationType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testCustomCipherSuite) ID() CipherSuiteID {
|
||||||
|
return 0xFFFF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testCustomCipherSuite) AuthenticationType() CipherSuiteAuthenticationType {
|
||||||
|
return t.authenticationType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that two connections that pass in a CipherSuite with a CustomID works
|
||||||
|
func TestCustomCipherSuite(t *testing.T) {
|
||||||
|
type result struct {
|
||||||
|
c *Conn
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for leaking routines
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
|
||||||
|
runTest := func(cipherFactory func() []CipherSuite) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ca, cb := dpipe.Pipe()
|
||||||
|
c := make(chan result)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
client, err := testClient(ctx, ca, &Config{
|
||||||
|
CipherSuites: []CipherSuiteID{},
|
||||||
|
CustomCipherSuites: cipherFactory,
|
||||||
|
}, true)
|
||||||
|
c <- result{client, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
server, err := testServer(ctx, cb, &Config{
|
||||||
|
CipherSuites: []CipherSuiteID{},
|
||||||
|
CustomCipherSuites: cipherFactory,
|
||||||
|
}, true)
|
||||||
|
|
||||||
|
clientResult := <-c
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
_ = server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientResult.err != nil {
|
||||||
|
t.Error(clientResult.err)
|
||||||
|
} else {
|
||||||
|
_ = clientResult.c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Custom ID", func(t *testing.T) {
|
||||||
|
runTest(func() []CipherSuite {
|
||||||
|
return []CipherSuite{&testCustomCipherSuite{authenticationType: CipherSuiteAuthenticationTypeCertificate}}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Anonymous Cipher", func(t *testing.T) {
|
||||||
|
runTest(func() []CipherSuite {
|
||||||
|
return []CipherSuite{&testCustomCipherSuite{authenticationType: CipherSuiteAuthenticationTypeAnonymous}}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
#
|
||||||
|
# DO NOT EDIT THIS FILE
|
||||||
|
#
|
||||||
|
# It is automatically copied from https://github.com/pion/.goassets repository.
|
||||||
|
#
|
||||||
|
|
||||||
|
coverage:
|
||||||
|
status:
|
||||||
|
project:
|
||||||
|
default:
|
||||||
|
# Allow decreasing 2% of total coverage to avoid noise.
|
||||||
|
threshold: 2%
|
||||||
|
patch:
|
||||||
|
default:
|
||||||
|
target: 70%
|
||||||
|
only_pulls: true
|
||||||
|
|
||||||
|
ignore:
|
||||||
|
- "examples/*"
|
||||||
|
- "examples/**/*"
|
|
@ -0,0 +1,9 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import "github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
|
||||||
|
func defaultCompressionMethods() []*protocol.CompressionMethod {
|
||||||
|
return []*protocol.CompressionMethod{
|
||||||
|
{},
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,197 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
const keyLogLabelTLS12 = "CLIENT_RANDOM"
|
||||||
|
|
||||||
|
// Config is used to configure a DTLS client or server.
|
||||||
|
// After a Config is passed to a DTLS function it must not be modified.
|
||||||
|
type Config struct {
|
||||||
|
// Certificates contains certificate chain to present to the other side of the connection.
|
||||||
|
// Server MUST set this if PSK is non-nil
|
||||||
|
// client SHOULD sets this so CertificateRequests can be handled if PSK is non-nil
|
||||||
|
Certificates []tls.Certificate
|
||||||
|
|
||||||
|
// CipherSuites is a list of supported cipher suites.
|
||||||
|
// If CipherSuites is nil, a default list is used
|
||||||
|
CipherSuites []CipherSuiteID
|
||||||
|
|
||||||
|
// CustomCipherSuites is a list of CipherSuites that can be
|
||||||
|
// provided by the user. This allow users to user Ciphers that are reserved
|
||||||
|
// for private usage.
|
||||||
|
CustomCipherSuites func() []CipherSuite
|
||||||
|
|
||||||
|
// SignatureSchemes contains the signature and hash schemes that the peer requests to verify.
|
||||||
|
SignatureSchemes []tls.SignatureScheme
|
||||||
|
|
||||||
|
// SRTPProtectionProfiles are the supported protection profiles
|
||||||
|
// Clients will send this via use_srtp and assert that the server properly responds
|
||||||
|
// Servers will assert that clients send one of these profiles and will respond as needed
|
||||||
|
SRTPProtectionProfiles []SRTPProtectionProfile
|
||||||
|
|
||||||
|
// ClientAuth determines the server's policy for
|
||||||
|
// TLS Client Authentication. The default is NoClientCert.
|
||||||
|
ClientAuth ClientAuthType
|
||||||
|
|
||||||
|
// RequireExtendedMasterSecret determines if the "Extended Master Secret" extension
|
||||||
|
// should be disabled, requested, or required (default requested).
|
||||||
|
ExtendedMasterSecret ExtendedMasterSecretType
|
||||||
|
|
||||||
|
// FlightInterval controls how often we send outbound handshake messages
|
||||||
|
// defaults to time.Second
|
||||||
|
FlightInterval time.Duration
|
||||||
|
|
||||||
|
// PSK sets the pre-shared key used by this DTLS connection
|
||||||
|
// If PSK is non-nil only PSK CipherSuites will be used
|
||||||
|
PSK PSKCallback
|
||||||
|
PSKIdentityHint []byte
|
||||||
|
|
||||||
|
CiscoCompat PSKCallback // TODO add cisco anyconnect support
|
||||||
|
|
||||||
|
// InsecureSkipVerify controls whether a client verifies the
|
||||||
|
// server's certificate chain and host name.
|
||||||
|
// If InsecureSkipVerify is true, TLS accepts any certificate
|
||||||
|
// presented by the server and any host name in that certificate.
|
||||||
|
// In this mode, TLS is susceptible to man-in-the-middle attacks.
|
||||||
|
// This should be used only for testing.
|
||||||
|
InsecureSkipVerify bool
|
||||||
|
|
||||||
|
// InsecureHashes allows the use of hashing algorithms that are known
|
||||||
|
// to be vulnerable.
|
||||||
|
InsecureHashes bool
|
||||||
|
|
||||||
|
// VerifyPeerCertificate, if not nil, is called after normal
|
||||||
|
// certificate verification by either a client or server. It
|
||||||
|
// receives the certificate provided by the peer and also a flag
|
||||||
|
// that tells if normal verification has succeedded. If it returns a
|
||||||
|
// non-nil error, the handshake is aborted and that error results.
|
||||||
|
//
|
||||||
|
// If normal verification fails then the handshake will abort before
|
||||||
|
// considering this callback. If normal verification is disabled by
|
||||||
|
// setting InsecureSkipVerify, or (for a server) when ClientAuth is
|
||||||
|
// RequestClientCert or RequireAnyClientCert, then this callback will
|
||||||
|
// be considered but the verifiedChains will always be nil.
|
||||||
|
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
|
||||||
|
|
||||||
|
// RootCAs defines the set of root certificate authorities
|
||||||
|
// that one peer uses when verifying the other peer's certificates.
|
||||||
|
// If RootCAs is nil, TLS uses the host's root CA set.
|
||||||
|
RootCAs *x509.CertPool
|
||||||
|
|
||||||
|
// ClientCAs defines the set of root certificate authorities
|
||||||
|
// that servers use if required to verify a client certificate
|
||||||
|
// by the policy in ClientAuth.
|
||||||
|
ClientCAs *x509.CertPool
|
||||||
|
|
||||||
|
// ServerName is used to verify the hostname on the returned
|
||||||
|
// certificates unless InsecureSkipVerify is given.
|
||||||
|
ServerName string
|
||||||
|
|
||||||
|
LoggerFactory logging.LoggerFactory
|
||||||
|
|
||||||
|
// ConnectContextMaker is a function to make a context used in Dial(),
|
||||||
|
// Client(), Server(), and Accept(). If nil, the default ConnectContextMaker
|
||||||
|
// is used. It can be implemented as following.
|
||||||
|
//
|
||||||
|
// func ConnectContextMaker() (context.Context, func()) {
|
||||||
|
// return context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
// }
|
||||||
|
ConnectContextMaker func() (context.Context, func())
|
||||||
|
|
||||||
|
// MTU is the length at which handshake messages will be fragmented to
|
||||||
|
// fit within the maximum transmission unit (default is 1200 bytes)
|
||||||
|
MTU int
|
||||||
|
|
||||||
|
// ReplayProtectionWindow is the size of the replay attack protection window.
|
||||||
|
// Duplication of the sequence number is checked in this window size.
|
||||||
|
// Packet with sequence number older than this value compared to the latest
|
||||||
|
// accepted packet will be discarded. (default is 64)
|
||||||
|
ReplayProtectionWindow int
|
||||||
|
|
||||||
|
// KeyLogWriter optionally specifies a destination for TLS master secrets
|
||||||
|
// in NSS key log format that can be used to allow external programs
|
||||||
|
// such as Wireshark to decrypt TLS connections.
|
||||||
|
// See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format.
|
||||||
|
// Use of KeyLogWriter compromises security and should only be
|
||||||
|
// used for debugging.
|
||||||
|
KeyLogWriter io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultConnectContextMaker() (context.Context, func()) {
|
||||||
|
return context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) connectContextMaker() (context.Context, func()) {
|
||||||
|
if c.ConnectContextMaker == nil {
|
||||||
|
return defaultConnectContextMaker()
|
||||||
|
}
|
||||||
|
return c.ConnectContextMaker()
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultMTU = 1200 // bytes
|
||||||
|
|
||||||
|
// PSKCallback is called once we have the remote's PSKIdentityHint.
|
||||||
|
// If the remote provided none it will be nil
|
||||||
|
type PSKCallback func([]byte) ([]byte, error)
|
||||||
|
|
||||||
|
// ClientAuthType declares the policy the server will follow for
|
||||||
|
// TLS Client Authentication.
|
||||||
|
type ClientAuthType int
|
||||||
|
|
||||||
|
// ClientAuthType enums
|
||||||
|
const (
|
||||||
|
NoClientCert ClientAuthType = iota
|
||||||
|
RequestClientCert
|
||||||
|
RequireAnyClientCert
|
||||||
|
VerifyClientCertIfGiven
|
||||||
|
RequireAndVerifyClientCert
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExtendedMasterSecretType declares the policy the client and server
|
||||||
|
// will follow for the Extended Master Secret extension
|
||||||
|
type ExtendedMasterSecretType int
|
||||||
|
|
||||||
|
// ExtendedMasterSecretType enums
|
||||||
|
const (
|
||||||
|
RequestExtendedMasterSecret ExtendedMasterSecretType = iota
|
||||||
|
RequireExtendedMasterSecret
|
||||||
|
DisableExtendedMasterSecret
|
||||||
|
)
|
||||||
|
|
||||||
|
func validateConfig(config *Config) error {
|
||||||
|
switch {
|
||||||
|
case config == nil:
|
||||||
|
return errNoConfigProvided
|
||||||
|
case config.PSKIdentityHint != nil && config.PSK == nil:
|
||||||
|
return errIdentityNoPSK
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cert := range config.Certificates {
|
||||||
|
if cert.Certificate == nil {
|
||||||
|
return errInvalidCertificate
|
||||||
|
}
|
||||||
|
if cert.PrivateKey != nil {
|
||||||
|
switch cert.PrivateKey.(type) {
|
||||||
|
case ed25519.PrivateKey:
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
default:
|
||||||
|
return errInvalidPrivateKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.PSK == nil || len(config.Certificates) > 0, config.PSK != nil)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,119 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/dsa" //nolint
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateConfig(t *testing.T) {
|
||||||
|
// Empty config
|
||||||
|
if err := validateConfig(nil); !errors.Is(err, errNoConfigProvided) {
|
||||||
|
t.Fatalf("TestValidateConfig: Config validation error exp(%v) failed(%v)", errNoConfigProvided, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PSK and Certificate, valid cipher suites
|
||||||
|
cert, err := selfsign.GenerateSelfSigned()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("TestValidateConfig: Config validation error(%v), self signed certificate not generated", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config := &Config{
|
||||||
|
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||||
|
PSK: func(hint []byte) ([]byte, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
if err = validateConfig(config); err != nil {
|
||||||
|
t.Fatalf("TestValidateConfig: Client error exp(%v) failed(%v)", nil, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PSK and Certificate, no PSK cipher suite
|
||||||
|
config = &Config{
|
||||||
|
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||||
|
PSK: func(hint []byte) ([]byte, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
if err = validateConfig(config); !errors.Is(errNoAvailablePSKCipherSuite, err) {
|
||||||
|
t.Fatalf("TestValidateConfig: Client error exp(%v) failed(%v)", errNoAvailablePSKCipherSuite, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PSK and Certificate, no non-PSK cipher suite
|
||||||
|
config = &Config{
|
||||||
|
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
|
||||||
|
PSK: func(hint []byte) ([]byte, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
if err = validateConfig(config); !errors.Is(errNoAvailableCertificateCipherSuite, err) {
|
||||||
|
t.Fatalf("TestValidateConfig: Client error exp(%v) failed(%v)", errNoAvailableCertificateCipherSuite, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PSK identity hint with not PSK
|
||||||
|
config = &Config{
|
||||||
|
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||||
|
PSK: nil,
|
||||||
|
PSKIdentityHint: []byte{},
|
||||||
|
}
|
||||||
|
if err = validateConfig(config); !errors.Is(err, errIdentityNoPSK) {
|
||||||
|
t.Fatalf("TestValidateConfig: Client error exp(%v) failed(%v)", errIdentityNoPSK, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid private key
|
||||||
|
dsaPrivateKey := &dsa.PrivateKey{}
|
||||||
|
err = dsa.GenerateParameters(&dsaPrivateKey.Parameters, rand.Reader, dsa.L1024N160)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("TestValidateConfig: Config validation error(%v), DSA parameters not generated", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = dsa.GenerateKey(dsaPrivateKey, rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("TestValidateConfig: Config validation error(%v), DSA private key not generated", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config = &Config{
|
||||||
|
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||||
|
Certificates: []tls.Certificate{{Certificate: cert.Certificate, PrivateKey: dsaPrivateKey}},
|
||||||
|
}
|
||||||
|
if err = validateConfig(config); !errors.Is(err, errInvalidPrivateKey) {
|
||||||
|
t.Fatalf("TestValidateConfig: Client error exp(%v) failed(%v)", errInvalidPrivateKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrivateKey without Certificate
|
||||||
|
config = &Config{
|
||||||
|
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||||
|
Certificates: []tls.Certificate{{PrivateKey: cert.PrivateKey}},
|
||||||
|
}
|
||||||
|
if err = validateConfig(config); !errors.Is(err, errInvalidCertificate) {
|
||||||
|
t.Fatalf("TestValidateConfig: Client error exp(%v) failed(%v)", errInvalidCertificate, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid cipher suites
|
||||||
|
config = &Config{CipherSuites: []CipherSuiteID{0x0000}}
|
||||||
|
if err = validateConfig(config); err == nil {
|
||||||
|
t.Fatal("TestValidateConfig: Client error expected with invalid CipherSuiteID")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid config
|
||||||
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("TestValidateConfig: Config validation error(%v), RSA private key not generated", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config = &Config{
|
||||||
|
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||||
|
Certificates: []tls.Certificate{cert, {Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}},
|
||||||
|
}
|
||||||
|
if err = validateConfig(config); err != nil {
|
||||||
|
t.Fatalf("TestValidateConfig: Client error exp(%v) failed(%v)", nil, err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,979 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/internal/closer"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
||||||
|
"github.com/pion/logging"
|
||||||
|
"github.com/pion/transport/connctx"
|
||||||
|
"github.com/pion/transport/deadline"
|
||||||
|
"github.com/pion/transport/replaydetector"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
initialTickerInterval = time.Second
|
||||||
|
cookieLength = 20
|
||||||
|
defaultNamedCurve = elliptic.X25519
|
||||||
|
inboundBufferSize = 8192
|
||||||
|
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
|
||||||
|
defaultReplayProtectionWindow = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
func invalidKeyingLabels() map[string]bool {
|
||||||
|
return map[string]bool{
|
||||||
|
"client finished": true,
|
||||||
|
"server finished": true,
|
||||||
|
"master secret": true,
|
||||||
|
"key expansion": true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn represents a DTLS connection
|
||||||
|
type Conn struct {
|
||||||
|
lock sync.RWMutex // Internal lock (must not be public)
|
||||||
|
nextConn connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from
|
||||||
|
fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling
|
||||||
|
handshakeCache *handshakeCache // caching of handshake messages for verifyData generation
|
||||||
|
decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read`
|
||||||
|
|
||||||
|
state State // Internal state
|
||||||
|
|
||||||
|
maximumTransmissionUnit int
|
||||||
|
|
||||||
|
handshakeCompletedSuccessfully atomic.Value
|
||||||
|
|
||||||
|
encryptedPackets [][]byte
|
||||||
|
|
||||||
|
connectionClosedByUser bool
|
||||||
|
closeLock sync.Mutex
|
||||||
|
closed *closer.Closer
|
||||||
|
handshakeLoopsFinished sync.WaitGroup
|
||||||
|
|
||||||
|
readDeadline *deadline.Deadline
|
||||||
|
writeDeadline *deadline.Deadline
|
||||||
|
|
||||||
|
log logging.LeveledLogger
|
||||||
|
|
||||||
|
reading chan struct{}
|
||||||
|
handshakeRecv chan chan struct{}
|
||||||
|
cancelHandshaker func()
|
||||||
|
cancelHandshakeReader func()
|
||||||
|
|
||||||
|
fsm *handshakeFSM
|
||||||
|
|
||||||
|
replayProtectionWindow uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
|
||||||
|
err := validateConfig(config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextConn == nil {
|
||||||
|
return nil, errNilNextConn
|
||||||
|
}
|
||||||
|
|
||||||
|
cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.PSK == nil || len(config.Certificates) > 0, config.PSK != nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
workerInterval := initialTickerInterval
|
||||||
|
if config.FlightInterval != 0 {
|
||||||
|
workerInterval = config.FlightInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
loggerFactory := config.LoggerFactory
|
||||||
|
if loggerFactory == nil {
|
||||||
|
loggerFactory = logging.NewDefaultLoggerFactory()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := loggerFactory.NewLogger("dtls")
|
||||||
|
|
||||||
|
mtu := config.MTU
|
||||||
|
if mtu <= 0 {
|
||||||
|
mtu = defaultMTU
|
||||||
|
}
|
||||||
|
|
||||||
|
replayProtectionWindow := config.ReplayProtectionWindow
|
||||||
|
if replayProtectionWindow <= 0 {
|
||||||
|
replayProtectionWindow = defaultReplayProtectionWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &Conn{
|
||||||
|
nextConn: connctx.New(nextConn),
|
||||||
|
fragmentBuffer: newFragmentBuffer(),
|
||||||
|
handshakeCache: newHandshakeCache(),
|
||||||
|
maximumTransmissionUnit: mtu,
|
||||||
|
|
||||||
|
decrypted: make(chan interface{}, 1),
|
||||||
|
log: logger,
|
||||||
|
|
||||||
|
readDeadline: deadline.New(),
|
||||||
|
writeDeadline: deadline.New(),
|
||||||
|
|
||||||
|
reading: make(chan struct{}, 1),
|
||||||
|
handshakeRecv: make(chan chan struct{}),
|
||||||
|
closed: closer.NewCloser(),
|
||||||
|
cancelHandshaker: func() {},
|
||||||
|
|
||||||
|
replayProtectionWindow: uint(replayProtectionWindow),
|
||||||
|
|
||||||
|
state: State{
|
||||||
|
isClient: isClient,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c.setRemoteEpoch(0)
|
||||||
|
c.setLocalEpoch(0)
|
||||||
|
|
||||||
|
serverName := config.ServerName
|
||||||
|
// Use host from conn address when serverName is not provided
|
||||||
|
if isClient && serverName == "" && nextConn.RemoteAddr() != nil {
|
||||||
|
remoteAddr := nextConn.RemoteAddr().String()
|
||||||
|
var host string
|
||||||
|
host, _, err = net.SplitHostPort(remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
serverName = remoteAddr
|
||||||
|
} else {
|
||||||
|
serverName = host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hsCfg := &handshakeConfig{
|
||||||
|
localPSKCallback: config.PSK,
|
||||||
|
localPSKIdentityHint: config.PSKIdentityHint,
|
||||||
|
localCiscoCompatCallback: config.CiscoCompat,
|
||||||
|
localCipherSuites: cipherSuites,
|
||||||
|
localSignatureSchemes: signatureSchemes,
|
||||||
|
extendedMasterSecret: config.ExtendedMasterSecret,
|
||||||
|
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
|
||||||
|
serverName: serverName,
|
||||||
|
clientAuth: config.ClientAuth,
|
||||||
|
localCertificates: config.Certificates,
|
||||||
|
insecureSkipVerify: config.InsecureSkipVerify,
|
||||||
|
verifyPeerCertificate: config.VerifyPeerCertificate,
|
||||||
|
rootCAs: config.RootCAs,
|
||||||
|
clientCAs: config.ClientCAs,
|
||||||
|
customCipherSuites: config.CustomCipherSuites,
|
||||||
|
retransmitInterval: workerInterval,
|
||||||
|
log: logger,
|
||||||
|
initialEpoch: 0,
|
||||||
|
keyLogWriter: config.KeyLogWriter,
|
||||||
|
}
|
||||||
|
|
||||||
|
var initialFlight flightVal
|
||||||
|
var initialFSMState handshakeState
|
||||||
|
|
||||||
|
if initialState != nil {
|
||||||
|
if c.state.isClient {
|
||||||
|
initialFlight = flight5
|
||||||
|
} else {
|
||||||
|
initialFlight = flight6
|
||||||
|
}
|
||||||
|
initialFSMState = handshakeFinished
|
||||||
|
|
||||||
|
c.state = *initialState
|
||||||
|
} else {
|
||||||
|
if c.state.isClient {
|
||||||
|
initialFlight = flight1
|
||||||
|
} else {
|
||||||
|
initialFlight = flight0
|
||||||
|
}
|
||||||
|
initialFSMState = handshakePreparing
|
||||||
|
}
|
||||||
|
// Do handshake
|
||||||
|
if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log.Trace("Handshake Completed")
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the given network address and establishes a DTLS connection on top.
|
||||||
|
// Connection handshake will timeout using ConnectContextMaker in the Config.
|
||||||
|
// If you want to specify the timeout duration, use DialWithContext() instead.
|
||||||
|
func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
|
||||||
|
ctx, cancel := config.connectContextMaker()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return DialWithContext(ctx, network, raddr, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client establishes a DTLS connection over an existing connection.
|
||||||
|
// Connection handshake will timeout using ConnectContextMaker in the Config.
|
||||||
|
// If you want to specify the timeout duration, use ClientWithContext() instead.
|
||||||
|
func Client(conn net.Conn, config *Config) (*Conn, error) {
|
||||||
|
ctx, cancel := config.connectContextMaker()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return ClientWithContext(ctx, conn, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server listens for incoming DTLS connections.
|
||||||
|
// Connection handshake will timeout using ConnectContextMaker in the Config.
|
||||||
|
// If you want to specify the timeout duration, use ServerWithContext() instead.
|
||||||
|
func Server(conn net.Conn, config *Config) (*Conn, error) {
|
||||||
|
ctx, cancel := config.connectContextMaker()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return ServerWithContext(ctx, conn, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialWithContext connects to the given network address and establishes a DTLS connection on top.
|
||||||
|
func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
|
||||||
|
pConn, err := net.DialUDP(network, nil, raddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ClientWithContext(ctx, pConn, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientWithContext establishes a DTLS connection over an existing connection.
|
||||||
|
func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
|
||||||
|
switch {
|
||||||
|
case config == nil:
|
||||||
|
return nil, errNoConfigProvided
|
||||||
|
case config.PSK != nil && config.PSKIdentityHint == nil:
|
||||||
|
return nil, errPSKAndIdentityMustBeSetForClient
|
||||||
|
}
|
||||||
|
|
||||||
|
return createConn(ctx, conn, config, true, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerWithContext listens for incoming DTLS connections.
|
||||||
|
func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
|
||||||
|
if config == nil {
|
||||||
|
return nil, errNoConfigProvided
|
||||||
|
}
|
||||||
|
|
||||||
|
return createConn(ctx, conn, config, false, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads data from the connection.
|
||||||
|
func (c *Conn) Read(p []byte) (n int, err error) {
|
||||||
|
if !c.isHandshakeCompletedSuccessfully() {
|
||||||
|
return 0, errHandshakeInProgress
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.readDeadline.Done():
|
||||||
|
return 0, errDeadlineExceeded
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.readDeadline.Done():
|
||||||
|
return 0, errDeadlineExceeded
|
||||||
|
case out, ok := <-c.decrypted:
|
||||||
|
if !ok {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
switch val := out.(type) {
|
||||||
|
case ([]byte):
|
||||||
|
if len(p) < len(val) {
|
||||||
|
return 0, errBufferTooSmall
|
||||||
|
}
|
||||||
|
copy(p, val)
|
||||||
|
return len(val), nil
|
||||||
|
case (error):
|
||||||
|
return 0, val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes len(p) bytes from p to the DTLS connection
|
||||||
|
func (c *Conn) Write(p []byte) (int, error) {
|
||||||
|
if c.isConnectionClosed() {
|
||||||
|
return 0, ErrConnClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.writeDeadline.Done():
|
||||||
|
return 0, errDeadlineExceeded
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.isHandshakeCompletedSuccessfully() {
|
||||||
|
return 0, errHandshakeInProgress
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(p), c.writePackets(c.writeDeadline, []*packet{
|
||||||
|
{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Epoch: c.getLocalEpoch(),
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &protocol.ApplicationData{
|
||||||
|
Data: p,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldEncrypt: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection.
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
err := c.close(true)
|
||||||
|
c.handshakeLoopsFinished.Wait()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionState returns basic DTLS details about the connection.
|
||||||
|
// Note that this replaced the `Export` function of v1.
|
||||||
|
func (c *Conn) ConnectionState() State {
|
||||||
|
c.lock.RLock()
|
||||||
|
defer c.lock.RUnlock()
|
||||||
|
return *c.state.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
|
||||||
|
func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
|
||||||
|
c.lock.RLock()
|
||||||
|
defer c.lock.RUnlock()
|
||||||
|
|
||||||
|
if c.state.srtpProtectionProfile == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.state.srtpProtectionProfile, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
|
var rawPackets [][]byte
|
||||||
|
|
||||||
|
for _, p := range pkts {
|
||||||
|
if h, ok := p.record.Content.(*handshake.Handshake); ok {
|
||||||
|
handshakeRaw, err := p.record.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
|
||||||
|
srvCliStr(c.state.isClient), h.Header.Type.String(),
|
||||||
|
p.record.Header.Epoch, h.Header.MessageSequence)
|
||||||
|
c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
|
||||||
|
|
||||||
|
rawHandshakePackets, err := c.processHandshakePacket(p, h)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rawPackets = append(rawPackets, rawHandshakePackets...)
|
||||||
|
} else {
|
||||||
|
rawPacket, err := c.processPacket(p)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rawPackets = append(rawPackets, rawPacket)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(rawPackets) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
compactedRawPackets := c.compactRawPackets(rawPackets)
|
||||||
|
|
||||||
|
for _, compactedRawPackets := range compactedRawPackets {
|
||||||
|
if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil {
|
||||||
|
return netError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
|
||||||
|
combinedRawPackets := make([][]byte, 0)
|
||||||
|
currentCombinedRawPacket := make([]byte, 0)
|
||||||
|
|
||||||
|
for _, rawPacket := range rawPackets {
|
||||||
|
if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
|
||||||
|
combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
|
||||||
|
currentCombinedRawPacket = []byte{}
|
||||||
|
}
|
||||||
|
currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
|
||||||
|
|
||||||
|
return combinedRawPackets
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) processPacket(p *packet) ([]byte, error) {
|
||||||
|
epoch := p.record.Header.Epoch
|
||||||
|
for len(c.state.localSequenceNumber) <= int(epoch) {
|
||||||
|
c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
|
||||||
|
}
|
||||||
|
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
|
||||||
|
if seq > recordlayer.MaxSequenceNumber {
|
||||||
|
// RFC 6347 Section 4.1.0
|
||||||
|
// The implementation must either abandon an association or rehandshake
|
||||||
|
// prior to allowing the sequence number to wrap.
|
||||||
|
return nil, errSequenceNumberOverflow
|
||||||
|
}
|
||||||
|
p.record.Header.SequenceNumber = seq
|
||||||
|
|
||||||
|
rawPacket, err := p.record.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.shouldEncrypt {
|
||||||
|
var err error
|
||||||
|
rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawPacket, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
|
||||||
|
rawPackets := make([][]byte, 0)
|
||||||
|
|
||||||
|
handshakeFragments, err := c.fragmentHandshake(h)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
epoch := p.record.Header.Epoch
|
||||||
|
for len(c.state.localSequenceNumber) <= int(epoch) {
|
||||||
|
c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, handshakeFragment := range handshakeFragments {
|
||||||
|
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
|
||||||
|
if seq > recordlayer.MaxSequenceNumber {
|
||||||
|
return nil, errSequenceNumberOverflow
|
||||||
|
}
|
||||||
|
|
||||||
|
recordlayerHeader := &recordlayer.Header{
|
||||||
|
Version: p.record.Header.Version,
|
||||||
|
ContentType: p.record.Header.ContentType,
|
||||||
|
ContentLen: uint16(len(handshakeFragment)),
|
||||||
|
Epoch: p.record.Header.Epoch,
|
||||||
|
SequenceNumber: seq,
|
||||||
|
}
|
||||||
|
|
||||||
|
recordlayerHeaderBytes, err := recordlayerHeader.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.record.Header = *recordlayerHeader
|
||||||
|
|
||||||
|
rawPacket := append(recordlayerHeaderBytes, handshakeFragment...)
|
||||||
|
if p.shouldEncrypt {
|
||||||
|
var err error
|
||||||
|
rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rawPackets = append(rawPackets, rawPacket)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawPackets, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
|
||||||
|
content, err := h.Message.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fragmentedHandshakes := make([][]byte, 0)
|
||||||
|
|
||||||
|
contentFragments := splitBytes(content, c.maximumTransmissionUnit)
|
||||||
|
if len(contentFragments) == 0 {
|
||||||
|
contentFragments = [][]byte{
|
||||||
|
{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := 0
|
||||||
|
for _, contentFragment := range contentFragments {
|
||||||
|
contentFragmentLen := len(contentFragment)
|
||||||
|
|
||||||
|
headerFragment := &handshake.Header{
|
||||||
|
Type: h.Header.Type,
|
||||||
|
Length: h.Header.Length,
|
||||||
|
MessageSequence: h.Header.MessageSequence,
|
||||||
|
FragmentOffset: uint32(offset),
|
||||||
|
FragmentLength: uint32(contentFragmentLen),
|
||||||
|
}
|
||||||
|
|
||||||
|
offset += contentFragmentLen
|
||||||
|
|
||||||
|
headerFragmentRaw, err := headerFragment.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fragmentedHandshake := append(headerFragmentRaw, contentFragment...)
|
||||||
|
fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fragmentedHandshakes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
|
||||||
|
New: func() interface{} {
|
||||||
|
b := make([]byte, inboundBufferSize)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) readAndBuffer(ctx context.Context) error {
|
||||||
|
bufptr := poolReadBuffer.Get().(*[]byte)
|
||||||
|
defer poolReadBuffer.Put(bufptr)
|
||||||
|
|
||||||
|
b := *bufptr
|
||||||
|
i, err := c.nextConn.ReadContext(ctx, b)
|
||||||
|
if err != nil {
|
||||||
|
return netError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkts, err := recordlayer.UnpackDatagram(b[:i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var hasHandshake bool
|
||||||
|
for _, p := range pkts {
|
||||||
|
hs, alert, err := c.handleIncomingPacket(p, true)
|
||||||
|
if alert != nil {
|
||||||
|
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
|
||||||
|
if err == nil {
|
||||||
|
err = alertErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hs {
|
||||||
|
hasHandshake = true
|
||||||
|
}
|
||||||
|
switch e := err.(type) {
|
||||||
|
case nil:
|
||||||
|
case *errAlert:
|
||||||
|
if e.IsFatalOrCloseNotify() {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasHandshake {
|
||||||
|
done := make(chan struct{})
|
||||||
|
select {
|
||||||
|
case c.handshakeRecv <- done:
|
||||||
|
// If the other party may retransmit the flight,
|
||||||
|
// we should respond even if it not a new message.
|
||||||
|
<-done
|
||||||
|
case <-c.fsm.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) handleQueuedPackets(ctx context.Context) error {
|
||||||
|
pkts := c.encryptedPackets
|
||||||
|
c.encryptedPackets = nil
|
||||||
|
|
||||||
|
for _, p := range pkts {
|
||||||
|
_, alert, err := c.handleIncomingPacket(p, false) // don't re-enqueue
|
||||||
|
if alert != nil {
|
||||||
|
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
|
||||||
|
if err == nil {
|
||||||
|
err = alertErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch e := err.(type) {
|
||||||
|
case nil:
|
||||||
|
case *errAlert:
|
||||||
|
if e.IsFatalOrCloseNotify() {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
|
||||||
|
h := &recordlayer.Header{}
|
||||||
|
if err := h.Unmarshal(buf); err != nil {
|
||||||
|
// Decode error must be silently discarded
|
||||||
|
// [RFC6347 Section-4.1.2.7]
|
||||||
|
c.log.Debugf("discarded broken packet: %v", err)
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate epoch
|
||||||
|
remoteEpoch := c.getRemoteEpoch()
|
||||||
|
if h.Epoch > remoteEpoch {
|
||||||
|
if h.Epoch > remoteEpoch+1 {
|
||||||
|
c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
|
||||||
|
h.Epoch, h.SequenceNumber,
|
||||||
|
)
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
if enqueue {
|
||||||
|
c.log.Debug("received packet of next epoch, queuing packet")
|
||||||
|
c.encryptedPackets = append(c.encryptedPackets, buf)
|
||||||
|
}
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Anti-replay protection
|
||||||
|
for len(c.state.replayDetector) <= int(h.Epoch) {
|
||||||
|
c.state.replayDetector = append(c.state.replayDetector,
|
||||||
|
replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
|
||||||
|
if !ok {
|
||||||
|
c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
|
||||||
|
h.Epoch, h.SequenceNumber,
|
||||||
|
)
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt
|
||||||
|
if h.Epoch != 0 {
|
||||||
|
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
|
||||||
|
if enqueue {
|
||||||
|
c.encryptedPackets = append(c.encryptedPackets, buf)
|
||||||
|
c.log.Debug("handshake not finished, queuing packet")
|
||||||
|
}
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
buf, err = c.state.cipherSuite.Decrypt(buf)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
|
||||||
|
if err != nil {
|
||||||
|
// Decode error must be silently discarded
|
||||||
|
// [RFC6347 Section-4.1.2.7]
|
||||||
|
c.log.Debugf("defragment failed: %s", err)
|
||||||
|
return false, nil, nil
|
||||||
|
} else if isHandshake {
|
||||||
|
markPacketAsValid()
|
||||||
|
for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
|
||||||
|
rawHandshake := &handshake.Handshake{}
|
||||||
|
if err := rawHandshake.Unmarshal(out); err != nil {
|
||||||
|
c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = c.handshakeCache.push(out, epoch, rawHandshake.Header.MessageSequence, rawHandshake.Header.Type, !c.state.isClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r := &recordlayer.RecordLayer{}
|
||||||
|
if err := r.Unmarshal(buf); err != nil {
|
||||||
|
return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch content := r.Content.(type) {
|
||||||
|
case *alert.Alert:
|
||||||
|
c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
|
||||||
|
var a *alert.Alert
|
||||||
|
if content.Description == alert.CloseNotify {
|
||||||
|
// Respond with a close_notify [RFC5246 Section 7.2.1]
|
||||||
|
a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
|
||||||
|
}
|
||||||
|
markPacketAsValid()
|
||||||
|
return false, a, &errAlert{content}
|
||||||
|
case *protocol.ChangeCipherSpec:
|
||||||
|
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
|
||||||
|
if enqueue {
|
||||||
|
c.encryptedPackets = append(c.encryptedPackets, buf)
|
||||||
|
c.log.Debugf("CipherSuite not initialized, queuing packet")
|
||||||
|
}
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newRemoteEpoch := h.Epoch + 1
|
||||||
|
c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
|
||||||
|
|
||||||
|
if c.getRemoteEpoch()+1 == newRemoteEpoch {
|
||||||
|
c.setRemoteEpoch(newRemoteEpoch)
|
||||||
|
markPacketAsValid()
|
||||||
|
}
|
||||||
|
case *protocol.ApplicationData:
|
||||||
|
if h.Epoch == 0 {
|
||||||
|
return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
|
||||||
|
}
|
||||||
|
|
||||||
|
markPacketAsValid()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.decrypted <- content.Data:
|
||||||
|
case <-c.closed.Done():
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
|
||||||
|
}
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) recvHandshake() <-chan chan struct{} {
|
||||||
|
return c.handshakeRecv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
|
||||||
|
return c.writePackets(ctx, []*packet{
|
||||||
|
{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Epoch: c.getLocalEpoch(),
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &alert.Alert{
|
||||||
|
Level: level,
|
||||||
|
Description: desc,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) setHandshakeCompletedSuccessfully() {
|
||||||
|
c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) isHandshakeCompletedSuccessfully() bool {
|
||||||
|
boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
|
||||||
|
return boolean.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
|
||||||
|
c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
ctxRead, cancelRead := context.WithCancel(context.Background())
|
||||||
|
c.cancelHandshakeReader = cancelRead
|
||||||
|
cfg.onFlightState = func(f flightVal, s handshakeState) {
|
||||||
|
if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
|
||||||
|
c.setHandshakeCompletedSuccessfully()
|
||||||
|
close(done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxHs, cancel := context.WithCancel(context.Background())
|
||||||
|
c.cancelHandshaker = cancel
|
||||||
|
|
||||||
|
firstErr := make(chan error, 1)
|
||||||
|
|
||||||
|
c.handshakeLoopsFinished.Add(2)
|
||||||
|
|
||||||
|
// Handshake routine should be live until close.
|
||||||
|
// The other party may request retransmission of the last flight to cope with packet drop.
|
||||||
|
go func() {
|
||||||
|
defer c.handshakeLoopsFinished.Done()
|
||||||
|
err := c.fsm.Run(ctxHs, c, initialState)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
select {
|
||||||
|
case firstErr <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
// Escaping read loop.
|
||||||
|
// It's safe to close decrypted channnel now.
|
||||||
|
close(c.decrypted)
|
||||||
|
|
||||||
|
// Force stop handshaker when the underlying connection is closed.
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
defer c.handshakeLoopsFinished.Done()
|
||||||
|
for {
|
||||||
|
if err := c.readAndBuffer(ctxRead); err != nil {
|
||||||
|
switch e := err.(type) {
|
||||||
|
case *errAlert:
|
||||||
|
if !e.IsFatalOrCloseNotify() {
|
||||||
|
if c.isHandshakeCompletedSuccessfully() {
|
||||||
|
// Pass the error to Read()
|
||||||
|
select {
|
||||||
|
case c.decrypted <- err:
|
||||||
|
case <-c.closed.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue // non-fatal alert must not stop read loop
|
||||||
|
}
|
||||||
|
case error:
|
||||||
|
switch err {
|
||||||
|
case context.DeadlineExceeded, context.Canceled, io.EOF:
|
||||||
|
default:
|
||||||
|
if c.isHandshakeCompletedSuccessfully() {
|
||||||
|
// Keep read loop and pass the read error to Read()
|
||||||
|
select {
|
||||||
|
case c.decrypted <- err:
|
||||||
|
case <-c.closed.Done():
|
||||||
|
}
|
||||||
|
continue // non-fatal alert must not stop read loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case firstErr <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if e, ok := err.(*errAlert); ok {
|
||||||
|
if e.IsFatalOrCloseNotify() {
|
||||||
|
_ = c.close(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-firstErr:
|
||||||
|
cancelRead()
|
||||||
|
cancel()
|
||||||
|
return c.translateHandshakeCtxError(err)
|
||||||
|
case <-ctx.Done():
|
||||||
|
cancelRead()
|
||||||
|
cancel()
|
||||||
|
return c.translateHandshakeCtxError(ctx.Err())
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) translateHandshakeCtxError(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &HandshakeError{Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) close(byUser bool) error {
|
||||||
|
c.cancelHandshaker()
|
||||||
|
c.cancelHandshakeReader()
|
||||||
|
|
||||||
|
if c.isHandshakeCompletedSuccessfully() && byUser {
|
||||||
|
// Discard error from notify() to return non-error on the first user call of Close()
|
||||||
|
// even if the underlying connection is already closed.
|
||||||
|
_ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.closeLock.Lock()
|
||||||
|
// Don't return ErrConnClosed at the first time of the call from user.
|
||||||
|
closedByUser := c.connectionClosedByUser
|
||||||
|
if byUser {
|
||||||
|
c.connectionClosedByUser = true
|
||||||
|
}
|
||||||
|
c.closed.Close()
|
||||||
|
c.closeLock.Unlock()
|
||||||
|
|
||||||
|
if closedByUser {
|
||||||
|
return ErrConnClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.nextConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) isConnectionClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-c.closed.Done():
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) setLocalEpoch(epoch uint16) {
|
||||||
|
c.state.localEpoch.Store(epoch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) getLocalEpoch() uint16 {
|
||||||
|
return c.state.localEpoch.Load().(uint16)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) setRemoteEpoch(epoch uint16) {
|
||||||
|
c.state.remoteEpoch.Store(epoch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) getRemoteEpoch() uint16 {
|
||||||
|
return c.state.remoteEpoch.Load().(uint16)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr implements net.Conn.LocalAddr
|
||||||
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
|
return c.nextConn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr implements net.Conn.RemoteAddr
|
||||||
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
|
return c.nextConn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline implements net.Conn.SetDeadline
|
||||||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
|
c.readDeadline.Set(t)
|
||||||
|
return c.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline implements net.Conn.SetReadDeadline
|
||||||
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
|
c.readDeadline.Set(t)
|
||||||
|
// Read deadline is fully managed by this layer.
|
||||||
|
// Don't set read deadline to underlying connection.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline implements net.Conn.SetWriteDeadline
|
||||||
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
|
c.writeDeadline.Set(t)
|
||||||
|
// Write deadline is also fully managed by this layer.
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,169 @@
|
||||||
|
// +build !js
|
||||||
|
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/internal/net/dpipe"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||||
|
"github.com/pion/transport/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextConfig(t *testing.T) {
|
||||||
|
// Limit runtime in case of deadlocks
|
||||||
|
lim := test.TimeOut(time.Second * 20)
|
||||||
|
defer lim.Stop()
|
||||||
|
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
|
||||||
|
addrListen, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dummy listener
|
||||||
|
listen, err := net.ListenUDP("udp", addrListen)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = listen.Close()
|
||||||
|
}()
|
||||||
|
addr := listen.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
cert, err := selfsign.GenerateSelfSigned()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
config := &Config{
|
||||||
|
ConnectContextMaker: func() (context.Context, func()) {
|
||||||
|
return context.WithTimeout(context.Background(), 40*time.Millisecond)
|
||||||
|
},
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
|
||||||
|
dials := map[string]struct {
|
||||||
|
f func() (func() (net.Conn, error), func())
|
||||||
|
order []byte
|
||||||
|
}{
|
||||||
|
"Dial": {
|
||||||
|
f: func() (func() (net.Conn, error), func()) {
|
||||||
|
return func() (net.Conn, error) {
|
||||||
|
return Dial("udp", addr, config)
|
||||||
|
}, func() {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
order: []byte{0, 1, 2},
|
||||||
|
},
|
||||||
|
"DialWithContext": {
|
||||||
|
f: func() (func() (net.Conn, error), func()) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
|
||||||
|
return func() (net.Conn, error) {
|
||||||
|
return DialWithContext(ctx, "udp", addr, config)
|
||||||
|
}, func() {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
order: []byte{0, 2, 1},
|
||||||
|
},
|
||||||
|
"Client": {
|
||||||
|
f: func() (func() (net.Conn, error), func()) {
|
||||||
|
ca, _ := dpipe.Pipe()
|
||||||
|
return func() (net.Conn, error) {
|
||||||
|
return Client(ca, config)
|
||||||
|
}, func() {
|
||||||
|
_ = ca.Close()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
order: []byte{0, 1, 2},
|
||||||
|
},
|
||||||
|
"ClientWithContext": {
|
||||||
|
f: func() (func() (net.Conn, error), func()) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
|
||||||
|
ca, _ := dpipe.Pipe()
|
||||||
|
return func() (net.Conn, error) {
|
||||||
|
return ClientWithContext(ctx, ca, config)
|
||||||
|
}, func() {
|
||||||
|
cancel()
|
||||||
|
_ = ca.Close()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
order: []byte{0, 2, 1},
|
||||||
|
},
|
||||||
|
"Server": {
|
||||||
|
f: func() (func() (net.Conn, error), func()) {
|
||||||
|
ca, _ := dpipe.Pipe()
|
||||||
|
return func() (net.Conn, error) {
|
||||||
|
return Server(ca, config)
|
||||||
|
}, func() {
|
||||||
|
_ = ca.Close()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
order: []byte{0, 1, 2},
|
||||||
|
},
|
||||||
|
"ServerWithContext": {
|
||||||
|
f: func() (func() (net.Conn, error), func()) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
|
||||||
|
ca, _ := dpipe.Pipe()
|
||||||
|
return func() (net.Conn, error) {
|
||||||
|
return ServerWithContext(ctx, ca, config)
|
||||||
|
}, func() {
|
||||||
|
cancel()
|
||||||
|
_ = ca.Close()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
order: []byte{0, 2, 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, dial := range dials {
|
||||||
|
dial := dial
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
d, cancel := dial.f()
|
||||||
|
conn, err := d()
|
||||||
|
defer cancel()
|
||||||
|
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
|
||||||
|
t.Errorf("Client error exp(Temporary network error) failed(%v)", err)
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
if err == nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var order []byte
|
||||||
|
early := time.After(20 * time.Millisecond)
|
||||||
|
late := time.After(60 * time.Millisecond)
|
||||||
|
func() {
|
||||||
|
for len(order) < 3 {
|
||||||
|
select {
|
||||||
|
case <-early:
|
||||||
|
order = append(order, 0)
|
||||||
|
case _, ok := <-done:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
order = append(order, 1)
|
||||||
|
case <-late:
|
||||||
|
order = append(order, 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if !bytes.Equal(dial.order, order) {
|
||||||
|
t.Errorf("Invalid cancel timing, expected: %v, got: %v", dial.order, order)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,221 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/asn1"
|
||||||
|
"encoding/binary"
|
||||||
|
"math/big"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/hash"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ecdsaSignature struct {
|
||||||
|
R, S *big.Int
|
||||||
|
}
|
||||||
|
|
||||||
|
func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve) []byte {
|
||||||
|
serverECDHParams := make([]byte, 4)
|
||||||
|
serverECDHParams[0] = 3 // named curve
|
||||||
|
binary.BigEndian.PutUint16(serverECDHParams[1:], uint16(namedCurve))
|
||||||
|
serverECDHParams[3] = byte(len(publicKey))
|
||||||
|
|
||||||
|
plaintext := []byte{}
|
||||||
|
plaintext = append(plaintext, clientRandom...)
|
||||||
|
plaintext = append(plaintext, serverRandom...)
|
||||||
|
plaintext = append(plaintext, serverECDHParams...)
|
||||||
|
plaintext = append(plaintext, publicKey...)
|
||||||
|
|
||||||
|
return plaintext
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the client provided a "signature_algorithms" extension, then all
|
||||||
|
// certificates provided by the server MUST be signed by a
|
||||||
|
// hash/signature algorithm pair that appears in that extension
|
||||||
|
//
|
||||||
|
// https://tools.ietf.org/html/rfc5246#section-7.4.2
|
||||||
|
func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) {
|
||||||
|
msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve)
|
||||||
|
switch p := privateKey.(type) {
|
||||||
|
case ed25519.PrivateKey:
|
||||||
|
// https://crypto.stackexchange.com/a/55483
|
||||||
|
return p.Sign(rand.Reader, msg, crypto.Hash(0))
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
hashed := hashAlgorithm.Digest(msg)
|
||||||
|
return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash())
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
hashed := hashAlgorithm.Digest(msg)
|
||||||
|
return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errKeySignatureGenerateUnimplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.Algorithm, rawCertificates [][]byte) error { //nolint:dupl
|
||||||
|
if len(rawCertificates) == 0 {
|
||||||
|
return errLengthMismatch
|
||||||
|
}
|
||||||
|
certificate, err := x509.ParseCertificate(rawCertificates[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p := certificate.PublicKey.(type) {
|
||||||
|
case ed25519.PublicKey:
|
||||||
|
if ok := ed25519.Verify(p, message, remoteKeySignature); !ok {
|
||||||
|
return errKeySignatureMismatch
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
ecdsaSig := &ecdsaSignature{}
|
||||||
|
if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
|
||||||
|
return errInvalidECDSASignature
|
||||||
|
}
|
||||||
|
hashed := hashAlgorithm.Digest(message)
|
||||||
|
if !ecdsa.Verify(p, hashed, ecdsaSig.R, ecdsaSig.S) {
|
||||||
|
return errKeySignatureMismatch
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
switch certificate.SignatureAlgorithm {
|
||||||
|
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
|
||||||
|
hashed := hashAlgorithm.Digest(message)
|
||||||
|
return rsa.VerifyPKCS1v15(p, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature)
|
||||||
|
default:
|
||||||
|
return errKeySignatureVerifyUnimplemented
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errKeySignatureVerifyUnimplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the server has sent a CertificateRequest message, the client MUST send the Certificate
|
||||||
|
// message. The ClientKeyExchange message is now sent, and the content
|
||||||
|
// of that message will depend on the public key algorithm selected
|
||||||
|
// between the ClientHello and the ServerHello. If the client has sent
|
||||||
|
// a certificate with signing ability, a digitally-signed
|
||||||
|
// CertificateVerify message is sent to explicitly verify possession of
|
||||||
|
// the private key in the certificate.
|
||||||
|
// https://tools.ietf.org/html/rfc5246#section-7.3
|
||||||
|
func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) {
|
||||||
|
h := sha256.New()
|
||||||
|
if _, err := h.Write(handshakeBodies); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hashed := h.Sum(nil)
|
||||||
|
|
||||||
|
switch p := privateKey.(type) {
|
||||||
|
case ed25519.PrivateKey:
|
||||||
|
// https://crypto.stackexchange.com/a/55483
|
||||||
|
return p.Sign(rand.Reader, hashed, crypto.Hash(0))
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash())
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errInvalidSignatureAlgorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hash.Algorithm, remoteKeySignature []byte, rawCertificates [][]byte) error { //nolint:dupl
|
||||||
|
if len(rawCertificates) == 0 {
|
||||||
|
return errLengthMismatch
|
||||||
|
}
|
||||||
|
certificate, err := x509.ParseCertificate(rawCertificates[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p := certificate.PublicKey.(type) {
|
||||||
|
case ed25519.PublicKey:
|
||||||
|
if ok := ed25519.Verify(p, handshakeBodies, remoteKeySignature); !ok {
|
||||||
|
return errKeySignatureMismatch
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
ecdsaSig := &ecdsaSignature{}
|
||||||
|
if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
|
||||||
|
return errInvalidECDSASignature
|
||||||
|
}
|
||||||
|
hash := hashAlgorithm.Digest(handshakeBodies)
|
||||||
|
if !ecdsa.Verify(p, hash, ecdsaSig.R, ecdsaSig.S) {
|
||||||
|
return errKeySignatureMismatch
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
switch certificate.SignatureAlgorithm {
|
||||||
|
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
|
||||||
|
hash := hashAlgorithm.Digest(handshakeBodies)
|
||||||
|
return rsa.VerifyPKCS1v15(p, hashAlgorithm.CryptoHash(), hash, remoteKeySignature)
|
||||||
|
default:
|
||||||
|
return errKeySignatureVerifyUnimplemented
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errKeySignatureVerifyUnimplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadCerts(rawCertificates [][]byte) ([]*x509.Certificate, error) {
|
||||||
|
if len(rawCertificates) == 0 {
|
||||||
|
return nil, errLengthMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
certs := make([]*x509.Certificate, 0, len(rawCertificates))
|
||||||
|
for _, rawCert := range rawCertificates {
|
||||||
|
cert, err := x509.ParseCertificate(rawCert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
certs = append(certs, cert)
|
||||||
|
}
|
||||||
|
return certs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyClientCert(rawCertificates [][]byte, roots *x509.CertPool) (chains [][]*x509.Certificate, err error) {
|
||||||
|
certificate, err := loadCerts(rawCertificates)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
intermediateCAPool := x509.NewCertPool()
|
||||||
|
for _, cert := range certificate[1:] {
|
||||||
|
intermediateCAPool.AddCert(cert)
|
||||||
|
}
|
||||||
|
opts := x509.VerifyOptions{
|
||||||
|
Roots: roots,
|
||||||
|
CurrentTime: time.Now(),
|
||||||
|
Intermediates: intermediateCAPool,
|
||||||
|
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||||
|
}
|
||||||
|
return certificate[0].Verify(opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName string) (chains [][]*x509.Certificate, err error) {
|
||||||
|
certificate, err := loadCerts(rawCertificates)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
intermediateCAPool := x509.NewCertPool()
|
||||||
|
for _, cert := range certificate[1:] {
|
||||||
|
intermediateCAPool.AddCert(cert)
|
||||||
|
}
|
||||||
|
opts := x509.VerifyOptions{
|
||||||
|
Roots: roots,
|
||||||
|
CurrentTime: time.Now(),
|
||||||
|
DNSName: serverName,
|
||||||
|
Intermediates: intermediateCAPool,
|
||||||
|
}
|
||||||
|
return certificate[0].Verify(opts)
|
||||||
|
}
|
|
@ -0,0 +1,73 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/hash"
|
||||||
|
)
|
||||||
|
|
||||||
|
const rawPrivateKey = `
|
||||||
|
-----BEGIN RSA PRIVATE KEY-----
|
||||||
|
MIIEowIBAAKCAQEAxIA2BrrnR2sIlATsp7aRBD/3krwZ7vt9dNeoDQAee0s6SuYP
|
||||||
|
6MBx/HPnAkwNvPS90R05a7pwRkoT6Ur4PfPhCVlUe8lV+0Eto3ZSEeHz3HdsqlM3
|
||||||
|
bso67L7Dqrc7MdVstlKcgJi8yeAoGOIL9/igOv0XBFCeznm9nznx6mnsR5cugw+1
|
||||||
|
ypXelaHmBCLV7r5SeVSh57+KhvZGbQ2fFpUaTPegRpJZXBNS8lSeWvtOv9d6N5UB
|
||||||
|
ROTAJodMZT5AfX0jB0QB9IT/0I96H6BSENH08NXOeXApMuLKvnAf361rS7cRAfRL
|
||||||
|
rWZqERMP4u6Cnk0Cnckc3WcW27kGGIbtwbqUIQIDAQABAoIBAGF7OVIdZp8Hejn0
|
||||||
|
N3L8HvT8xtUEe9kS6ioM0lGgvX5s035Uo4/T6LhUx0VcdXRH9eLHnLTUyN4V4cra
|
||||||
|
ZkxVsE3zAvZl60G6E+oDyLMWZOP6Wu4kWlub9597A5atT7BpMIVCdmFVZFLB4SJ3
|
||||||
|
AXkC3nplFAYP+Lh1rJxRIrIn2g+pEeBboWbYA++oDNuMQffDZaokTkJ8Bn1JZYh0
|
||||||
|
xEXKY8Bi2Egd5NMeZa1UFO6y8tUbZfwgVs6Enq5uOgtfayq79vZwyjj1kd29MBUD
|
||||||
|
8g8byV053ZKxbUOiOuUts97eb+fN3DIDRTcT2c+lXt/4C54M1FclJAbtYRK/qwsl
|
||||||
|
pYWKQAECgYEA4ZUbqQnTo1ICvj81ifGrz+H4LKQqe92Hbf/W51D/Umk2kP702W22
|
||||||
|
HP4CvrJRtALThJIG9m2TwUjl/WAuZIBrhSAbIvc3Fcoa2HjdRp+sO5U1ueDq7d/S
|
||||||
|
Z+PxRI8cbLbRpEdIaoR46qr/2uWZ943PHMv9h4VHPYn1w8b94hwD6vkCgYEA3v87
|
||||||
|
mFLzyM9ercnEv9zHMRlMZFQhlcUGQZvfb8BuJYl/WogyT6vRrUuM0QXULNEPlrin
|
||||||
|
mBQTqc1nCYbgkFFsD2VVt1qIyiAJsB9MD1LNV6YuvE7T2KOSadmsA4fa9PUqbr71
|
||||||
|
hf3lTTq+LeR09LebO7WgSGYY+5YKVOEGpYMR1GkCgYEAxPVQmk3HKHEhjgRYdaG5
|
||||||
|
lp9A9ZE8uruYVJWtiHgzBTxx9TV2iST+fd/We7PsHFTfY3+wbpcMDBXfIVRKDVwH
|
||||||
|
BMwchXH9+Ztlxx34bYJaegd0SmA0Hw9ugWEHNgoSEmWpM1s9wir5/ELjc7dGsFtz
|
||||||
|
uzvsl9fpdLSxDYgAAdzeGtkCgYBAzKIgrVox7DBzB8KojhtD5ToRnXD0+H/M6OKQ
|
||||||
|
srZPKhlb0V/tTtxrIx0UUEFLlKSXA6mPw6XDHfDnD86JoV9pSeUSlrhRI+Ysy6tq
|
||||||
|
eIE7CwthpPZiaYXORHZ7wCqcK/HcpJjsCs9rFbrV0yE5S3FMdIbTAvgXg44VBB7O
|
||||||
|
UbwIoQKBgDuY8gSrA5/A747wjjmsdRWK4DMTMEV4eCW1BEP7Tg7Cxd5n3xPJiYhr
|
||||||
|
nhLGN+mMnVIcv2zEMS0/eNZr1j/0BtEdx+3IC6Eq+ONY0anZ4Irt57/5QeKgKn/L
|
||||||
|
JPhfPySIPG4UmwE4gW8t79vfOKxnUu2fDD1ZXUYopan6EckACNH/
|
||||||
|
-----END RSA PRIVATE KEY-----
|
||||||
|
`
|
||||||
|
|
||||||
|
func TestGenerateKeySignature(t *testing.T) {
|
||||||
|
block, _ := pem.Decode([]byte(rawPrivateKey))
|
||||||
|
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}
|
||||||
|
serverRandom := []byte{0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f}
|
||||||
|
publicKey := []byte{0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15}
|
||||||
|
expectedSignature := []byte{
|
||||||
|
0x6f, 0x47, 0x97, 0x85, 0xcc, 0x76, 0x50, 0x93, 0xbd, 0xe2, 0x6a, 0x69, 0x0b, 0xc3, 0x03, 0xd1, 0xb7, 0xe4, 0xab, 0x88, 0x7b, 0xa6, 0x52, 0x80, 0xdf,
|
||||||
|
0xaa, 0x25, 0x7a, 0xdb, 0x29, 0x32, 0xe4, 0xd8, 0x28, 0x28, 0xb3, 0xe8, 0x04, 0x3c, 0x38, 0x16, 0xfc, 0x78, 0xe9, 0x15, 0x7b, 0xc5, 0xbd, 0x7d, 0xfc,
|
||||||
|
0xcd, 0x83, 0x00, 0x57, 0x4a, 0x3c, 0x23, 0x85, 0x75, 0x6b, 0x37, 0xd5, 0x89, 0x72, 0x73, 0xf0, 0x44, 0x8c, 0x00, 0x70, 0x1f, 0x6e, 0xa2, 0x81, 0xd0,
|
||||||
|
0x09, 0xc5, 0x20, 0x36, 0xab, 0x23, 0x09, 0x40, 0x1f, 0x4d, 0x45, 0x96, 0x62, 0xbb, 0x81, 0xb0, 0x30, 0x72, 0xad, 0x3a, 0x0a, 0xac, 0x31, 0x63, 0x40,
|
||||||
|
0x52, 0x0a, 0x27, 0xf3, 0x34, 0xde, 0x27, 0x7d, 0xb7, 0x54, 0xff, 0x0f, 0x9f, 0x5a, 0xfe, 0x07, 0x0f, 0x4e, 0x9f, 0x53, 0x04, 0x34, 0x62, 0xf4, 0x30,
|
||||||
|
0x74, 0x83, 0x35, 0xfc, 0xe4, 0x7e, 0xbf, 0x5a, 0xc4, 0x52, 0xd0, 0xea, 0xf9, 0x61, 0x4e, 0xf5, 0x1c, 0x0e, 0x58, 0x02, 0x71, 0xfb, 0x1f, 0x34, 0x55,
|
||||||
|
0xe8, 0x36, 0x70, 0x3c, 0xc1, 0xcb, 0xc9, 0xb7, 0xbb, 0xb5, 0x1c, 0x44, 0x9a, 0x6d, 0x88, 0x78, 0x98, 0xd4, 0x91, 0x2e, 0xeb, 0x98, 0x81, 0x23, 0x30,
|
||||||
|
0x73, 0x39, 0x43, 0xd5, 0xbb, 0x70, 0x39, 0xba, 0x1f, 0xdb, 0x70, 0x9f, 0x91, 0x83, 0x56, 0xc2, 0xde, 0xed, 0x17, 0x6d, 0x2c, 0x3e, 0x21, 0xea, 0x36,
|
||||||
|
0xb4, 0x91, 0xd8, 0x31, 0x05, 0x60, 0x90, 0xfd, 0xc6, 0x74, 0xa9, 0x7b, 0x18, 0xfc, 0x1c, 0x6a, 0x1c, 0x6e, 0xec, 0xd3, 0xc1, 0xc0, 0x0d, 0x11, 0x25,
|
||||||
|
0x48, 0x37, 0x3d, 0x45, 0x11, 0xa2, 0x31, 0x14, 0x0a, 0x66, 0x9f, 0xd8, 0xac, 0x74, 0xa2, 0xcd, 0xc8, 0x79, 0xb3, 0x9e, 0xc6, 0x66, 0x25, 0xcf, 0x2c,
|
||||||
|
0x87, 0x5e, 0x5c, 0x36, 0x75, 0x86,
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else if !bytes.Equal(expectedSignature, signature) {
|
||||||
|
t.Errorf("Signature generation failed \nexp % 02x \nactual % 02x ", expectedSignature, signature)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,2 @@
|
||||||
|
// Package dtls implements Datagram Transport Layer Security (DTLS) 1.2
|
||||||
|
package dtls
|
|
@ -0,0 +1,11 @@
|
||||||
|
FROM golang:1.14-alpine3.11
|
||||||
|
|
||||||
|
RUN apk add --no-cache \
|
||||||
|
openssl
|
||||||
|
|
||||||
|
ENV CGO_ENABLED=0
|
||||||
|
|
||||||
|
COPY . /go/src/github.com/pion/dtls
|
||||||
|
WORKDIR /go/src/github.com/pion/dtls/e2e
|
||||||
|
|
||||||
|
CMD ["go", "test", "-tags=openssl", "-v", "."]
|
|
@ -0,0 +1,2 @@
|
||||||
|
// Package e2e contains end to end tests for pion/dtls
|
||||||
|
package e2e
|
|
@ -0,0 +1,207 @@
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||||
|
transportTest "github.com/pion/transport/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
flightInterval = time.Millisecond * 100
|
||||||
|
lossyTestTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
DTLS Client/Server over a lossy transport, just asserts it can handle at increasing increments
|
||||||
|
*/
|
||||||
|
func TestPionE2ELossy(t *testing.T) {
|
||||||
|
// Check for leaking routines
|
||||||
|
report := transportTest.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
|
||||||
|
type runResult struct {
|
||||||
|
dtlsConn *dtls.Conn
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
serverCert, err := selfsign.GenerateSelfSigned()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientCert, err := selfsign.GenerateSelfSigned()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range []struct {
|
||||||
|
LossChanceRange int
|
||||||
|
DoClientAuth bool
|
||||||
|
CipherSuites []dtls.CipherSuiteID
|
||||||
|
MTU int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
LossChanceRange: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 0,
|
||||||
|
DoClientAuth: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 10,
|
||||||
|
DoClientAuth: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 20,
|
||||||
|
DoClientAuth: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 50,
|
||||||
|
DoClientAuth: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 0,
|
||||||
|
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 10,
|
||||||
|
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 20,
|
||||||
|
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 50,
|
||||||
|
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 10,
|
||||||
|
MTU: 100,
|
||||||
|
DoClientAuth: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 20,
|
||||||
|
MTU: 100,
|
||||||
|
DoClientAuth: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
LossChanceRange: 50,
|
||||||
|
MTU: 100,
|
||||||
|
DoClientAuth: true,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
name := fmt.Sprintf("Loss%d_MTU%d", test.LossChanceRange, test.MTU)
|
||||||
|
if test.DoClientAuth {
|
||||||
|
name += "_WithCliAuth"
|
||||||
|
}
|
||||||
|
for _, ciph := range test.CipherSuites {
|
||||||
|
name += "_With" + ciph.String()
|
||||||
|
}
|
||||||
|
test := test
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
// Limit runtime in case of deadlocks
|
||||||
|
lim := transportTest.TimeOut(lossyTestTimeout + time.Second)
|
||||||
|
defer lim.Stop()
|
||||||
|
|
||||||
|
rand.Seed(time.Now().UTC().UnixNano())
|
||||||
|
chosenLoss := rand.Intn(9) + test.LossChanceRange //nolint:gosec
|
||||||
|
serverDone := make(chan runResult)
|
||||||
|
clientDone := make(chan runResult)
|
||||||
|
br := transportTest.NewBridge()
|
||||||
|
|
||||||
|
if err = br.SetLossChance(chosenLoss); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
cfg := &dtls.Config{
|
||||||
|
FlightInterval: flightInterval,
|
||||||
|
CipherSuites: test.CipherSuites,
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
MTU: test.MTU,
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.DoClientAuth {
|
||||||
|
cfg.Certificates = []tls.Certificate{clientCert}
|
||||||
|
}
|
||||||
|
|
||||||
|
client, startupErr := dtls.Client(br.GetConn0(), cfg)
|
||||||
|
clientDone <- runResult{client, startupErr}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
cfg := &dtls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
FlightInterval: flightInterval,
|
||||||
|
MTU: test.MTU,
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.DoClientAuth {
|
||||||
|
cfg.ClientAuth = dtls.RequireAnyClientCert
|
||||||
|
}
|
||||||
|
|
||||||
|
server, startupErr := dtls.Server(br.GetConn1(), cfg)
|
||||||
|
serverDone <- runResult{server, startupErr}
|
||||||
|
}()
|
||||||
|
|
||||||
|
testTimer := time.NewTimer(lossyTestTimeout)
|
||||||
|
var serverConn, clientConn *dtls.Conn
|
||||||
|
defer func() {
|
||||||
|
if serverConn != nil {
|
||||||
|
if err = serverConn.Close(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if clientConn != nil {
|
||||||
|
if err = clientConn.Close(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if serverConn != nil && clientConn != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
br.Tick()
|
||||||
|
select {
|
||||||
|
case serverResult := <-serverDone:
|
||||||
|
if serverResult.err != nil {
|
||||||
|
t.Errorf("Fail, serverError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, serverResult.err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverConn = serverResult.dtlsConn
|
||||||
|
case clientResult := <-clientDone:
|
||||||
|
if clientResult.err != nil {
|
||||||
|
t.Errorf("Fail, clientError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, clientResult.err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConn = clientResult.dtlsConn
|
||||||
|
case <-testTimer.C:
|
||||||
|
t.Errorf("Test expired: clientComplete(%t) serverComplete(%t) LossChance(%d)", clientConn != nil, serverConn != nil, chosenLoss)
|
||||||
|
return
|
||||||
|
case <-time.After(10 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,250 @@
|
||||||
|
// +build openssl,!js
|
||||||
|
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func serverOpenSSL(c *comm) {
|
||||||
|
go func() {
|
||||||
|
c.serverMutex.Lock()
|
||||||
|
defer c.serverMutex.Unlock()
|
||||||
|
|
||||||
|
cfg := c.serverConfig
|
||||||
|
|
||||||
|
// create openssl arguments
|
||||||
|
args := []string{
|
||||||
|
"s_server",
|
||||||
|
"-dtls1_2",
|
||||||
|
"-quiet",
|
||||||
|
"-verify_quiet",
|
||||||
|
"-verify_return_error",
|
||||||
|
fmt.Sprintf("-accept=%d", c.serverPort),
|
||||||
|
}
|
||||||
|
ciphers := ciphersOpenSSL(cfg)
|
||||||
|
if ciphers != "" {
|
||||||
|
args = append(args, fmt.Sprintf("-cipher=%s", ciphers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// psk arguments
|
||||||
|
if cfg.PSK != nil {
|
||||||
|
psk, err := cfg.PSK(nil)
|
||||||
|
if err != nil {
|
||||||
|
c.errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
args = append(args, fmt.Sprintf("-psk=%X", psk))
|
||||||
|
if len(cfg.PSKIdentityHint) > 0 {
|
||||||
|
args = append(args, fmt.Sprintf("-psk_hint=%s", cfg.PSKIdentityHint))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// certs arguments
|
||||||
|
if len(cfg.Certificates) > 0 {
|
||||||
|
// create temporary cert files
|
||||||
|
certPEM, keyPEM, err := writeTempPEM(cfg)
|
||||||
|
if err != nil {
|
||||||
|
c.errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
args = append(args,
|
||||||
|
fmt.Sprintf("-cert=%s", certPEM),
|
||||||
|
fmt.Sprintf("-key=%s", keyPEM))
|
||||||
|
defer func() {
|
||||||
|
_ = os.Remove(certPEM)
|
||||||
|
_ = os.Remove(keyPEM)
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
args = append(args, "-nocert")
|
||||||
|
}
|
||||||
|
|
||||||
|
// launch command
|
||||||
|
// #nosec G204
|
||||||
|
cmd := exec.CommandContext(c.ctx, "openssl", args...)
|
||||||
|
var inner net.Conn
|
||||||
|
inner, c.serverConn = net.Pipe()
|
||||||
|
cmd.Stdin = inner
|
||||||
|
cmd.Stdout = inner
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
c.errChan <- err
|
||||||
|
_ = inner.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that server has started
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
c.serverReady <- struct{}{}
|
||||||
|
simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientOpenSSL(c *comm) {
|
||||||
|
select {
|
||||||
|
case <-c.serverReady:
|
||||||
|
// OK
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
c.errChan <- errors.New("waiting on serverReady err: timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.clientMutex.Lock()
|
||||||
|
defer c.clientMutex.Unlock()
|
||||||
|
|
||||||
|
cfg := c.clientConfig
|
||||||
|
|
||||||
|
// create openssl arguments
|
||||||
|
args := []string{
|
||||||
|
"s_client",
|
||||||
|
"-dtls1_2",
|
||||||
|
"-quiet",
|
||||||
|
"-verify_quiet",
|
||||||
|
"-verify_return_error",
|
||||||
|
"-servername=localhost",
|
||||||
|
fmt.Sprintf("-connect=127.0.0.1:%d", c.serverPort),
|
||||||
|
}
|
||||||
|
ciphers := ciphersOpenSSL(cfg)
|
||||||
|
if ciphers != "" {
|
||||||
|
args = append(args, fmt.Sprintf("-cipher=%s", ciphers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// psk arguments
|
||||||
|
if cfg.PSK != nil {
|
||||||
|
psk, err := cfg.PSK(nil)
|
||||||
|
if err != nil {
|
||||||
|
c.errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
args = append(args, fmt.Sprintf("-psk=%X", psk))
|
||||||
|
}
|
||||||
|
|
||||||
|
// certificate arguments
|
||||||
|
if len(cfg.Certificates) > 0 {
|
||||||
|
// create temporary cert files
|
||||||
|
certPEM, keyPEM, err := writeTempPEM(cfg)
|
||||||
|
if err != nil {
|
||||||
|
c.errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
args = append(args, fmt.Sprintf("-CAfile=%s", certPEM))
|
||||||
|
defer func() {
|
||||||
|
_ = os.Remove(certPEM)
|
||||||
|
_ = os.Remove(keyPEM)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// launch command
|
||||||
|
// #nosec G204
|
||||||
|
cmd := exec.CommandContext(c.ctx, "openssl", args...)
|
||||||
|
var inner net.Conn
|
||||||
|
inner, c.clientConn = net.Pipe()
|
||||||
|
cmd.Stdin = inner
|
||||||
|
cmd.Stdout = inner
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
c.errChan <- err
|
||||||
|
_ = inner.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ciphersOpenSSL(cfg *dtls.Config) string {
|
||||||
|
// See https://tls.mbed.org/supported-ssl-ciphersuites
|
||||||
|
translate := map[dtls.CipherSuiteID]string{
|
||||||
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM: "ECDHE-ECDSA-AES128-CCM",
|
||||||
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8: "ECDHE-ECDSA-AES128-CCM8",
|
||||||
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "ECDHE-ECDSA-AES128-GCM-SHA256",
|
||||||
|
dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "ECDHE-RSA-AES128-GCM-SHA256",
|
||||||
|
|
||||||
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "ECDHE-ECDSA-AES256-SHA",
|
||||||
|
dtls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "ECDHE-RSA-AES128-SHA",
|
||||||
|
|
||||||
|
dtls.TLS_PSK_WITH_AES_128_CCM: "PSK-AES128-CCM",
|
||||||
|
dtls.TLS_PSK_WITH_AES_128_CCM_8: "PSK-AES128-CCM8",
|
||||||
|
dtls.TLS_PSK_WITH_AES_128_GCM_SHA256: "PSK-AES128-GCM-SHA256",
|
||||||
|
}
|
||||||
|
|
||||||
|
var ciphers []string
|
||||||
|
for _, c := range cfg.CipherSuites {
|
||||||
|
if text, ok := translate[c]; ok {
|
||||||
|
ciphers = append(ciphers, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(ciphers, ";")
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeTempPEM(cfg *dtls.Config) (string, string, error) {
|
||||||
|
certOut, err := ioutil.TempFile("", "cert.pem")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to create temporary file: %w", err)
|
||||||
|
}
|
||||||
|
keyOut, err := ioutil.TempFile("", "key.pem")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to create temporary file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := cfg.Certificates[0]
|
||||||
|
derBytes := cert.Certificate[0]
|
||||||
|
if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to write data to cert.pem: %w", err)
|
||||||
|
}
|
||||||
|
if err = certOut.Close(); err != nil {
|
||||||
|
return "", "", fmt.Errorf("error closing cert.pem: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
priv := cert.PrivateKey
|
||||||
|
var privBytes []byte
|
||||||
|
privBytes, err = x509.MarshalPKCS8PrivateKey(priv)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("unable to marshal private key: %w", err)
|
||||||
|
}
|
||||||
|
if err = pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to write data to key.pem: %w", err)
|
||||||
|
}
|
||||||
|
if err = keyOut.Close(); err != nil {
|
||||||
|
return "", "", fmt.Errorf("error closing key.pem: %w", err)
|
||||||
|
}
|
||||||
|
return certOut.Name(), keyOut.Name(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPionOpenSSLE2ESimple(t *testing.T) {
|
||||||
|
t.Run("OpenSSLServer", func(t *testing.T) {
|
||||||
|
testPionE2ESimple(t, serverOpenSSL, clientPion)
|
||||||
|
})
|
||||||
|
t.Run("OpenSSLClient", func(t *testing.T) {
|
||||||
|
testPionE2ESimple(t, serverPion, clientOpenSSL)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPionOpenSSLE2ESimplePSK(t *testing.T) {
|
||||||
|
t.Run("OpenSSLServer", func(t *testing.T) {
|
||||||
|
testPionE2ESimplePSK(t, serverOpenSSL, clientPion)
|
||||||
|
})
|
||||||
|
t.Run("OpenSSLClient", func(t *testing.T) {
|
||||||
|
testPionE2ESimplePSK(t, serverPion, clientOpenSSL)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPionOpenSSLE2EMTUs(t *testing.T) {
|
||||||
|
t.Run("OpenSSLServer", func(t *testing.T) {
|
||||||
|
testPionE2EMTUs(t, serverOpenSSL, clientPion)
|
||||||
|
})
|
||||||
|
t.Run("OpenSSLClient", func(t *testing.T) {
|
||||||
|
testPionE2EMTUs(t, serverPion, clientOpenSSL)
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
// +build openssl,go1.13,!js
|
||||||
|
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPionOpenSSLE2ESimpleED25519(t *testing.T) {
|
||||||
|
t.Skip("TODO: waiting OpenSSL's DTLS Ed25519 support")
|
||||||
|
t.Run("OpenSSLServer", func(t *testing.T) {
|
||||||
|
testPionE2ESimpleED25519(t, serverOpenSSL, clientPion)
|
||||||
|
})
|
||||||
|
t.Run("OpenSSLClient", func(t *testing.T) {
|
||||||
|
testPionE2ESimpleED25519(t, serverPion, clientOpenSSL)
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,329 @@
|
||||||
|
// +build !js
|
||||||
|
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||||
|
"github.com/pion/transport/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testMessage = "Hello World"
|
||||||
|
testTimeLimit = 5 * time.Second
|
||||||
|
messageRetry = 200 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
var errServerTimeout = errors.New("waiting on serverReady err: timeout")
|
||||||
|
|
||||||
|
func randomPort(t testing.TB) int {
|
||||||
|
t.Helper()
|
||||||
|
conn, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to pickPort: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.Close()
|
||||||
|
}()
|
||||||
|
switch addr := conn.LocalAddr().(type) {
|
||||||
|
case *net.UDPAddr:
|
||||||
|
return addr.Port
|
||||||
|
default:
|
||||||
|
t.Fatalf("unknown addr type %T", addr)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) {
|
||||||
|
go func() {
|
||||||
|
buffer := make([]byte, 8192)
|
||||||
|
n, err := conn.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
outChan <- string(buffer[:n])
|
||||||
|
atomic.AddUint64(messageRecvCount, 1)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if atomic.LoadUint64(messageRecvCount) == 2 {
|
||||||
|
break
|
||||||
|
} else if _, err := conn.Write([]byte(testMessage)); err != nil {
|
||||||
|
errChan <- err
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(messageRetry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type comm struct {
|
||||||
|
ctx context.Context
|
||||||
|
clientConfig, serverConfig *dtls.Config
|
||||||
|
serverPort int
|
||||||
|
messageRecvCount *uint64 // Counter to make sure both sides got a message
|
||||||
|
clientMutex *sync.Mutex
|
||||||
|
clientConn net.Conn
|
||||||
|
serverMutex *sync.Mutex
|
||||||
|
serverConn net.Conn
|
||||||
|
serverListener net.Listener
|
||||||
|
serverReady chan struct{}
|
||||||
|
errChan chan error
|
||||||
|
clientChan chan string
|
||||||
|
serverChan chan string
|
||||||
|
client func(*comm)
|
||||||
|
server func(*comm)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm {
|
||||||
|
messageRecvCount := uint64(0)
|
||||||
|
c := &comm{
|
||||||
|
ctx: ctx,
|
||||||
|
clientConfig: clientConfig,
|
||||||
|
serverConfig: serverConfig,
|
||||||
|
serverPort: serverPort,
|
||||||
|
messageRecvCount: &messageRecvCount,
|
||||||
|
clientMutex: &sync.Mutex{},
|
||||||
|
serverMutex: &sync.Mutex{},
|
||||||
|
serverReady: make(chan struct{}),
|
||||||
|
errChan: make(chan error),
|
||||||
|
clientChan: make(chan string),
|
||||||
|
serverChan: make(chan string),
|
||||||
|
server: server,
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *comm) assert(t *testing.T) {
|
||||||
|
// DTLS Client
|
||||||
|
go c.client(c)
|
||||||
|
|
||||||
|
// DTLS Server
|
||||||
|
go c.server(c)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if c.clientConn != nil {
|
||||||
|
if err := c.clientConn.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.serverConn != nil {
|
||||||
|
if err := c.serverConn.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.serverListener != nil {
|
||||||
|
if err := c.serverListener.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
func() {
|
||||||
|
seenClient, seenServer := false, false
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case err := <-c.errChan:
|
||||||
|
t.Fatal(err)
|
||||||
|
case <-time.After(testTimeLimit):
|
||||||
|
t.Fatalf("Test timeout, seenClient %t seenServer %t", seenClient, seenServer)
|
||||||
|
case clientMsg := <-c.clientChan:
|
||||||
|
if clientMsg != testMessage {
|
||||||
|
t.Fatalf("clientMsg does not equal test message: %s %s", clientMsg, testMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
seenClient = true
|
||||||
|
if seenClient && seenServer {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case serverMsg := <-c.serverChan:
|
||||||
|
if serverMsg != testMessage {
|
||||||
|
t.Fatalf("serverMsg does not equal test message: %s %s", serverMsg, testMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
seenServer = true
|
||||||
|
if seenClient && seenServer {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientPion(c *comm) {
|
||||||
|
select {
|
||||||
|
case <-c.serverReady:
|
||||||
|
// OK
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
c.errChan <- errServerTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
c.clientMutex.Lock()
|
||||||
|
defer c.clientMutex.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
c.clientConn, err = dtls.DialWithContext(c.ctx, "udp",
|
||||||
|
&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
|
||||||
|
c.clientConfig,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
c.errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverPion(c *comm) {
|
||||||
|
c.serverMutex.Lock()
|
||||||
|
defer c.serverMutex.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
c.serverListener, err = dtls.Listen("udp",
|
||||||
|
&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
|
||||||
|
c.serverConfig,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
c.errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.serverReady <- struct{}{}
|
||||||
|
c.serverConn, err = c.serverListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
c.errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Simple DTLS Client/Server can communicate
|
||||||
|
- Assert that you can send messages both ways
|
||||||
|
- Assert that Close() on both ends work
|
||||||
|
- Assert that no Goroutines are leaked
|
||||||
|
*/
|
||||||
|
func testPionE2ESimple(t *testing.T, server, client func(*comm)) {
|
||||||
|
lim := test.TimeOut(time.Second * 30)
|
||||||
|
defer lim.Stop()
|
||||||
|
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
|
||||||
|
for _, cipherSuite := range []dtls.CipherSuiteID{
|
||||||
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||||
|
} {
|
||||||
|
cipherSuite := cipherSuite
|
||||||
|
t.Run(cipherSuite.String(), func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &dtls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
CipherSuites: []dtls.CipherSuiteID{cipherSuite},
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}
|
||||||
|
serverPort := randomPort(t)
|
||||||
|
comm := newComm(ctx, cfg, cfg, serverPort, server, client)
|
||||||
|
comm.assert(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) {
|
||||||
|
lim := test.TimeOut(time.Second * 30)
|
||||||
|
defer lim.Stop()
|
||||||
|
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
|
||||||
|
for _, cipherSuite := range []dtls.CipherSuiteID{
|
||||||
|
dtls.TLS_PSK_WITH_AES_128_CCM,
|
||||||
|
dtls.TLS_PSK_WITH_AES_128_CCM_8,
|
||||||
|
dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
|
||||||
|
} {
|
||||||
|
cipherSuite := cipherSuite
|
||||||
|
t.Run(cipherSuite.String(), func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cfg := &dtls.Config{
|
||||||
|
PSK: func(hint []byte) ([]byte, error) {
|
||||||
|
return []byte{0xAB, 0xC1, 0x23}, nil
|
||||||
|
},
|
||||||
|
PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
||||||
|
CipherSuites: []dtls.CipherSuiteID{cipherSuite},
|
||||||
|
}
|
||||||
|
serverPort := randomPort(t)
|
||||||
|
comm := newComm(ctx, cfg, cfg, serverPort, server, client)
|
||||||
|
comm.assert(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testPionE2EMTUs(t *testing.T, server, client func(*comm)) {
|
||||||
|
lim := test.TimeOut(time.Second * 30)
|
||||||
|
defer lim.Stop()
|
||||||
|
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
|
||||||
|
for _, mtu := range []int{
|
||||||
|
10000,
|
||||||
|
1000,
|
||||||
|
100,
|
||||||
|
} {
|
||||||
|
mtu := mtu
|
||||||
|
t.Run(fmt.Sprintf("MTU%d", mtu), func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &dtls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
MTU: mtu,
|
||||||
|
}
|
||||||
|
serverPort := randomPort(t)
|
||||||
|
comm := newComm(ctx, cfg, cfg, serverPort, server, client)
|
||||||
|
comm.assert(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPionE2ESimple(t *testing.T) {
|
||||||
|
testPionE2ESimple(t, serverPion, clientPion)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPionE2ESimplePSK(t *testing.T) {
|
||||||
|
testPionE2ESimplePSK(t, serverPion, clientPion)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPionE2EMTUs(t *testing.T) {
|
||||||
|
testPionE2EMTUs(t, serverPion, clientPion)
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
// +build go1.13,!js
|
||||||
|
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||||
|
"github.com/pion/transport/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ED25519 is not supported in Go 1.12 crypto/x509.
|
||||||
|
// Once Go 1.12 is deprecated, move this test to e2e_test.go.
|
||||||
|
|
||||||
|
func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm)) {
|
||||||
|
lim := test.TimeOut(time.Second * 30)
|
||||||
|
defer lim.Stop()
|
||||||
|
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
|
||||||
|
for _, cipherSuite := range []dtls.CipherSuiteID{
|
||||||
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
|
||||||
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
|
||||||
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||||
|
} {
|
||||||
|
cipherSuite := cipherSuite
|
||||||
|
t.Run(cipherSuite.String(), func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, key, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cert, err := selfsign.SelfSign(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &dtls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
CipherSuites: []dtls.CipherSuiteID{cipherSuite},
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}
|
||||||
|
serverPort := randomPort(t)
|
||||||
|
comm := newComm(ctx, cfg, cfg, serverPort, server, client)
|
||||||
|
comm.assert(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPionE2ESimpleED25519(t *testing.T) {
|
||||||
|
testPionE2ESimpleED25519(t, serverPion, clientPion)
|
||||||
|
}
|
|
@ -0,0 +1,141 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Typed errors
|
||||||
|
var (
|
||||||
|
ErrConnClosed = &FatalError{Err: errors.New("conn is closed")} //nolint:goerr113
|
||||||
|
|
||||||
|
errDeadlineExceeded = &TimeoutError{Err: xerrors.Errorf("read/write timeout: %w", context.DeadlineExceeded)}
|
||||||
|
errInvalidContentType = &TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113
|
||||||
|
|
||||||
|
errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
|
||||||
|
errContextUnsupported = &TemporaryError{Err: errors.New("context is not supported for ExportKeyingMaterial")} //nolint:goerr113
|
||||||
|
errHandshakeInProgress = &TemporaryError{Err: errors.New("handshake is in progress")} //nolint:goerr113
|
||||||
|
errReservedExportKeyingMaterial = &TemporaryError{Err: errors.New("ExportKeyingMaterial can not be used with a reserved label")} //nolint:goerr113
|
||||||
|
errApplicationDataEpochZero = &TemporaryError{Err: errors.New("ApplicationData with epoch of 0")} //nolint:goerr113
|
||||||
|
errUnhandledContextType = &TemporaryError{Err: errors.New("unhandled contentType")} //nolint:goerr113
|
||||||
|
|
||||||
|
errCertificateVerifyNoCertificate = &FatalError{Err: errors.New("client sent certificate verify but we have no certificate to verify")} //nolint:goerr113
|
||||||
|
errCipherSuiteNoIntersection = &FatalError{Err: errors.New("client+server do not support any shared cipher suites")} //nolint:goerr113
|
||||||
|
errClientCertificateNotVerified = &FatalError{Err: errors.New("client sent certificate but did not verify it")} //nolint:goerr113
|
||||||
|
errClientCertificateRequired = &FatalError{Err: errors.New("server required client verification, but got none")} //nolint:goerr113
|
||||||
|
errClientNoMatchingSRTPProfile = &FatalError{Err: errors.New("server responded with SRTP Profile we do not support")} //nolint:goerr113
|
||||||
|
errClientRequiredButNoServerEMS = &FatalError{Err: errors.New("client required Extended Master Secret extension, but server does not support it")} //nolint:goerr113
|
||||||
|
errCookieMismatch = &FatalError{Err: errors.New("client+server cookie does not match")} //nolint:goerr113
|
||||||
|
errIdentityNoPSK = &FatalError{Err: errors.New("PSK Identity Hint provided but PSK is nil")} //nolint:goerr113
|
||||||
|
errInvalidCertificate = &FatalError{Err: errors.New("no certificate provided")} //nolint:goerr113
|
||||||
|
errInvalidCipherSuite = &FatalError{Err: errors.New("invalid or unknown cipher suite")} //nolint:goerr113
|
||||||
|
errInvalidECDSASignature = &FatalError{Err: errors.New("ECDSA signature contained zero or negative values")} //nolint:goerr113
|
||||||
|
errInvalidPrivateKey = &FatalError{Err: errors.New("invalid private key type")} //nolint:goerr113
|
||||||
|
errInvalidSignatureAlgorithm = &FatalError{Err: errors.New("invalid signature algorithm")} //nolint:goerr113
|
||||||
|
errKeySignatureMismatch = &FatalError{Err: errors.New("expected and actual key signature do not match")} //nolint:goerr113
|
||||||
|
errNilNextConn = &FatalError{Err: errors.New("Conn can not be created with a nil nextConn")} //nolint:goerr113
|
||||||
|
errNoAvailableCipherSuites = &FatalError{Err: errors.New("connection can not be created, no CipherSuites satisfy this Config")} //nolint:goerr113
|
||||||
|
errNoAvailablePSKCipherSuite = &FatalError{Err: errors.New("connection can not be created, pre-shared key present but no compatible CipherSuite")} //nolint:goerr113
|
||||||
|
errNoAvailableCertificateCipherSuite = &FatalError{Err: errors.New("connection can not be created, certificate present but no compatible CipherSuite")} //nolint:goerr113
|
||||||
|
errNoAvailableSignatureSchemes = &FatalError{Err: errors.New("connection can not be created, no SignatureScheme satisfy this Config")} //nolint:goerr113
|
||||||
|
errNoCertificates = &FatalError{Err: errors.New("no certificates configured")} //nolint:goerr113
|
||||||
|
errNoConfigProvided = &FatalError{Err: errors.New("no config provided")} //nolint:goerr113
|
||||||
|
errNoSupportedEllipticCurves = &FatalError{Err: errors.New("client requested zero or more elliptic curves that are not supported by the server")} //nolint:goerr113
|
||||||
|
errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113
|
||||||
|
errPSKAndIdentityMustBeSetForClient = &FatalError{Err: errors.New("PSK and PSK Identity Hint must both be set for client")} //nolint:goerr113
|
||||||
|
errRequestedButNoSRTPExtension = &FatalError{Err: errors.New("SRTP support was requested but server did not respond with use_srtp extension")} //nolint:goerr113
|
||||||
|
errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")} //nolint:goerr113
|
||||||
|
errServerRequiredButNoClientEMS = &FatalError{Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it")} //nolint:goerr113
|
||||||
|
errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")} //nolint:goerr113
|
||||||
|
|
||||||
|
errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} //nolint:goerr113
|
||||||
|
errKeySignatureGenerateUnimplemented = &InternalError{Err: errors.New("unable to generate key signature, unimplemented")} //nolint:goerr113
|
||||||
|
errKeySignatureVerifyUnimplemented = &InternalError{Err: errors.New("unable to verify key signature, unimplemented")} //nolint:goerr113
|
||||||
|
errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113
|
||||||
|
errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113
|
||||||
|
errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} //nolint:goerr113
|
||||||
|
)
|
||||||
|
|
||||||
|
// FatalError indicates that the DTLS connection is no longer available.
|
||||||
|
// It is mainly caused by wrong configuration of server or client.
|
||||||
|
type FatalError = protocol.FatalError
|
||||||
|
|
||||||
|
// InternalError indicates and internal error caused by the implementation, and the DTLS connection is no longer available.
|
||||||
|
// It is mainly caused by bugs or tried to use unimplemented features.
|
||||||
|
type InternalError = protocol.InternalError
|
||||||
|
|
||||||
|
// TemporaryError indicates that the DTLS connection is still available, but the request was failed temporary.
|
||||||
|
type TemporaryError = protocol.TemporaryError
|
||||||
|
|
||||||
|
// TimeoutError indicates that the request was timed out.
|
||||||
|
type TimeoutError = protocol.TimeoutError
|
||||||
|
|
||||||
|
// HandshakeError indicates that the handshake failed.
|
||||||
|
type HandshakeError = protocol.HandshakeError
|
||||||
|
|
||||||
|
// invalidCipherSuite indicates an attempt at using an unsupported cipher suite.
|
||||||
|
type invalidCipherSuite struct {
|
||||||
|
id CipherSuiteID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidCipherSuite) Error() string {
|
||||||
|
return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidCipherSuite) Is(err error) bool {
|
||||||
|
if other, ok := err.(*invalidCipherSuite); ok {
|
||||||
|
return e.id == other.id
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// errAlert wraps DTLS alert notification as an error
|
||||||
|
type errAlert struct {
|
||||||
|
*alert.Alert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errAlert) Error() string {
|
||||||
|
return fmt.Sprintf("alert: %s", e.Alert.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errAlert) IsFatalOrCloseNotify() bool {
|
||||||
|
return e.Level == alert.Fatal || e.Description == alert.CloseNotify
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errAlert) Is(err error) bool {
|
||||||
|
if other, ok := err.(*errAlert); ok {
|
||||||
|
return e.Level == other.Level && e.Description == other.Description
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// netError translates an error from underlying Conn to corresponding net.Error.
|
||||||
|
func netError(err error) error {
|
||||||
|
switch err {
|
||||||
|
case io.EOF, context.Canceled, context.DeadlineExceeded:
|
||||||
|
// Return io.EOF and context errors as is.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch e := err.(type) {
|
||||||
|
case (*net.OpError):
|
||||||
|
if se, ok := e.Err.(*os.SyscallError); ok {
|
||||||
|
if se.Timeout() {
|
||||||
|
return &TimeoutError{Err: err}
|
||||||
|
}
|
||||||
|
if isOpErrorTemporary(se) {
|
||||||
|
return &TemporaryError{Err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case (net.Error):
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return &FatalError{Err: err}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
// +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows
|
||||||
|
|
||||||
|
// For systems having syscall.Errno.
|
||||||
|
// Update build targets by following command:
|
||||||
|
// $ grep -R ECONN $(go env GOROOT)/src/syscall/zerrors_*.go \
|
||||||
|
// | tr "." "_" | cut -d"_" -f"2" | sort | uniq
|
||||||
|
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isOpErrorTemporary(err *os.SyscallError) bool {
|
||||||
|
if ne, ok := err.Err.(syscall.Errno); ok {
|
||||||
|
switch ne {
|
||||||
|
case syscall.ECONNREFUSED:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
// +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows
|
||||||
|
|
||||||
|
// For systems having syscall.Errno.
|
||||||
|
// The build target must be same as errors_errno.go.
|
||||||
|
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestErrorsTemporary(t *testing.T) {
|
||||||
|
addrListen, errListen := net.ResolveUDPAddr("udp", "localhost:0")
|
||||||
|
if errListen != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", errListen)
|
||||||
|
}
|
||||||
|
// Server is not listening.
|
||||||
|
conn, errDial := net.DialUDP("udp", nil, addrListen)
|
||||||
|
if errDial != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", errDial)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = conn.Write([]byte{0x00}) // trigger
|
||||||
|
_, err := conn.Read(make([]byte, 10))
|
||||||
|
_ = conn.Close()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Skip("ECONNREFUSED is not set by system")
|
||||||
|
}
|
||||||
|
ne, ok := netError(err).(net.Error)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("netError must return net.Error")
|
||||||
|
}
|
||||||
|
if ne.Timeout() {
|
||||||
|
t.Errorf("%v must not be timeout error", err)
|
||||||
|
}
|
||||||
|
if !ne.Temporary() {
|
||||||
|
t.Errorf("%v must be temporary error", err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,14 @@
|
||||||
|
// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!nacl,!nacljs,!netbsd,!openbsd,!solaris,!windows
|
||||||
|
|
||||||
|
// For systems without syscall.Errno.
|
||||||
|
// Build targets must be inverse of errors_errno.go
|
||||||
|
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isOpErrorTemporary(err *os.SyscallError) bool {
|
||||||
|
return false
|
||||||
|
}
|
|
@ -0,0 +1,85 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errExample = errors.New("an example error")
|
||||||
|
|
||||||
|
func TestErrorUnwrap(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
err error
|
||||||
|
errUnwrapped []error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
&FatalError{Err: errExample},
|
||||||
|
[]error{errExample},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
&TemporaryError{Err: errExample},
|
||||||
|
[]error{errExample},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
&InternalError{Err: errExample},
|
||||||
|
[]error{errExample},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
&TimeoutError{Err: errExample},
|
||||||
|
[]error{errExample},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
&HandshakeError{Err: errExample},
|
||||||
|
[]error{errExample},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
c := c
|
||||||
|
t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) {
|
||||||
|
err := c.err
|
||||||
|
for _, unwrapped := range c.errUnwrapped {
|
||||||
|
e := xerrors.Unwrap(err)
|
||||||
|
if !errors.Is(e, unwrapped) {
|
||||||
|
t.Errorf("Unwrapped error is expected to be '%v', got '%v'", unwrapped, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrorNetError(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
err error
|
||||||
|
str string
|
||||||
|
timeout, temporary bool
|
||||||
|
}{
|
||||||
|
{&FatalError{Err: errExample}, "dtls fatal: an example error", false, false},
|
||||||
|
{&TemporaryError{Err: errExample}, "dtls temporary: an example error", false, true},
|
||||||
|
{&InternalError{Err: errExample}, "dtls internal: an example error", false, false},
|
||||||
|
{&TimeoutError{Err: errExample}, "dtls timeout: an example error", true, true},
|
||||||
|
{&HandshakeError{Err: errExample}, "handshake error: an example error", false, false},
|
||||||
|
{&HandshakeError{Err: &TimeoutError{Err: errExample}}, "handshake error: dtls timeout: an example error", true, true},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
c := c
|
||||||
|
t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) {
|
||||||
|
ne, ok := c.err.(net.Error)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("%T doesn't implement net.Error", c.err)
|
||||||
|
}
|
||||||
|
if ne.Timeout() != c.timeout {
|
||||||
|
t.Errorf("%T.Timeout() should be %v", c.err, c.timeout)
|
||||||
|
}
|
||||||
|
if ne.Temporary() != c.temporary {
|
||||||
|
t.Errorf("%T.Temporary() should be %v", c.err, c.temporary)
|
||||||
|
}
|
||||||
|
if ne.Error() != c.str {
|
||||||
|
t.Errorf("%T.Error() should be %v", c.err, c.str)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
# Certificates
|
||||||
|
|
||||||
|
The certificates in for the examples are generated using the commands shown below.
|
||||||
|
|
||||||
|
Note that this was run on OpenSSL 1.1.1d, of which the arguments can be found in the [OpenSSL Manpages](https://www.openssl.org/docs/man1.1.1/man1), and is not guaranteed to work on different OpenSSL versions.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# Extensions required for certificate validation.
|
||||||
|
$ EXTFILE='extfile.conf'
|
||||||
|
$ echo 'subjectAltName = IP:127.0.0.1\nbasicConstraints = critical,CA:true' > "${EXTFILE}"
|
||||||
|
|
||||||
|
# Server.
|
||||||
|
$ SERVER_NAME='server'
|
||||||
|
$ openssl ecparam -name prime256v1 -genkey -noout -out "${SERVER_NAME}.pem"
|
||||||
|
$ openssl req -key "${SERVER_NAME}.pem" -new -sha256 -subj '/C=NL' -out "${SERVER_NAME}.csr"
|
||||||
|
$ openssl x509 -req -in "${SERVER_NAME}.csr" -extfile "${EXTFILE}" -days 365 -signkey "${SERVER_NAME}.pem" -sha256 -out "${SERVER_NAME}.pub.pem"
|
||||||
|
|
||||||
|
# Client.
|
||||||
|
$ CLIENT_NAME='client'
|
||||||
|
$ openssl ecparam -name prime256v1 -genkey -noout -out "${CLIENT_NAME}.pem"
|
||||||
|
$ openssl req -key "${CLIENT_NAME}.pem" -new -sha256 -subj '/C=NL' -out "${CLIENT_NAME}.csr"
|
||||||
|
$ openssl x509 -req -in "${CLIENT_NAME}.csr" -extfile "${EXTFILE}" -days 365 -CA "${SERVER_NAME}.pub.pem" -CAkey "${SERVER_NAME}.pem" -set_serial '0xabcd' -sha256 -out "${CLIENT_NAME}.pub.pem"
|
||||||
|
|
||||||
|
# Cleanup.
|
||||||
|
$ rm "${EXTFILE}" "${SERVER_NAME}.csr" "${CLIENT_NAME}.csr"
|
||||||
|
```
|
|
@ -0,0 +1,5 @@
|
||||||
|
-----BEGIN EC PRIVATE KEY-----
|
||||||
|
MHcCAQEEIGOO78dEAcepxdUIeDzC28jMcFrJr2q7x+UdhgtJ/RS3oAoGCCqGSM49
|
||||||
|
AwEHoUQDQgAEGLSNxlkJ9mETKI2Hogq3Cyh06pJKA1YMgcKqYKS6yQQlvvk5rU88
|
||||||
|
+RojFPgXJukymhfIJmw4eGxxEMSjuEZY7w==
|
||||||
|
-----END EC PRIVATE KEY-----
|
|
@ -0,0 +1,9 @@
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBLTCB1aADAgECAgMAq80wCgYIKoZIzj0EAwIwDTELMAkGA1UEBhMCTkwwHhcN
|
||||||
|
MjAwMzIwMDk0NjQ0WhcNMjEwMzIwMDk0NjQ0WjANMQswCQYDVQQGEwJOTDBZMBMG
|
||||||
|
ByqGSM49AgEGCCqGSM49AwEHA0IABBi0jcZZCfZhEyiNh6IKtwsodOqSSgNWDIHC
|
||||||
|
qmCkuskEJb75Oa1PPPkaIxT4FybpMpoXyCZsOHhscRDEo7hGWO+jJDAiMA8GA1Ud
|
||||||
|
EQQIMAaHBH8AAAEwDwYDVR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBx
|
||||||
|
sIkcADN9E60veZOFOeANaRWAiQaLWZfUxqkOmfHztQIgI2CfHMjDQwJZFh35HvFs
|
||||||
|
NOPJj8wxFhqR5pqMF23cgOY=
|
||||||
|
-----END CERTIFICATE-----
|
|
@ -0,0 +1,5 @@
|
||||||
|
-----BEGIN EC PRIVATE KEY-----
|
||||||
|
MHcCAQEEIDT8Xyx5RpPP+98ulYZKsvKIVdBUJug/L9H2M8JThv+GoAoGCCqGSM49
|
||||||
|
AwEHoUQDQgAE6Wf0qQqIb5G7g51P83Dh1Yst52kyntGYz1Bt6S7crpmQFs9ZRZMy
|
||||||
|
bJ6MGIwGcVBMgoL3pfxDKdZ3mnzmoibU0w==
|
||||||
|
-----END EC PRIVATE KEY-----
|
|
@ -0,0 +1,9 @@
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBPzCB5qADAgECAhRtzyVTL+9D0KHfbcKYeKckpLVRmTAKBggqhkjOPQQDAjAN
|
||||||
|
MQswCQYDVQQGEwJOTDAeFw0yMDAzMjAwOTQ2NDRaFw0yMTAzMjAwOTQ2NDRaMA0x
|
||||||
|
CzAJBgNVBAYTAk5MMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6Wf0qQqIb5G7
|
||||||
|
g51P83Dh1Yst52kyntGYz1Bt6S7crpmQFs9ZRZMybJ6MGIwGcVBMgoL3pfxDKdZ3
|
||||||
|
mnzmoibU06MkMCIwDwYDVR0RBAgwBocEfwAAATAPBgNVHRMBAf8EBTADAQH/MAoG
|
||||||
|
CCqGSM49BAMCA0gAMEUCIQD000SU+klkNLGvHZcMYNVkCFsImnGKIqPMy3LELSiF
|
||||||
|
0gIgSGIFkNEIAyNxn44CXZJu3piyz1ouK2fLefDJMYfcXgM=
|
||||||
|
-----END CERTIFICATE-----
|
|
@ -0,0 +1,45 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2"
|
||||||
|
"github.com/pion/dtls/v2/examples/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Prepare the IP to connect to
|
||||||
|
addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Prepare the configuration of the DTLS connection
|
||||||
|
config := &dtls.Config{
|
||||||
|
PSK: func(hint []byte) ([]byte, error) {
|
||||||
|
fmt.Printf("Server's hint: %s \n", hint)
|
||||||
|
return []byte{0xAB, 0xC1, 0x23}, nil
|
||||||
|
},
|
||||||
|
PSKIdentityHint: []byte("Pion DTLS Server"),
|
||||||
|
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8},
|
||||||
|
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to a DTLS server
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config)
|
||||||
|
util.Check(err)
|
||||||
|
defer func() {
|
||||||
|
util.Check(dtlsConn.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
|
fmt.Println("Connected; type 'exit' to shutdown gracefully")
|
||||||
|
|
||||||
|
// Simulate a chat session
|
||||||
|
util.Chat(dtlsConn)
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2"
|
||||||
|
"github.com/pion/dtls/v2/examples/util"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Prepare the IP to connect to
|
||||||
|
addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}
|
||||||
|
|
||||||
|
// Generate a certificate and private key to secure the connection
|
||||||
|
certificate, genErr := selfsign.GenerateSelfSigned()
|
||||||
|
util.Check(genErr)
|
||||||
|
|
||||||
|
//
|
||||||
|
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Prepare the configuration of the DTLS connection
|
||||||
|
config := &dtls.Config{
|
||||||
|
Certificates: []tls.Certificate{certificate},
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to a DTLS server
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config)
|
||||||
|
util.Check(err)
|
||||||
|
defer func() {
|
||||||
|
util.Check(dtlsConn.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
|
fmt.Println("Connected; type 'exit' to shutdown gracefully")
|
||||||
|
|
||||||
|
// Simulate a chat session
|
||||||
|
util.Chat(dtlsConn)
|
||||||
|
}
|
|
@ -0,0 +1,54 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2"
|
||||||
|
"github.com/pion/dtls/v2/examples/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Prepare the IP to connect to
|
||||||
|
addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
|
||||||
|
//
|
||||||
|
|
||||||
|
certificate, err := util.LoadKeyAndCertificate("examples/certificates/client.pem",
|
||||||
|
"examples/certificates/client.pub.pem")
|
||||||
|
util.Check(err)
|
||||||
|
|
||||||
|
rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem")
|
||||||
|
util.Check(err)
|
||||||
|
certPool := x509.NewCertPool()
|
||||||
|
cert, err := x509.ParseCertificate(rootCertificate.Certificate[0])
|
||||||
|
util.Check(err)
|
||||||
|
certPool.AddCert(cert)
|
||||||
|
|
||||||
|
// Prepare the configuration of the DTLS connection
|
||||||
|
config := &dtls.Config{
|
||||||
|
Certificates: []tls.Certificate{*certificate},
|
||||||
|
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||||
|
RootCAs: certPool,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to a DTLS server
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config)
|
||||||
|
util.Check(err)
|
||||||
|
defer func() {
|
||||||
|
util.Check(dtlsConn.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
|
fmt.Println("Connected; type 'exit' to shutdown gracefully")
|
||||||
|
|
||||||
|
// Simulate a chat session
|
||||||
|
util.Chat(dtlsConn)
|
||||||
|
}
|
|
@ -0,0 +1,72 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2"
|
||||||
|
"github.com/pion/dtls/v2/examples/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Prepare the IP to connect to
|
||||||
|
addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}
|
||||||
|
|
||||||
|
// Create parent context to cleanup handshaking connections on exit.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
//
|
||||||
|
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Prepare the configuration of the DTLS connection
|
||||||
|
config := &dtls.Config{
|
||||||
|
PSK: func(hint []byte) ([]byte, error) {
|
||||||
|
fmt.Printf("Client's hint: %s \n", hint)
|
||||||
|
return []byte{0xAB, 0xC1, 0x23}, nil
|
||||||
|
},
|
||||||
|
PSKIdentityHint: []byte("Pion DTLS Client"),
|
||||||
|
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8},
|
||||||
|
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||||
|
// Create timeout context for accepted connection.
|
||||||
|
ConnectContextMaker: func() (context.Context, func()) {
|
||||||
|
return context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to a DTLS server
|
||||||
|
listener, err := dtls.Listen("udp", addr, config)
|
||||||
|
util.Check(err)
|
||||||
|
defer func() {
|
||||||
|
util.Check(listener.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
|
fmt.Println("Listening")
|
||||||
|
|
||||||
|
// Simulate a chat session
|
||||||
|
hub := util.NewHub()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
// Wait for a connection.
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
util.Check(err)
|
||||||
|
// defer conn.Close() // TODO: graceful shutdown
|
||||||
|
|
||||||
|
// `conn` is of type `net.Conn` but may be casted to `dtls.Conn`
|
||||||
|
// using `dtlsConn := conn.(*dtls.Conn)` in order to to expose
|
||||||
|
// functions like `ConnectionState` etc.
|
||||||
|
|
||||||
|
// Register the connection with the chat hub
|
||||||
|
if err == nil {
|
||||||
|
hub.Register(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start chatting
|
||||||
|
hub.Chat()
|
||||||
|
}
|
|
@ -0,0 +1,73 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2"
|
||||||
|
"github.com/pion/dtls/v2/examples/util"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Prepare the IP to connect to
|
||||||
|
addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}
|
||||||
|
|
||||||
|
// Generate a certificate and private key to secure the connection
|
||||||
|
certificate, genErr := selfsign.GenerateSelfSigned()
|
||||||
|
util.Check(genErr)
|
||||||
|
|
||||||
|
// Create parent context to cleanup handshaking connections on exit.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
//
|
||||||
|
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Prepare the configuration of the DTLS connection
|
||||||
|
config := &dtls.Config{
|
||||||
|
Certificates: []tls.Certificate{certificate},
|
||||||
|
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||||
|
// Create timeout context for accepted connection.
|
||||||
|
ConnectContextMaker: func() (context.Context, func()) {
|
||||||
|
return context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to a DTLS server
|
||||||
|
listener, err := dtls.Listen("udp", addr, config)
|
||||||
|
util.Check(err)
|
||||||
|
defer func() {
|
||||||
|
util.Check(listener.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
|
fmt.Println("Listening")
|
||||||
|
|
||||||
|
// Simulate a chat session
|
||||||
|
hub := util.NewHub()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
// Wait for a connection.
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
util.Check(err)
|
||||||
|
// defer conn.Close() // TODO: graceful shutdown
|
||||||
|
|
||||||
|
// `conn` is of type `net.Conn` but may be casted to `dtls.Conn`
|
||||||
|
// using `dtlsConn := conn.(*dtls.Conn)` in order to to expose
|
||||||
|
// functions like `ConnectionState` etc.
|
||||||
|
|
||||||
|
// Register the connection with the chat hub
|
||||||
|
if err == nil {
|
||||||
|
hub.Register(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start chatting
|
||||||
|
hub.Chat()
|
||||||
|
}
|
|
@ -0,0 +1,80 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2"
|
||||||
|
"github.com/pion/dtls/v2/examples/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Prepare the IP to connect to
|
||||||
|
addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}
|
||||||
|
|
||||||
|
// Create parent context to cleanup handshaking connections on exit.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
//
|
||||||
|
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
|
||||||
|
//
|
||||||
|
|
||||||
|
certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem",
|
||||||
|
"examples/certificates/server.pub.pem")
|
||||||
|
util.Check(err)
|
||||||
|
|
||||||
|
rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem")
|
||||||
|
util.Check(err)
|
||||||
|
certPool := x509.NewCertPool()
|
||||||
|
cert, err := x509.ParseCertificate(rootCertificate.Certificate[0])
|
||||||
|
util.Check(err)
|
||||||
|
certPool.AddCert(cert)
|
||||||
|
|
||||||
|
// Prepare the configuration of the DTLS connection
|
||||||
|
config := &dtls.Config{
|
||||||
|
Certificates: []tls.Certificate{*certificate},
|
||||||
|
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||||
|
ClientAuth: dtls.RequireAndVerifyClientCert,
|
||||||
|
ClientCAs: certPool,
|
||||||
|
// Create timeout context for accepted connection.
|
||||||
|
ConnectContextMaker: func() (context.Context, func()) {
|
||||||
|
return context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to a DTLS server
|
||||||
|
listener, err := dtls.Listen("udp", addr, config)
|
||||||
|
util.Check(err)
|
||||||
|
defer func() {
|
||||||
|
util.Check(listener.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
|
fmt.Println("Listening")
|
||||||
|
|
||||||
|
// Simulate a chat session
|
||||||
|
hub := util.NewHub()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
// Wait for a connection.
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
util.Check(err)
|
||||||
|
// defer conn.Close() // TODO: graceful shutdown
|
||||||
|
|
||||||
|
// `conn` is of type `net.Conn` but may be casted to `dtls.Conn`
|
||||||
|
// using `dtlsConn := conn.(*dtls.Conn)` in order to to expose
|
||||||
|
// functions like `ConnectionState` etc.
|
||||||
|
|
||||||
|
// Register the connection with the chat hub
|
||||||
|
hub.Register(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start chatting
|
||||||
|
hub.Chat()
|
||||||
|
}
|
|
@ -0,0 +1,80 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Hub is a helper to handle one to many chat
|
||||||
|
type Hub struct {
|
||||||
|
conns map[string]net.Conn
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHub builds a new hub
|
||||||
|
func NewHub() *Hub {
|
||||||
|
return &Hub{conns: make(map[string]net.Conn)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds a new conn to the Hub
|
||||||
|
func (h *Hub) Register(conn net.Conn) {
|
||||||
|
fmt.Printf("Connected to %s\n", conn.RemoteAddr())
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
|
h.conns[conn.RemoteAddr().String()] = conn
|
||||||
|
|
||||||
|
go h.readLoop(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) readLoop(conn net.Conn) {
|
||||||
|
b := make([]byte, bufSize)
|
||||||
|
for {
|
||||||
|
n, err := conn.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
h.unregister(conn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("Got message: %s\n", string(b[:n]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) unregister(conn net.Conn) {
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
delete(h.conns, conn.RemoteAddr().String())
|
||||||
|
err := conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Failed to disconnect", conn.RemoteAddr(), err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("Disconnected ", conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) broadcast(msg []byte) {
|
||||||
|
h.lock.RLock()
|
||||||
|
defer h.lock.RUnlock()
|
||||||
|
for _, conn := range h.conns {
|
||||||
|
_, err := conn.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to write message to %s: %v\n", conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat starts the stdin readloop to dispatch messages to the hub
|
||||||
|
func (h *Hub) Chat() {
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
for {
|
||||||
|
msg, err := reader.ReadString('\n')
|
||||||
|
Check(err)
|
||||||
|
if strings.TrimSpace(msg) == "exit" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.broadcast([]byte(msg))
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,154 @@
|
||||||
|
// Package util provides auxiliary utilities used in examples
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const bufSize = 8192
|
||||||
|
|
||||||
|
var (
|
||||||
|
errBlockIsNotPrivateKey = errors.New("block is not a private key, unable to load key")
|
||||||
|
errUnknownKeyTime = errors.New("unknown key time in PKCS#8 wrapping, unable to load key")
|
||||||
|
errNoPrivateKeyFound = errors.New("no private key found, unable to load key")
|
||||||
|
errBlockIsNotCertificate = errors.New("block is not a certificate, unable to load certificates")
|
||||||
|
errNoCertificateFound = errors.New("no certificate found, unable to load certificates")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Chat simulates a simple text chat session over the connection
|
||||||
|
func Chat(conn io.ReadWriter) {
|
||||||
|
go func() {
|
||||||
|
b := make([]byte, bufSize)
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := conn.Read(b)
|
||||||
|
Check(err)
|
||||||
|
fmt.Printf("Got message: %s\n", string(b[:n]))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
|
||||||
|
for {
|
||||||
|
text, err := reader.ReadString('\n')
|
||||||
|
Check(err)
|
||||||
|
|
||||||
|
if strings.TrimSpace(text) == "exit" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Write([]byte(text))
|
||||||
|
Check(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check is a helper to throw errors in the examples
|
||||||
|
func Check(err error) {
|
||||||
|
switch e := err.(type) {
|
||||||
|
case nil:
|
||||||
|
case (net.Error):
|
||||||
|
if e.Temporary() {
|
||||||
|
fmt.Printf("Warning: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("net.Error: %v\n", err)
|
||||||
|
panic(err)
|
||||||
|
default:
|
||||||
|
fmt.Printf("error: %v\n", err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadKeyAndCertificate reads certificates or key from file
|
||||||
|
func LoadKeyAndCertificate(keyPath string, certificatePath string) (*tls.Certificate, error) {
|
||||||
|
privateKey, err := LoadKey(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
certificate, err := LoadCertificate(certificatePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
certificate.PrivateKey = privateKey
|
||||||
|
|
||||||
|
return certificate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadKey Load/read key from file
|
||||||
|
func LoadKey(path string) (crypto.PrivateKey, error) {
|
||||||
|
rawData, err := ioutil.ReadFile(filepath.Clean(path))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(rawData)
|
||||||
|
if block == nil || !strings.HasSuffix(block.Type, "PRIVATE KEY") {
|
||||||
|
return nil, errBlockIsNotPrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
|
||||||
|
switch key := key.(type) {
|
||||||
|
case *rsa.PrivateKey, *ecdsa.PrivateKey:
|
||||||
|
return key, nil
|
||||||
|
default:
|
||||||
|
return nil, errUnknownKeyTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errNoPrivateKeyFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadCertificate Load/read certificate(s) from file
|
||||||
|
func LoadCertificate(path string) (*tls.Certificate, error) {
|
||||||
|
rawData, err := ioutil.ReadFile(filepath.Clean(path))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var certificate tls.Certificate
|
||||||
|
|
||||||
|
for {
|
||||||
|
block, rest := pem.Decode(rawData)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if block.Type != "CERTIFICATE" {
|
||||||
|
return nil, errBlockIsNotCertificate
|
||||||
|
}
|
||||||
|
|
||||||
|
certificate.Certificate = append(certificate.Certificate, block.Bytes)
|
||||||
|
rawData = rest
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(certificate.Certificate) == 0 {
|
||||||
|
return nil, errNoCertificateFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return &certificate, nil
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
/*
|
||||||
|
DTLS messages are grouped into a series of message flights, according
|
||||||
|
to the diagrams below. Although each flight of messages may consist
|
||||||
|
of a number of messages, they should be viewed as monolithic for the
|
||||||
|
purpose of timeout and retransmission.
|
||||||
|
https://tools.ietf.org/html/rfc4347#section-4.2.4
|
||||||
|
Client Server
|
||||||
|
------ ------
|
||||||
|
Waiting Flight 0
|
||||||
|
|
||||||
|
ClientHello --------> Flight 1
|
||||||
|
|
||||||
|
<------- HelloVerifyRequest Flight 2
|
||||||
|
|
||||||
|
ClientHello --------> Flight 3
|
||||||
|
|
||||||
|
ServerHello \
|
||||||
|
Certificate* \
|
||||||
|
ServerKeyExchange* Flight 4
|
||||||
|
CertificateRequest* /
|
||||||
|
<-------- ServerHelloDone /
|
||||||
|
|
||||||
|
Certificate* \
|
||||||
|
ClientKeyExchange \
|
||||||
|
CertificateVerify* Flight 5
|
||||||
|
[ChangeCipherSpec] /
|
||||||
|
Finished --------> /
|
||||||
|
|
||||||
|
[ChangeCipherSpec] \ Flight 6
|
||||||
|
<-------- Finished /
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
type flightVal uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
flight0 flightVal = iota + 1
|
||||||
|
flight1
|
||||||
|
flight2
|
||||||
|
flight3
|
||||||
|
flight4
|
||||||
|
flight5
|
||||||
|
flight6
|
||||||
|
)
|
||||||
|
|
||||||
|
func (f flightVal) String() string {
|
||||||
|
switch f {
|
||||||
|
case flight0:
|
||||||
|
return "Flight 0"
|
||||||
|
case flight1:
|
||||||
|
return "Flight 1"
|
||||||
|
case flight2:
|
||||||
|
return "Flight 2"
|
||||||
|
case flight3:
|
||||||
|
return "Flight 3"
|
||||||
|
case flight4:
|
||||||
|
return "Flight 4"
|
||||||
|
case flight5:
|
||||||
|
return "Flight 5"
|
||||||
|
case flight6:
|
||||||
|
return "Flight 6"
|
||||||
|
default:
|
||||||
|
return "Invalid Flight"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f flightVal) isLastSendFlight() bool {
|
||||||
|
return f == flight6
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f flightVal) isLastRecvFlight() bool {
|
||||||
|
return f == flight5
|
||||||
|
}
|
|
@ -0,0 +1,102 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/extension"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
||||||
|
)
|
||||||
|
|
||||||
|
func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
|
||||||
|
seq, msgs, ok := cache.fullPullMap(0,
|
||||||
|
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
|
||||||
|
)
|
||||||
|
if !ok {
|
||||||
|
// No valid message received. Keep reading
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
state.handshakeRecvSequence = seq
|
||||||
|
|
||||||
|
var clientHello *handshake.MessageClientHello
|
||||||
|
|
||||||
|
// Validate type
|
||||||
|
if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !clientHello.Version.Equal(protocol.Version1_2) {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
state.remoteRandom = clientHello.Random
|
||||||
|
|
||||||
|
cipherSuites := []CipherSuite{}
|
||||||
|
for _, id := range clientHello.CipherSuiteIDs {
|
||||||
|
if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil {
|
||||||
|
cipherSuites = append(cipherSuites, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, val := range clientHello.Extensions {
|
||||||
|
switch e := val.(type) {
|
||||||
|
case *extension.SupportedEllipticCurves:
|
||||||
|
if len(e.EllipticCurves) == 0 {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves
|
||||||
|
}
|
||||||
|
state.namedCurve = e.EllipticCurves[0]
|
||||||
|
case *extension.UseSRTP:
|
||||||
|
profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
|
||||||
|
if !ok {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile
|
||||||
|
}
|
||||||
|
state.srtpProtectionProfile = profile
|
||||||
|
case *extension.UseExtendedMasterSecret:
|
||||||
|
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
|
||||||
|
state.extendedMasterSecret = true
|
||||||
|
}
|
||||||
|
case *extension.ServerName:
|
||||||
|
state.serverName = e.ServerName // remote server name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.localKeypair == nil {
|
||||||
|
var err error
|
||||||
|
state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return flight2, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flight0Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
|
||||||
|
// Initialize
|
||||||
|
state.cookie = make([]byte, cookieLength)
|
||||||
|
if _, err := rand.Read(state.cookie); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var zeroEpoch uint16
|
||||||
|
state.localEpoch.Store(zeroEpoch)
|
||||||
|
state.remoteEpoch.Store(zeroEpoch)
|
||||||
|
state.namedCurve = defaultNamedCurve
|
||||||
|
|
||||||
|
if err := state.localRandom.Populate(); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
|
@ -0,0 +1,112 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/extension"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
|
||||||
|
// HelloVerifyRequest can be skipped by the server,
|
||||||
|
// so allow ServerHello during flight1 also
|
||||||
|
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
|
||||||
|
handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true},
|
||||||
|
)
|
||||||
|
if !ok {
|
||||||
|
// No valid message received. Keep reading
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := msgs[handshake.TypeServerHello]; ok {
|
||||||
|
// Flight1 and flight2 were skipped.
|
||||||
|
// Parse as flight3.
|
||||||
|
return flight3Parse(ctx, c, state, cache, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok {
|
||||||
|
// DTLS 1.2 clients must not assume that the server will use the protocol version
|
||||||
|
// specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
|
||||||
|
if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
|
||||||
|
}
|
||||||
|
state.cookie = append([]byte{}, h.Cookie...)
|
||||||
|
state.handshakeRecvSequence = seq
|
||||||
|
return flight3, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flight1Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
|
||||||
|
var zeroEpoch uint16
|
||||||
|
state.localEpoch.Store(zeroEpoch)
|
||||||
|
state.remoteEpoch.Store(zeroEpoch)
|
||||||
|
state.namedCurve = defaultNamedCurve
|
||||||
|
state.cookie = nil
|
||||||
|
|
||||||
|
if err := state.localRandom.Populate(); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
extensions := []extension.Extension{
|
||||||
|
&extension.SupportedSignatureAlgorithms{
|
||||||
|
SignatureHashAlgorithms: cfg.localSignatureSchemes,
|
||||||
|
},
|
||||||
|
&extension.RenegotiationInfo{
|
||||||
|
RenegotiatedConnection: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if cfg.localPSKCallback == nil {
|
||||||
|
extensions = append(extensions, []extension.Extension{
|
||||||
|
&extension.SupportedEllipticCurves{
|
||||||
|
EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
|
||||||
|
},
|
||||||
|
&extension.SupportedPointFormats{
|
||||||
|
PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
|
||||||
|
},
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.localSRTPProtectionProfiles) > 0 {
|
||||||
|
extensions = append(extensions, &extension.UseSRTP{
|
||||||
|
ProtectionProfiles: cfg.localSRTPProtectionProfiles,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
|
||||||
|
cfg.extendedMasterSecret == RequireExtendedMasterSecret {
|
||||||
|
extensions = append(extensions, &extension.UseExtendedMasterSecret{
|
||||||
|
Supported: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.serverName) > 0 {
|
||||||
|
extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
|
||||||
|
}
|
||||||
|
|
||||||
|
return []*packet{
|
||||||
|
{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageClientHello{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
Cookie: state.cookie,
|
||||||
|
Random: state.localRandom,
|
||||||
|
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
|
||||||
|
CompressionMethods: defaultCompressionMethods(),
|
||||||
|
Extensions: extensions,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil, nil
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
|
||||||
|
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
|
||||||
|
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
|
||||||
|
)
|
||||||
|
if !ok {
|
||||||
|
// Client may retransmit the first ClientHello when HelloVerifyRequest is dropped.
|
||||||
|
// Parse as flight 0 in this case.
|
||||||
|
return flight0Parse(ctx, c, state, cache, cfg)
|
||||||
|
}
|
||||||
|
state.handshakeRecvSequence = seq
|
||||||
|
|
||||||
|
var clientHello *handshake.MessageClientHello
|
||||||
|
|
||||||
|
// Validate type
|
||||||
|
if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !clientHello.Version.Equal(protocol.Version1_2) {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(clientHello.Cookie) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if !bytes.Equal(state.cookie, clientHello.Cookie) {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.AccessDenied}, errCookieMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO 添加 CiscoCompat 支持
|
||||||
|
if cfg.localCiscoCompatCallback != nil {
|
||||||
|
var err error
|
||||||
|
state.SessionID = clientHello.SessionID
|
||||||
|
if len(state.SessionID) == 0 {
|
||||||
|
err = fmt.Errorf("clientHello SessionID is nil")
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
state.masterSecret, err = cfg.localCiscoCompatCallback(state.SessionID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return flight4, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flight2Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
|
||||||
|
state.handshakeSendSequence = 0
|
||||||
|
return []*packet{
|
||||||
|
{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageHelloVerifyRequest{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
Cookie: state.cookie,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil, nil
|
||||||
|
}
|
|
@ -0,0 +1,194 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/prf"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/extension"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
|
||||||
|
// Clients may receive multiple HelloVerifyRequest messages with different cookies.
|
||||||
|
// Clients SHOULD handle this by sending a new ClientHello with a cookie in response
|
||||||
|
// to the new HelloVerifyRequest. RFC 6347 Section 4.2.1
|
||||||
|
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
|
||||||
|
handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
|
||||||
|
)
|
||||||
|
if ok {
|
||||||
|
if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk {
|
||||||
|
// DTLS 1.2 clients must not assume that the server will use the protocol version
|
||||||
|
// specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
|
||||||
|
if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
|
||||||
|
}
|
||||||
|
state.cookie = append([]byte{}, h.Cookie...)
|
||||||
|
state.handshakeRecvSequence = seq
|
||||||
|
return flight3, nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.localPSKCallback != nil {
|
||||||
|
seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence,
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence,
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
// Don't have enough messages. Keep reading
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
state.handshakeRecvSequence = seq
|
||||||
|
|
||||||
|
if h, ok := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); ok {
|
||||||
|
if !h.Version.Equal(protocol.Version1_2) {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
|
||||||
|
}
|
||||||
|
for _, v := range h.Extensions {
|
||||||
|
switch e := v.(type) {
|
||||||
|
case *extension.UseSRTP:
|
||||||
|
profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
|
||||||
|
if !ok {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile
|
||||||
|
}
|
||||||
|
state.srtpProtectionProfile = profile
|
||||||
|
case *extension.UseExtendedMasterSecret:
|
||||||
|
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
|
||||||
|
state.extendedMasterSecret = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
|
||||||
|
}
|
||||||
|
if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites)
|
||||||
|
if remoteCipherSuite == nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
|
||||||
|
}
|
||||||
|
|
||||||
|
selectedCipherSuite, ok := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites)
|
||||||
|
if !ok {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
|
||||||
|
}
|
||||||
|
|
||||||
|
state.cipherSuite = selectedCipherSuite
|
||||||
|
state.remoteRandom = h.Random
|
||||||
|
cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok {
|
||||||
|
state.PeerCertificates = h.Certificate
|
||||||
|
} else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate
|
||||||
|
}
|
||||||
|
|
||||||
|
if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok {
|
||||||
|
alertPtr, err := handleServerKeyExchange(c, state, cfg, h)
|
||||||
|
if err != nil {
|
||||||
|
return 0, alertPtr, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
|
||||||
|
state.remoteRequestedCertificate = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return flight5, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
|
||||||
|
var err error
|
||||||
|
if cfg.localPSKCallback != nil {
|
||||||
|
var psk []byte
|
||||||
|
if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
state.IdentityHint = h.IdentityHint
|
||||||
|
state.preMasterSecret = prf.PSKPreMasterSecret(psk)
|
||||||
|
} else {
|
||||||
|
if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flight3Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
|
||||||
|
extensions := []extension.Extension{
|
||||||
|
&extension.SupportedSignatureAlgorithms{
|
||||||
|
SignatureHashAlgorithms: cfg.localSignatureSchemes,
|
||||||
|
},
|
||||||
|
&extension.RenegotiationInfo{
|
||||||
|
RenegotiatedConnection: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if cfg.localPSKCallback == nil {
|
||||||
|
extensions = append(extensions, []extension.Extension{
|
||||||
|
&extension.SupportedEllipticCurves{
|
||||||
|
EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
|
||||||
|
},
|
||||||
|
&extension.SupportedPointFormats{
|
||||||
|
PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
|
||||||
|
},
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.localSRTPProtectionProfiles) > 0 {
|
||||||
|
extensions = append(extensions, &extension.UseSRTP{
|
||||||
|
ProtectionProfiles: cfg.localSRTPProtectionProfiles,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
|
||||||
|
cfg.extendedMasterSecret == RequireExtendedMasterSecret {
|
||||||
|
extensions = append(extensions, &extension.UseExtendedMasterSecret{
|
||||||
|
Supported: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.serverName) > 0 {
|
||||||
|
extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
|
||||||
|
}
|
||||||
|
|
||||||
|
return []*packet{
|
||||||
|
{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageClientHello{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
Cookie: state.cookie,
|
||||||
|
Random: state.localRandom,
|
||||||
|
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
|
||||||
|
CompressionMethods: defaultCompressionMethods(),
|
||||||
|
Extensions: extensions,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil, nil
|
||||||
|
}
|
|
@ -0,0 +1,352 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/prf"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/extension"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
|
||||||
|
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true},
|
||||||
|
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, true},
|
||||||
|
)
|
||||||
|
if !ok {
|
||||||
|
// No valid message received. Keep reading
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate type
|
||||||
|
var clientKeyExchange *handshake.MessageClientKeyExchange
|
||||||
|
if clientKeyExchange, ok = msgs[handshake.TypeClientKeyExchange].(*handshake.MessageClientKeyExchange); !ok {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if h, hasCert := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); hasCert {
|
||||||
|
state.PeerCertificates = h.Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
if h, hasCertVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasCertVerify {
|
||||||
|
if state.PeerCertificates == nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate
|
||||||
|
}
|
||||||
|
|
||||||
|
plainText := cache.pullAndMerge(
|
||||||
|
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Verify that the pair of hash algorithm and signiture is listed.
|
||||||
|
var validSignatureScheme bool
|
||||||
|
for _, ss := range cfg.localSignatureSchemes {
|
||||||
|
if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
|
||||||
|
validSignatureScheme = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !validSignatureScheme {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := verifyCertificateVerify(plainText, h.HashAlgorithm, h.Signature, state.PeerCertificates); err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
|
||||||
|
}
|
||||||
|
var chains [][]*x509.Certificate
|
||||||
|
var err error
|
||||||
|
var verified bool
|
||||||
|
if cfg.clientAuth >= VerifyClientCertIfGiven {
|
||||||
|
if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
|
||||||
|
}
|
||||||
|
verified = true
|
||||||
|
}
|
||||||
|
if cfg.verifyPeerCertificate != nil {
|
||||||
|
if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.peerCertificatesVerified = verified
|
||||||
|
}
|
||||||
|
|
||||||
|
if !state.cipherSuite.IsInitialized() {
|
||||||
|
serverRandom := state.localRandom.MarshalFixed()
|
||||||
|
clientRandom := state.remoteRandom.MarshalFixed()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var preMasterSecret []byte
|
||||||
|
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey {
|
||||||
|
var psk []byte
|
||||||
|
if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
state.IdentityHint = clientKeyExchange.IdentityHint
|
||||||
|
preMasterSecret = prf.PSKPreMasterSecret(psk)
|
||||||
|
} else {
|
||||||
|
preMasterSecret, err = prf.PreMasterSecret(clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.extendedMasterSecret {
|
||||||
|
var sessionHash []byte
|
||||||
|
sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
state.masterSecret, err = prf.ExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
|
||||||
|
if err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
state.masterSecret, err = prf.MasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
|
||||||
|
if err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now, encrypted packets can be handled
|
||||||
|
if err := c.handleQueuedPackets(ctx); err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
seq, msgs, ok = cache.fullPullMap(seq,
|
||||||
|
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
|
||||||
|
)
|
||||||
|
if !ok {
|
||||||
|
// No valid message received. Keep reading
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
state.handshakeRecvSequence = seq
|
||||||
|
|
||||||
|
if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous {
|
||||||
|
return flight6, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch cfg.clientAuth {
|
||||||
|
case RequireAnyClientCert:
|
||||||
|
if state.PeerCertificates == nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
|
||||||
|
}
|
||||||
|
case VerifyClientCertIfGiven:
|
||||||
|
if state.PeerCertificates != nil && !state.peerCertificatesVerified {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
|
||||||
|
}
|
||||||
|
case RequireAndVerifyClientCert:
|
||||||
|
if state.PeerCertificates == nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
|
||||||
|
}
|
||||||
|
if !state.peerCertificatesVerified {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
|
||||||
|
}
|
||||||
|
case NoClientCert, RequestClientCert:
|
||||||
|
return flight6, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return flight6, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flight4Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
|
||||||
|
extensions := []extension.Extension{&extension.RenegotiationInfo{
|
||||||
|
RenegotiatedConnection: 0,
|
||||||
|
}}
|
||||||
|
if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
|
||||||
|
cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
|
||||||
|
extensions = append(extensions, &extension.UseExtendedMasterSecret{
|
||||||
|
Supported: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if state.srtpProtectionProfile != 0 {
|
||||||
|
extensions = append(extensions, &extension.UseSRTP{
|
||||||
|
ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
|
||||||
|
extensions = append(extensions, []extension.Extension{
|
||||||
|
&extension.SupportedEllipticCurves{
|
||||||
|
EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
|
||||||
|
},
|
||||||
|
&extension.SupportedPointFormats{
|
||||||
|
PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
|
||||||
|
},
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var pkts []*packet
|
||||||
|
cipherSuiteID := uint16(state.cipherSuite.ID())
|
||||||
|
|
||||||
|
pkts = append(pkts, &packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageServerHello{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
Random: state.localRandom,
|
||||||
|
SessionID: state.SessionID,
|
||||||
|
CipherSuiteID: &cipherSuiteID,
|
||||||
|
CompressionMethod: defaultCompressionMethods()[0],
|
||||||
|
Extensions: extensions,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// TODO 添加 CiscoCompat 支持
|
||||||
|
if cfg.localCiscoCompatCallback != nil {
|
||||||
|
if !state.cipherSuite.IsInitialized() {
|
||||||
|
serverRandom := state.localRandom.MarshalFixed()
|
||||||
|
clientRandom := state.remoteRandom.MarshalFixed()
|
||||||
|
|
||||||
|
if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
|
||||||
|
}
|
||||||
|
return pkts, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate:
|
||||||
|
certificate, err := cfg.getCertificate(cfg.serverName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pkts = append(pkts, &packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageCertificate{
|
||||||
|
Certificate: certificate.Certificate,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
serverRandom := state.localRandom.MarshalFixed()
|
||||||
|
clientRandom := state.remoteRandom.MarshalFixed()
|
||||||
|
|
||||||
|
// Find compatible signature scheme
|
||||||
|
signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.Hash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
state.localKeySignature = signature
|
||||||
|
|
||||||
|
pkts = append(pkts, &packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageServerKeyExchange{
|
||||||
|
EllipticCurveType: elliptic.CurveTypeNamedCurve,
|
||||||
|
NamedCurve: state.namedCurve,
|
||||||
|
PublicKey: state.localKeypair.PublicKey,
|
||||||
|
HashAlgorithm: signatureHashAlgo.Hash,
|
||||||
|
SignatureAlgorithm: signatureHashAlgo.Signature,
|
||||||
|
Signature: state.localKeySignature,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if cfg.clientAuth > NoClientCert {
|
||||||
|
pkts = append(pkts, &packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageCertificateRequest{
|
||||||
|
CertificateTypes: []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign},
|
||||||
|
SignatureHashAlgorithms: cfg.localSignatureSchemes,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case cfg.localPSKIdentityHint != nil:
|
||||||
|
// To help the client in selecting which identity to use, the server
|
||||||
|
// can provide a "PSK identity hint" in the ServerKeyExchange message.
|
||||||
|
// If no hint is provided, the ServerKeyExchange message is omitted.
|
||||||
|
//
|
||||||
|
// https://tools.ietf.org/html/rfc4279#section-2
|
||||||
|
pkts = append(pkts, &packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageServerKeyExchange{
|
||||||
|
IdentityHint: cfg.localPSKIdentityHint,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous:
|
||||||
|
pkts = append(pkts, &packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageServerKeyExchange{
|
||||||
|
EllipticCurveType: elliptic.CurveTypeNamedCurve,
|
||||||
|
NamedCurve: state.namedCurve,
|
||||||
|
PublicKey: state.localKeypair.PublicKey,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pkts = append(pkts, &packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageServerHelloDone{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return pkts, nil, nil
|
||||||
|
}
|
|
@ -0,0 +1,323 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
|
"crypto/x509"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/prf"
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func flight5Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
|
||||||
|
_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
|
||||||
|
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
|
||||||
|
)
|
||||||
|
if !ok {
|
||||||
|
// No valid message received. Keep reading
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var finished *handshake.MessageFinished
|
||||||
|
if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
|
||||||
|
}
|
||||||
|
plainText := cache.pullAndMerge(
|
||||||
|
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
|
||||||
|
)
|
||||||
|
|
||||||
|
expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
|
||||||
|
if err != nil {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
return flight5, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit
|
||||||
|
var certBytes [][]byte
|
||||||
|
var privateKey crypto.PrivateKey
|
||||||
|
if len(cfg.localCertificates) > 0 {
|
||||||
|
certificate, err := cfg.getCertificate(cfg.serverName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
|
||||||
|
}
|
||||||
|
certBytes = certificate.Certificate
|
||||||
|
privateKey = certificate.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
var pkts []*packet
|
||||||
|
|
||||||
|
if state.remoteRequestedCertificate {
|
||||||
|
pkts = append(pkts,
|
||||||
|
&packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageCertificate{
|
||||||
|
Certificate: certBytes,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
clientKeyExchange := &handshake.MessageClientKeyExchange{}
|
||||||
|
if cfg.localPSKCallback == nil {
|
||||||
|
clientKeyExchange.PublicKey = state.localKeypair.PublicKey
|
||||||
|
} else {
|
||||||
|
clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint
|
||||||
|
}
|
||||||
|
|
||||||
|
pkts = append(pkts,
|
||||||
|
&packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: clientKeyExchange,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
serverKeyExchangeData := cache.pullAndMerge(
|
||||||
|
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
|
||||||
|
)
|
||||||
|
|
||||||
|
serverKeyExchange := &handshake.MessageServerKeyExchange{}
|
||||||
|
|
||||||
|
// handshakeMessageServerKeyExchange is optional for PSK
|
||||||
|
if len(serverKeyExchangeData) == 0 {
|
||||||
|
alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshake.MessageServerKeyExchange{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, alertPtr, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rawHandshake := &handshake.Handshake{}
|
||||||
|
err := rawHandshake.Unmarshal(serverKeyExchangeData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch h := rawHandshake.Message.(type) {
|
||||||
|
case *handshake.MessageServerKeyExchange:
|
||||||
|
serverKeyExchange = h
|
||||||
|
default:
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append not-yet-sent packets
|
||||||
|
merged := []byte{}
|
||||||
|
seqPred := uint16(state.handshakeSendSequence)
|
||||||
|
for _, p := range pkts {
|
||||||
|
h, ok := p.record.Content.(*handshake.Handshake)
|
||||||
|
if !ok {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
|
||||||
|
}
|
||||||
|
h.Header.MessageSequence = seqPred
|
||||||
|
seqPred++
|
||||||
|
raw, err := h.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
merged = append(merged, raw...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if alertPtr, err := initalizeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil {
|
||||||
|
return nil, alertPtr, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the client has sent a certificate with signing ability, a digitally-signed
|
||||||
|
// CertificateVerify message is sent to explicitly verify possession of the
|
||||||
|
// private key in the certificate.
|
||||||
|
if state.remoteRequestedCertificate && len(cfg.localCertificates) > 0 {
|
||||||
|
plainText := append(cache.pullAndMerge(
|
||||||
|
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
|
||||||
|
), merged...)
|
||||||
|
|
||||||
|
// Find compatible signature scheme
|
||||||
|
signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.Hash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
state.localCertificatesVerify = certVerify
|
||||||
|
|
||||||
|
p := &packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageCertificateVerify{
|
||||||
|
HashAlgorithm: signatureHashAlgo.Hash,
|
||||||
|
SignatureAlgorithm: signatureHashAlgo.Signature,
|
||||||
|
Signature: state.localCertificatesVerify,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pkts = append(pkts, p)
|
||||||
|
|
||||||
|
h, ok := p.record.Content.(*handshake.Handshake)
|
||||||
|
if !ok {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
|
||||||
|
}
|
||||||
|
h.Header.MessageSequence = seqPred
|
||||||
|
// seqPred++ // this is the last use of seqPred
|
||||||
|
raw, err := h.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
merged = append(merged, raw...)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkts = append(pkts,
|
||||||
|
&packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &protocol.ChangeCipherSpec{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(state.localVerifyData) == 0 {
|
||||||
|
plainText := cache.pullAndMerge(
|
||||||
|
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
|
||||||
|
)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc())
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pkts = append(pkts,
|
||||||
|
&packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
Epoch: 1,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageFinished{
|
||||||
|
VerifyData: state.localVerifyData,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldEncrypt: true,
|
||||||
|
resetLocalSequenceNumber: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
return pkts, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit
|
||||||
|
if state.cipherSuite.IsInitialized() {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
clientRandom := state.localRandom.MarshalFixed()
|
||||||
|
serverRandom := state.remoteRandom.MarshalFixed()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if state.extendedMasterSecret {
|
||||||
|
var sessionHash []byte
|
||||||
|
sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText)
|
||||||
|
if err != nil {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
|
||||||
|
if err != nil {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
state.masterSecret, err = prf.MasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
|
||||||
|
if err != nil {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
|
||||||
|
// Verify that the pair of hash algorithm and signiture is listed.
|
||||||
|
var validSignatureScheme bool
|
||||||
|
for _, ss := range cfg.localSignatureSchemes {
|
||||||
|
if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
|
||||||
|
validSignatureScheme = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !validSignatureScheme {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.PublicKey, h.NamedCurve)
|
||||||
|
if err = verifyKeySignature(expectedMsg, h.Signature, h.HashAlgorithm, state.PeerCertificates); err != nil {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
|
||||||
|
}
|
||||||
|
var chains [][]*x509.Certificate
|
||||||
|
if !cfg.insecureSkipVerify {
|
||||||
|
if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cfg.verifyPeerCertificate != nil {
|
||||||
|
if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
|
||||||
|
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
|
@ -0,0 +1,82 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/crypto/prf"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func flight6Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
|
||||||
|
_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1,
|
||||||
|
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
|
||||||
|
)
|
||||||
|
if !ok {
|
||||||
|
// No valid message received. Keep reading
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
|
||||||
|
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other party retransmitted the last flight.
|
||||||
|
return flight6, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flight6Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
|
||||||
|
var pkts []*packet
|
||||||
|
|
||||||
|
pkts = append(pkts,
|
||||||
|
&packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
},
|
||||||
|
Content: &protocol.ChangeCipherSpec{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(state.localVerifyData) == 0 {
|
||||||
|
plainText := cache.pullAndMerge(
|
||||||
|
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
|
||||||
|
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
|
||||||
|
)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
|
||||||
|
if err != nil {
|
||||||
|
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pkts = append(pkts,
|
||||||
|
&packet{
|
||||||
|
record: &recordlayer.RecordLayer{
|
||||||
|
Header: recordlayer.Header{
|
||||||
|
Version: protocol.Version1_2,
|
||||||
|
Epoch: 1,
|
||||||
|
},
|
||||||
|
Content: &handshake.Handshake{
|
||||||
|
Message: &handshake.MessageFinished{
|
||||||
|
VerifyData: state.localVerifyData,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldEncrypt: true,
|
||||||
|
resetLocalSequenceNumber: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return pkts, nil, nil
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/alert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parse received handshakes and return next flightVal
|
||||||
|
type flightParser func(context.Context, flightConn, *State, *handshakeCache, *handshakeConfig) (flightVal, *alert.Alert, error)
|
||||||
|
|
||||||
|
// Generate flights
|
||||||
|
type flightGenerator func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert.Alert, error)
|
||||||
|
|
||||||
|
func (f flightVal) getFlightParser() (flightParser, error) {
|
||||||
|
switch f {
|
||||||
|
case flight0:
|
||||||
|
return flight0Parse, nil
|
||||||
|
case flight1:
|
||||||
|
return flight1Parse, nil
|
||||||
|
case flight2:
|
||||||
|
return flight2Parse, nil
|
||||||
|
case flight3:
|
||||||
|
return flight3Parse, nil
|
||||||
|
case flight4:
|
||||||
|
return flight4Parse, nil
|
||||||
|
case flight5:
|
||||||
|
return flight5Parse, nil
|
||||||
|
case flight6:
|
||||||
|
return flight6Parse, nil
|
||||||
|
default:
|
||||||
|
return nil, errInvalidFlight
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f flightVal) getFlightGenerator() (gen flightGenerator, retransmit bool, err error) {
|
||||||
|
switch f {
|
||||||
|
case flight0:
|
||||||
|
return flight0Generate, true, nil
|
||||||
|
case flight1:
|
||||||
|
return flight1Generate, true, nil
|
||||||
|
case flight2:
|
||||||
|
// https://tools.ietf.org/html/rfc6347#section-3.2.1
|
||||||
|
// HelloVerifyRequests must not be retransmitted.
|
||||||
|
return flight2Generate, false, nil
|
||||||
|
case flight3:
|
||||||
|
return flight3Generate, true, nil
|
||||||
|
case flight4:
|
||||||
|
return flight4Generate, true, nil
|
||||||
|
case flight5:
|
||||||
|
return flight5Generate, true, nil
|
||||||
|
case flight6:
|
||||||
|
return flight6Generate, true, nil
|
||||||
|
default:
|
||||||
|
return nil, false, errInvalidFlight
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,111 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/handshake"
|
||||||
|
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fragment struct {
|
||||||
|
recordLayerHeader recordlayer.Header
|
||||||
|
handshakeHeader handshake.Header
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type fragmentBuffer struct {
|
||||||
|
// map of MessageSequenceNumbers that hold slices of fragments
|
||||||
|
cache map[uint16][]*fragment
|
||||||
|
|
||||||
|
currentMessageSequenceNumber uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFragmentBuffer() *fragmentBuffer {
|
||||||
|
return &fragmentBuffer{cache: map[uint16][]*fragment{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempts to push a DTLS packet to the fragmentBuffer
|
||||||
|
// when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled
|
||||||
|
// when an error returns it is fatal, and the DTLS connection should be stopped
|
||||||
|
func (f *fragmentBuffer) push(buf []byte) (bool, error) {
|
||||||
|
frag := new(fragment)
|
||||||
|
if err := frag.recordLayerHeader.Unmarshal(buf); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// fragment isn't a handshake, we don't need to handle it
|
||||||
|
if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for buf = buf[recordlayer.HeaderSize:]; len(buf) != 0; frag = new(fragment) {
|
||||||
|
if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok {
|
||||||
|
f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// end index should be the length of handshake header but if the handshake
|
||||||
|
// was fragmented, we should keep them all
|
||||||
|
end := int(handshake.HeaderLength + frag.handshakeHeader.Length)
|
||||||
|
if size := len(buf); end > size {
|
||||||
|
end = size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discard all headers, when rebuilding the packet we will re-build
|
||||||
|
frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...)
|
||||||
|
f.cache[frag.handshakeHeader.MessageSequence] = append(f.cache[frag.handshakeHeader.MessageSequence], frag)
|
||||||
|
buf = buf[end:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
|
||||||
|
frags, ok := f.cache[f.currentMessageSequenceNumber]
|
||||||
|
if !ok {
|
||||||
|
return nil, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Go doesn't support recursive lambdas
|
||||||
|
var appendMessage func(targetOffset uint32) bool
|
||||||
|
|
||||||
|
rawMessage := []byte{}
|
||||||
|
appendMessage = func(targetOffset uint32) bool {
|
||||||
|
for _, f := range frags {
|
||||||
|
if f.handshakeHeader.FragmentOffset == targetOffset {
|
||||||
|
fragmentEnd := (f.handshakeHeader.FragmentOffset + f.handshakeHeader.FragmentLength)
|
||||||
|
if fragmentEnd != f.handshakeHeader.Length {
|
||||||
|
if !appendMessage(fragmentEnd) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rawMessage = append(f.data, rawMessage...)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively collect up
|
||||||
|
if !appendMessage(0) {
|
||||||
|
return nil, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
firstHeader := frags[0].handshakeHeader
|
||||||
|
firstHeader.FragmentOffset = 0
|
||||||
|
firstHeader.FragmentLength = firstHeader.Length
|
||||||
|
|
||||||
|
rawHeader, err := firstHeader.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
messageEpoch := frags[0].recordLayerHeader.Epoch
|
||||||
|
|
||||||
|
delete(f.cache, f.currentMessageSequenceNumber)
|
||||||
|
f.currentMessageSequenceNumber++
|
||||||
|
return append(rawHeader, rawMessage...), messageEpoch
|
||||||
|
}
|
|
@ -0,0 +1,101 @@
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFragmentBuffer(t *testing.T) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
Name string
|
||||||
|
In [][]byte
|
||||||
|
Expected [][]byte
|
||||||
|
Epoch uint16
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "Single Fragment",
|
||||||
|
In: [][]byte{
|
||||||
|
{0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00},
|
||||||
|
},
|
||||||
|
Expected: [][]byte{
|
||||||
|
{0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00},
|
||||||
|
},
|
||||||
|
Epoch: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Single Fragment Epoch 3",
|
||||||
|
In: [][]byte{
|
||||||
|
{0x16, 0xfe, 0xff, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00},
|
||||||
|
},
|
||||||
|
Expected: [][]byte{
|
||||||
|
{0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00},
|
||||||
|
},
|
||||||
|
Epoch: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Multiple Fragments",
|
||||||
|
In: [][]byte{
|
||||||
|
{0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04},
|
||||||
|
{0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09},
|
||||||
|
{0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E},
|
||||||
|
},
|
||||||
|
Expected: [][]byte{
|
||||||
|
{0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e},
|
||||||
|
},
|
||||||
|
Epoch: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Multiple Unordered Fragments",
|
||||||
|
In: [][]byte{
|
||||||
|
{0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04},
|
||||||
|
{0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E},
|
||||||
|
{0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09},
|
||||||
|
},
|
||||||
|
Expected: [][]byte{
|
||||||
|
{0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e},
|
||||||
|
},
|
||||||
|
Epoch: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Multiple Handshakes in Signle Fragment",
|
||||||
|
In: [][]byte{
|
||||||
|
{
|
||||||
|
0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x30, /* record header */
|
||||||
|
0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 1*/
|
||||||
|
0x03, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 2*/
|
||||||
|
0x03, 0x00, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 3*/
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Expected: [][]byte{
|
||||||
|
{0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01},
|
||||||
|
{0x03, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01},
|
||||||
|
{0x03, 0x00, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01},
|
||||||
|
},
|
||||||
|
Epoch: 0,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
fragmentBuffer := newFragmentBuffer()
|
||||||
|
for _, frag := range test.In {
|
||||||
|
status, err := fragmentBuffer.push(frag)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else if !status {
|
||||||
|
t.Errorf("fragmentBuffer didn't accept fragments for '%s'", test.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range test.Expected {
|
||||||
|
out, epoch := fragmentBuffer.pop()
|
||||||
|
if !reflect.DeepEqual(out, expected) {
|
||||||
|
t.Errorf("fragmentBuffer '%s' push/pop: got % 02x, want % 02x", test.Name, out, expected)
|
||||||
|
}
|
||||||
|
if epoch != test.Epoch {
|
||||||
|
t.Errorf("fragmentBuffer returned wrong epoch: got %d, want %d", epoch, test.Epoch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if frag, _ := fragmentBuffer.pop(); frag != nil {
|
||||||
|
t.Errorf("fragmentBuffer popped single buffer multiple times for '%s'", test.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,38 @@
|
||||||
|
// +build gofuzz
|
||||||
|
|
||||||
|
package dtls
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func partialHeaderMismatch(a, b recordlayer.Header) bool {
|
||||||
|
// Ignoring content length for now.
|
||||||
|
a.contentLen = b.contentLen
|
||||||
|
return a != b
|
||||||
|
}
|
||||||
|
|
||||||
|
func FuzzRecordLayer(data []byte) int {
|
||||||
|
var r recordLayer
|
||||||
|
if err := r.Unmarshal(data); err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
buf, err := r.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if len(buf) == 0 {
|
||||||
|
panic("zero buff") // nolint
|
||||||
|
}
|
||||||
|
var nr recordLayer
|
||||||
|
if err = nr.Unmarshal(data); err != nil {
|
||||||
|
panic(err) // nolint
|
||||||
|
}
|
||||||
|
if partialHeaderMismatch(nr.recordlayer.Header, r.recordlayer.Header) {
|
||||||
|
panic( // nolint
|
||||||
|
fmt.Sprintf("header mismatch: %+v != %+v",
|
||||||
|
nr.recordlayer.Header, r.recordlayer.Header,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1
|
||||||
|
}
|
Binary file not shown.
|
@ -0,0 +1 @@
|
||||||
|
Ñ12‡™ŠÇ[A51
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1 @@
|
||||||
|
864797660130
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue