Merge pull request #124 from lanrenwo/ip_audit_map_concurrent

解决IpAuditMap在UDP下的fatal error: concurrent map read and map write
This commit is contained in:
bjdgyc 2022-08-01 19:00:06 +08:00 committed by GitHub
commit e72994c0b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 204 additions and 7 deletions

View File

@ -13,6 +13,7 @@ require (
github.com/gorilla/mux v1.8.0
github.com/lib/pq v1.10.2
github.com/mattn/go-sqlite3 v1.14.8
github.com/orcaman/concurrent-map v1.0.0 // indirect
github.com/pion/dtls/v2 v2.0.9
github.com/pion/logging v0.2.2
github.com/shirou/gopsutil v3.21.7+incompatible

View File

@ -149,14 +149,14 @@ func logAudit(cSess *sessdata.ConnSession, pl *sessdata.Payload) {
nu := utils.NowSec().Unix()
// 判断已经存在,并且没有过期
v, ok := cSess.IpAuditMap[s]
if ok && nu-v < int64(base.Cfg.AuditInterval) {
v, ok := cSess.IpAuditMap.Get(s)
if ok && nu-v.(int64) < int64(base.Cfg.AuditInterval) {
// 回收byte对象
putByte51(b)
return
}
cSess.IpAuditMap[s] = nu
cSess.IpAuditMap.Set(s, nu)
audit := dbdata.AccessAudit{
Username: cSess.Sess.Username,

115
server/pkg/utils/maps.go Normal file
View File

@ -0,0 +1,115 @@
package utils
import (
"sync"
cmap "github.com/orcaman/concurrent-map"
)
type IMaps interface {
Set(key string, val interface{})
Get(key string) (interface{}, bool)
Del(key string)
}
/**
* 基础的Map结构
*
*/
type BaseMap struct {
m map[string]interface{}
}
func (m *BaseMap) Set(key string, value interface{}) {
m.m[key] = value
}
func (m *BaseMap) Get(key string) (interface{}, bool) {
v, ok := m.m[key]
return v, ok
}
func (m *BaseMap) Del(key string) {
delete(m.m, key)
}
/**
* CMap 并发结构
*
*/
type ConcurrentMap struct {
m cmap.ConcurrentMap
}
func (m *ConcurrentMap) Set(key string, value interface{}) {
m.m.Set(key, value)
}
func (m *ConcurrentMap) Get(key string) (interface{}, bool) {
return m.m.Get(key)
}
func (m *ConcurrentMap) Del(key string) {
m.m.Remove(key)
}
/**
* Map 读写结构
*
*/
type RWLockMap struct {
m map[string]interface{}
lock sync.RWMutex
}
func (m *RWLockMap) Set(key string, value interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
m.m[key] = value
}
func (m *RWLockMap) Get(key string) (interface{}, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
v, ok := m.m[key]
return v, ok
}
func (m *RWLockMap) Del(key string) {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.m, key)
}
/**
* sync.Map 结构
*
*/
type SyncMap struct {
m sync.Map
}
func (m *SyncMap) Set(key string, val interface{}) {
m.m.Store(key, val)
}
func (m *SyncMap) Get(key string) (interface{}, bool) {
return m.m.Load(key)
}
func (m *SyncMap) Del(key string) {
m.m.Delete(key)
}
func NewMap(name string, len int) IMaps {
switch name {
case "cmap":
return &ConcurrentMap{m: cmap.New()}
case "rwmap":
m := make(map[string]interface{}, len)
return &RWLockMap{m: m}
case "syncmap":
return &SyncMap{}
default:
m := make(map[string]interface{}, len)
return &BaseMap{m: m}
}
}

View File

@ -0,0 +1,76 @@
package utils
import (
"strconv"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
const (
NumOfReader = 200
NumOfWriter = 100
)
func TestMaps(t *testing.T) {
assert := assert.New(t)
var ipAuditMap IMaps
key := "one"
value := 100
testMapData := map[string]int{"basemap": 512, "cmap": 0, "rwmap": 512, "syncmap": 0}
for name, len := range testMapData {
ipAuditMap = NewMap(name, len)
ipAuditMap.Set(key, value)
v, ok := ipAuditMap.Get(key)
assert.Equal(v.(int), value)
assert.True(ok)
ipAuditMap.Del(key)
v, ok = ipAuditMap.Get(key)
assert.Nil(v)
assert.False(ok)
}
}
func benchmarkMap(b *testing.B, hm IMaps) {
for i := 0; i < b.N; i++ {
var wg sync.WaitGroup
for i := 0; i < NumOfWriter; i++ {
wg.Add(1)
go func() {
for i := 0; i < 100; i++ {
hm.Set(strconv.Itoa(i), i*i)
hm.Set(strconv.Itoa(i), i*i)
hm.Del(strconv.Itoa(i))
}
wg.Done()
}()
}
for i := 0; i < NumOfReader; i++ {
wg.Add(1)
go func() {
for i := 0; i < 100; i++ {
hm.Get(strconv.Itoa(i))
}
wg.Done()
}()
}
wg.Wait()
}
}
func BenchmarkMaps(b *testing.B) {
b.Run("RW map", func(b *testing.B) {
myMap := NewMap("rwmap", 512)
benchmarkMap(b, myMap)
})
b.Run("Concurrent map", func(b *testing.B) {
myMap := NewMap("cmap", 0)
benchmarkMap(b, myMap)
})
b.Run("Sync map", func(b *testing.B) {
myMap := NewMap("syncmap", 0)
benchmarkMap(b, myMap)
})
}

View File

@ -13,6 +13,7 @@ import (
"github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/dbdata"
"github.com/bjdgyc/anylink/pkg/utils"
)
var (
@ -46,9 +47,9 @@ type ConnSession struct {
closeOnce sync.Once
CloseChan chan struct{}
PayloadIn chan *Payload
PayloadOutCstp chan *Payload // Cstp的数据
PayloadOutDtls chan *Payload // Dtls的数据
IpAuditMap map[string]int64 // 审计的ip数据
PayloadOutCstp chan *Payload // Cstp的数据
PayloadOutDtls chan *Payload // Dtls的数据
IpAuditMap utils.IMaps // 审计的ip数据
// dSess *DtlsSession
dSess *atomic.Value
@ -191,7 +192,11 @@ func (s *Session) NewConn() *ConnSession {
// ip 审计
if base.Cfg.AuditInterval >= 0 {
cSess.IpAuditMap = make(map[string]int64, 512)
if base.Cfg.ServerDTLS {
cSess.IpAuditMap = utils.NewMap("cmap", 0)
} else {
cSess.IpAuditMap = utils.NewMap("", 512)
}
}
dSess := &DtlsSession{