diff --git a/server/go.mod b/server/go.mod index 7065a25..1ef67ff 100644 --- a/server/go.mod +++ b/server/go.mod @@ -13,6 +13,7 @@ require ( github.com/gorilla/mux v1.8.0 github.com/lib/pq v1.10.2 github.com/mattn/go-sqlite3 v1.14.8 + github.com/orcaman/concurrent-map v1.0.0 // indirect github.com/pion/dtls/v2 v2.0.9 github.com/pion/logging v0.2.2 github.com/shirou/gopsutil v3.21.7+incompatible diff --git a/server/handler/payload.go b/server/handler/payload.go index 8062af8..a66eabf 100644 --- a/server/handler/payload.go +++ b/server/handler/payload.go @@ -149,14 +149,14 @@ func logAudit(cSess *sessdata.ConnSession, pl *sessdata.Payload) { nu := utils.NowSec().Unix() // 判断已经存在,并且没有过期 - v, ok := cSess.IpAuditMap[s] - if ok && nu-v < int64(base.Cfg.AuditInterval) { + v, ok := cSess.IpAuditMap.Get(s) + if ok && nu-v.(int64) < int64(base.Cfg.AuditInterval) { // 回收byte对象 putByte51(b) return } - cSess.IpAuditMap[s] = nu + cSess.IpAuditMap.Set(s, nu) audit := dbdata.AccessAudit{ Username: cSess.Sess.Username, diff --git a/server/pkg/utils/maps.go b/server/pkg/utils/maps.go new file mode 100644 index 0000000..df2a3c0 --- /dev/null +++ b/server/pkg/utils/maps.go @@ -0,0 +1,115 @@ +package utils + +import ( + "sync" + + cmap "github.com/orcaman/concurrent-map" +) + +type IMaps interface { + Set(key string, val interface{}) + Get(key string) (interface{}, bool) + Del(key string) +} + +/** + * 基础的Map结构 + * + */ +type BaseMap struct { + m map[string]interface{} +} + +func (m *BaseMap) Set(key string, value interface{}) { + m.m[key] = value +} +func (m *BaseMap) Get(key string) (interface{}, bool) { + v, ok := m.m[key] + return v, ok +} +func (m *BaseMap) Del(key string) { + delete(m.m, key) +} + +/** + * CMap 并发结构 + * + */ +type ConcurrentMap struct { + m cmap.ConcurrentMap +} + +func (m *ConcurrentMap) Set(key string, value interface{}) { + m.m.Set(key, value) +} + +func (m *ConcurrentMap) Get(key string) (interface{}, bool) { + return m.m.Get(key) +} + +func (m *ConcurrentMap) Del(key string) { + m.m.Remove(key) +} + +/** + * Map 读写结构 + * + */ +type RWLockMap struct { + m map[string]interface{} + lock sync.RWMutex +} + +func (m *RWLockMap) Set(key string, value interface{}) { + m.lock.Lock() + defer m.lock.Unlock() + m.m[key] = value +} + +func (m *RWLockMap) Get(key string) (interface{}, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + v, ok := m.m[key] + return v, ok +} + +func (m *RWLockMap) Del(key string) { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.m, key) +} + +/** + * sync.Map 结构 + * + */ +type SyncMap struct { + m sync.Map +} + +func (m *SyncMap) Set(key string, val interface{}) { + m.m.Store(key, val) +} + +func (m *SyncMap) Get(key string) (interface{}, bool) { + return m.m.Load(key) +} + +func (m *SyncMap) Del(key string) { + m.m.Delete(key) +} + +func NewMap(name string, len int) IMaps { + switch name { + case "cmap": + return &ConcurrentMap{m: cmap.New()} + case "rwmap": + m := make(map[string]interface{}, len) + return &RWLockMap{m: m} + case "syncmap": + return &SyncMap{} + default: + m := make(map[string]interface{}, len) + return &BaseMap{m: m} + } +} diff --git a/server/pkg/utils/maps_test.go b/server/pkg/utils/maps_test.go new file mode 100644 index 0000000..0646c5d --- /dev/null +++ b/server/pkg/utils/maps_test.go @@ -0,0 +1,76 @@ +package utils + +import ( + "strconv" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + NumOfReader = 200 + NumOfWriter = 100 +) + +func TestMaps(t *testing.T) { + assert := assert.New(t) + var ipAuditMap IMaps + key := "one" + value := 100 + + testMapData := map[string]int{"basemap": 512, "cmap": 0, "rwmap": 512, "syncmap": 0} + for name, len := range testMapData { + ipAuditMap = NewMap(name, len) + ipAuditMap.Set(key, value) + v, ok := ipAuditMap.Get(key) + assert.Equal(v.(int), value) + assert.True(ok) + ipAuditMap.Del(key) + v, ok = ipAuditMap.Get(key) + assert.Nil(v) + assert.False(ok) + } +} + +func benchmarkMap(b *testing.B, hm IMaps) { + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + for i := 0; i < NumOfWriter; i++ { + wg.Add(1) + go func() { + for i := 0; i < 100; i++ { + hm.Set(strconv.Itoa(i), i*i) + hm.Set(strconv.Itoa(i), i*i) + hm.Del(strconv.Itoa(i)) + } + wg.Done() + }() + } + for i := 0; i < NumOfReader; i++ { + wg.Add(1) + go func() { + for i := 0; i < 100; i++ { + hm.Get(strconv.Itoa(i)) + } + wg.Done() + }() + } + wg.Wait() + } +} + +func BenchmarkMaps(b *testing.B) { + b.Run("RW map", func(b *testing.B) { + myMap := NewMap("rwmap", 512) + benchmarkMap(b, myMap) + }) + b.Run("Concurrent map", func(b *testing.B) { + myMap := NewMap("cmap", 0) + benchmarkMap(b, myMap) + }) + b.Run("Sync map", func(b *testing.B) { + myMap := NewMap("syncmap", 0) + benchmarkMap(b, myMap) + }) +} diff --git a/server/sessdata/session.go b/server/sessdata/session.go index f165247..87277d6 100644 --- a/server/sessdata/session.go +++ b/server/sessdata/session.go @@ -13,6 +13,7 @@ import ( "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/dbdata" + "github.com/bjdgyc/anylink/pkg/utils" ) var ( @@ -46,9 +47,9 @@ type ConnSession struct { closeOnce sync.Once CloseChan chan struct{} PayloadIn chan *Payload - PayloadOutCstp chan *Payload // Cstp的数据 - PayloadOutDtls chan *Payload // Dtls的数据 - IpAuditMap map[string]int64 // 审计的ip数据 + PayloadOutCstp chan *Payload // Cstp的数据 + PayloadOutDtls chan *Payload // Dtls的数据 + IpAuditMap utils.IMaps // 审计的ip数据 // dSess *DtlsSession dSess *atomic.Value @@ -191,7 +192,11 @@ func (s *Session) NewConn() *ConnSession { // ip 审计 if base.Cfg.AuditInterval >= 0 { - cSess.IpAuditMap = make(map[string]int64, 512) + if base.Cfg.ServerDTLS { + cSess.IpAuditMap = utils.NewMap("cmap", 0) + } else { + cSess.IpAuditMap = utils.NewMap("", 512) + } } dSess := &DtlsSession{