mirror of https://github.com/bjdgyc/anylink.git
Merge pull request #124 from lanrenwo/ip_audit_map_concurrent
解决IpAuditMap在UDP下的fatal error: concurrent map read and map write
This commit is contained in:
commit
e72994c0b8
|
@ -13,6 +13,7 @@ require (
|
||||||
github.com/gorilla/mux v1.8.0
|
github.com/gorilla/mux v1.8.0
|
||||||
github.com/lib/pq v1.10.2
|
github.com/lib/pq v1.10.2
|
||||||
github.com/mattn/go-sqlite3 v1.14.8
|
github.com/mattn/go-sqlite3 v1.14.8
|
||||||
|
github.com/orcaman/concurrent-map v1.0.0 // indirect
|
||||||
github.com/pion/dtls/v2 v2.0.9
|
github.com/pion/dtls/v2 v2.0.9
|
||||||
github.com/pion/logging v0.2.2
|
github.com/pion/logging v0.2.2
|
||||||
github.com/shirou/gopsutil v3.21.7+incompatible
|
github.com/shirou/gopsutil v3.21.7+incompatible
|
||||||
|
|
|
@ -149,14 +149,14 @@ func logAudit(cSess *sessdata.ConnSession, pl *sessdata.Payload) {
|
||||||
nu := utils.NowSec().Unix()
|
nu := utils.NowSec().Unix()
|
||||||
|
|
||||||
// 判断已经存在,并且没有过期
|
// 判断已经存在,并且没有过期
|
||||||
v, ok := cSess.IpAuditMap[s]
|
v, ok := cSess.IpAuditMap.Get(s)
|
||||||
if ok && nu-v < int64(base.Cfg.AuditInterval) {
|
if ok && nu-v.(int64) < int64(base.Cfg.AuditInterval) {
|
||||||
// 回收byte对象
|
// 回收byte对象
|
||||||
putByte51(b)
|
putByte51(b)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cSess.IpAuditMap[s] = nu
|
cSess.IpAuditMap.Set(s, nu)
|
||||||
|
|
||||||
audit := dbdata.AccessAudit{
|
audit := dbdata.AccessAudit{
|
||||||
Username: cSess.Sess.Username,
|
Username: cSess.Sess.Username,
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
cmap "github.com/orcaman/concurrent-map"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IMaps interface {
|
||||||
|
Set(key string, val interface{})
|
||||||
|
Get(key string) (interface{}, bool)
|
||||||
|
Del(key string)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 基础的Map结构
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
type BaseMap struct {
|
||||||
|
m map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *BaseMap) Set(key string, value interface{}) {
|
||||||
|
m.m[key] = value
|
||||||
|
}
|
||||||
|
func (m *BaseMap) Get(key string) (interface{}, bool) {
|
||||||
|
v, ok := m.m[key]
|
||||||
|
return v, ok
|
||||||
|
}
|
||||||
|
func (m *BaseMap) Del(key string) {
|
||||||
|
delete(m.m, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CMap 并发结构
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
type ConcurrentMap struct {
|
||||||
|
m cmap.ConcurrentMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConcurrentMap) Set(key string, value interface{}) {
|
||||||
|
m.m.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConcurrentMap) Get(key string) (interface{}, bool) {
|
||||||
|
return m.m.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConcurrentMap) Del(key string) {
|
||||||
|
m.m.Remove(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map 读写结构
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
type RWLockMap struct {
|
||||||
|
m map[string]interface{}
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RWLockMap) Set(key string, value interface{}) {
|
||||||
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
m.m[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RWLockMap) Get(key string) (interface{}, bool) {
|
||||||
|
m.lock.RLock()
|
||||||
|
defer m.lock.RUnlock()
|
||||||
|
v, ok := m.m[key]
|
||||||
|
return v, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RWLockMap) Del(key string) {
|
||||||
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
delete(m.m, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* sync.Map 结构
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
type SyncMap struct {
|
||||||
|
m sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SyncMap) Set(key string, val interface{}) {
|
||||||
|
m.m.Store(key, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SyncMap) Get(key string) (interface{}, bool) {
|
||||||
|
return m.m.Load(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SyncMap) Del(key string) {
|
||||||
|
m.m.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMap(name string, len int) IMaps {
|
||||||
|
switch name {
|
||||||
|
case "cmap":
|
||||||
|
return &ConcurrentMap{m: cmap.New()}
|
||||||
|
case "rwmap":
|
||||||
|
m := make(map[string]interface{}, len)
|
||||||
|
return &RWLockMap{m: m}
|
||||||
|
case "syncmap":
|
||||||
|
return &SyncMap{}
|
||||||
|
default:
|
||||||
|
m := make(map[string]interface{}, len)
|
||||||
|
return &BaseMap{m: m}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,76 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NumOfReader = 200
|
||||||
|
NumOfWriter = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMaps(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
var ipAuditMap IMaps
|
||||||
|
key := "one"
|
||||||
|
value := 100
|
||||||
|
|
||||||
|
testMapData := map[string]int{"basemap": 512, "cmap": 0, "rwmap": 512, "syncmap": 0}
|
||||||
|
for name, len := range testMapData {
|
||||||
|
ipAuditMap = NewMap(name, len)
|
||||||
|
ipAuditMap.Set(key, value)
|
||||||
|
v, ok := ipAuditMap.Get(key)
|
||||||
|
assert.Equal(v.(int), value)
|
||||||
|
assert.True(ok)
|
||||||
|
ipAuditMap.Del(key)
|
||||||
|
v, ok = ipAuditMap.Get(key)
|
||||||
|
assert.Nil(v)
|
||||||
|
assert.False(ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkMap(b *testing.B, hm IMaps) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < NumOfWriter; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
hm.Set(strconv.Itoa(i), i*i)
|
||||||
|
hm.Set(strconv.Itoa(i), i*i)
|
||||||
|
hm.Del(strconv.Itoa(i))
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
for i := 0; i < NumOfReader; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
hm.Get(strconv.Itoa(i))
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMaps(b *testing.B) {
|
||||||
|
b.Run("RW map", func(b *testing.B) {
|
||||||
|
myMap := NewMap("rwmap", 512)
|
||||||
|
benchmarkMap(b, myMap)
|
||||||
|
})
|
||||||
|
b.Run("Concurrent map", func(b *testing.B) {
|
||||||
|
myMap := NewMap("cmap", 0)
|
||||||
|
benchmarkMap(b, myMap)
|
||||||
|
})
|
||||||
|
b.Run("Sync map", func(b *testing.B) {
|
||||||
|
myMap := NewMap("syncmap", 0)
|
||||||
|
benchmarkMap(b, myMap)
|
||||||
|
})
|
||||||
|
}
|
|
@ -13,6 +13,7 @@ import (
|
||||||
|
|
||||||
"github.com/bjdgyc/anylink/base"
|
"github.com/bjdgyc/anylink/base"
|
||||||
"github.com/bjdgyc/anylink/dbdata"
|
"github.com/bjdgyc/anylink/dbdata"
|
||||||
|
"github.com/bjdgyc/anylink/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -48,7 +49,7 @@ type ConnSession struct {
|
||||||
PayloadIn chan *Payload
|
PayloadIn chan *Payload
|
||||||
PayloadOutCstp chan *Payload // Cstp的数据
|
PayloadOutCstp chan *Payload // Cstp的数据
|
||||||
PayloadOutDtls chan *Payload // Dtls的数据
|
PayloadOutDtls chan *Payload // Dtls的数据
|
||||||
IpAuditMap map[string]int64 // 审计的ip数据
|
IpAuditMap utils.IMaps // 审计的ip数据
|
||||||
|
|
||||||
// dSess *DtlsSession
|
// dSess *DtlsSession
|
||||||
dSess *atomic.Value
|
dSess *atomic.Value
|
||||||
|
@ -191,7 +192,11 @@ func (s *Session) NewConn() *ConnSession {
|
||||||
|
|
||||||
// ip 审计
|
// ip 审计
|
||||||
if base.Cfg.AuditInterval >= 0 {
|
if base.Cfg.AuditInterval >= 0 {
|
||||||
cSess.IpAuditMap = make(map[string]int64, 512)
|
if base.Cfg.ServerDTLS {
|
||||||
|
cSess.IpAuditMap = utils.NewMap("cmap", 0)
|
||||||
|
} else {
|
||||||
|
cSess.IpAuditMap = utils.NewMap("", 512)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dSess := &DtlsSession{
|
dSess := &DtlsSession{
|
||||||
|
|
Loading…
Reference in New Issue