diff --git a/model/mysql_query_piece.go b/model/mysql_query_piece.go index fe19aaf..e0527a9 100644 --- a/model/mysql_query_piece.go +++ b/model/mysql_query_piece.go @@ -8,17 +8,21 @@ import ( type MysqlQueryPiece struct { BaseQueryPiece - SessionID *string `json:"-"` - ClientHost *string `json:"cip"` - ClientPort int `json:"cport"` + SessionID *string `json:"-"` + ClientHost *string `json:"cip"` + ClientPort int `json:"cport"` VisitUser *string `json:"user"` VisitDB *string `json:"db"` QuerySQL *string `json:"sql"` CostTimeInMS int64 `json:"cms"` + // SQL执行返回状态 1代表成功,-1代表失败, 0代表未知 + ResponseStatus int `json:"qrs"` + // SQL执行返回信息,成功的时候代表影响行数,失败的时候代表错误码 + ResponseInfo int `json:"qri"` } -func (mqp *MysqlQueryPiece) String() (*string) { +func (mqp *MysqlQueryPiece) String() *string { content := mqp.Bytes() contentStr := hack.String(content) return &contentStr diff --git a/session-dealer/mysql/session.go b/session-dealer/mysql/session.go index 1141975..22aab2b 100644 --- a/session-dealer/mysql/session.go +++ b/session-dealer/mysql/session.go @@ -11,7 +11,7 @@ import ( "github.com/zr-hebo/sniffer-agent/model" ) -type MysqlSession struct { +type MySQLSession struct { connectionID *string visitUser *string visitDB *string @@ -39,6 +39,11 @@ type MysqlSession struct { ignoreAckID int64 sendSize int64 + + // SQL执行返回状态 1代表成功,-1代表失败, 0代表未知 + responseStatus int + // SQL执行返回信息,成功的时候代表影响行数,失败的时候代表错误码 + responseInfo int } type prepareInfo struct { @@ -47,8 +52,8 @@ type prepareInfo struct { func NewMysqlSession( sessionKey, clientIP *string, clientPort int, serverIP *string, serverPort int, - receiver chan model.QueryPiece) (ms *MysqlSession) { - ms = &MysqlSession{ + receiver chan model.QueryPiece) (ms *MySQLSession) { + ms = &MySQLSession{ connectionID: sessionKey, clientIP: clientIP, clientPort: clientPort, @@ -60,7 +65,6 @@ func NewMysqlSession( closeConn: make(chan bool, 1), expectReceiveSize: -1, coverRanges: NewCoverRanges(), - ignoreAckID: -1, sendSize: 0, pkgCacheLock: sync.Mutex{}, } @@ -68,20 +72,11 @@ func NewMysqlSession( return } -func (ms *MysqlSession) ReceiveTCPPacket(newPkt *model.TCPPacket) { +func (ms *MySQLSession) ReceiveTCPPacket(newPkt *model.TCPPacket) { if newPkt == nil { return } - if !newPkt.ToServer && ms.ignoreAckID == newPkt.Seq { - // ignore to response to client data - ms.ignoreAckID = ms.ignoreAckID + int64(len(newPkt.Payload)) - return - - } else if !newPkt.ToServer { - ms.ignoreAckID = newPkt.Seq + int64(len(newPkt.Payload)) - } - if newPkt.ToServer { ms.resetBeginTime() ms.readFromClient(newPkt.Seq, newPkt.Payload) @@ -95,25 +90,11 @@ func (ms *MysqlSession) ReceiveTCPPacket(newPkt *model.TCPPacket) { } } -func (ms *MysqlSession) resetBeginTime() { +func (ms *MySQLSession) resetBeginTime() { ms.stmtBeginTimeNano = time.Now().UnixNano() } -func (ms *MysqlSession) readFromServer(respSeq int64, bytes []byte) { - if ms.expectSendSize < 1 && len(bytes) > 4 { - ms.expectSendSize = extractMysqlPayloadSize(bytes[:4]) - contents := bytes[4:] - if ms.prepareInfo != nil && contents[0] == 0 { - ms.prepareInfo.prepareStmtID = bytesToInt(contents[1:5]) - } - } - - if ms.coverRanges.head.next == nil || ms.coverRanges.head.next.end != respSeq { - ms.clear() - } -} - -func (ms *MysqlSession) checkFinish() bool { +func (ms *MySQLSession) checkFinish() bool { if ms.coverRanges.head == nil || ms.coverRanges.head.next == nil { return false } @@ -126,11 +107,11 @@ func (ms *MysqlSession) checkFinish() bool { return false } -func (ms *MysqlSession) Close() { +func (ms *MySQLSession) Close() { ms.clear() } -func (ms *MysqlSession) clear() { +func (ms *MySQLSession) clear() { localStmtCache.Enqueue(ms.cachedStmtBytes) ms.cachedStmtBytes = nil ms.expectReceiveSize = -1 @@ -138,12 +119,13 @@ func (ms *MysqlSession) clear() { ms.prepareInfo = nil ms.beginSeqID = -1 ms.endSeqID = -1 - ms.ignoreAckID = -1 ms.sendSize = 0 ms.coverRanges.clear() + ms.responseStatus = 0 + ms.responseInfo = 0 } -func (ms *MysqlSession) readFromClient(seqID int64, bytes []byte) { +func (ms *MySQLSession) readFromClient(seqID int64, bytes []byte) { contentSize := int64(len(bytes)) if ms.expectReceiveSize == -1 { @@ -152,7 +134,7 @@ func (ms *MysqlSession) readFromClient(seqID int64, bytes []byte) { return } - ms.expectReceiveSize = extractMysqlPayloadSize(bytes[:4]) + ms.expectReceiveSize = parseInt3(bytes[:4]) // ignore too big mysql packet if ms.expectReceiveSize >= MaxMySQLPacketLen { log.Infof("expect receive size is bigger than max deal size: %d", MaxMySQLPacketLen) @@ -213,11 +195,38 @@ func (ms *MysqlSession) readFromClient(seqID int64, bytes []byte) { // ms.expectReceiveSize = ms.expectReceiveSize - int(contentSize) } +func (ms *MySQLSession) readFromServer(respSeq int64, bytes []byte) { + if ms.expectSendSize < 1 && len(bytes) > 4 { + fmt.Printf("%v", bytes) + ms.expectSendSize = parseInt3(bytes[:3]) + contents := bytes[4:] + respStatus := contents[0] + // the COM_STMT_PREPARE succeeded + if ms.prepareInfo != nil && respStatus == 0 { + ms.prepareInfo.prepareStmtID = parseInt4(contents[1:5]) + } else { + if respStatus == 0x00 || respStatus == 0xfe { + ms.responseStatus = 1 + errCode, _, _ := parseLengthEncodedInt(contents[1:]) + ms.responseInfo = int(errCode) + + } else if respStatus == 0xff { + ms.responseStatus = -1 + ms.responseInfo = parseInt2(contents[1:3]) + } + } + } + + if ms.coverRanges.head.next == nil || ms.coverRanges.head.next.end != respSeq { + ms.clear() + } +} + func IsAuth(val byte) bool { return val > 32 } -func (ms *MysqlSession) GenerateQueryPiece() (qp model.QueryPiece) { +func (ms *MySQLSession) GenerateQueryPiece() (qp model.QueryPiece) { defer ms.clear() if len(ms.cachedStmtBytes) < 1 { @@ -278,7 +287,7 @@ func (ms *MysqlSession) GenerateQueryPiece() (qp model.QueryPiece) { log.Infof("prepare statement %s, get id:%d", querySQL, ms.prepareInfo.prepareStmtID) case ComStmtExecute: - prepareStmtID := bytesToInt(ms.cachedStmtBytes[1:5]) + prepareStmtID := parseInt4(ms.cachedStmtBytes[1:5]) mqp = ms.composeQueryPiece() var ok bool querySQLInBytes, ok = ms.cachedPrepareStmt[prepareStmtID] @@ -291,7 +300,7 @@ func (ms *MysqlSession) GenerateQueryPiece() (qp model.QueryPiece) { // log.Debugf("execute prepare statement:%d", prepareStmtID) case ComStmtClose: - prepareStmtID := bytesToInt(ms.cachedStmtBytes[1:5]) + prepareStmtID := parseInt4(ms.cachedStmtBytes[1:5]) delete(ms.cachedPrepareStmt, prepareStmtID) log.Infof("remove prepare statement:%d", prepareStmtID) @@ -334,10 +343,13 @@ func filterQueryPieceBySQL(mqp *model.PooledMysqlQueryPiece, querySQL []byte) *m return mqp } -func (ms *MysqlSession) composeQueryPiece() (mqp *model.PooledMysqlQueryPiece) { +func (ms *MySQLSession) composeQueryPiece() (mqp *model.PooledMysqlQueryPiece) { clientIP := ms.clientIP clientPort := ms.clientPort - return model.NewPooledMysqlQueryPiece( + mqp = model.NewPooledMysqlQueryPiece( ms.connectionID, clientIP, ms.visitUser, ms.visitDB, ms.serverIP, clientPort, ms.serverPort, communicator.GetMysqlCapturePacketRate(), ms.stmtBeginTimeNano) + mqp.ResponseStatus = ms.responseStatus + mqp.ResponseInfo = ms.responseStatus + return } diff --git a/session-dealer/mysql/util.go b/session-dealer/mysql/util.go index 53c1752..70e7f98 100644 --- a/session-dealer/mysql/util.go +++ b/session-dealer/mysql/util.go @@ -121,10 +121,14 @@ func parseLengthEncodedInt(b []byte) (num uint64, isNull bool, n int) { return } -func extractMysqlPayloadSize(header []byte) int { +func parseInt2(header []byte) int { + return int(uint32(header[0]) | uint32(header[1])<<8) +} + +func parseInt3(header []byte) int { return int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) } -func bytesToInt(contents []byte) int { +func parseInt4(contents []byte) int { return int(uint32(contents[0]) | uint32(contents[1])<<8 | uint32(contents[2])<<16 | uint32(contents[3])<<24) -} \ No newline at end of file +}