mirror of
https://github.com/bjdgyc/anylink.git
synced 2025-08-08 11:10:14 +08:00
添加dtls支持
This commit is contained in:
@@ -1,59 +1,78 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"log"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
"os"
|
||||
|
||||
"github.com/bjdgyc/anylink/base"
|
||||
"github.com/bjdgyc/anylink/sessdata"
|
||||
"github.com/pion/dtls/v2"
|
||||
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
|
||||
"github.com/pion/logging"
|
||||
)
|
||||
|
||||
// 因本项目对 github.com/pion/dtls 的代码,进行了大量的修改
|
||||
// 且短时间内无法合并到上游项目
|
||||
// 所以本项目暂时copy了一份代码
|
||||
// 最后,感谢 github.com/pion/dtls 对golang生态做出的贡献
|
||||
|
||||
func startDtls() {
|
||||
certificate, err := selfsign.GenerateSelfSigned()
|
||||
|
||||
logf := logging.NewDefaultLoggerFactory()
|
||||
logf.DefaultLogLevel = logging.LogLevelTrace
|
||||
f, err := os.OpenFile("/tmp/key.log", os.O_TRUNC|os.O_RDWR, 0600)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
logf := logging.NewDefaultLoggerFactory()
|
||||
logf.Writer = base.GetBaseLw()
|
||||
// logf.DefaultLogLevel = logging.LogLevelTrace
|
||||
logf.DefaultLogLevel = logging.LogLevelInfo
|
||||
|
||||
config := &dtls.Config{
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
InsecureSkipVerify: true,
|
||||
ExtendedMasterSecret: dtls.DisableExtendedMasterSecret,
|
||||
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
|
||||
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||
LoggerFactory: logf,
|
||||
KeyLogWriter: f,
|
||||
MTU: BufferSize,
|
||||
CiscoCompat: func(sessid []byte) ([]byte, error) {
|
||||
masterSecret := sessdata.Dtls2MasterSecret(hex.EncodeToString(sessid))
|
||||
if masterSecret == "" {
|
||||
return nil, fmt.Errorf("masterSecret is err")
|
||||
}
|
||||
return hex.DecodeString(masterSecret)
|
||||
},
|
||||
ConnectContextMaker: func() (context.Context, func()) {
|
||||
return context.WithTimeout(context.Background(), 5*time.Second)
|
||||
},
|
||||
}
|
||||
|
||||
addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 4433}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", base.Cfg.ServerDTLSAddr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ln, err := dtls.Listen("udp", addr, config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
base.Info("listen DTLS server", addr)
|
||||
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
log.Println("Accept error", err)
|
||||
base.Error("DTLS Accept error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(1 * time.Second)
|
||||
cc := c.(*dtls.Conn)
|
||||
id := hex.EncodeToString(cc.ConnectionState().SessionID)
|
||||
s, ok := ss.Load(id)
|
||||
log.Println("get link", id, ok)
|
||||
cs := s.(*sessdata.ConnSession)
|
||||
LinkDtls(c, cs)
|
||||
// time.Sleep(1 * time.Second)
|
||||
cc := conn.(*dtls.Conn)
|
||||
sessid := hex.EncodeToString(cc.ConnectionState().SessionID)
|
||||
sess := sessdata.Dtls2Sess(sessid)
|
||||
LinkDtls(conn, sess.CSess)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
@@ -55,7 +55,7 @@ func LinkCstp(conn net.Conn, cSess *sessdata.ConnSession) {
|
||||
return
|
||||
case 0x03: // DPD-REQ
|
||||
// base.Debug("recv DPD-REQ", cSess.IpAddr)
|
||||
if payloadOut(cSess, sessdata.LTypeIPData, 0x04, nil) {
|
||||
if payloadOutCstp(cSess, sessdata.LTypeIPData, 0x04, nil) {
|
||||
return
|
||||
}
|
||||
case 0x04:
|
||||
@@ -86,7 +86,7 @@ func cstpWrite(conn net.Conn, cSess *sessdata.ConnSession) {
|
||||
|
||||
for {
|
||||
select {
|
||||
case payload = <-cSess.PayloadOut:
|
||||
case payload = <-cSess.PayloadOutCstp:
|
||||
case <-cSess.CloseChan:
|
||||
return
|
||||
}
|
||||
|
@@ -9,19 +9,33 @@ import (
|
||||
)
|
||||
|
||||
func LinkDtls(conn net.Conn, cSess *sessdata.ConnSession) {
|
||||
dSess := cSess.NewDtlsConn()
|
||||
if dSess == nil {
|
||||
// 创建失败,直接关闭链接
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
base.Debug("LinkDtls return", cSess.IpAddr)
|
||||
_ = conn.Close()
|
||||
cSess.Close()
|
||||
dSess.Close()
|
||||
}()
|
||||
|
||||
var (
|
||||
dead = time.Duration(cSess.CstpDpd+5) * time.Second
|
||||
)
|
||||
|
||||
go dtlsWrite(conn, cSess)
|
||||
go dtlsWrite(conn, dSess, cSess)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
for {
|
||||
|
||||
if time.Now().Sub(now) > time.Second*30 {
|
||||
// return
|
||||
}
|
||||
|
||||
err := conn.SetReadDeadline(time.Now().Add(dead))
|
||||
if err != nil {
|
||||
base.Error("SetDeadline: ", err)
|
||||
@@ -48,26 +62,33 @@ func LinkDtls(conn net.Conn, cSess *sessdata.ConnSession) {
|
||||
base.Debug("DISCONNECT", cSess.IpAddr)
|
||||
return
|
||||
case 0x03: // DPD-REQ
|
||||
base.Debug("recv DPD-REQ", cSess.IpAddr)
|
||||
if payloadOut(cSess, sessdata.LTypeIPData, 0x04, nil) {
|
||||
// base.Debug("recv DPD-REQ", cSess.IpAddr)
|
||||
payload := &sessdata.Payload{
|
||||
LType: sessdata.LTypeIPData,
|
||||
PType: 0x04,
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
select {
|
||||
case cSess.PayloadOutDtls <- payload:
|
||||
case <-dSess.CloseChan:
|
||||
return
|
||||
}
|
||||
case 0x04:
|
||||
base.Debug("recv DPD-RESP", cSess.IpAddr)
|
||||
// base.Debug("recv DPD-RESP", cSess.IpAddr)
|
||||
case 0x00: // DATA
|
||||
if payloadIn(cSess, sessdata.LTypeIPData, 0x00, hdata[1:]) {
|
||||
if payloadIn(cSess, sessdata.LTypeIPData, 0x00, hdata[1:n]) {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dtlsWrite(conn net.Conn, cSess *sessdata.ConnSession) {
|
||||
func dtlsWrite(conn net.Conn, dSess *sessdata.DtlsSession, cSess *sessdata.ConnSession) {
|
||||
defer func() {
|
||||
base.Debug("dtlsWrite return", cSess.IpAddr)
|
||||
_ = conn.Close()
|
||||
cSess.Close()
|
||||
dSess.Close()
|
||||
}()
|
||||
|
||||
var (
|
||||
@@ -76,9 +97,10 @@ func dtlsWrite(conn net.Conn, cSess *sessdata.ConnSession) {
|
||||
)
|
||||
|
||||
for {
|
||||
// dtls优先推送数据
|
||||
select {
|
||||
case payload = <-cSess.PayloadOut:
|
||||
case <-cSess.CloseChan:
|
||||
case payload = <-cSess.PayloadOutDtls:
|
||||
case <-dSess.CloseChan:
|
||||
return
|
||||
}
|
||||
|
||||
|
@@ -46,7 +46,7 @@ func LinkTun(cSess *sessdata.ConnSession) error {
|
||||
cmdstr1 := fmt.Sprintf("ip link set dev %s up mtu %d multicast off", ifce.Name(), cSess.Mtu)
|
||||
cmdstr2 := fmt.Sprintf("ip addr add dev %s local %s peer %s/32",
|
||||
ifce.Name(), base.Cfg.Ipv4Gateway, cSess.IpAddr)
|
||||
cmdstr3 := "true"
|
||||
cmdstr3 := fmt.Sprintf("sysctl -w net.ipv6.conf.%s.disable_ipv6=1", ifce.Name())
|
||||
cmdStrs := []string{cmdstr1, cmdstr2, cmdstr3}
|
||||
err = execCmd(cmdStrs)
|
||||
if err != nil {
|
||||
|
@@ -2,23 +2,20 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"os"
|
||||
"sync"
|
||||
"strings"
|
||||
|
||||
"github.com/bjdgyc/anylink/base"
|
||||
"github.com/bjdgyc/anylink/sessdata"
|
||||
"github.com/pion/dtls/v2"
|
||||
)
|
||||
|
||||
var hn string
|
||||
|
||||
var ss sync.Map
|
||||
var (
|
||||
hn string
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 获取主机名称
|
||||
@@ -27,9 +24,9 @@ func init() {
|
||||
|
||||
func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO 调试信息输出
|
||||
hd, _ := httputil.DumpRequest(r, true)
|
||||
fmt.Println("DumpRequest: ", string(hd))
|
||||
fmt.Println("LinkTunnel", r.RemoteAddr)
|
||||
// hd, _ := httputil.DumpRequest(r, true)
|
||||
// fmt.Println("DumpRequest: ", string(hd))
|
||||
// fmt.Println("LinkTunnel", r.RemoteAddr)
|
||||
|
||||
// 判断session-token的值
|
||||
cookie, err := r.Cookie("webvpn")
|
||||
@@ -58,14 +55,6 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
localIp := r.Header.Get("X-Cstp-Local-Address-Ip4")
|
||||
mobile := r.Header.Get("X-Cstp-License")
|
||||
|
||||
preMasterSecret, err := hex.DecodeString(masterSecret)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
dtls.Sessions.Store(sess.DtlsSid, preMasterSecret)
|
||||
|
||||
cSess.SetMtu(cstpMtu)
|
||||
cSess.MasterSecret = masterSecret
|
||||
cSess.RemoteAddr = r.RemoteAddr
|
||||
@@ -81,6 +70,12 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
cSess.CstpDpd = cstpDpd
|
||||
|
||||
dtlsPort := ""
|
||||
if strings.Contains(base.Cfg.ServerDTLSAddr, ":") {
|
||||
ss := strings.Split(base.Cfg.ServerDTLSAddr, ":")
|
||||
dtlsPort = ss[1]
|
||||
}
|
||||
|
||||
base.Debug(cSess.IpAddr, cSess.MacHw, sess.Username, mobile)
|
||||
|
||||
// 返回客户端数据
|
||||
@@ -126,17 +121,15 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-CSTP-MSIE-Proxy-Lockdown", "true")
|
||||
w.Header().Set("X-CSTP-Smartcard-Removal-Disconnect", "true")
|
||||
|
||||
w.Header().Set("X-MTU", fmt.Sprintf("%d", cSess.Mtu)) // 1399
|
||||
w.Header().Set("X-CSTP-MTU", fmt.Sprintf("%d", cSess.Mtu)) // 1399
|
||||
w.Header().Set("X-DTLS-MTU", fmt.Sprintf("%d", cSess.Mtu))
|
||||
|
||||
w.Header().Set("X-DTLS-Session-ID", sess.DtlsSid)
|
||||
w.Header().Set("X-DTLS-Port", "4433")
|
||||
w.Header().Set("X-DTLS-Port", dtlsPort)
|
||||
w.Header().Set("X-DTLS-DPD", fmt.Sprintf("%d", cstpDpd))
|
||||
w.Header().Set("X-DTLS-Keepalive", fmt.Sprintf("%d", base.Cfg.CstpKeepalive))
|
||||
w.Header().Set("X-DTLS-Keepalive", fmt.Sprintf("%d", cstpKeepalive))
|
||||
w.Header().Set("X-DTLS-Rekey-Time", "5400")
|
||||
w.Header().Set("X-DTLS12-CipherSuite", "ECDHE-ECDSA-AES128-GCM-SHA256")
|
||||
// w.Header().Set("X-DTLS12-CipherSuite", "ECDHE-RSA-AES128-GCM-SHA256")
|
||||
|
||||
w.Header().Set("X-CSTP-License", "accept")
|
||||
w.Header().Set("X-CSTP-Routing-Filtering-Ignore", "false")
|
||||
@@ -173,7 +166,5 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ss.Store(cSess.Sess.DtlsSid, cSess)
|
||||
|
||||
go LinkCstp(conn, cSess)
|
||||
}
|
||||
|
@@ -35,20 +35,25 @@ func payloadInData(cSess *sessdata.ConnSession, payload *sessdata.Payload) bool
|
||||
}
|
||||
|
||||
func payloadOut(cSess *sessdata.ConnSession, lType sessdata.LType, pType byte, data []byte) bool {
|
||||
dSess := cSess.GetDtlsSession()
|
||||
if dSess == nil {
|
||||
return payloadOutCstp(cSess, lType, pType, data)
|
||||
} else {
|
||||
return payloadOutDtls(dSess, lType, pType, data)
|
||||
}
|
||||
}
|
||||
|
||||
func payloadOutCstp(cSess *sessdata.ConnSession, lType sessdata.LType, pType byte, data []byte) bool {
|
||||
payload := &sessdata.Payload{
|
||||
LType: lType,
|
||||
PType: pType,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
return payloadOutData(cSess, payload)
|
||||
}
|
||||
|
||||
func payloadOutData(cSess *sessdata.ConnSession, payload *sessdata.Payload) bool {
|
||||
closed := false
|
||||
|
||||
select {
|
||||
case cSess.PayloadOut <- payload:
|
||||
case cSess.PayloadOutCstp <- payload:
|
||||
case <-cSess.CloseChan:
|
||||
closed = true
|
||||
}
|
||||
@@ -56,6 +61,21 @@ func payloadOutData(cSess *sessdata.ConnSession, payload *sessdata.Payload) bool
|
||||
return closed
|
||||
}
|
||||
|
||||
func payloadOutDtls(dSess *sessdata.DtlsSession, lType sessdata.LType, pType byte, data []byte) bool {
|
||||
payload := &sessdata.Payload{
|
||||
LType: lType,
|
||||
PType: pType,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
select {
|
||||
case dSess.CSess.PayloadOutDtls <- payload:
|
||||
case <-dSess.CloseChan:
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Acl规则校验
|
||||
func checkLinkAcl(group *dbdata.Group, payload *sessdata.Payload) bool {
|
||||
if payload.LType == sessdata.LTypeIPData && payload.PType == 0x00 && len(group.LinkAcl) > 0 {
|
||||
|
Reference in New Issue
Block a user