添加dtls支持

This commit is contained in:
bjdgyc
2021-05-21 19:00:23 +08:00
parent 07163aa33b
commit c8f090c9e3
17 changed files with 279 additions and 183 deletions

View File

@@ -1,59 +1,78 @@
package handler
import (
"context"
"crypto/tls"
"encoding/hex"
"log"
"fmt"
"net"
"time"
"os"
"github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/sessdata"
"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/logging"
)
// 因本项目对 github.com/pion/dtls 的代码,进行了大量的修改
// 且短时间内无法合并到上游项目
// 所以本项目暂时copy了一份代码
// 最后,感谢 github.com/pion/dtls 对golang生态做出的贡献
func startDtls() {
certificate, err := selfsign.GenerateSelfSigned()
logf := logging.NewDefaultLoggerFactory()
logf.DefaultLogLevel = logging.LogLevelTrace
f, err := os.OpenFile("/tmp/key.log", os.O_TRUNC|os.O_RDWR, 0600)
if err != nil {
panic(err)
}
logf := logging.NewDefaultLoggerFactory()
logf.Writer = base.GetBaseLw()
// logf.DefaultLogLevel = logging.LogLevelTrace
logf.DefaultLogLevel = logging.LogLevelInfo
config := &dtls.Config{
Certificates: []tls.Certificate{certificate},
InsecureSkipVerify: true,
ExtendedMasterSecret: dtls.DisableExtendedMasterSecret,
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
LoggerFactory: logf,
KeyLogWriter: f,
MTU: BufferSize,
CiscoCompat: func(sessid []byte) ([]byte, error) {
masterSecret := sessdata.Dtls2MasterSecret(hex.EncodeToString(sessid))
if masterSecret == "" {
return nil, fmt.Errorf("masterSecret is err")
}
return hex.DecodeString(masterSecret)
},
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(context.Background(), 5*time.Second)
},
}
addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 4433}
addr, err := net.ResolveUDPAddr("udp", base.Cfg.ServerDTLSAddr)
if err != nil {
panic(err)
}
ln, err := dtls.Listen("udp", addr, config)
if err != nil {
panic(err)
}
base.Info("listen DTLS server", addr)
for {
c, err := ln.Accept()
conn, err := ln.Accept()
if err != nil {
log.Println("Accept error", err)
base.Error("DTLS Accept error", err)
continue
}
go func() {
time.Sleep(1 * time.Second)
cc := c.(*dtls.Conn)
id := hex.EncodeToString(cc.ConnectionState().SessionID)
s, ok := ss.Load(id)
log.Println("get link", id, ok)
cs := s.(*sessdata.ConnSession)
LinkDtls(c, cs)
// time.Sleep(1 * time.Second)
cc := conn.(*dtls.Conn)
sessid := hex.EncodeToString(cc.ConnectionState().SessionID)
sess := sessdata.Dtls2Sess(sessid)
LinkDtls(conn, sess.CSess)
}()
}
}

View File

@@ -55,7 +55,7 @@ func LinkCstp(conn net.Conn, cSess *sessdata.ConnSession) {
return
case 0x03: // DPD-REQ
// base.Debug("recv DPD-REQ", cSess.IpAddr)
if payloadOut(cSess, sessdata.LTypeIPData, 0x04, nil) {
if payloadOutCstp(cSess, sessdata.LTypeIPData, 0x04, nil) {
return
}
case 0x04:
@@ -86,7 +86,7 @@ func cstpWrite(conn net.Conn, cSess *sessdata.ConnSession) {
for {
select {
case payload = <-cSess.PayloadOut:
case payload = <-cSess.PayloadOutCstp:
case <-cSess.CloseChan:
return
}

View File

@@ -9,19 +9,33 @@ import (
)
func LinkDtls(conn net.Conn, cSess *sessdata.ConnSession) {
dSess := cSess.NewDtlsConn()
if dSess == nil {
// 创建失败,直接关闭链接
_ = conn.Close()
return
}
defer func() {
base.Debug("LinkDtls return", cSess.IpAddr)
_ = conn.Close()
cSess.Close()
dSess.Close()
}()
var (
dead = time.Duration(cSess.CstpDpd+5) * time.Second
)
go dtlsWrite(conn, cSess)
go dtlsWrite(conn, dSess, cSess)
now := time.Now()
for {
if time.Now().Sub(now) > time.Second*30 {
// return
}
err := conn.SetReadDeadline(time.Now().Add(dead))
if err != nil {
base.Error("SetDeadline: ", err)
@@ -48,26 +62,33 @@ func LinkDtls(conn net.Conn, cSess *sessdata.ConnSession) {
base.Debug("DISCONNECT", cSess.IpAddr)
return
case 0x03: // DPD-REQ
base.Debug("recv DPD-REQ", cSess.IpAddr)
if payloadOut(cSess, sessdata.LTypeIPData, 0x04, nil) {
// base.Debug("recv DPD-REQ", cSess.IpAddr)
payload := &sessdata.Payload{
LType: sessdata.LTypeIPData,
PType: 0x04,
Data: nil,
}
select {
case cSess.PayloadOutDtls <- payload:
case <-dSess.CloseChan:
return
}
case 0x04:
base.Debug("recv DPD-RESP", cSess.IpAddr)
// base.Debug("recv DPD-RESP", cSess.IpAddr)
case 0x00: // DATA
if payloadIn(cSess, sessdata.LTypeIPData, 0x00, hdata[1:]) {
if payloadIn(cSess, sessdata.LTypeIPData, 0x00, hdata[1:n]) {
return
}
}
}
}
func dtlsWrite(conn net.Conn, cSess *sessdata.ConnSession) {
func dtlsWrite(conn net.Conn, dSess *sessdata.DtlsSession, cSess *sessdata.ConnSession) {
defer func() {
base.Debug("dtlsWrite return", cSess.IpAddr)
_ = conn.Close()
cSess.Close()
dSess.Close()
}()
var (
@@ -76,9 +97,10 @@ func dtlsWrite(conn net.Conn, cSess *sessdata.ConnSession) {
)
for {
// dtls优先推送数据
select {
case payload = <-cSess.PayloadOut:
case <-cSess.CloseChan:
case payload = <-cSess.PayloadOutDtls:
case <-dSess.CloseChan:
return
}

View File

@@ -46,7 +46,7 @@ func LinkTun(cSess *sessdata.ConnSession) error {
cmdstr1 := fmt.Sprintf("ip link set dev %s up mtu %d multicast off", ifce.Name(), cSess.Mtu)
cmdstr2 := fmt.Sprintf("ip addr add dev %s local %s peer %s/32",
ifce.Name(), base.Cfg.Ipv4Gateway, cSess.IpAddr)
cmdstr3 := "true"
cmdstr3 := fmt.Sprintf("sysctl -w net.ipv6.conf.%s.disable_ipv6=1", ifce.Name())
cmdStrs := []string{cmdstr1, cmdstr2, cmdstr3}
err = execCmd(cmdStrs)
if err != nil {

View File

@@ -2,23 +2,20 @@ package handler
import (
"bytes"
"encoding/hex"
"fmt"
"log"
"net"
"net/http"
"net/http/httputil"
"os"
"sync"
"strings"
"github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/sessdata"
"github.com/pion/dtls/v2"
)
var hn string
var ss sync.Map
var (
hn string
)
func init() {
// 获取主机名称
@@ -27,9 +24,9 @@ func init() {
func LinkTunnel(w http.ResponseWriter, r *http.Request) {
// TODO 调试信息输出
hd, _ := httputil.DumpRequest(r, true)
fmt.Println("DumpRequest: ", string(hd))
fmt.Println("LinkTunnel", r.RemoteAddr)
// hd, _ := httputil.DumpRequest(r, true)
// fmt.Println("DumpRequest: ", string(hd))
// fmt.Println("LinkTunnel", r.RemoteAddr)
// 判断session-token的值
cookie, err := r.Cookie("webvpn")
@@ -58,14 +55,6 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
localIp := r.Header.Get("X-Cstp-Local-Address-Ip4")
mobile := r.Header.Get("X-Cstp-License")
preMasterSecret, err := hex.DecodeString(masterSecret)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
return
}
dtls.Sessions.Store(sess.DtlsSid, preMasterSecret)
cSess.SetMtu(cstpMtu)
cSess.MasterSecret = masterSecret
cSess.RemoteAddr = r.RemoteAddr
@@ -81,6 +70,12 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
}
cSess.CstpDpd = cstpDpd
dtlsPort := ""
if strings.Contains(base.Cfg.ServerDTLSAddr, ":") {
ss := strings.Split(base.Cfg.ServerDTLSAddr, ":")
dtlsPort = ss[1]
}
base.Debug(cSess.IpAddr, cSess.MacHw, sess.Username, mobile)
// 返回客户端数据
@@ -126,17 +121,15 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-CSTP-MSIE-Proxy-Lockdown", "true")
w.Header().Set("X-CSTP-Smartcard-Removal-Disconnect", "true")
w.Header().Set("X-MTU", fmt.Sprintf("%d", cSess.Mtu)) // 1399
w.Header().Set("X-CSTP-MTU", fmt.Sprintf("%d", cSess.Mtu)) // 1399
w.Header().Set("X-DTLS-MTU", fmt.Sprintf("%d", cSess.Mtu))
w.Header().Set("X-DTLS-Session-ID", sess.DtlsSid)
w.Header().Set("X-DTLS-Port", "4433")
w.Header().Set("X-DTLS-Port", dtlsPort)
w.Header().Set("X-DTLS-DPD", fmt.Sprintf("%d", cstpDpd))
w.Header().Set("X-DTLS-Keepalive", fmt.Sprintf("%d", base.Cfg.CstpKeepalive))
w.Header().Set("X-DTLS-Keepalive", fmt.Sprintf("%d", cstpKeepalive))
w.Header().Set("X-DTLS-Rekey-Time", "5400")
w.Header().Set("X-DTLS12-CipherSuite", "ECDHE-ECDSA-AES128-GCM-SHA256")
// w.Header().Set("X-DTLS12-CipherSuite", "ECDHE-RSA-AES128-GCM-SHA256")
w.Header().Set("X-CSTP-License", "accept")
w.Header().Set("X-CSTP-Routing-Filtering-Ignore", "false")
@@ -173,7 +166,5 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
return
}
ss.Store(cSess.Sess.DtlsSid, cSess)
go LinkCstp(conn, cSess)
}

View File

@@ -35,20 +35,25 @@ func payloadInData(cSess *sessdata.ConnSession, payload *sessdata.Payload) bool
}
func payloadOut(cSess *sessdata.ConnSession, lType sessdata.LType, pType byte, data []byte) bool {
dSess := cSess.GetDtlsSession()
if dSess == nil {
return payloadOutCstp(cSess, lType, pType, data)
} else {
return payloadOutDtls(dSess, lType, pType, data)
}
}
func payloadOutCstp(cSess *sessdata.ConnSession, lType sessdata.LType, pType byte, data []byte) bool {
payload := &sessdata.Payload{
LType: lType,
PType: pType,
Data: data,
}
return payloadOutData(cSess, payload)
}
func payloadOutData(cSess *sessdata.ConnSession, payload *sessdata.Payload) bool {
closed := false
select {
case cSess.PayloadOut <- payload:
case cSess.PayloadOutCstp <- payload:
case <-cSess.CloseChan:
closed = true
}
@@ -56,6 +61,21 @@ func payloadOutData(cSess *sessdata.ConnSession, payload *sessdata.Payload) bool
return closed
}
func payloadOutDtls(dSess *sessdata.DtlsSession, lType sessdata.LType, pType byte, data []byte) bool {
payload := &sessdata.Payload{
LType: lType,
PType: pType,
Data: data,
}
select {
case dSess.CSess.PayloadOutDtls <- payload:
case <-dSess.CloseChan:
}
return false
}
// Acl规则校验
func checkLinkAcl(group *dbdata.Group, payload *sessdata.Payload) bool {
if payload.LType == sessdata.LTypeIPData && payload.PType == 0x00 && len(group.LinkAcl) > 0 {