From f342b123725472cda0769b993461eacba4a5799d Mon Sep 17 00:00:00 2001 From: bjdgyc Date: Wed, 26 May 2021 19:13:09 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9dtlssession=E7=9A=84=E5=AD=98?= =?UTF-8?q?=E5=82=A8=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/sessdata/session.go | 52 +++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/server/sessdata/session.go b/server/sessdata/session.go index 260fea3..1988a8b 100644 --- a/server/sessdata/session.go +++ b/server/sessdata/session.go @@ -50,12 +50,12 @@ type ConnSession struct { PayloadOutCstp chan *Payload // Cstp的数据 PayloadOutDtls chan *Payload // Dtls的数据 - mux sync.RWMutex - dSess *DtlsSession // Dtls Session - // DSess *atomic.Value + // dSess *DtlsSession + dSess *atomic.Value } type DtlsSession struct { + isActive int32 CSess *ConnSession CloseChan chan struct{} closeOnce sync.Once @@ -169,6 +169,14 @@ func (s *Session) NewConn() *ConnSession { return nil } + // 查询group信息 + group := &dbdata.Group{} + err = dbdata.One("Name", s.Group, group) + if err != nil { + base.Error(err) + return nil + } + cSess := &ConnSession{ Sess: s, MacHw: macHw, @@ -178,16 +186,17 @@ func (s *Session) NewConn() *ConnSession { PayloadIn: make(chan *Payload), PayloadOutCstp: make(chan *Payload), PayloadOutDtls: make(chan *Payload), + dSess: &atomic.Value{}, } - // 查询group信息 - group := &dbdata.Group{} - err = dbdata.One("Name", s.Group, group) - if err != nil { - base.Error(err) - cSess.Close() - return nil + dSess := &DtlsSession{ + isActive: -1, + CSess: cSess, + CloseChan: make(chan struct{}), + closeOnce: sync.Once{}, } + cSess.dSess.Store(dSess) + cSess.Group = group if group.Bandwidth > 0 { // 限流设置 @@ -222,20 +231,20 @@ func (cs *ConnSession) Close() { // 创建dtls链接 func (cs *ConnSession) NewDtlsConn() *DtlsSession { - cs.mux.Lock() - defer cs.mux.Unlock() - - if cs.dSess != nil { + ds := cs.dSess.Load().(*DtlsSession) + isActive := atomic.LoadInt32(&ds.isActive) + if isActive > 0 { // 判断原有连接存在,不进行创建 return nil } dSess := &DtlsSession{ + isActive: 1, CSess: cs, CloseChan: make(chan struct{}), closeOnce: sync.Once{}, } - cs.dSess = dSess + cs.dSess.Store(dSess) return dSess } @@ -243,18 +252,19 @@ func (cs *ConnSession) NewDtlsConn() *DtlsSession { func (ds *DtlsSession) Close() { ds.closeOnce.Do(func() { base.Info("closeOnce dtls:", ds.CSess.IpAddr) - ds.CSess.mux.Lock() - defer ds.CSess.mux.Unlock() + atomic.StoreInt32(&ds.isActive, -1) close(ds.CloseChan) - ds.CSess.dSess = nil }) } func (cs *ConnSession) GetDtlsSession() *DtlsSession { - cs.mux.RLock() - defer cs.mux.RUnlock() - return cs.dSess + ds := cs.dSess.Load().(*DtlsSession) + isActive := atomic.LoadInt32(&ds.isActive) + if isActive > 0 { + return ds + } + return nil } const BandwidthPeriodSec = 2 // 流量速率统计周期(秒)