mirror of https://github.com/bjdgyc/anylink.git
Merge pull request #124 from lanrenwo/ip_audit_map_concurrent
解决IpAuditMap在UDP下的fatal error: concurrent map read and map write
This commit is contained in:
commit
e72994c0b8
|
@ -13,6 +13,7 @@ require (
|
|||
github.com/gorilla/mux v1.8.0
|
||||
github.com/lib/pq v1.10.2
|
||||
github.com/mattn/go-sqlite3 v1.14.8
|
||||
github.com/orcaman/concurrent-map v1.0.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.0.9
|
||||
github.com/pion/logging v0.2.2
|
||||
github.com/shirou/gopsutil v3.21.7+incompatible
|
||||
|
|
|
@ -149,14 +149,14 @@ func logAudit(cSess *sessdata.ConnSession, pl *sessdata.Payload) {
|
|||
nu := utils.NowSec().Unix()
|
||||
|
||||
// 判断已经存在,并且没有过期
|
||||
v, ok := cSess.IpAuditMap[s]
|
||||
if ok && nu-v < int64(base.Cfg.AuditInterval) {
|
||||
v, ok := cSess.IpAuditMap.Get(s)
|
||||
if ok && nu-v.(int64) < int64(base.Cfg.AuditInterval) {
|
||||
// 回收byte对象
|
||||
putByte51(b)
|
||||
return
|
||||
}
|
||||
|
||||
cSess.IpAuditMap[s] = nu
|
||||
cSess.IpAuditMap.Set(s, nu)
|
||||
|
||||
audit := dbdata.AccessAudit{
|
||||
Username: cSess.Sess.Username,
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
cmap "github.com/orcaman/concurrent-map"
|
||||
)
|
||||
|
||||
type IMaps interface {
|
||||
Set(key string, val interface{})
|
||||
Get(key string) (interface{}, bool)
|
||||
Del(key string)
|
||||
}
|
||||
|
||||
/**
|
||||
* 基础的Map结构
|
||||
*
|
||||
*/
|
||||
type BaseMap struct {
|
||||
m map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *BaseMap) Set(key string, value interface{}) {
|
||||
m.m[key] = value
|
||||
}
|
||||
func (m *BaseMap) Get(key string) (interface{}, bool) {
|
||||
v, ok := m.m[key]
|
||||
return v, ok
|
||||
}
|
||||
func (m *BaseMap) Del(key string) {
|
||||
delete(m.m, key)
|
||||
}
|
||||
|
||||
/**
|
||||
* CMap 并发结构
|
||||
*
|
||||
*/
|
||||
type ConcurrentMap struct {
|
||||
m cmap.ConcurrentMap
|
||||
}
|
||||
|
||||
func (m *ConcurrentMap) Set(key string, value interface{}) {
|
||||
m.m.Set(key, value)
|
||||
}
|
||||
|
||||
func (m *ConcurrentMap) Get(key string) (interface{}, bool) {
|
||||
return m.m.Get(key)
|
||||
}
|
||||
|
||||
func (m *ConcurrentMap) Del(key string) {
|
||||
m.m.Remove(key)
|
||||
}
|
||||
|
||||
/**
|
||||
* Map 读写结构
|
||||
*
|
||||
*/
|
||||
type RWLockMap struct {
|
||||
m map[string]interface{}
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *RWLockMap) Set(key string, value interface{}) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
m.m[key] = value
|
||||
}
|
||||
|
||||
func (m *RWLockMap) Get(key string) (interface{}, bool) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
v, ok := m.m[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (m *RWLockMap) Del(key string) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
delete(m.m, key)
|
||||
}
|
||||
|
||||
/**
|
||||
* sync.Map 结构
|
||||
*
|
||||
*/
|
||||
type SyncMap struct {
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
func (m *SyncMap) Set(key string, val interface{}) {
|
||||
m.m.Store(key, val)
|
||||
}
|
||||
|
||||
func (m *SyncMap) Get(key string) (interface{}, bool) {
|
||||
return m.m.Load(key)
|
||||
}
|
||||
|
||||
func (m *SyncMap) Del(key string) {
|
||||
m.m.Delete(key)
|
||||
}
|
||||
|
||||
func NewMap(name string, len int) IMaps {
|
||||
switch name {
|
||||
case "cmap":
|
||||
return &ConcurrentMap{m: cmap.New()}
|
||||
case "rwmap":
|
||||
m := make(map[string]interface{}, len)
|
||||
return &RWLockMap{m: m}
|
||||
case "syncmap":
|
||||
return &SyncMap{}
|
||||
default:
|
||||
m := make(map[string]interface{}, len)
|
||||
return &BaseMap{m: m}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
NumOfReader = 200
|
||||
NumOfWriter = 100
|
||||
)
|
||||
|
||||
func TestMaps(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
var ipAuditMap IMaps
|
||||
key := "one"
|
||||
value := 100
|
||||
|
||||
testMapData := map[string]int{"basemap": 512, "cmap": 0, "rwmap": 512, "syncmap": 0}
|
||||
for name, len := range testMapData {
|
||||
ipAuditMap = NewMap(name, len)
|
||||
ipAuditMap.Set(key, value)
|
||||
v, ok := ipAuditMap.Get(key)
|
||||
assert.Equal(v.(int), value)
|
||||
assert.True(ok)
|
||||
ipAuditMap.Del(key)
|
||||
v, ok = ipAuditMap.Get(key)
|
||||
assert.Nil(v)
|
||||
assert.False(ok)
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkMap(b *testing.B, hm IMaps) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < NumOfWriter; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
hm.Set(strconv.Itoa(i), i*i)
|
||||
hm.Set(strconv.Itoa(i), i*i)
|
||||
hm.Del(strconv.Itoa(i))
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
for i := 0; i < NumOfReader; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
hm.Get(strconv.Itoa(i))
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMaps(b *testing.B) {
|
||||
b.Run("RW map", func(b *testing.B) {
|
||||
myMap := NewMap("rwmap", 512)
|
||||
benchmarkMap(b, myMap)
|
||||
})
|
||||
b.Run("Concurrent map", func(b *testing.B) {
|
||||
myMap := NewMap("cmap", 0)
|
||||
benchmarkMap(b, myMap)
|
||||
})
|
||||
b.Run("Sync map", func(b *testing.B) {
|
||||
myMap := NewMap("syncmap", 0)
|
||||
benchmarkMap(b, myMap)
|
||||
})
|
||||
}
|
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
"github.com/bjdgyc/anylink/base"
|
||||
"github.com/bjdgyc/anylink/dbdata"
|
||||
"github.com/bjdgyc/anylink/pkg/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -46,9 +47,9 @@ type ConnSession struct {
|
|||
closeOnce sync.Once
|
||||
CloseChan chan struct{}
|
||||
PayloadIn chan *Payload
|
||||
PayloadOutCstp chan *Payload // Cstp的数据
|
||||
PayloadOutDtls chan *Payload // Dtls的数据
|
||||
IpAuditMap map[string]int64 // 审计的ip数据
|
||||
PayloadOutCstp chan *Payload // Cstp的数据
|
||||
PayloadOutDtls chan *Payload // Dtls的数据
|
||||
IpAuditMap utils.IMaps // 审计的ip数据
|
||||
|
||||
// dSess *DtlsSession
|
||||
dSess *atomic.Value
|
||||
|
@ -191,7 +192,11 @@ func (s *Session) NewConn() *ConnSession {
|
|||
|
||||
// ip 审计
|
||||
if base.Cfg.AuditInterval >= 0 {
|
||||
cSess.IpAuditMap = make(map[string]int64, 512)
|
||||
if base.Cfg.ServerDTLS {
|
||||
cSess.IpAuditMap = utils.NewMap("cmap", 0)
|
||||
} else {
|
||||
cSess.IpAuditMap = utils.NewMap("", 512)
|
||||
}
|
||||
}
|
||||
|
||||
dSess := &DtlsSession{
|
||||
|
|
Loading…
Reference in New Issue