diff --git a/server/base/cfg.go b/server/base/cfg.go index 1539c3e..af98a03 100644 --- a/server/base/cfg.go +++ b/server/base/cfg.go @@ -73,8 +73,10 @@ type ServerConfig struct { // AuthTimeout int `json:"auth_timeout"` // in seconds AuditInterval int `json:"audit_interval"` // in seconds - ShowSQL bool `json:"show_sql"` // bool - IptablesNat bool `json:"iptables_nat"` + ShowSQL bool `json:"show_sql"` // bool + IptablesNat bool `json:"iptables_nat"` + Compression bool `json:"compression"` // bool + NoCompressLimit int `json:"no_compress_limit"` // int } func initServerCfg() { diff --git a/server/base/config.go b/server/base/config.go index f63eac0..63976ad 100644 --- a/server/base/config.go +++ b/server/base/config.go @@ -62,6 +62,8 @@ var configs = []config{ {Typ: cfgBool, Name: "show_sql", Usage: "显示sql语句,用于调试", ValBool: false}, {Typ: cfgBool, Name: "iptables_nat", Usage: "是否自动添加NAT", ValBool: true}, + {Typ: cfgBool, Name: "compression", Usage: "启用压缩", ValBool: false}, + {Typ: cfgInt, Name: "no_compress_limit", Usage: "低于及等于多少字节不压缩", ValInt: 256}, } var envs = map[string]string{} diff --git a/server/conf/server-sample.toml b/server/conf/server-sample.toml index bc12165..a177e05 100644 --- a/server/conf/server-sample.toml +++ b/server/conf/server-sample.toml @@ -78,4 +78,7 @@ show_sql = false #是否自动添加nat iptables_nat = true - +#启用压缩 +compression = false +#低于及等于多少字节不压缩 +no_compress_limit = 256 \ No newline at end of file diff --git a/server/go.mod b/server/go.mod index a99f6a5..63051fc 100644 --- a/server/go.mod +++ b/server/go.mod @@ -14,6 +14,7 @@ require ( github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 github.com/ivpusic/grpool v1.0.0 + github.com/lanrenwo/lzsgo v0.0.2 github.com/lib/pq v1.10.2 github.com/mattn/go-sqlite3 v1.14.9 github.com/orcaman/concurrent-map v1.0.0 diff --git a/server/go.sum b/server/go.sum index 406727b..e396041 100644 --- a/server/go.sum +++ b/server/go.sum @@ -358,6 +358,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lanrenwo/lzsgo v0.0.2 h1:FA30LAaJFYLoaM17b+H32gA+5H+abjoomNLSA9HCbrI= +github.com/lanrenwo/lzsgo v0.0.2/go.mod h1:oxDZy2vgi6VBGIdvL80ayRMtIyXV+TbjavVuINXZY2k= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= diff --git a/server/handler/link_cstp.go b/server/handler/link_cstp.go index 0bda896..ba947e7 100644 --- a/server/handler/link_cstp.go +++ b/server/handler/link_cstp.go @@ -66,7 +66,22 @@ func LinkCstp(conn net.Conn, bufRW *bufio.ReadWriter, cSess *sessdata.ConnSessio return } case 0x04: - // log.Println("recv DPD-RESP") + // log.Println("recv DPD-RESP") + case 0x08: // decompress + if cSess.CstpPickCmp == nil { + continue + } + dst := getByteFull() + nn, err := cSess.CstpPickCmp.Uncompress(pl.Data[8:], *dst) + if err != nil { + putByte(dst) + base.Error("cstp decompress error", err, nn) + continue + } + binary.BigEndian.PutUint16(pl.Data[4:6], uint16(nn)) + pl.Data = append(pl.Data[:8], (*dst)[:nn]...) + putByte(dst) + fallthrough case 0x00: // DATA // 获取数据长度 dataLen = binary.BigEndian.Uint16(pl.Data[4:6]) // 4,5 @@ -112,16 +127,31 @@ func cstpWrite(conn net.Conn, bufRW *bufio.ReadWriter, cSess *sessdata.ConnSessi } if pl.PType == 0x00 { - // 获取数据长度 - l := len(pl.Data) - // 先扩容 +8 - pl.Data = pl.Data[:l+8] - // 数据后移 - copy(pl.Data[8:], pl.Data) - // 添加头信息 - copy(pl.Data[:8], plHeader) - // 更新头长度 - binary.BigEndian.PutUint16(pl.Data[4:6], uint16(l)) + isCompress := false + if cSess.CstpPickCmp != nil && len(pl.Data) > base.Cfg.NoCompressLimit { + dst := getByteFull() + size, err := cSess.CstpPickCmp.Compress(pl.Data, (*dst)[8:]) + if err == nil && size < len(pl.Data) { + copy((*dst)[:8], plHeader) + binary.BigEndian.PutUint16((*dst)[4:6], uint16(size)) + (*dst)[6] = 0x08 + pl.Data = append(pl.Data[:0], (*dst)[:size+8]...) + isCompress = true + } + putByte(dst) + } + if !isCompress { + // 获取数据长度 + l := len(pl.Data) + // 先扩容 +8 + pl.Data = pl.Data[:l+8] + // 数据后移 + copy(pl.Data[8:], pl.Data) + // 添加头信息 + copy(pl.Data[:8], plHeader) + // 更新头长度 + binary.BigEndian.PutUint16(pl.Data[4:6], uint16(l)) + } } else { pl.Data = append(pl.Data[:0], plHeader...) // 设置头类型 diff --git a/server/handler/link_dtls.go b/server/handler/link_dtls.go index e2b7311..34fe578 100644 --- a/server/handler/link_dtls.go +++ b/server/handler/link_dtls.go @@ -68,7 +68,22 @@ func LinkDtls(conn net.Conn, cSess *sessdata.ConnSession) { return } case 0x04: - // base.Debug("recv DPD-RESP", cSess.IpAddr) + // base.Debug("recv DPD-RESP", cSess.IpAddr) + case 0x08: // decompress + if cSess.DtlsPickCmp == nil { + continue + } + dst := getByteFull() + nn, err := cSess.DtlsPickCmp.Uncompress(pl.Data[1:], *dst) + if err != nil { + putByte(dst) + base.Error("dtls decompress error", err, n) + continue + } + pl.Data = append(pl.Data[:1], (*dst)[:nn]...) + putByte(dst) + n = nn + 1 + fallthrough case 0x00: // DATA // 去除数据头 // copy(pl.Data, pl.Data[1:n]) @@ -108,14 +123,28 @@ func dtlsWrite(conn net.Conn, dSess *sessdata.DtlsSession, cSess *sessdata.ConnS // header = []byte{payload.PType} if pl.PType == 0x00 { // data - // 获取数据长度 - l := len(pl.Data) - // 先扩容 +1 - pl.Data = pl.Data[:l+1] - // 数据后移 - copy(pl.Data[1:], pl.Data) - // 添加头信息 - pl.Data[0] = pl.PType + isCompress := false + if cSess.DtlsPickCmp != nil && len(pl.Data) > base.Cfg.NoCompressLimit { + dst := getByteFull() + size, err := cSess.DtlsPickCmp.Compress(pl.Data, (*dst)[1:]) + if err == nil && size < len(pl.Data) { + (*dst)[0] = 0x08 + pl.Data = append(pl.Data[:0], (*dst)[:size+1]...) + isCompress = true + } + putByte(dst) + } + // 未压缩 + if !isCompress { + // 获取数据长度 + l := len(pl.Data) + // 先扩容 +1 + pl.Data = pl.Data[:l+1] + // 数据后移 + copy(pl.Data[1:], pl.Data) + // 添加头信息 + pl.Data[0] = pl.PType + } } else { // 设置头类型 pl.Data = append(pl.Data[:0], pl.PType) diff --git a/server/handler/link_tunnel.go b/server/handler/link_tunnel.go index 7128a4c..c004081 100644 --- a/server/handler/link_tunnel.go +++ b/server/handler/link_tunnel.go @@ -89,6 +89,14 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) { base.Debug(cSess.IpAddr, cSess.MacHw, sess.Username, mobile) + // 压缩 + if cmpName, ok := cSess.SetPickCmp("cstp", r.Header.Get("X-Cstp-Accept-Encoding")); ok { + HttpSetHeader(w, "X-CSTP-Content-Encoding", cmpName) + } + if cmpName, ok := cSess.SetPickCmp("dtls", r.Header.Get("X-Dtls-Accept-Encoding")); ok { + HttpSetHeader(w, "X-DTLS-Content-Encoding", cmpName) + } + // 返回客户端数据 HttpSetHeader(w, "Server", fmt.Sprintf("%s %s", base.APP_NAME, base.APP_VER)) HttpSetHeader(w, "X-CSTP-Version", "1") diff --git a/server/sessdata/compress.go b/server/sessdata/compress.go new file mode 100644 index 0000000..7156f89 --- /dev/null +++ b/server/sessdata/compress.go @@ -0,0 +1,35 @@ +package sessdata + +import ( + "github.com/lanrenwo/lzsgo" +) + +type CmpEncoding interface { + Compress(src []byte, dst []byte) (int, error) + Uncompress(src []byte, dst []byte) (int, error) +} + +type LzsgoCmp struct { +} + +func (l LzsgoCmp) Compress(src []byte, dst []byte) (int, error) { + n, err := lzsgo.Compress(src, dst) + return n, err +} + +func (l LzsgoCmp) Uncompress(src []byte, dst []byte) (int, error) { + n, err := lzsgo.Uncompress(src, dst) + return n, err +} + +// type Lz4Cmp struct { +// c lz4.Compressor +// } + +// func (l Lz4Cmp) Compress(src []byte, dst []byte) (int, error) { +// return l.c.CompressBlock(src, dst) +// } + +// func (l Lz4Cmp) Uncompress(src []byte, dst []byte) (int, error) { +// return lz4.UncompressBlock(src, dst) +// } diff --git a/server/sessdata/compress_test.go b/server/sessdata/compress_test.go new file mode 100644 index 0000000..ce3a317 --- /dev/null +++ b/server/sessdata/compress_test.go @@ -0,0 +1,28 @@ +package sessdata + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLzsCompress(t *testing.T) { + var ( + n int + err error + ) + assert := assert.New(t) + c := LzsgoCmp{} + s := "hello anylink, you are best!" + src := []byte(strings.Repeat(s, 50)) + + comprBuf := make([]byte, 2048) + n, err = c.Compress(src, comprBuf) + assert.Nil(err) + + unprBuf := make([]byte, 2048) + n, err = c.Uncompress(comprBuf[:n], unprBuf) + assert.Nil(err) + assert.Equal(src, unprBuf[:n]) +} diff --git a/server/sessdata/session.go b/server/sessdata/session.go index b60675f..cedeca5 100644 --- a/server/sessdata/session.go +++ b/server/sessdata/session.go @@ -54,6 +54,9 @@ type ConnSession struct { PayloadOutDtls chan *Payload // Dtls的数据 // dSess *DtlsSession dSess *atomic.Value + // compress + CstpPickCmp CmpEncoding + DtlsPickCmp CmpEncoding } type DtlsSession struct { @@ -359,6 +362,30 @@ func (cs *ConnSession) RateLimit(byt int, isUp bool) error { return cs.Limit.Wait(byt) } +func (cs *ConnSession) SetPickCmp(cate, encoding string) (string, bool) { + var cmpName string + if !base.Cfg.Compression { + return cmpName, false + } + var cmp CmpEncoding + switch { + // case strings.Contains(encoding, "oc-lz4"): + // cmpName = "oc-lz4" + // cmp = Lz4Cmp{} + case strings.Contains(encoding, "lzs"): + cmpName = "lzs" + cmp = LzsgoCmp{} + default: + return cmpName, false + } + if cate == "cstp" { + cs.CstpPickCmp = cmp + } else { + cs.DtlsPickCmp = cmp + } + return cmpName, true +} + func SToken2Sess(stoken string) *Session { stoken = strings.TrimSpace(stoken) sarr := strings.Split(stoken, "@") diff --git a/server/sessdata/session_test.go b/server/sessdata/session_test.go index c9219b2..6c60ccc 100644 --- a/server/sessdata/session_test.go +++ b/server/sessdata/session_test.go @@ -1,8 +1,10 @@ package sessdata import ( + "fmt" "testing" + "github.com/bjdgyc/anylink/base" "github.com/stretchr/testify/assert" ) @@ -34,5 +36,23 @@ func TestConnSession(t *testing.T) { err = cSess.RateLimit(200, false) ast.Nil(err) ast.Equal(cSess.BandwidthDown.Load(), uint32(200)) + + var ( + cmpName string + ok bool + ) + base.Cfg.Compression = true + + cmpName, ok = cSess.SetPickCmp("cstp", "oc-lz4,lzs") + fmt.Println(cmpName, ok) + ast.True(ok) + ast.Equal(cmpName, "lzs") + cmpName, ok = cSess.SetPickCmp("dtls", "lzs") + ast.True(ok) + ast.Equal(cmpName, "lzs") + cmpName, ok = cSess.SetPickCmp("dtls", "test") + ast.False(ok) + ast.Equal(cmpName, "") + cSess.Close() }