sniffer-agent/session-dealer/mysql/session.go

356 lines
8.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mysql
import (
"fmt"
"sync"
"time"
log "github.com/golang/glog"
"github.com/pingcap/tidb/util/hack"
"github.com/zr-hebo/sniffer-agent/communicator"
"github.com/zr-hebo/sniffer-agent/model"
)
// MySQLSession holds the capture state of one client<->server MySQL
// connection: it reassembles the current client request from TCP payloads,
// watches the first server response packet, and turns each finished exchange
// into a model.QueryPiece.
type MySQLSession struct {
// connectionID is the session key; it is passed as the session ID of every
// query piece and used in log messages.
connectionID *string
// visitUser/visitDB are learned from the auth packet; visitDB is also
// updated by COM_INIT_DB.
visitUser *string
visitDB *string
// client/server endpoints copied into every query piece; serverPort is also
// used by querySessionInfo in strict mode.
clientIP *string
clientPort int
serverIP *string
serverPort int
// stmtBeginTimeNano is time.Now().UnixNano(), refreshed on every client packet.
stmtBeginTimeNano int64
// packageOffset int64
// beginSeqID is the TCP sequence number of the first request byte after the
// 4-byte header; -1 when no request is being reassembled. endSeqID is set
// alongside it but not read elsewhere in this file.
beginSeqID int64
endSeqID int64
// coverRanges records which sequence ranges of the current request arrived.
coverRanges *coverRanges
// expectReceiveSize is the payload length from the request header; -1 means
// the next client payload is expected to start with a new header.
expectReceiveSize int
// expectSendSize is the payload length of the first response packet; values
// below 1 mean no response header has been seen yet.
expectSendSize int
// prepareInfo is non-nil while a COM_STMT_PREPARE request is in flight.
prepareInfo *prepareInfo
// cachedPrepareStmt maps prepared-statement IDs to their statement text.
cachedPrepareStmt map[int][]byte
// cachedStmtBytes is the reassembled request payload (command byte first);
// the buffer comes from and is returned to localStmtCache.
cachedStmtBytes []byte
// The following fields are not used by the logic in this file (sendSize is
// only reset in clear); presumably used elsewhere — verify before removing.
computeWindowSizeCounter int
tcpPacketCache []*model.TCPPacket
// queryPieceReceiver receives every query piece produced by this session.
queryPieceReceiver chan model.QueryPiece
closeConn chan bool
pkgCacheLock sync.Mutex
ignoreAckID int64
sendSize int64
// Response status of the executed SQL: 1 means success, -1 means failure, 0 means unknown.
responseStatus int
// Response info of the executed SQL: the affected-row count on success, the error code on failure.
responseInfo int
}
// prepareInfo is created when a client request starts with COM_STMT_PREPARE;
// prepareStmtID is filled in from the server's successful prepare response
// and used as the key of MySQLSession.cachedPrepareStmt.
type prepareInfo struct {
prepareStmtID int
}
// NewMysqlSession creates the capture state for one client<->server MySQL
// connection identified by sessionKey. Finished query pieces are delivered on
// receiver.
//
// The "nothing in progress" sentinels (expectReceiveSize, expectSendSize,
// beginSeqID, endSeqID) start at -1, the same values clear() resets them to,
// so a fresh session and a cleared one begin from the same parser state.
func NewMysqlSession(
	sessionKey, clientIP *string, clientPort int, serverIP *string, serverPort int,
	receiver chan model.QueryPiece) (ms *MySQLSession) {
	ms = &MySQLSession{
		connectionID:       sessionKey,
		clientIP:           clientIP,
		clientPort:         clientPort,
		serverIP:           serverIP,
		serverPort:         serverPort,
		stmtBeginTimeNano:  time.Now().UnixNano(),
		cachedPrepareStmt:  make(map[int][]byte, 8),
		queryPieceReceiver: receiver,
		closeConn:          make(chan bool, 1),
		expectReceiveSize:  -1,
		expectSendSize:     -1,
		beginSeqID:         -1,
		endSeqID:           -1,
		coverRanges:        NewCoverRanges(),
	}
	return
}
// ReceiveTCPPacket dispatches one captured TCP packet of this connection.
// Client->server packets restart the statement timer and are appended to the
// request being reassembled. Server->client packets are inspected as the
// response, after which a query piece is generated and, if one was produced,
// sent to queryPieceReceiver. A nil packet is ignored.
func (ms *MySQLSession) ReceiveTCPPacket(newPkt *model.TCPPacket) {
	switch {
	case newPkt == nil:
		return
	case newPkt.ToServer:
		ms.resetBeginTime()
		ms.readFromClient(newPkt.Seq, newPkt.Payload)
		return
	}

	ms.readFromServer(newPkt.Seq, newPkt.Payload)
	if piece := ms.GenerateQueryPiece(); piece != nil {
		ms.queryPieceReceiver <- piece
	}
}
// resetBeginTime records the current wall-clock time, in nanoseconds since the
// Unix epoch, as the start time of the statement being captured.
func (ms *MySQLSession) resetBeginTime() {
	now := time.Now()
	ms.stmtBeginTimeNano = now.UnixNano()
}
// checkFinish reports whether the request has been fully received. It checks
// that the range stored at coverRanges.head.next spans exactly
// len(cachedStmtBytes) sequence numbers. It returns false when no range has
// been recorded yet.
func (ms *MySQLSession) checkFinish() bool {
	head := ms.coverRanges.head
	if head == nil {
		return false
	}
	first := head.next
	if first == nil {
		return false
	}
	return first.end-first.begin == int64(len(ms.cachedStmtBytes))
}
// Close discards the in-progress request/response state by delegating to
// clear, which hands the statement buffer back to localStmtCache and resets
// the sentinel fields. It does not close closeConn or queryPieceReceiver.
func (ms *MySQLSession) Close() {
ms.clear()
}
// clear resets all per-statement state so the next client payload is parsed
// as a fresh request header. The statement buffer is returned to
// localStmtCache before the session drops its reference to it.
// Prepared-statement mappings and user/database info are kept.
func (ms *MySQLSession) clear() {
	localStmtCache.Enqueue(ms.cachedStmtBytes)
	ms.cachedStmtBytes = nil
	ms.prepareInfo = nil

	ms.expectReceiveSize, ms.expectSendSize = -1, -1
	ms.beginSeqID, ms.endSeqID = -1, -1
	ms.sendSize = 0
	ms.responseStatus, ms.responseInfo = 0, 0

	ms.coverRanges.clear()
}
// readFromClient feeds one client->server TCP payload into the request
// reassembly buffer.
//
// When no request is in progress (expectReceiveSize == -1), the payload must
// begin with a 4-byte MySQL packet header. The declared payload length is read
// from it, a buffer of that size is taken from localStmtCache, and the bytes
// after the header are copied to its start. Later payloads are copied into the
// buffer at their TCP sequence offset relative to beginSeqID. Every accepted
// payload's sequence range is added to coverRanges so checkFinish can detect
// completion.
//
// These payloads are ignored:
//   - header payloads of 4 bytes or fewer;
//   - payloads of requests declared at MaxMySQLPacketLen or larger;
//   - payloads whose sequence number precedes beginSeqID.
//
// A header whose declared length is smaller than the bytes that follow it, or
// a continuation payload that overflows the buffer, resets the session via
// clear.
func (ms *MySQLSession) readFromClient(seqID int64, bytes []byte) {
	contentSize := int64(len(bytes))
	if ms.expectReceiveSize == -1 {
		// ignore invalid head package
		if len(bytes) <= 4 {
			return
		}
		ms.expectReceiveSize = parseInt3(bytes[:4])
		// Ignore too big MySQL packets. expectReceiveSize stays set so that the
		// rest of this request is skipped by the branch below.
		if ms.expectReceiveSize >= MaxMySQLPacketLen {
			log.Infof("expect receive size is bigger than max deal size: %d", MaxMySQLPacketLen)
			return
		}
		contents := bytes[4:]
		// add prepare info
		if contents[0] == ComStmtPrepare {
			ms.prepareInfo = &prepareInfo{}
		}
		contentSize = int64(len(contents))
		// Sequence numbers are tracked from the first byte after the header.
		seqID += 4
		ms.beginSeqID = seqID
		ms.endSeqID = seqID
		if int64(ms.expectReceiveSize) < contentSize {
			log.Warning("receive invalid mysql packet")
			// Drop the half-initialised state so the next client payload is
			// parsed as a fresh header instead of being misread (and then
			// discarded) as a continuation of this one.
			ms.clear()
			return
		}
		newCache := localStmtCache.DequeueWithInit(ms.expectReceiveSize)
		copy(newCache[:len(contents)], contents)
		ms.cachedStmtBytes = newCache
	} else {
		// ignore too big mysql packet
		if ms.expectReceiveSize >= MaxMySQLPacketLen {
			return
		}
		if ms.beginSeqID == -1 {
			log.Info("cover range is empty")
			return
		}
		if seqID < ms.beginSeqID {
			// out date packet
			log.Infof("in session %s get outdate package with Seq:%d, beginSeq:%d",
				*ms.connectionID, seqID, ms.beginSeqID)
			return
		}
		seqOffset := seqID - ms.beginSeqID
		if seqOffset+contentSize > int64(len(ms.cachedStmtBytes)) {
			// not in a normal mysql packet
			log.Info("receive an unexpect packet")
			ms.clear()
			return
		}
		// add byte to stmt cache
		copy(ms.cachedStmtBytes[seqOffset:seqOffset+contentSize], bytes)
	}
	ms.coverRanges.addRange(coverRangePool.NewCoverage(seqID, seqID+contentSize))
}
// readFromServer inspects a server->client TCP payload as the response to the
// request currently being reassembled.
//
// For the first response payload (expectSendSize < 1) that is longer than the
// 4-byte packet header, it records the 3-byte payload length and looks at the
// first payload byte:
//   - If a COM_STMT_PREPARE is pending and the byte is 0x00, the following
//     4-byte statement ID is stored in prepareInfo.
//   - Otherwise, 0x00 or 0xfe marks success (responseStatus = 1), and the
//     following length-encoded integer (affected rows for an OK packet) goes
//     into responseInfo.
//   - 0xff marks failure (responseStatus = -1), and the following 2-byte
//     error code goes into responseInfo.
//
// Fields are only read when the packet is long enough to hold them, so a
// truncated capture cannot cause an out-of-range panic.
//
// If respSeq does not match the end of the first covered request range, the
// response does not belong to the buffered request and the session is cleared.
func (ms *MySQLSession) readFromServer(respSeq int64, bytes []byte) {
	if ms.expectSendSize < 1 && len(bytes) > 4 {
		ms.expectSendSize = parseInt3(bytes[:3])
		contents := bytes[4:]
		respStatus := contents[0]
		if ms.prepareInfo != nil && respStatus == 0 {
			// the COM_STMT_PREPARE succeeded: status byte + 4-byte statement ID
			if len(contents) >= 5 {
				ms.prepareInfo.prepareStmtID = parseInt4(contents[1:5])
			}
		} else {
			switch respStatus {
			case 0x00, 0xfe:
				ms.responseStatus = 1
				if len(contents) > 1 {
					affectedRows, _, _ := parseLengthEncodedInt(contents[1:])
					ms.responseInfo = int(affectedRows)
				}
			case 0xff:
				ms.responseStatus = -1
				if len(contents) >= 3 {
					ms.responseInfo = parseInt2(contents[1:3])
				}
			}
		}
	}

	head := ms.coverRanges.head
	if head == nil || head.next == nil || head.next.end != respSeq {
		ms.clear()
	}
}
// IsAuth reports whether val, the first byte of a reassembled client payload,
// should be handled as an authentication packet rather than as a command byte.
// GenerateQueryPiece passes such payloads to parseAuthInfo. Any value strictly
// greater than 32 qualifies.
func IsAuth(val byte) bool {
	const authThreshold byte = 32
	return authThreshold < val
}
// GenerateQueryPiece turns the fully reassembled client request into a query
// piece. It returns nil when there is nothing to report. Session state is
// always cleared on return (deferred clear), which hands cachedStmtBytes back
// to localStmtCache for reuse. For that reason, SQL text stored in the
// returned piece must never alias cachedStmtBytes.
//
// Request handling by first byte:
//   - auth packet (IsAuth): updates visitUser/visitDB; nothing is reported.
//   - COM_INIT_DB: reported as "use <db>"; updates visitDB.
//   - COM_CREATE_DB / COM_DROP_DB: the payload is a schema name and is
//     reported as "create database <db>" / "drop database <db>".
//   - COM_QUERY: the query text.
//   - COM_STMT_PREPARE: the statement text, remembered under its prepared ID.
//   - COM_STMT_EXECUTE: the remembered statement text, or PrepareStatement
//     when the ID is unknown.
//   - COM_STMT_CLOSE: forgets the prepared statement; nothing is reported.
//   - anything else is ignored.
//
// In strictMode, user and database are looked up via querySessionInfo when
// unknown. Finally the piece passes through filterQueryPieceBySQL, which may
// drop it.
func (ms *MySQLSession) GenerateQueryPiece() (qp model.QueryPiece) {
	defer ms.clear()

	if len(ms.cachedStmtBytes) < 1 {
		return
	}
	if !ms.checkFinish() {
		log.Warning("receive a not complete cover")
		return
	}
	if len(ms.cachedStmtBytes) > maxSQLLen {
		log.Warning("sql in cache is too long, ignore it")
		return
	}

	var mqp *model.PooledMysqlQueryPiece
	var querySQLInBytes []byte
	if IsAuth(ms.cachedStmtBytes[0]) {
		userName, dbName, err := parseAuthInfo(ms.cachedStmtBytes)
		if err != nil {
			log.Errorf("parse auth info failed <-- %s", err.Error())
			return
		}
		ms.visitUser = &userName
		ms.visitDB = &dbName
	} else {
		// payload is everything after the command byte.
		payload := ms.cachedStmtBytes[1:]
		switch ms.cachedStmtBytes[0] {
		case ComInitDB:
			newDBName := string(payload)
			useSQL := fmt.Sprintf("use %s", newDBName)
			querySQLInBytes = hack.Slice(useSQL)
			mqp = ms.composeQueryPiece()
			mqp.QuerySQL = &useSQL
			// update session database
			ms.visitDB = &newDBName
		case ComCreateDB:
			createSQL := fmt.Sprintf("create database %s", string(payload))
			querySQLInBytes = hack.Slice(createSQL)
			mqp = ms.composeQueryPiece()
			mqp.QuerySQL = &createSQL
		case ComDropDB:
			dropSQL := fmt.Sprintf("drop database %s", string(payload))
			// Must be set, otherwise filterQueryPieceBySQL discards the piece.
			querySQLInBytes = hack.Slice(dropSQL)
			mqp = ms.composeQueryPiece()
			mqp.QuerySQL = &dropSQL
		case ComQuery:
			mqp = ms.composeQueryPiece()
			// Copy out of the recycled statement buffer: the piece outlives
			// the deferred clear() that returns it to localStmtCache.
			querySQLInBytes = make([]byte, len(payload))
			copy(querySQLInBytes, payload)
			querySQL := hack.String(querySQLInBytes)
			mqp.QuerySQL = &querySQL
		case ComStmtPrepare:
			mqp = ms.composeQueryPiece()
			querySQLInBytes = make([]byte, len(payload))
			copy(querySQLInBytes, payload)
			querySQL := hack.String(querySQLInBytes)
			mqp.QuerySQL = &querySQL
			if ms.prepareInfo != nil {
				ms.cachedPrepareStmt[ms.prepareInfo.prepareStmtID] = querySQLInBytes
				log.Infof("prepare statement %s, get id:%d", querySQL, ms.prepareInfo.prepareStmtID)
			}
		case ComStmtExecute:
			// The 4-byte statement ID must be present.
			if len(payload) < 4 {
				return
			}
			prepareStmtID := parseInt4(payload[:4])
			mqp = ms.composeQueryPiece()
			var ok bool
			querySQLInBytes, ok = ms.cachedPrepareStmt[prepareStmtID]
			if !ok {
				querySQLInBytes = PrepareStatement
			}
			querySQL := hack.String(querySQLInBytes)
			mqp.QuerySQL = &querySQL
		case ComStmtClose:
			if len(payload) < 4 {
				return
			}
			prepareStmtID := parseInt4(payload[:4])
			delete(ms.cachedPrepareStmt, prepareStmtID)
			log.Infof("remove prepare statement:%d", prepareStmtID)
		default:
			return
		}
	}

	if strictMode && mqp != nil && mqp.VisitUser == nil {
		user, db, err := querySessionInfo(ms.serverPort, mqp.SessionID)
		if err != nil {
			log.Errorf("query user and db from mysql failed <-- %s", err.Error())
		} else {
			mqp.VisitUser = user
			mqp.VisitDB = db
		}
	}

	mqp = filterQueryPieceBySQL(mqp, querySQLInBytes)
	if mqp == nil {
		// Return a literal nil so callers do not get a non-nil interface
		// wrapping a nil pointer.
		return nil
	}
	communicator.ReceiveExecTime(ms.stmtBeginTimeNano)
	return mqp
}
// filterQueryPieceBySQL decides whether a query piece should be reported.
// It returns nil when there is no piece, no SQL text, or the SQL matches
// uselessSQLPattern. A piece whose SQL matches ddlPatern is flagged with
// SetNeedSyncSend(true). Otherwise the piece is returned unchanged.
func filterQueryPieceBySQL(mqp *model.PooledMysqlQueryPiece, querySQL []byte) *model.PooledMysqlQueryPiece {
	discard := mqp == nil || querySQL == nil || uselessSQLPattern.Match(querySQL)
	if discard {
		return nil
	}

	if isDDL := ddlPatern.Match(querySQL); isDDL {
		mqp.SetNeedSyncSend(true)
	}
	return mqp
}
// composeQueryPiece builds a pooled query piece pre-filled with this session's
// connection ID, client/server endpoints, current user and database, the
// capture sampling rate, the statement start time, and the response captured
// by readFromServer.
func (ms *MySQLSession) composeQueryPiece() (mqp *model.PooledMysqlQueryPiece) {
	mqp = model.NewPooledMysqlQueryPiece(
		ms.connectionID, ms.clientIP, ms.visitUser, ms.visitDB, ms.serverIP,
		ms.clientPort, ms.serverPort, communicator.GetMysqlCapturePacketRate(), ms.stmtBeginTimeNano)
	// responseStatus: 1 success, -1 failure, 0 unknown.
	mqp.ResponseStatus = ms.responseStatus
	// responseInfo: affected rows on success, error code on failure.
	mqp.ResponseInfo = ms.responseInfo
	return
}