diff --git a/capture/network.go b/capture/network.go index 62e2d5e..c8c9797 100644 --- a/capture/network.go +++ b/capture/network.go @@ -267,8 +267,6 @@ func readFromServerPackage( sessionKey := spliceSessionKey(srcIP, srcPort) session := sessionPool[*sessionKey] if session != nil { - // session.readFromServer(tcpPayload) - // qp = session.GenerateQueryPiece() pkt := model.NewTCPPacket(tcpPayload, int64(tcpPkt.Ack), false) session.ReceiveTCPPacket(pkt) } diff --git a/session-dealer/mysql/const.go b/session-dealer/mysql/const.go index a6ac377..98845f0 100644 --- a/session-dealer/mysql/const.go +++ b/session-dealer/mysql/const.go @@ -96,7 +96,7 @@ const ( // See https://dev.mysql.com/doc/refman/5.7/en/identifiers.html const ( // MaxMysqlPacketLen is the max packet payload length. - MaxMysqlPacketLen = 1<<24 - 1 + MaxMysqlPacketLen = 4 * 1024 * 1024 ) const ( diff --git a/session-dealer/mysql/model.go b/session-dealer/mysql/model.go index a02bd6b..53d7ef5 100644 --- a/session-dealer/mysql/model.go +++ b/session-dealer/mysql/model.go @@ -8,8 +8,15 @@ type handshakeResponse41 struct { Auth []byte } -// receiveRange record mysql package begin and end seq id -type receiveRange struct { - beginSeqID int64 - endSeqID int64 +// jigsaw record tcp package begin and end seq id +type jigsaw struct { + begin int64 + end int64 } + +func newJigsaw(begin, end int64) (js *jigsaw) { + return &jigsaw{ + begin: begin, + end: end, + } +} \ No newline at end of file diff --git a/session-dealer/mysql/session.go b/session-dealer/mysql/session.go index 3447a34..e350f2d 100644 --- a/session-dealer/mysql/session.go +++ b/session-dealer/mysql/session.go @@ -1,15 +1,15 @@ package mysql import ( + "container/list" "fmt" - "github.com/zr-hebo/sniffer-agent/communicator" - // "strings" "sync" "time" - log "github.com/sirupsen/logrus" "github.com/pingcap/tidb/util/hack" + log "github.com/sirupsen/logrus" + "github.com/zr-hebo/sniffer-agent/communicator" "github.com/zr-hebo/sniffer-agent/model" ) @@ -22,13 +22,11 @@ type MysqlSession struct { serverIP *string serverPort int stmtBeginTime int64 - packageOffset int64 + // packageOffset int64 + beginSeqID int64 + endSeqID int64 + coverRanges *list.List expectReceiveSize int - // coverRanges []*receiveRange - // coverRange *receiveRange - beginSeqID int64 - expectSeqID int64 - expectSendSize int prepareInfo *prepareInfo cachedPrepareStmt map[int][]byte @@ -63,7 +61,7 @@ func NewMysqlSession( queryPieceReceiver: receiver, closeConn: make(chan bool, 1), expectReceiveSize: -1, - expectSeqID: -1, + coverRanges: list.New(), ignoreAckID: -1, sendSize: 0, pkgCacheLock: sync.Mutex{}, @@ -97,6 +95,7 @@ func (ms *MysqlSession) ReceiveTCPPacket(newPkt *model.TCPPacket) { ms.queryPieceReceiver <- qp } } + } func (ms *MysqlSession) resetBeginTime() { @@ -113,16 +112,39 @@ func (ms *MysqlSession) readFromServer(bytes []byte) { } } -func (ms *MysqlSession) oneMysqlPackageFinish() bool { - if int64(len(ms.cachedStmtBytes))%MaxMysqlPacketLen == 0 { - return true +func mergeRanges(currRange *jigsaw, pkgRanges []*jigsaw) (mergedRange *jigsaw, newPkgRanges []*jigsaw) { + var nextRange *jigsaw + newPkgRanges = make([]*jigsaw, 0, 4) + + if len(pkgRanges) < 1 { + return currRange, newPkgRanges + + } else if len(pkgRanges) == 1 { + nextRange = pkgRanges[0] + + } else { + nextRange, newPkgRanges = mergeRanges(pkgRanges[0], pkgRanges[1:]) } - return false + if currRange.end >= nextRange.begin { + mergedRange = &jigsaw{begin: currRange.begin, end: nextRange.end} + + } else { + tmpRanges := make([]*jigsaw, len(newPkgRanges)+1) + tmpRanges[0] = nextRange + if len(newPkgRanges) > 0 { + copy(tmpRanges[1:], newPkgRanges) + } + newPkgRanges = tmpRanges + mergedRange = currRange + } + + return } func (ms *MysqlSession) checkFinish() bool { - if ms.beginSeqID != -1 && ms.expectReceiveSize == 0 { + ms.mergeRange() + if ms.endSeqID - ms.beginSeqID == int64(len(ms.cachedStmtBytes)) && ms.expectReceiveSize == 0 { return true } @@ -135,19 +157,24 @@ func (ms *MysqlSession) clear() { ms.expectSendSize = -1 ms.prepareInfo = nil ms.beginSeqID = -1 - ms.expectSeqID = -1 + ms.endSeqID = -1 ms.ignoreAckID = -1 ms.sendSize = 0 + ms.coverRanges = list.New() } func (ms *MysqlSession) readFromClient(seqID int64, bytes []byte) { contentSize := int64(len(bytes)) - if ms.expectReceiveSize == -1 || ms.oneMysqlPackageFinish() { + if ms.expectReceiveSize == -1 { ms.expectReceiveSize = extractMysqlPayloadSize(bytes[:4]) - // ms.packageOffset = int64(len(ms.cachedStmtBytes)) + // ignore too big mysql packet + if ms.expectReceiveSize >= MaxMysqlPacketLen { + return + } contents := bytes[4:] + // add prepare info if contents[0] == ComStmtPrepare { ms.prepareInfo = &prepareInfo{} } @@ -155,18 +182,26 @@ func (ms *MysqlSession) readFromClient(seqID int64, bytes []byte) { contentSize = int64(len(contents)) seqID += 4 ms.beginSeqID = seqID + ms.endSeqID = seqID - newCache := make([]byte, ms.expectReceiveSize+len(ms.cachedStmtBytes)) - if len(ms.cachedStmtBytes) > 0 { - copy(newCache[:len(ms.cachedStmtBytes)], ms.cachedStmtBytes) + // if len(ms.cachedStmtBytes) > 0 { + // copy(newCache[:len(ms.cachedStmtBytes)], ms.cachedStmtBytes) + // } + if int64(ms.expectReceiveSize) < int64(len(contents)) { + log.Warnf("receive invalid mysql packet") + return } - if int64(ms.expectReceiveSize+len(ms.cachedStmtBytes)) >= ms.packageOffset+int64(len(contents)) { - copy(newCache[ms.packageOffset:ms.packageOffset+int64(len(contents))], contents) - ms.cachedStmtBytes = newCache - } + newCache := make([]byte, ms.expectReceiveSize) + copy(newCache[:len(contents)], contents) + ms.cachedStmtBytes = newCache } else { + // ignore too big mysql packet + if ms.expectReceiveSize >= MaxMysqlPacketLen { + return + } + if ms.beginSeqID == -1 { log.Warnf("cover range is empty") return @@ -181,29 +216,65 @@ func (ms *MysqlSession) readFromClient(seqID int64, bytes []byte) { } else if seqID + int64(len(bytes)) <= ms.beginSeqID { // repeat packet log.Debugf("receive repeat packet") - return - - } else if seqID > ms.expectSeqID { - // discontinuous packet - log.Debugf("receive discontinuous packet") - ms.clear() - return } - seqOffset := seqID - ms.beginSeqID + ms.packageOffset - if seqOffset+int64(len(bytes)) > int64(len(ms.cachedStmtBytes)) { - // is not a normal mysql packet + seqOffset := seqID - ms.beginSeqID + if seqOffset+contentSize > int64(len(ms.cachedStmtBytes)) { + // not in a normal mysql packet log.Debugf("receive an unexpect packet") - ms.clear() + // ms.clear() return } // add byte to stmt cache - copy(ms.cachedStmtBytes[ms.packageOffset+seqOffset:ms.packageOffset+seqOffset+contentSize], bytes) + copy(ms.cachedStmtBytes[seqOffset:seqOffset+contentSize], bytes) } + ms.addRange(newJigsaw(seqID, seqID+contentSize)) ms.expectReceiveSize = ms.expectReceiveSize - int(contentSize) - ms.expectSeqID = seqID + contentSize +} + +func (ms *MysqlSession) addRange(js *jigsaw) { + head := ms.coverRanges.Front() + // empty list + if head == nil { + ms.coverRanges.PushBack(js) + return + } + + // find insert position + var node = head + for ; node != nil; node = node.Next() { + nodeVal := node.Value.(*jigsaw) + if nodeVal.begin > js.begin { + break + } + } + + // insert element + if node != nil { + ms.coverRanges.InsertBefore(js, node) + } else { + ms.coverRanges.PushBack(js) + } +} + +func (ms *MysqlSession) mergeRange() { + head := ms.coverRanges.Front() + // empty list + if head == nil { + return + } + + // find insert position + var node = head + + for ; node != nil; node = node.Next() { + nodeVal := node.Value.(*jigsaw) + if nodeVal.begin <= ms.endSeqID && nodeVal.end > ms.endSeqID { + ms.endSeqID = nodeVal.end + } + } } func (ms *MysqlSession) GenerateQueryPiece() (qp model.QueryPiece) {