mirror of
https://github.com/zr-hebo/sniffer-agent.git
synced 2025-08-10 22:28:36 +08:00
Initial commit
This commit is contained in:
18
session-dealer/config.go
Normal file
18
session-dealer/config.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package session_dealer
|
||||
|
||||
import (
|
||||
"flag"
|
||||
)
|
||||
|
||||
const (
|
||||
ServiceTypeMysql = "mysql"
|
||||
)
|
||||
|
||||
var (
|
||||
serviceType string
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&serviceType, "service_type", "mysql", "service type. Default is mysql")
|
||||
}
|
||||
|
24
session-dealer/controller.go
Normal file
24
session-dealer/controller.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package session_dealer
|
||||
|
||||
import (
|
||||
"github.com/zr-hebo/sniffer-agent/session-dealer/mysql"
|
||||
)
|
||||
|
||||
func NewSession(sessionKey string, clientIP string, clientPort int, serverIP string, serverPort int) (session ConnSession) {
|
||||
switch serviceType {
|
||||
case ServiceTypeMysql:
|
||||
session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort)
|
||||
default:
|
||||
session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func CheckParams() {
|
||||
switch serviceType {
|
||||
case ServiceTypeMysql:
|
||||
mysql.CheckParams()
|
||||
default:
|
||||
mysql.CheckParams()
|
||||
}
|
||||
}
|
9
session-dealer/model.go
Normal file
9
session-dealer/model.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package session_dealer
|
||||
|
||||
import "github.com/zr-hebo/sniffer-agent/model"
|
||||
|
||||
type ConnSession interface {
|
||||
ReadFromClient(bytes []byte)
|
||||
ReadFromServer(bytes []byte)
|
||||
GenerateQueryPiece() (qp model.QueryPiece)
|
||||
}
|
19
session-dealer/mysql/auth_info.go
Normal file
19
session-dealer/mysql/auth_info.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package mysql
|
||||
|
||||
// parseAuthInfo parse username, dbname from mysql client auth info
|
||||
func parseAuthInfo(data []byte) (userName, dbName string, err error) {
|
||||
var resp handshakeResponse41
|
||||
pos, err := parseHandshakeResponseHeader(&resp, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Read the remaining part of the packet.
|
||||
if err = parseHandshakeResponseBody(&resp, data, pos); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
userName = resp.User
|
||||
dbName = resp.DBName
|
||||
return
|
||||
}
|
38
session-dealer/mysql/config.go
Normal file
38
session-dealer/mysql/config.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var (
|
||||
uselessSQLPattern = regexp.MustCompile(`(?i)^\s*(select 1|select @@version_comment limit 1|`+
|
||||
`SELECT user, db FROM information_schema.processlist WHERE host=|commit|begin)`)
|
||||
)
|
||||
|
||||
var (
|
||||
strictMode bool
|
||||
adminUser string
|
||||
adminPasswd string
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&strictMode,"strict_mode", false, "strict mode. Default is false")
|
||||
flag.StringVar(&adminUser,"admin_user", "", "admin user name. When set strict mode, must set admin user to query session info")
|
||||
flag.StringVar(&adminPasswd,"admin_passwd", "", "admin user passwd. When use strict mode, must set admin user to query session info")
|
||||
}
|
||||
|
||||
func CheckParams() {
|
||||
if !strictMode {
|
||||
return
|
||||
}
|
||||
|
||||
if len(adminUser) < 1 {
|
||||
panic(fmt.Sprintf("In strict mode, admin user name cannot be empty"))
|
||||
}
|
||||
|
||||
if len(adminPasswd) < 1 {
|
||||
panic(fmt.Sprintf("In strict mode, admin passwd cannot be empty"))
|
||||
}
|
||||
}
|
50
session-dealer/mysql/connections.go
Normal file
50
session-dealer/mysql/connections.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
du "github.com/zr-hebo/util-db"
|
||||
// log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func expandLocalMysql(port int) (mysqlHost *du.MysqlDB) {
|
||||
mysqlHost = new(du.MysqlDB)
|
||||
mysqlHost.IP = "localhost"
|
||||
mysqlHost.Port = port
|
||||
mysqlHost.UserName = adminUser
|
||||
mysqlHost.Passwd = adminPasswd
|
||||
mysqlHost.DatabaseType = "mysql"
|
||||
mysqlHost.ConnectTimeout = 1
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func querySessionInfo(snifferPort int, clientHost string) (user, db *string, err error) {
|
||||
mysqlServer := expandLocalMysql(snifferPort)
|
||||
querySQL := fmt.Sprintf(
|
||||
"SELECT user, db FROM information_schema.processlist WHERE host='%s'", clientHost)
|
||||
// log.Debug(querySQL)
|
||||
queryRow, err := mysqlServer.QueryRow(querySQL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if queryRow == nil {
|
||||
return
|
||||
}
|
||||
|
||||
userVal := queryRow.Record["user"]
|
||||
if userVal != nil {
|
||||
usrStr := userVal.(string)
|
||||
user = &usrStr
|
||||
}
|
||||
|
||||
dbVal := queryRow.Record["db"]
|
||||
if dbVal != nil {
|
||||
dbStr := dbVal.(string)
|
||||
db = &dbStr
|
||||
}
|
||||
|
||||
return
|
||||
}
|
105
session-dealer/mysql/const.go
Normal file
105
session-dealer/mysql/const.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package mysql
|
||||
|
||||
import "errors"
|
||||
|
||||
// Command information.
|
||||
const (
|
||||
ComSleep byte = iota
|
||||
ComQuit
|
||||
ComInitDB
|
||||
ComQuery
|
||||
ComFieldList
|
||||
ComCreateDB
|
||||
ComDropDB
|
||||
ComRefresh
|
||||
ComShutdown
|
||||
ComStatistics
|
||||
ComProcessInfo
|
||||
ComConnect
|
||||
ComProcessKill
|
||||
ComDebug
|
||||
ComPing
|
||||
ComTime
|
||||
ComDelayedInsert
|
||||
ComChangeUser
|
||||
ComBinlogDump
|
||||
ComTableDump
|
||||
ComConnectOut
|
||||
ComRegisterSlave
|
||||
ComStmtPrepare
|
||||
ComStmtExecute
|
||||
ComStmtSendLongData
|
||||
ComStmtClose
|
||||
ComStmtReset
|
||||
ComSetOption
|
||||
ComStmtFetch
|
||||
ComBinlogDumpGtid
|
||||
ComResetConnection
|
||||
)
|
||||
|
||||
const (
|
||||
ComAuth = 141
|
||||
)
|
||||
|
||||
// Client information.
|
||||
const (
|
||||
ClientLongPassword uint32 = 1 << iota
|
||||
ClientFoundRows
|
||||
ClientLongFlag
|
||||
ClientConnectWithDB
|
||||
ClientNoSchema
|
||||
ClientCompress
|
||||
ClientODBC
|
||||
ClientLocalFiles
|
||||
ClientIgnoreSpace
|
||||
ClientProtocol41
|
||||
ClientInteractive
|
||||
ClientSSL
|
||||
ClientIgnoreSigpipe
|
||||
ClientTransactions
|
||||
ClientReserved
|
||||
ClientSecureConnection
|
||||
ClientMultiStatements
|
||||
ClientMultiResults
|
||||
ClientPSMultiResults
|
||||
ClientPluginAuth
|
||||
ClientConnectAtts
|
||||
ClientPluginAuthLenencClientData
|
||||
)
|
||||
|
||||
|
||||
// Auth name information.
|
||||
const (
|
||||
AuthName = "mysql_native_password"
|
||||
)
|
||||
|
||||
|
||||
// MySQL database and tables.
|
||||
const (
|
||||
// SystemDB is the name of system database.
|
||||
SystemDB = "mysql"
|
||||
// UserTable is the table in system db contains user info.
|
||||
UserTable = "User"
|
||||
// DBTable is the table in system db contains db scope privilege info.
|
||||
DBTable = "DB"
|
||||
// GlobalVariablesTable is the table contains global system variables.
|
||||
GlobalVariablesTable = "GLOBAL_VARIABLES"
|
||||
// GlobalStatusTable is the table contains global status variables.
|
||||
GlobalStatusTable = "GLOBAL_STATUS"
|
||||
)
|
||||
|
||||
|
||||
// Identifier length limitations.
|
||||
// See https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
|
||||
const (
|
||||
// MaxPayloadLen is the max packet payload length.
|
||||
MaxPayloadLen = 1<<24 - 1
|
||||
)
|
||||
|
||||
const (
|
||||
datetimeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMalformPacket = errors.New("malform packet error")
|
||||
)
|
9
session-dealer/mysql/model.go
Normal file
9
session-dealer/mysql/model.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package mysql
|
||||
|
||||
type handshakeResponse41 struct {
|
||||
Capability uint32
|
||||
Collation uint8
|
||||
User string
|
||||
DBName string
|
||||
Auth []byte
|
||||
}
|
175
session-dealer/mysql/session.go
Normal file
175
session-dealer/mysql/session.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zr-hebo/sniffer-agent/model"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type MysqlSession struct {
|
||||
connectionID string
|
||||
visitUser *string
|
||||
visitDB *string
|
||||
clientHost string
|
||||
clientPort int
|
||||
serverIP string
|
||||
serverPort int
|
||||
beginTime int64
|
||||
expectSize int
|
||||
prepareInfo *prepareInfo
|
||||
cachedPrepareStmt map[int]*string
|
||||
tcpCache []byte
|
||||
cachedStmtBytes []byte
|
||||
}
|
||||
|
||||
type prepareInfo struct {
|
||||
prepareStmtID int
|
||||
}
|
||||
|
||||
func NewMysqlSession(sessionKey string, clientIP string, clientPort int, serverIP string, serverPort int) (ms *MysqlSession) {
|
||||
ms = &MysqlSession{
|
||||
connectionID: sessionKey,
|
||||
clientHost: clientIP,
|
||||
clientPort: clientPort,
|
||||
serverIP: serverIP,
|
||||
serverPort: serverPort,
|
||||
beginTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||
cachedPrepareStmt: make(map[int]*string),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ms *MysqlSession) ReadFromServer(bytes []byte) {
|
||||
if ms.expectSize < 1 {
|
||||
ms.expectSize = extractMysqlPayloadSize(bytes)
|
||||
contents := bytes[4:]
|
||||
if ms.prepareInfo != nil && contents[0] == 0 {
|
||||
ms.prepareInfo.prepareStmtID = bytesToInt(contents[1:5])
|
||||
}
|
||||
ms.expectSize = ms.expectSize - len(contents)
|
||||
|
||||
} else {
|
||||
ms.expectSize = ms.expectSize - len(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *MysqlSession) ReadFromClient(bytes []byte) {
|
||||
if ms.expectSize < 1 {
|
||||
ms.expectSize = extractMysqlPayloadSize(bytes)
|
||||
contents := bytes[4:]
|
||||
if contents[0] == ComStmtPrepare {
|
||||
ms.prepareInfo = &prepareInfo{}
|
||||
}
|
||||
|
||||
ms.expectSize = ms.expectSize - len(contents)
|
||||
ms.tcpCache = append(ms.tcpCache, contents...)
|
||||
|
||||
} else {
|
||||
ms.expectSize = ms.expectSize - len(bytes)
|
||||
ms.tcpCache = append(ms.tcpCache, bytes...)
|
||||
if len(ms.tcpCache) == MaxPayloadLen {
|
||||
ms.cachedStmtBytes = append(ms.cachedStmtBytes, ms.tcpCache...)
|
||||
ms.tcpCache = ms.tcpCache[:0]
|
||||
ms.expectSize = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *MysqlSession) GenerateQueryPiece() (qp model.QueryPiece) {
|
||||
if len(ms.cachedStmtBytes) < 1 && len(ms.tcpCache) < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
var mqp *model.MysqlQueryPiece = nil
|
||||
ms.cachedStmtBytes = append(ms.cachedStmtBytes, ms.tcpCache...)
|
||||
switch ms.cachedStmtBytes[0] {
|
||||
case ComAuth:
|
||||
var userName, dbName string
|
||||
var err error
|
||||
userName, dbName, err = parseAuthInfo(ms.cachedStmtBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ms.visitUser = &userName
|
||||
ms.visitDB = &dbName
|
||||
|
||||
case ComInitDB:
|
||||
newDBName := string(ms.cachedStmtBytes[1:])
|
||||
useSQL := fmt.Sprintf("use %s", newDBName)
|
||||
mqp = ms.composeQueryPiece()
|
||||
mqp.QuerySQL = &useSQL
|
||||
// update session database
|
||||
ms.visitDB = &newDBName
|
||||
|
||||
case ComCreateDB:
|
||||
case ComDropDB:
|
||||
case ComQuery:
|
||||
mqp = ms.composeQueryPiece()
|
||||
querySQL := string(ms.cachedStmtBytes[1:])
|
||||
mqp.QuerySQL = &querySQL
|
||||
|
||||
case ComStmtPrepare:
|
||||
mqp = ms.composeQueryPiece()
|
||||
querySQL := string(ms.cachedStmtBytes[1:])
|
||||
mqp.QuerySQL = &querySQL
|
||||
ms.cachedPrepareStmt[ms.prepareInfo.prepareStmtID] = &querySQL
|
||||
log.Debugf("prepare statement %s, get id:%d", querySQL, ms.prepareInfo.prepareStmtID)
|
||||
|
||||
case ComStmtExecute:
|
||||
prepareStmtID := bytesToInt(ms.cachedStmtBytes[1:5])
|
||||
mqp = ms.composeQueryPiece()
|
||||
mqp.QuerySQL = ms.cachedPrepareStmt[prepareStmtID]
|
||||
log.Debugf("execute prepare statement:%d", prepareStmtID)
|
||||
|
||||
case ComStmtClose:
|
||||
prepareStmtID := bytesToInt(ms.cachedStmtBytes[1:5])
|
||||
delete(ms.cachedPrepareStmt, prepareStmtID)
|
||||
log.Debugf("remove prepare statement:%d", prepareStmtID)
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
if strictMode && mqp != nil && mqp.VisitUser == nil {
|
||||
user, db, err := querySessionInfo(ms.serverPort, mqp.SessionID)
|
||||
if err != nil {
|
||||
log.Errorf("query user and db from mysql failed <-- %s", err.Error())
|
||||
} else {
|
||||
mqp.VisitUser = user
|
||||
mqp.VisitDB = db
|
||||
}
|
||||
}
|
||||
|
||||
ms.tcpCache = ms.tcpCache[:0]
|
||||
ms.cachedStmtBytes = ms.cachedStmtBytes[:0]
|
||||
ms.expectSize = 0
|
||||
ms.prepareInfo = nil
|
||||
return filterQueryPieceBySQL(mqp)
|
||||
}
|
||||
|
||||
func filterQueryPieceBySQL(mqp *model.MysqlQueryPiece) (model.QueryPiece) {
|
||||
if mqp == nil || mqp.QuerySQL == nil {
|
||||
return nil
|
||||
|
||||
} else if (uselessSQLPattern.MatchString(*mqp.QuerySQL)) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return mqp
|
||||
}
|
||||
|
||||
func (ms *MysqlSession) composeQueryPiece() (mqp *model.MysqlQueryPiece) {
|
||||
nowInMS := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
mqp = &model.MysqlQueryPiece{
|
||||
SessionID: ms.connectionID,
|
||||
ClientHost: ms.clientHost,
|
||||
ServerIP: ms.serverIP,
|
||||
ServerPort: ms.serverPort,
|
||||
VisitUser: ms.visitUser,
|
||||
VisitDB: ms.visitDB,
|
||||
BeginTime: time.Unix(ms.beginTime/1000, 0).Format(datetimeFormat),
|
||||
CostTimeInMS: nowInMS - ms.beginTime,
|
||||
}
|
||||
return mqp
|
||||
}
|
131
session-dealer/mysql/util.go
Normal file
131
session-dealer/mysql/util.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// parseHandshakeResponseHeader parses the common header of SSLRequest and HandshakeResponse41.
|
||||
func parseHandshakeResponseHeader(packet *handshakeResponse41, data []byte) (parsedBytes int, err error) {
|
||||
// Ensure there are enough data to read:
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
|
||||
if len(data) < 4+4+1+23 {
|
||||
return 0, ErrMalformPacket
|
||||
}
|
||||
|
||||
offset := 0
|
||||
// capability
|
||||
capability := binary.LittleEndian.Uint32(data[:4])
|
||||
packet.Capability = capability
|
||||
offset += 4
|
||||
// skip max packet size
|
||||
offset += 4
|
||||
// charset, skip, if you want to use another charset, use set names
|
||||
packet.Collation = data[offset]
|
||||
offset++
|
||||
// skip reserved 23[00]
|
||||
offset += 23
|
||||
|
||||
return offset, nil
|
||||
}
|
||||
|
||||
// parseHandshakeResponseBody parse the HandshakeResponse (except the common header part).
|
||||
func parseHandshakeResponseBody(packet *handshakeResponse41, data []byte, offset int) (err error) {
|
||||
defer func() {
|
||||
// Check malformat packet cause out of range is disgusting, but don't panic!
|
||||
if r := recover(); r != nil {
|
||||
err = ErrMalformPacket
|
||||
}
|
||||
}()
|
||||
// user name
|
||||
packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)])
|
||||
offset += len(packet.User) + 1
|
||||
|
||||
if packet.Capability&ClientPluginAuthLenencClientData > 0 {
|
||||
// MySQL client sets the wrong capability, it will set this bit even server doesn't
|
||||
// support ClientPluginAuthLenencClientData.
|
||||
// https://github.com/mysql/mysql-server/blob/5.7/sql-common/client.c#L3478
|
||||
num, null, off := parseLengthEncodedInt(data[offset:])
|
||||
offset += off
|
||||
if !null {
|
||||
packet.Auth = data[offset : offset+int(num)]
|
||||
offset += int(num)
|
||||
}
|
||||
} else if packet.Capability&ClientSecureConnection > 0 {
|
||||
// auth length and auth
|
||||
authLen := int(data[offset])
|
||||
offset++
|
||||
packet.Auth = data[offset : offset+authLen]
|
||||
offset += authLen
|
||||
} else {
|
||||
packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)]
|
||||
offset += len(packet.Auth) + 1
|
||||
}
|
||||
|
||||
if packet.Capability&ClientConnectWithDB > 0 {
|
||||
if len(data[offset:]) > 0 {
|
||||
idx := bytes.IndexByte(data[offset:], 0)
|
||||
packet.DBName = string(data[offset : offset+idx])
|
||||
offset = offset + idx + 1
|
||||
}
|
||||
}
|
||||
|
||||
if packet.Capability&ClientPluginAuth > 0 {
|
||||
// TODO: Support mysql.ClientPluginAuth, skip it now
|
||||
idx := bytes.IndexByte(data[offset:], 0)
|
||||
offset = offset + idx + 1
|
||||
}
|
||||
|
||||
if packet.Capability&ClientConnectAtts > 0 {
|
||||
if len(data[offset:]) == 0 {
|
||||
// Defend some ill-formated packet, connection attribute is not important and can be ignored.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseLengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
|
||||
switch b[0] {
|
||||
// 251: NULL
|
||||
case 0xfb:
|
||||
n = 1
|
||||
isNull = true
|
||||
return
|
||||
|
||||
// 252: value of following 2
|
||||
case 0xfc:
|
||||
num = uint64(b[1]) | uint64(b[2])<<8
|
||||
n = 3
|
||||
return
|
||||
|
||||
// 253: value of following 3
|
||||
case 0xfd:
|
||||
num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16
|
||||
n = 4
|
||||
return
|
||||
|
||||
// 254: value of following 8
|
||||
case 0xfe:
|
||||
num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
|
||||
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
|
||||
uint64(b[7])<<48 | uint64(b[8])<<56
|
||||
n = 9
|
||||
return
|
||||
}
|
||||
|
||||
// 0-250: value of first byte
|
||||
num = uint64(b[0])
|
||||
n = 1
|
||||
return
|
||||
}
|
||||
|
||||
func extractMysqlPayloadSize(payload []byte) int {
|
||||
header := payload[:4]
|
||||
return int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
|
||||
}
|
||||
|
||||
func bytesToInt(contents []byte) int {
|
||||
return int(uint32(contents[0]) | uint32(contents[1])<<8 | uint32(contents[2])<<16 | uint32(contents[3])<<24)
|
||||
}
|
Reference in New Issue
Block a user