From 8f8c8442a7ea7c2a2bebcb6c90a8a60e25af9227 Mon Sep 17 00:00:00 2001 From: hebo Date: Wed, 21 Aug 2019 22:52:43 +0800 Subject: [PATCH] add more info --- capture/config.go | 2 + capture/network.go | 96 +++++++------- session-dealer/controller.go | 8 +- session-dealer/model.go | 3 + session-dealer/mysql/model.go | 48 +++++++ session-dealer/mysql/session.go | 219 +++++++++++++++++++++++--------- 6 files changed, 270 insertions(+), 106 deletions(-) diff --git a/capture/config.go b/capture/config.go index c466107..5fac129 100644 --- a/capture/config.go +++ b/capture/config.go @@ -3,12 +3,14 @@ package capture import ( sd "github.com/zr-hebo/sniffer-agent/session-dealer" log "github.com/sirupsen/logrus" + "sync" ) var ( localIPAddr *string sessionPool = make(map[string]sd.ConnSession) + sessionPoolLock sync.Mutex ) func init() { diff --git a/capture/network.go b/capture/network.go index 5972470..e66e026 100644 --- a/capture/network.go +++ b/capture/network.go @@ -31,11 +31,16 @@ func init() { type networkCard struct { name string listenPort int + receiver chan model.QueryPiece } func NewNetworkCard() (nc *networkCard) { // init device - return &networkCard{name: DeviceName, listenPort: snifferPort} + return &networkCard{ + name: DeviceName, + listenPort: snifferPort, + receiver: make(chan model.QueryPiece, 100), + } } func initEthernetHandlerFromPacpgo() (handler *pcapgo.EthernetHandle) { @@ -81,7 +86,6 @@ func initEthernetHandlerFromPacp() (handler *pcap.Handle) { if err != nil { panic(err.Error()) } - handler.SnapLen() return } @@ -96,10 +100,8 @@ func (nc *networkCard) Listen() (receiver chan model.QueryPiece) { // Listen get a connection. func (nc *networkCard) listenNormal() (receiver chan model.QueryPiece) { - receiver = make(chan model.QueryPiece, 100) - go func() { - handler := initEthernetHandlerFromPacpgo() + handler := initEthernetHandlerFromPacp() for { var data []byte data, ci, err := handler.ZeroCopyReadPacketData() @@ -113,11 +115,7 @@ func (nc *networkCard) listenNormal() (receiver chan model.QueryPiece) { m := packet.Metadata() m.CaptureInfo = ci m.Truncated = m.Truncated || ci.CaptureLength < ci.Length - - qp := nc.parseTCPPackage(packet) - if qp != nil { - receiver <- qp - } + nc.parseTCPPackage(packet) } }() @@ -126,8 +124,13 @@ func (nc *networkCard) listenNormal() (receiver chan model.QueryPiece) { // Listen get a connection. func (nc *networkCard) listenInParallel() (receiver chan model.QueryPiece) { - receiver = make(chan model.QueryPiece, 100) - packageChan := make(chan gopacket.Packet, 10) + type captureInfo struct { + bytes []byte + captureInfo gopacket.CaptureInfo + } + + rawDataChan := make(chan *captureInfo, 20) + packageChan := make(chan gopacket.Packet, 20) // read packet go func() { @@ -146,33 +149,33 @@ func (nc *networkCard) listenInParallel() (receiver chan model.QueryPiece) { continue } - packet := gopacket.NewPacket(data, layers.LayerTypeEthernet, gopacket.NoCopy) - m := packet.Metadata() - m.CaptureInfo = ci - m.Truncated = m.Truncated || ci.CaptureLength < ci.Length - - packageChan <- packet + rawDataChan <- &captureInfo{ + bytes: data, + captureInfo: ci, + } } }() - // deal packet + // parse package go func() { defer func() { close(receiver) }() - for packet := range packageChan { - qp := nc.parseTCPPackage(packet) - if qp != nil { - receiver <- qp - } + for captureInfo := range rawDataChan { + packet := gopacket.NewPacket(captureInfo.bytes, layers.LayerTypeEthernet, gopacket.NoCopy) + m := packet.Metadata() + m.CaptureInfo = captureInfo.captureInfo + m.Truncated = m.Truncated || captureInfo.captureInfo.CaptureLength < captureInfo.captureInfo.Length + + nc.parseTCPPackage(packet) } }() return } -func (nc *networkCard) parseTCPPackage(packet gopacket.Packet) (qp model.QueryPiece) { +func (nc *networkCard) parseTCPPackage(packet gopacket.Packet) { var err error defer func() { if err != nil { @@ -180,8 +183,8 @@ func (nc *networkCard) parseTCPPackage(packet gopacket.Packet) (qp model.QueryPi } }() - tcpConn := packet.TransportLayer().(*layers.TCP) - if tcpConn.SYN || tcpConn.RST { + tcpPkt := packet.TransportLayer().(*layers.TCP) + if tcpPkt.SYN || tcpPkt.RST { return } @@ -200,18 +203,18 @@ func (nc *networkCard) parseTCPPackage(packet gopacket.Packet) (qp model.QueryPi // get IP from ip layer srcIP := ipInfo.SrcIP.String() dstIP := ipInfo.DstIP.String() - srcPort := int(tcpConn.SrcPort) - dstPort := int(tcpConn.DstPort) + srcPort := int(tcpPkt.SrcPort) + dstPort := int(tcpPkt.DstPort) if dstPort == nc.listenPort { // deal mysql server response - err = readToServerPackage(&srcIP, srcPort, tcpConn) + err = readToServerPackage(&srcIP, srcPort, tcpPkt, nc.receiver) if err != nil { return } } else if srcPort == nc.listenPort { // deal mysql client request - qp, err = readFromServerPackage(&dstIP, dstPort, tcpConn) + err = readFromServerPackage(&dstIP, dstPort, tcpPkt) if err != nil { return } @@ -220,38 +223,40 @@ func (nc *networkCard) parseTCPPackage(packet gopacket.Packet) (qp model.QueryPi return } -func readFromServerPackage(srcIP *string, srcPort int, tcpConn *layers.TCP) (qp model.QueryPiece, err error) { +func readFromServerPackage( + srcIP *string, srcPort int, tcpPkt *layers.TCP) (err error) { defer func() { if err != nil { log.Error("read Mysql package send from mysql server to client failed <-- %s", err.Error()) } }() - if tcpConn.FIN { + if tcpPkt.FIN { sessionKey := spliceSessionKey(srcIP, srcPort) delete(sessionPool, *sessionKey) log.Debugf("close connection from %s", *sessionKey) return } - tcpPayload := tcpConn.Payload + tcpPayload := tcpPkt.Payload if (len(tcpPayload) < 1) { return } - _ = tcpConn.Seq - sessionKey := spliceSessionKey(srcIP, srcPort) session := sessionPool[*sessionKey] if session != nil { - session.ReadFromServer(tcpPayload) - qp = session.GenerateQueryPiece() + // session.ReadFromServer(tcpPayload) + // qp = session.GenerateQueryPiece() + pkt := model.NewTCPPacket(tcpPayload, int64(tcpPkt.Ack), false) + session.ReceiveTCPPacket(pkt) } return } -func readToServerPackage(srcIP *string, srcPort int, tcpConn *layers.TCP) (err error) { +func readToServerPackage( + srcIP *string, srcPort int, tcpPkt *layers.TCP, receiver chan model.QueryPiece) (err error) { defer func() { if err != nil { log.Error("read package send from client to mysql server failed <-- %s", err.Error()) @@ -259,14 +264,14 @@ func readToServerPackage(srcIP *string, srcPort int, tcpConn *layers.TCP) (err e }() // when client try close connection remove session from session pool - if tcpConn.FIN { + if tcpPkt.FIN { sessionKey := spliceSessionKey(srcIP, srcPort) delete(sessionPool, *sessionKey) log.Debugf("close connection from %s", *sessionKey) return } - tcpPayload := tcpConn.Payload + tcpPayload := tcpPkt.Payload if (len(tcpPayload) < 1) { return } @@ -274,12 +279,15 @@ func readToServerPackage(srcIP *string, srcPort int, tcpConn *layers.TCP) (err e sessionKey := spliceSessionKey(srcIP, srcPort) session := sessionPool[*sessionKey] if session == nil { - session = sd.NewSession(sessionKey, srcIP, srcPort, localIPAddr, snifferPort) + session = sd.NewSession(sessionKey, srcIP, srcPort, localIPAddr, snifferPort, receiver) sessionPool[*sessionKey] = session } - session.ResetBeginTime() - session.ReadFromClient(int64(tcpConn.Seq), tcpPayload) + pkt := model.NewTCPPacket(tcpPayload, int64(tcpPkt.Seq), true) + session.ReceiveTCPPacket(pkt) + + // session.ResetBeginTime() + // session.ReadFromClient(int64(tcpPkt.Seq), tcpPayload) return } diff --git a/session-dealer/controller.go b/session-dealer/controller.go index 963233c..72b7572 100644 --- a/session-dealer/controller.go +++ b/session-dealer/controller.go @@ -1,15 +1,17 @@ package session_dealer import ( + "github.com/zr-hebo/sniffer-agent/model" "github.com/zr-hebo/sniffer-agent/session-dealer/mysql" ) -func NewSession(sessionKey *string, clientIP *string, clientPort int, serverIP *string, serverPort int) (session ConnSession) { +func NewSession(sessionKey *string, clientIP *string, clientPort int, serverIP *string, serverPort int, + receiver chan model.QueryPiece) (session ConnSession) { switch serviceType { case ServiceTypeMysql: - session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort) + session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort, receiver) default: - session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort) + session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort, receiver) } return } diff --git a/session-dealer/model.go b/session-dealer/model.go index bdbb9ac..c023565 100644 --- a/session-dealer/model.go +++ b/session-dealer/model.go @@ -7,4 +7,7 @@ type ConnSession interface { ReadFromServer(bytes []byte) ResetBeginTime() GenerateQueryPiece() (qp model.QueryPiece) + + ReceiveTCPPacket(*model.TCPPacket) + Stop() } diff --git a/session-dealer/mysql/model.go b/session-dealer/mysql/model.go index d47e30a..b8eb78b 100644 --- a/session-dealer/mysql/model.go +++ b/session-dealer/mysql/model.go @@ -1,5 +1,9 @@ package mysql +import ( + log "github.com/sirupsen/logrus" +) + type handshakeResponse41 struct { Capability uint32 Collation uint8 @@ -13,3 +17,47 @@ type jigsaw struct { b int64 e int64 } + +type packageWindowCounter struct { + sizeCount map[int]int64 + visitCount int64 + suggestSize int +} + +func newPackageWindowCounter() *packageWindowCounter { + return &packageWindowCounter{ + sizeCount: make(map[int]int64, 4), + suggestSize: 512, + } +} + +func (pwc *packageWindowCounter) refresh (readSize int, isLastPackage bool) { + if pwc.visitCount > 10000 { + return + } + + log.Debugf("WindowCounter: %#v", pwc.sizeCount) + pwc.visitCount += 1 + miniMatchSize := maxIPPackageSize + for size := range pwc.sizeCount { + if readSize % size == 0 && miniMatchSize > size { + miniMatchSize = size + } + } + if miniMatchSize < maxIPPackageSize { + pwc.sizeCount[miniMatchSize] = pwc.sizeCount[miniMatchSize] + 1 + } else if !isLastPackage { + pwc.sizeCount[readSize] = 1 + } + + mostFrequentSize := pwc.suggestSize + mostFrequentCount := int64(0) + for size, count := range pwc.sizeCount { + if count > mostFrequentCount { + mostFrequentSize = size + mostFrequentCount = count + } + } + + pwc.suggestSize = mostFrequentSize +} \ No newline at end of file diff --git a/session-dealer/mysql/session.go b/session-dealer/mysql/session.go index 2c43ab2..a9b5291 100644 --- a/session-dealer/mysql/session.go +++ b/session-dealer/mysql/session.go @@ -2,6 +2,7 @@ package mysql import ( "fmt" + "strings" "time" "github.com/siddontang/go/hack" @@ -22,25 +23,37 @@ type MysqlSession struct { packageOffset int64 expectReceiveSize int coverRanges []*jigsaw - tcpWindowSize int expectSendSize int prepareInfo *prepareInfo - sizeCount map[int]int64 cachedPrepareStmt map[int]*string cachedStmtBytes []byte computeWindowSizeCounter int + + tcpPacketCache []*model.TCPPacket + + queryPieceReceiver chan model.QueryPiece + lastSeq int64 + keepAlive chan bool + + ackID int64 + sendSize int64 } type prepareInfo struct { prepareStmtID int } +var ( + windowSizeCache = make(map[string]*packageWindowCounter, 16) +) + const ( - defaultCacheSize = 1 << 16 maxIPPackageSize = 1 << 16 ) -func NewMysqlSession(sessionKey *string, clientIP *string, clientPort int, serverIP *string, serverPort int) (ms *MysqlSession) { +func NewMysqlSession( + sessionKey *string, clientIP *string, clientPort int, serverIP *string, serverPort int, + receiver chan model.QueryPiece) (ms *MysqlSession) { ms = &MysqlSession{ connectionID: sessionKey, clientHost: clientIP, @@ -49,14 +62,106 @@ func NewMysqlSession(sessionKey *string, clientIP *string, clientPort int, serve serverPort: serverPort, stmtBeginTime: time.Now().UnixNano() / millSecondUnit, cachedPrepareStmt: make(map[int]*string, 8), + coverRanges: make([]*jigsaw, 0, 4), + queryPieceReceiver: receiver, + keepAlive: make(chan bool, 1), + lastSeq: -1, + ackID: -1, + sendSize: 0, } - ms.tcpWindowSize = 512 - ms.coverRanges = make([]*jigsaw, 0, 4) - ms.sizeCount = make(map[int]int64) + + go ms.haha() return } +func (ms *MysqlSession) Stop() { + ms.keepAlive <- false +} + +func (ms *MysqlSession) haha() { + + for true { + // select { + // case <- ms.keepAlive: + // return + // default: + // } + + if len(ms.tcpPacketCache) < 1 { + // log.Debugf("there are %d packages in tcp packet cache", ) + time.Sleep(1) + continue + } + + beginIdx := -1 + if ms.lastSeq < 0 { + ms.lastSeq = ms.tcpPacketCache[0].Seq + } + for idx := 0; idx < len(ms.tcpPacketCache); idx++ { + pkt := ms.tcpPacketCache[idx] + if ms.lastSeq == pkt.Seq { + beginIdx = idx + ms.lastSeq = pkt.Seq + int64(len(pkt.Payload)) + + } else { + break + } + } + + + if beginIdx < 0 { + return + } + + inOrderPkgs := ms.tcpPacketCache[:beginIdx+1] + if beginIdx == len(ms.tcpPacketCache) - 1 { + ms.tcpPacketCache = make([]*model.TCPPacket, 0, 4) + } else { + ms.tcpPacketCache = ms.tcpPacketCache[beginIdx+1:] + } + + for _, pkg := range inOrderPkgs { + if pkg.ToServer { + ms.ReadFromClient(pkg.Seq, pkg.Payload) + + } else { + ms.ReadFromServer(pkg.Payload) + ms.queryPieceReceiver <- ms.GenerateQueryPiece() + } + } + } +} + +func (ms *MysqlSession) ReceiveTCPPacket(newPkt *model.TCPPacket) { + if !newPkt.ToServer && ms.ackID + ms.sendSize == newPkt.Seq { + // ignore to response to client data + ms.ackID = ms.ackID + newPkt.Seq + ms.sendSize = ms.sendSize + int64(len(newPkt.Payload)) + return + + } else if !newPkt.ToServer { + ms.ackID = newPkt.Seq + ms.sendSize = int64(len(newPkt.Payload)) + } + + insertIdx := len(ms.tcpPacketCache) + for idx, pkt := range ms.tcpPacketCache { + if pkt.Seq > newPkt.Seq { + insertIdx = idx + } + } + + if insertIdx == len(ms.tcpPacketCache) { + ms.tcpPacketCache = append(ms.tcpPacketCache, newPkt) + } else { + newCache := make([]*model.TCPPacket, len(ms.tcpPacketCache)+1) + copy(newCache[:insertIdx], ms.tcpPacketCache[:insertIdx]) + newCache[insertIdx] = newPkt + copy(newCache[insertIdx+1:], ms.tcpPacketCache[insertIdx:]) + } +} + func (ms *MysqlSession) ResetBeginTime() { ms.stmtBeginTime = time.Now().UnixNano() / millSecondUnit } @@ -74,19 +179,24 @@ func (ms *MysqlSession) ReadFromServer(bytes []byte) { func (ms *MysqlSession) mergeRanges() { if len(ms.coverRanges) > 1 { newRange, newPkgRanges := mergeRanges(ms.coverRanges[0], ms.coverRanges[1:]) - newPkgRanges = append(newPkgRanges, newRange) - ms.coverRanges = newPkgRanges + tmpRanges := make([]*jigsaw, len(newPkgRanges)+1) + tmpRanges[0] = newRange + if len(newPkgRanges) > 0 { + copy(tmpRanges[1:], newPkgRanges) + } + ms.coverRanges = tmpRanges } } func mergeRanges(currRange *jigsaw, pkgRanges []*jigsaw) (mergedRange *jigsaw, newPkgRanges []*jigsaw) { var nextRange *jigsaw + newPkgRanges = make([]*jigsaw, 0, 4) + if len(pkgRanges) < 1 { - return currRange, make([]*jigsaw, 0) + return currRange, newPkgRanges } else if len(pkgRanges) == 1 { nextRange = pkgRanges[0] - newPkgRanges = make([]*jigsaw, 0, 4) } else { nextRange, newPkgRanges = mergeRanges(pkgRanges[0], pkgRanges[1:]) @@ -96,9 +206,15 @@ func mergeRanges(currRange *jigsaw, pkgRanges []*jigsaw) (mergedRange *jigsaw, n mergedRange = &jigsaw{b: currRange.b, e: nextRange.e} } else { - newPkgRanges = append(newPkgRanges, nextRange) + tmpRanges := make([]*jigsaw, len(newPkgRanges)+1) + tmpRanges[0] = nextRange + if len(newPkgRanges) > 0 { + copy(tmpRanges[1:], newPkgRanges) + } + newPkgRanges = tmpRanges mergedRange = currRange } + return } @@ -112,7 +228,16 @@ func (ms *MysqlSession) oneMysqlPackageFinish() bool { func (ms *MysqlSession) checkFinish() bool { if len(ms.coverRanges) != 1 { - return true + ranges := make([]string, 0, len(ms.coverRanges)) + for _, cr := range ms.coverRanges { + log.Errorf("miss values: %s", string(ms.cachedStmtBytes[cr.b-ms.beginSeqID: cr.e-ms.beginSeqID])) + + ranges = append(ranges, fmt.Sprintf("[%d -- %d]", cr.b, cr.e)) + } + + + log.Errorf("in session %s get invalid range: %s", *ms.connectionID, strings.Join(ranges, ", ")) + return false } firstRange := ms.coverRanges[0] @@ -124,7 +249,6 @@ func (ms *MysqlSession) checkFinish() bool { } func (ms *MysqlSession) ReadFromClient(seqID int64, bytes []byte) { - readSize := len(bytes) contentSize := int64(len(bytes)) if ms.expectReceiveSize == 0 || ms.oneMysqlPackageFinish() { @@ -143,12 +267,18 @@ func (ms *MysqlSession) ReadFromClient(seqID int64, bytes []byte) { if len(ms.cachedStmtBytes) > 0 { copy(newCache[:len(ms.cachedStmtBytes)], ms.cachedStmtBytes) } - copy(newCache[ms.packageOffset:ms.packageOffset+int64(len(contents))], contents) - ms.cachedStmtBytes = newCache + + if int64(ms.expectReceiveSize+len(ms.cachedStmtBytes)) > ms.packageOffset+int64(len(contents)) { + copy(newCache[ms.packageOffset:ms.packageOffset+int64(len(contents))], contents) + ms.cachedStmtBytes = newCache + } else { + log.Debugf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXxxx") + return + } } else { if seqID < ms.beginSeqID { - log.Debugf("outdate package with Seq:%d", seqID) + log.Debugf("in session %s get outdate package with Seq:%d", *ms.connectionID, seqID) return } @@ -158,8 +288,6 @@ func (ms *MysqlSession) ReadFromClient(seqID int64, bytes []byte) { } } - ms.refreshWindowSize(readSize) - insertIdx := len(ms.coverRanges) for idx, cr := range ms.coverRanges { if seqID < cr.b { @@ -168,8 +296,8 @@ func (ms *MysqlSession) ReadFromClient(seqID int64, bytes []byte) { } } - cr := &jigsaw{b: seqID, e: seqID+int64(contentSize)} - if insertIdx == len(ms.coverRanges) - 1 { + cr := &jigsaw{b: seqID, e: seqID+contentSize} + if len(ms.coverRanges) < 1 || insertIdx == len(ms.coverRanges) { ms.coverRanges = append(ms.coverRanges, cr) } else { @@ -179,57 +307,31 @@ func (ms *MysqlSession) ReadFromClient(seqID int64, bytes []byte) { copy(newCoverRanges[insertIdx+1:], ms.coverRanges[insertIdx:]) ms.coverRanges = newCoverRanges } - ms.mergeRanges() + ms.mergeRanges() } func (ms *MysqlSession) refreshWindowSize(readSize int) { - if ms.computeWindowSizeCounter > 5000 { - return + windowCounter := windowSizeCache[*ms.clientHost] + if windowCounter == nil { + windowCounter = newPackageWindowCounter() + windowSizeCache[*ms.clientHost] = windowCounter } - log.Debugf("sizeCount: %#v", ms.sizeCount) - - ms.computeWindowSizeCounter += 1 - miniMatchSize := maxIPPackageSize - for size := range ms.sizeCount { - if readSize % size == 0 && miniMatchSize > size { - miniMatchSize = size - } - } - if miniMatchSize < maxIPPackageSize { - ms.sizeCount[miniMatchSize] = ms.sizeCount[miniMatchSize] + 1 - } else if ms.checkFinish() { - ms.sizeCount[readSize] = 1 - } - - mostFrequentSize := ms.tcpWindowSize - miniSize := ms.tcpWindowSize - mostFrequentCount := int64(0) - for size, count := range ms.sizeCount { - if count > mostFrequentCount { - mostFrequentSize = size - mostFrequentCount = count - } - - if miniSize > size { - miniSize = size - } - } - - ms.tcpWindowSize = mostFrequentSize + // windowCounter.refresh(readSize, ms.checkFinish()) + // ms.tcpWindowSize = windowCounter.suggestSize } - func (ms *MysqlSession) GenerateQueryPiece() (qp model.QueryPiece) { defer func() { - // ms.tcpCache = ms.tcpCache[0:0] ms.cachedStmtBytes = nil ms.expectReceiveSize = 0 ms.expectSendSize = 0 ms.prepareInfo = nil ms.coverRanges = make([]*jigsaw, 0, 4) - // ms.packageComplete = false + ms.lastSeq = -1 + ms.ackID = -1 + ms.sendSize = 0 }() if len(ms.cachedStmtBytes) < 1 { @@ -238,7 +340,7 @@ func (ms *MysqlSession) GenerateQueryPiece() (qp model.QueryPiece) { // fmt.Printf("packageComplete in generate: %v\n", ms.packageComplete) if !ms.checkFinish() { - log.Errorf("is not a complete cover") + log.Errorf("receive a not complete cover") return } @@ -323,7 +425,6 @@ func filterQueryPieceBySQL(mqp *model.PooledMysqlQueryPiece, querySQL []byte) (m mqp.SetNeedSyncSend(true) } - // log.Debug(mqp.String()) return mqp }