Merge pull request #219 from lanrenwo/add_lzs_compress

新增压缩功能-LZS算法
This commit is contained in:
bjdgyc 2023-02-16 16:01:19 +08:00 committed by GitHub
commit df52087473
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 210 additions and 23 deletions

View File

@ -73,8 +73,10 @@ type ServerConfig struct {
// AuthTimeout int `json:"auth_timeout"` // in seconds // AuthTimeout int `json:"auth_timeout"` // in seconds
AuditInterval int `json:"audit_interval"` // in seconds AuditInterval int `json:"audit_interval"` // in seconds
ShowSQL bool `json:"show_sql"` // bool ShowSQL bool `json:"show_sql"` // bool
IptablesNat bool `json:"iptables_nat"` IptablesNat bool `json:"iptables_nat"`
Compression bool `json:"compression"` // bool
NoCompressLimit int `json:"no_compress_limit"` // int
} }
func initServerCfg() { func initServerCfg() {

View File

@ -62,6 +62,8 @@ var configs = []config{
{Typ: cfgBool, Name: "show_sql", Usage: "显示sql语句用于调试", ValBool: false}, {Typ: cfgBool, Name: "show_sql", Usage: "显示sql语句用于调试", ValBool: false},
{Typ: cfgBool, Name: "iptables_nat", Usage: "是否自动添加NAT", ValBool: true}, {Typ: cfgBool, Name: "iptables_nat", Usage: "是否自动添加NAT", ValBool: true},
{Typ: cfgBool, Name: "compression", Usage: "启用压缩", ValBool: false},
{Typ: cfgInt, Name: "no_compress_limit", Usage: "低于及等于多少字节不压缩", ValInt: 256},
} }
var envs = map[string]string{} var envs = map[string]string{}

View File

@ -78,4 +78,7 @@ show_sql = false
#是否自动添加nat #是否自动添加nat
iptables_nat = true iptables_nat = true
#启用压缩
compression = false
#低于及等于多少字节不压缩
no_compress_limit = 256

View File

@ -14,6 +14,7 @@ require (
github.com/gorilla/handlers v1.5.1 github.com/gorilla/handlers v1.5.1
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/ivpusic/grpool v1.0.0 github.com/ivpusic/grpool v1.0.0
github.com/lanrenwo/lzsgo v0.0.2
github.com/lib/pq v1.10.2 github.com/lib/pq v1.10.2
github.com/mattn/go-sqlite3 v1.14.9 github.com/mattn/go-sqlite3 v1.14.9
github.com/orcaman/concurrent-map v1.0.0 github.com/orcaman/concurrent-map v1.0.0

View File

@ -358,6 +358,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lanrenwo/lzsgo v0.0.2 h1:FA30LAaJFYLoaM17b+H32gA+5H+abjoomNLSA9HCbrI=
github.com/lanrenwo/lzsgo v0.0.2/go.mod h1:oxDZy2vgi6VBGIdvL80ayRMtIyXV+TbjavVuINXZY2k=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=

View File

@ -66,7 +66,22 @@ func LinkCstp(conn net.Conn, bufRW *bufio.ReadWriter, cSess *sessdata.ConnSessio
return return
} }
case 0x04: case 0x04:
// log.Println("recv DPD-RESP") // log.Println("recv DPD-RESP")
case 0x08: // decompress
if cSess.CstpPickCmp == nil {
continue
}
dst := getByteFull()
nn, err := cSess.CstpPickCmp.Uncompress(pl.Data[8:], *dst)
if err != nil {
putByte(dst)
base.Error("cstp decompress error", err, nn)
continue
}
binary.BigEndian.PutUint16(pl.Data[4:6], uint16(nn))
pl.Data = append(pl.Data[:8], (*dst)[:nn]...)
putByte(dst)
fallthrough
case 0x00: // DATA case 0x00: // DATA
// 获取数据长度 // 获取数据长度
dataLen = binary.BigEndian.Uint16(pl.Data[4:6]) // 4,5 dataLen = binary.BigEndian.Uint16(pl.Data[4:6]) // 4,5
@ -112,16 +127,31 @@ func cstpWrite(conn net.Conn, bufRW *bufio.ReadWriter, cSess *sessdata.ConnSessi
} }
if pl.PType == 0x00 { if pl.PType == 0x00 {
// 获取数据长度 isCompress := false
l := len(pl.Data) if cSess.CstpPickCmp != nil && len(pl.Data) > base.Cfg.NoCompressLimit {
// 先扩容 +8 dst := getByteFull()
pl.Data = pl.Data[:l+8] size, err := cSess.CstpPickCmp.Compress(pl.Data, (*dst)[8:])
// 数据后移 if err == nil && size < len(pl.Data) {
copy(pl.Data[8:], pl.Data) copy((*dst)[:8], plHeader)
// 添加头信息 binary.BigEndian.PutUint16((*dst)[4:6], uint16(size))
copy(pl.Data[:8], plHeader) (*dst)[6] = 0x08
// 更新头长度 pl.Data = append(pl.Data[:0], (*dst)[:size+8]...)
binary.BigEndian.PutUint16(pl.Data[4:6], uint16(l)) isCompress = true
}
putByte(dst)
}
if !isCompress {
// 获取数据长度
l := len(pl.Data)
// 先扩容 +8
pl.Data = pl.Data[:l+8]
// 数据后移
copy(pl.Data[8:], pl.Data)
// 添加头信息
copy(pl.Data[:8], plHeader)
// 更新头长度
binary.BigEndian.PutUint16(pl.Data[4:6], uint16(l))
}
} else { } else {
pl.Data = append(pl.Data[:0], plHeader...) pl.Data = append(pl.Data[:0], plHeader...)
// 设置头类型 // 设置头类型

View File

@ -68,7 +68,22 @@ func LinkDtls(conn net.Conn, cSess *sessdata.ConnSession) {
return return
} }
case 0x04: case 0x04:
// base.Debug("recv DPD-RESP", cSess.IpAddr) // base.Debug("recv DPD-RESP", cSess.IpAddr)
case 0x08: // decompress
if cSess.DtlsPickCmp == nil {
continue
}
dst := getByteFull()
nn, err := cSess.DtlsPickCmp.Uncompress(pl.Data[1:], *dst)
if err != nil {
putByte(dst)
base.Error("dtls decompress error", err, n)
continue
}
pl.Data = append(pl.Data[:1], (*dst)[:nn]...)
putByte(dst)
n = nn + 1
fallthrough
case 0x00: // DATA case 0x00: // DATA
// 去除数据头 // 去除数据头
// copy(pl.Data, pl.Data[1:n]) // copy(pl.Data, pl.Data[1:n])
@ -108,14 +123,28 @@ func dtlsWrite(conn net.Conn, dSess *sessdata.DtlsSession, cSess *sessdata.ConnS
// header = []byte{payload.PType} // header = []byte{payload.PType}
if pl.PType == 0x00 { // data if pl.PType == 0x00 { // data
// 获取数据长度 isCompress := false
l := len(pl.Data) if cSess.DtlsPickCmp != nil && len(pl.Data) > base.Cfg.NoCompressLimit {
// 先扩容 +1 dst := getByteFull()
pl.Data = pl.Data[:l+1] size, err := cSess.DtlsPickCmp.Compress(pl.Data, (*dst)[1:])
// 数据后移 if err == nil && size < len(pl.Data) {
copy(pl.Data[1:], pl.Data) (*dst)[0] = 0x08
// 添加头信息 pl.Data = append(pl.Data[:0], (*dst)[:size+1]...)
pl.Data[0] = pl.PType isCompress = true
}
putByte(dst)
}
// 未压缩
if !isCompress {
// 获取数据长度
l := len(pl.Data)
// 先扩容 +1
pl.Data = pl.Data[:l+1]
// 数据后移
copy(pl.Data[1:], pl.Data)
// 添加头信息
pl.Data[0] = pl.PType
}
} else { } else {
// 设置头类型 // 设置头类型
pl.Data = append(pl.Data[:0], pl.PType) pl.Data = append(pl.Data[:0], pl.PType)

View File

@ -89,6 +89,14 @@ func LinkTunnel(w http.ResponseWriter, r *http.Request) {
base.Debug(cSess.IpAddr, cSess.MacHw, sess.Username, mobile) base.Debug(cSess.IpAddr, cSess.MacHw, sess.Username, mobile)
// 压缩
if cmpName, ok := cSess.SetPickCmp("cstp", r.Header.Get("X-Cstp-Accept-Encoding")); ok {
HttpSetHeader(w, "X-CSTP-Content-Encoding", cmpName)
}
if cmpName, ok := cSess.SetPickCmp("dtls", r.Header.Get("X-Dtls-Accept-Encoding")); ok {
HttpSetHeader(w, "X-DTLS-Content-Encoding", cmpName)
}
// 返回客户端数据 // 返回客户端数据
HttpSetHeader(w, "Server", fmt.Sprintf("%s %s", base.APP_NAME, base.APP_VER)) HttpSetHeader(w, "Server", fmt.Sprintf("%s %s", base.APP_NAME, base.APP_VER))
HttpSetHeader(w, "X-CSTP-Version", "1") HttpSetHeader(w, "X-CSTP-Version", "1")

View File

@ -0,0 +1,35 @@
package sessdata
import (
"github.com/lanrenwo/lzsgo"
)
type CmpEncoding interface {
Compress(src []byte, dst []byte) (int, error)
Uncompress(src []byte, dst []byte) (int, error)
}
type LzsgoCmp struct {
}
func (l LzsgoCmp) Compress(src []byte, dst []byte) (int, error) {
n, err := lzsgo.Compress(src, dst)
return n, err
}
func (l LzsgoCmp) Uncompress(src []byte, dst []byte) (int, error) {
n, err := lzsgo.Uncompress(src, dst)
return n, err
}
// type Lz4Cmp struct {
// c lz4.Compressor
// }
// func (l Lz4Cmp) Compress(src []byte, dst []byte) (int, error) {
// return l.c.CompressBlock(src, dst)
// }
// func (l Lz4Cmp) Uncompress(src []byte, dst []byte) (int, error) {
// return lz4.UncompressBlock(src, dst)
// }

View File

@ -0,0 +1,28 @@
package sessdata
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLzsCompress(t *testing.T) {
var (
n int
err error
)
assert := assert.New(t)
c := LzsgoCmp{}
s := "hello anylink, you are best!"
src := []byte(strings.Repeat(s, 50))
comprBuf := make([]byte, 2048)
n, err = c.Compress(src, comprBuf)
assert.Nil(err)
unprBuf := make([]byte, 2048)
n, err = c.Uncompress(comprBuf[:n], unprBuf)
assert.Nil(err)
assert.Equal(src, unprBuf[:n])
}

View File

@ -54,6 +54,9 @@ type ConnSession struct {
PayloadOutDtls chan *Payload // Dtls的数据 PayloadOutDtls chan *Payload // Dtls的数据
// dSess *DtlsSession // dSess *DtlsSession
dSess *atomic.Value dSess *atomic.Value
// compress
CstpPickCmp CmpEncoding
DtlsPickCmp CmpEncoding
} }
type DtlsSession struct { type DtlsSession struct {
@ -359,6 +362,30 @@ func (cs *ConnSession) RateLimit(byt int, isUp bool) error {
return cs.Limit.Wait(byt) return cs.Limit.Wait(byt)
} }
func (cs *ConnSession) SetPickCmp(cate, encoding string) (string, bool) {
var cmpName string
if !base.Cfg.Compression {
return cmpName, false
}
var cmp CmpEncoding
switch {
// case strings.Contains(encoding, "oc-lz4"):
// cmpName = "oc-lz4"
// cmp = Lz4Cmp{}
case strings.Contains(encoding, "lzs"):
cmpName = "lzs"
cmp = LzsgoCmp{}
default:
return cmpName, false
}
if cate == "cstp" {
cs.CstpPickCmp = cmp
} else {
cs.DtlsPickCmp = cmp
}
return cmpName, true
}
func SToken2Sess(stoken string) *Session { func SToken2Sess(stoken string) *Session {
stoken = strings.TrimSpace(stoken) stoken = strings.TrimSpace(stoken)
sarr := strings.Split(stoken, "@") sarr := strings.Split(stoken, "@")

View File

@ -1,8 +1,10 @@
package sessdata package sessdata
import ( import (
"fmt"
"testing" "testing"
"github.com/bjdgyc/anylink/base"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -34,5 +36,23 @@ func TestConnSession(t *testing.T) {
err = cSess.RateLimit(200, false) err = cSess.RateLimit(200, false)
ast.Nil(err) ast.Nil(err)
ast.Equal(cSess.BandwidthDown.Load(), uint32(200)) ast.Equal(cSess.BandwidthDown.Load(), uint32(200))
var (
cmpName string
ok bool
)
base.Cfg.Compression = true
cmpName, ok = cSess.SetPickCmp("cstp", "oc-lz4,lzs")
fmt.Println(cmpName, ok)
ast.True(ok)
ast.Equal(cmpName, "lzs")
cmpName, ok = cSess.SetPickCmp("dtls", "lzs")
ast.True(ok)
ast.Equal(cmpName, "lzs")
cmpName, ok = cSess.SetPickCmp("dtls", "test")
ast.False(ok)
ast.Equal(cmpName, "")
cSess.Close() cSess.Close()
} }