package mysql

import (
	"fmt"
	"sync"
	"time"

	log "github.com/golang/glog"
	"github.com/pingcap/tidb/util/hack"
	"github.com/zr-hebo/sniffer-agent/communicator"
	"github.com/zr-hebo/sniffer-agent/model"
)

// MySQLSession holds the per-connection state used to reassemble a client
// request from captured TCP packets, pair it with the first server response,
// and turn the pair into a model.QueryPiece.
type MySQLSession struct {
	connectionID      *string
	visitUser         *string
	visitDB           *string
	clientIP          *string
	clientPort        int
	serverIP          *string
	serverPort        int
	stmtBeginTimeNano int64

	// beginSeqID is the TCP sequence number of the first payload byte (after
	// the 4-byte MySQL header) of the request currently being reassembled;
	// -1 means no request is in progress.
	beginSeqID               int64
	endSeqID                 int64
	coverRanges              *coverRanges
	expectReceiveSize        int
	expectSendSize           int
	prepareInfo              *prepareInfo
	cachedPrepareStmt        map[int][]byte
	cachedStmtBytes          []byte
	computeWindowSizeCounter int
	tcpPacketCache           []*model.TCPPacket
	queryPieceReceiver       chan model.QueryPiece
	closeConn                chan bool
	pkgCacheLock             sync.Mutex
	ignoreAckID              int64
	sendSize                 int64

	// responseStatus is the SQL execution status: 1 means success, -1 means
	// failure, 0 means unknown.
	responseStatus int
	// responseInfo is the SQL execution result detail: the affected-row count
	// on success, the error code on failure.
	responseInfo int
}

// prepareInfo records the statement id taken from the server response to a
// pending COM_STMT_PREPARE request.
type prepareInfo struct {
	prepareStmtID int
}

// NewMysqlSession creates a session for the connection identified by
// sessionKey. Generated query pieces are sent to receiver.
func NewMysqlSession(
	sessionKey, clientIP *string, clientPort int, serverIP *string, serverPort int,
	receiver chan model.QueryPiece) (ms *MySQLSession) {
	ms = &MySQLSession{
		connectionID:       sessionKey,
		clientIP:           clientIP,
		clientPort:         clientPort,
		serverIP:           serverIP,
		serverPort:         serverPort,
		stmtBeginTimeNano:  time.Now().UnixNano(),
		cachedPrepareStmt:  make(map[int][]byte, 8),
		queryPieceReceiver: receiver,
		closeConn:          make(chan bool, 1),
		// Start from the same "nothing in progress" sentinels that clear() uses.
		expectReceiveSize: -1,
		expectSendSize:    -1,
		beginSeqID:        -1,
		endSeqID:          -1,
		coverRanges:       NewCoverRanges(),
		sendSize:          0,
	}
	return ms
}

// ReceiveTCPPacket feeds one captured packet into the session. Client-to-server
// packets are accumulated into the statement cache; a server-to-client packet
// is parsed as the response and then triggers generation of a query piece,
// which is sent to the receiver channel when non-nil. A nil packet is ignored.
func (ms *MySQLSession) ReceiveTCPPacket(newPkt *model.TCPPacket) {
	if newPkt == nil {
		return
	}

	if newPkt.ToServer {
		ms.resetBeginTime()
		ms.readFromClient(newPkt.Seq, newPkt.Payload)
		return
	}

	ms.readFromServer(newPkt.Seq, newPkt.Payload)
	if qp := ms.GenerateQueryPiece(); qp != nil {
		ms.queryPieceReceiver <- qp
	}
}

// resetBeginTime stamps the statement begin time with the current time.
func (ms *MySQLSession) resetBeginTime() {
	ms.stmtBeginTimeNano = time.Now().UnixNano()
}

// checkFinish reports whether the first recorded cover range spans exactly as
// many bytes as the statement cache holds, i.e. the request is fully received.
func (ms *MySQLSession) checkFinish() bool {
	if ms.coverRanges.head == nil || ms.coverRanges.head.next == nil {
		return false
	}

	checkNode := ms.coverRanges.head.next
	return checkNode.end-checkNode.begin == int64(len(ms.cachedStmtBytes))
}

// Close releases the per-statement state held by the session.
func (ms *MySQLSession) Close() {
	ms.clear()
}

// clear hands the statement buffer back to localStmtCache and resets all
// per-statement fields to their "nothing in progress" values.
func (ms *MySQLSession) clear() {
	localStmtCache.Enqueue(ms.cachedStmtBytes)
	ms.cachedStmtBytes = nil
	ms.expectReceiveSize = -1
	ms.expectSendSize = -1
	ms.prepareInfo = nil
	ms.beginSeqID = -1
	ms.endSeqID = -1
	ms.sendSize = 0
	ms.coverRanges.clear()
	ms.responseStatus = 0
	ms.responseInfo = 0
}

// readFromClient stores the payload of one client packet in the statement
// cache. The first packet of a request carries the 4-byte MySQL header, from
// which the expected payload size is read; later packets are placed at their
// offset relative to beginSeqID. The covered byte range is recorded so that
// checkFinish can tell when the request is complete.
func (ms *MySQLSession) readFromClient(seqID int64, bytes []byte) {
	contentSize := int64(len(bytes))

	if ms.expectReceiveSize == -1 {
		// ignore invalid head package
		if len(bytes) <= 4 {
			return
		}

		ms.expectReceiveSize = parseInt3(bytes[:4])
		// ignore too big mysql packet
		if ms.expectReceiveSize >= MaxMySQLPacketLen {
			log.Infof("expect receive size is bigger than max deal size: %d", MaxMySQLPacketLen)
			return
		}

		contents := bytes[4:]
		// add prepare info
		if contents[0] == ComStmtPrepare {
			ms.prepareInfo = &prepareInfo{}
		}

		contentSize = int64(len(contents))
		seqID += 4
		ms.beginSeqID = seqID
		ms.endSeqID = seqID

		if int64(ms.expectReceiveSize) < contentSize {
			log.Warning("receive invalid mysql packet")
			return
		}

		newCache := localStmtCache.DequeueWithInit(ms.expectReceiveSize)
		copy(newCache[:len(contents)], contents)
		ms.cachedStmtBytes = newCache
	} else {
		// ignore too big mysql packet
		if ms.expectReceiveSize >= MaxMySQLPacketLen {
			return
		}

		if ms.beginSeqID == -1 {
			log.Info("cover range is empty")
			return
		}

		if seqID < ms.beginSeqID {
			// out date packet
			log.Infof("in session %s get outdate package with Seq:%d, beginSeq:%d",
				*ms.connectionID, seqID, ms.beginSeqID)
			return
		}

		seqOffset := seqID - ms.beginSeqID
		if seqOffset+contentSize > int64(len(ms.cachedStmtBytes)) {
			// not in a normal mysql packet
			log.Info("receive an unexpect packet")
			ms.clear()
			return
		}

		// add byte to stmt cache
		copy(ms.cachedStmtBytes[seqOffset:seqOffset+contentSize], bytes)
	}

	ms.coverRanges.addRange(coverRangePool.NewCoverage(seqID, seqID+contentSize))
}

// readFromServer parses the first server packet that answers the current
// request: it records the statement id of a successful prepare, or the
// response status plus affected-row count / error code otherwise. If the
// response sequence number does not line up with the end of the received
// request range, the session state is discarded.
func (ms *MySQLSession) readFromServer(respSeq int64, bytes []byte) {
	if ms.expectSendSize < 1 && len(bytes) > 4 {
		ms.expectSendSize = parseInt3(bytes[:3])
		contents := bytes[4:]
		respStatus := contents[0]

		switch {
		case ms.prepareInfo != nil && respStatus == 0:
			// the COM_STMT_PREPARE succeeded; the statement id needs 4 bytes
			// after the status byte, so skip truncated payloads
			if len(contents) >= 5 {
				ms.prepareInfo.prepareStmtID = parseInt4(contents[1:5])
			}

		case respStatus == 0x00 || respStatus == 0xfe:
			ms.responseStatus = 1
			if len(contents) > 1 {
				// Only the decoded value is needed here; the remaining return
				// values of the decoder are deliberately ignored.
				affectedRows, _, _ := parseLengthEncodedInt(contents[1:])
				ms.responseInfo = int(affectedRows)
			}

		case respStatus == 0xff:
			ms.responseStatus = -1
			// the error code occupies the 2 bytes after the status byte
			if len(contents) >= 3 {
				ms.responseInfo = parseInt2(contents[1:3])
			}
		}
	}

	head := ms.coverRanges.head
	if head == nil || head.next == nil || head.next.end != respSeq {
		ms.clear()
	}
}

// IsAuth reports whether val is greater than 32. GenerateQueryPiece applies it
// to the first payload byte to choose between auth parsing and command dispatch.
func IsAuth(val byte) bool {
	return val > 32
}

// GenerateQueryPiece builds a query piece from the cached request bytes and the
// recorded response state. It returns nil when there is nothing cached, the
// request is incomplete or too long, the command is not one that is reported,
// or the SQL is filtered out. The per-statement state is always cleared on
// return, which hands the statement buffer back to localStmtCache.
func (ms *MySQLSession) GenerateQueryPiece() (qp model.QueryPiece) {
	defer ms.clear()

	if len(ms.cachedStmtBytes) < 1 {
		return nil
	}

	if !ms.checkFinish() {
		log.Warning("receive a not complete cover")
		return nil
	}

	if len(ms.cachedStmtBytes) > maxSQLLen {
		log.Warning("sql in cache is too long, ignore it")
		return nil
	}

	var mqp *model.PooledMysqlQueryPiece
	var querySQLInBytes []byte
	if IsAuth(ms.cachedStmtBytes[0]) {
		userName, dbName, err := parseAuthInfo(ms.cachedStmtBytes)
		if err != nil {
			log.Errorf("parse auth info failed <-- %s", err.Error())
			return nil
		}
		ms.visitUser = &userName
		ms.visitDB = &dbName
	} else {
		switch ms.cachedStmtBytes[0] {
		case ComInitDB:
			newDBName := string(ms.cachedStmtBytes[1:])
			useSQL := fmt.Sprintf("use %s", newDBName)
			querySQLInBytes = hack.Slice(useSQL)
			mqp = ms.composeQueryPiece()
			mqp.QuerySQL = &useSQL
			// update session database
			ms.visitDB = &newDBName

		case ComDropDB:
			dbName := string(ms.cachedStmtBytes[1:])
			dropSQL := fmt.Sprintf("drop database %s", dbName)
			// Without the SQL bytes filterQueryPieceBySQL would discard the
			// piece, so the drop would never be reported.
			querySQLInBytes = hack.Slice(dropSQL)
			mqp = ms.composeQueryPiece()
			mqp.QuerySQL = &dropSQL

		case ComCreateDB, ComQuery:
			mqp = ms.composeQueryPiece()
			// Copy the SQL out of the statement buffer: the deferred clear()
			// returns that buffer to localStmtCache, so a zero-copy view could
			// be overwritten while the query piece is still in use.
			querySQL := string(ms.cachedStmtBytes[1:])
			querySQLInBytes = hack.Slice(querySQL)
			mqp.QuerySQL = &querySQL

		case ComStmtPrepare:
			mqp = ms.composeQueryPiece()
			querySQLInBytes = make([]byte, len(ms.cachedStmtBytes)-1)
			copy(querySQLInBytes, ms.cachedStmtBytes[1:])
			querySQL := hack.String(querySQLInBytes)
			mqp.QuerySQL = &querySQL
			// prepareInfo is only set when the head packet was seen as a
			// prepare; without it there is no statement id to cache under.
			if ms.prepareInfo != nil {
				ms.cachedPrepareStmt[ms.prepareInfo.prepareStmtID] = querySQLInBytes
				log.Infof("prepare statement %s, get id:%d", querySQL, ms.prepareInfo.prepareStmtID)
			}

		case ComStmtExecute:
			// the statement id occupies the 4 bytes after the command byte
			if len(ms.cachedStmtBytes) < 5 {
				log.Warning("execute statement packet is too short, ignore it")
				return nil
			}
			prepareStmtID := parseInt4(ms.cachedStmtBytes[1:5])
			mqp = ms.composeQueryPiece()
			var ok bool
			querySQLInBytes, ok = ms.cachedPrepareStmt[prepareStmtID]
			if !ok {
				querySQLInBytes = PrepareStatement
			}
			querySQL := hack.String(querySQLInBytes)
			mqp.QuerySQL = &querySQL

		case ComStmtClose:
			// the statement id occupies the 4 bytes after the command byte
			if len(ms.cachedStmtBytes) < 5 {
				log.Warning("close statement packet is too short, ignore it")
				return nil
			}
			prepareStmtID := parseInt4(ms.cachedStmtBytes[1:5])
			delete(ms.cachedPrepareStmt, prepareStmtID)
			log.Infof("remove prepare statement:%d", prepareStmtID)

		default:
			return nil
		}
	}

	if strictMode && mqp != nil && mqp.VisitUser == nil {
		user, db, err := querySessionInfo(ms.serverPort, mqp.SessionID)
		if err != nil {
			log.Errorf("query user and db from mysql failed <-- %s", err.Error())
		} else {
			mqp.VisitUser = user
			mqp.VisitDB = db
		}
	}

	mqp = filterQueryPieceBySQL(mqp, querySQLInBytes)
	if mqp == nil {
		// Return an untyped nil so the interface result compares equal to nil.
		return nil
	}

	communicator.ReceiveExecTime(ms.stmtBeginTimeNano)
	return mqp
}

// filterQueryPieceBySQL returns nil when mqp or querySQL is nil or when the SQL
// matches uselessSQLPattern. Otherwise it returns mqp, marked for synchronous
// sending when the SQL matches ddlPatern.
func filterQueryPieceBySQL(mqp *model.PooledMysqlQueryPiece, querySQL []byte) *model.PooledMysqlQueryPiece {
	if mqp == nil || querySQL == nil {
		return nil
	}

	if uselessSQLPattern.Match(querySQL) {
		return nil
	}

	if ddlPatern.Match(querySQL) {
		mqp.SetNeedSyncSend(true)
	}
	return mqp
}

// composeQueryPiece creates a pooled query piece carrying the session's
// connection details, the statement begin time, and the recorded response
// status and response info.
func (ms *MySQLSession) composeQueryPiece() (mqp *model.PooledMysqlQueryPiece) {
	mqp = model.NewPooledMysqlQueryPiece(
		ms.connectionID, ms.clientIP, ms.visitUser, ms.visitDB, ms.serverIP,
		ms.clientPort, ms.serverPort,
		communicator.GetMysqlCapturePacketRate(), ms.stmtBeginTimeNano)
	mqp.ResponseStatus = ms.responseStatus
	mqp.ResponseInfo = ms.responseInfo
	return mqp
}