Initial commit

This commit is contained in:
hebo
2019-08-08 15:20:56 +08:00
parent 16b067ec89
commit 0f4a202c60
29 changed files with 1482 additions and 2 deletions

18
session-dealer/config.go Normal file
View File

@@ -0,0 +1,18 @@
package session_dealer
import (
"flag"
)
const (
ServiceTypeMysql = "mysql"
)
var (
serviceType string
)
func init() {
flag.StringVar(&serviceType, "service_type", "mysql", "service type. Default is mysql")
}

View File

@@ -0,0 +1,24 @@
package session_dealer
import (
"github.com/zr-hebo/sniffer-agent/session-dealer/mysql"
)
func NewSession(sessionKey string, clientIP string, clientPort int, serverIP string, serverPort int) (session ConnSession) {
switch serviceType {
case ServiceTypeMysql:
session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort)
default:
session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort)
}
return
}
func CheckParams() {
switch serviceType {
case ServiceTypeMysql:
mysql.CheckParams()
default:
mysql.CheckParams()
}
}

9
session-dealer/model.go Normal file
View File

@@ -0,0 +1,9 @@
package session_dealer
import "github.com/zr-hebo/sniffer-agent/model"
type ConnSession interface {
ReadFromClient(bytes []byte)
ReadFromServer(bytes []byte)
GenerateQueryPiece() (qp model.QueryPiece)
}

View File

@@ -0,0 +1,19 @@
package mysql
// parseAuthInfo parse username, dbname from mysql client auth info
func parseAuthInfo(data []byte) (userName, dbName string, err error) {
var resp handshakeResponse41
pos, err := parseHandshakeResponseHeader(&resp, data)
if err != nil {
return
}
// Read the remaining part of the packet.
if err = parseHandshakeResponseBody(&resp, data, pos); err != nil {
return
}
userName = resp.User
dbName = resp.DBName
return
}

View File

@@ -0,0 +1,38 @@
package mysql
import (
"flag"
"fmt"
"regexp"
)
var (
uselessSQLPattern = regexp.MustCompile(`(?i)^\s*(select 1|select @@version_comment limit 1|`+
`SELECT user, db FROM information_schema.processlist WHERE host=|commit|begin)`)
)
var (
strictMode bool
adminUser string
adminPasswd string
)
func init() {
flag.BoolVar(&strictMode,"strict_mode", false, "strict mode. Default is false")
flag.StringVar(&adminUser,"admin_user", "", "admin user name. When set strict mode, must set admin user to query session info")
flag.StringVar(&adminPasswd,"admin_passwd", "", "admin user passwd. When use strict mode, must set admin user to query session info")
}
func CheckParams() {
if !strictMode {
return
}
if len(adminUser) < 1 {
panic(fmt.Sprintf("In strict mode, admin user name cannot be empty"))
}
if len(adminPasswd) < 1 {
panic(fmt.Sprintf("In strict mode, admin passwd cannot be empty"))
}
}

View File

@@ -0,0 +1,50 @@
package mysql
import (
"fmt"
_ "github.com/go-sql-driver/mysql"
du "github.com/zr-hebo/util-db"
// log "github.com/sirupsen/logrus"
)
func expandLocalMysql(port int) (mysqlHost *du.MysqlDB) {
mysqlHost = new(du.MysqlDB)
mysqlHost.IP = "localhost"
mysqlHost.Port = port
mysqlHost.UserName = adminUser
mysqlHost.Passwd = adminPasswd
mysqlHost.DatabaseType = "mysql"
mysqlHost.ConnectTimeout = 1
return
}
func querySessionInfo(snifferPort int, clientHost string) (user, db *string, err error) {
mysqlServer := expandLocalMysql(snifferPort)
querySQL := fmt.Sprintf(
"SELECT user, db FROM information_schema.processlist WHERE host='%s'", clientHost)
// log.Debug(querySQL)
queryRow, err := mysqlServer.QueryRow(querySQL)
if err != nil {
return
}
if queryRow == nil {
return
}
userVal := queryRow.Record["user"]
if userVal != nil {
usrStr := userVal.(string)
user = &usrStr
}
dbVal := queryRow.Record["db"]
if dbVal != nil {
dbStr := dbVal.(string)
db = &dbStr
}
return
}

View File

@@ -0,0 +1,105 @@
package mysql
import "errors"
// Command information.
const (
ComSleep byte = iota
ComQuit
ComInitDB
ComQuery
ComFieldList
ComCreateDB
ComDropDB
ComRefresh
ComShutdown
ComStatistics
ComProcessInfo
ComConnect
ComProcessKill
ComDebug
ComPing
ComTime
ComDelayedInsert
ComChangeUser
ComBinlogDump
ComTableDump
ComConnectOut
ComRegisterSlave
ComStmtPrepare
ComStmtExecute
ComStmtSendLongData
ComStmtClose
ComStmtReset
ComSetOption
ComStmtFetch
ComBinlogDumpGtid
ComResetConnection
)
const (
ComAuth = 141
)
// Client information.
const (
ClientLongPassword uint32 = 1 << iota
ClientFoundRows
ClientLongFlag
ClientConnectWithDB
ClientNoSchema
ClientCompress
ClientODBC
ClientLocalFiles
ClientIgnoreSpace
ClientProtocol41
ClientInteractive
ClientSSL
ClientIgnoreSigpipe
ClientTransactions
ClientReserved
ClientSecureConnection
ClientMultiStatements
ClientMultiResults
ClientPSMultiResults
ClientPluginAuth
ClientConnectAtts
ClientPluginAuthLenencClientData
)
// Auth name information.
const (
AuthName = "mysql_native_password"
)
// MySQL database and tables.
const (
// SystemDB is the name of system database.
SystemDB = "mysql"
// UserTable is the table in system db contains user info.
UserTable = "User"
// DBTable is the table in system db contains db scope privilege info.
DBTable = "DB"
// GlobalVariablesTable is the table contains global system variables.
GlobalVariablesTable = "GLOBAL_VARIABLES"
// GlobalStatusTable is the table contains global status variables.
GlobalStatusTable = "GLOBAL_STATUS"
)
// Identifier length limitations.
// See https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
const (
// MaxPayloadLen is the max packet payload length.
MaxPayloadLen = 1<<24 - 1
)
const (
datetimeFormat = "2006-01-02 15:04:05"
)
var (
ErrMalformPacket = errors.New("malform packet error")
)

View File

@@ -0,0 +1,9 @@
package mysql
type handshakeResponse41 struct {
Capability uint32
Collation uint8
User string
DBName string
Auth []byte
}

View File

@@ -0,0 +1,175 @@
package mysql
import (
"fmt"
"time"
"github.com/zr-hebo/sniffer-agent/model"
log "github.com/sirupsen/logrus"
)
type MysqlSession struct {
connectionID string
visitUser *string
visitDB *string
clientHost string
clientPort int
serverIP string
serverPort int
beginTime int64
expectSize int
prepareInfo *prepareInfo
cachedPrepareStmt map[int]*string
tcpCache []byte
cachedStmtBytes []byte
}
type prepareInfo struct {
prepareStmtID int
}
func NewMysqlSession(sessionKey string, clientIP string, clientPort int, serverIP string, serverPort int) (ms *MysqlSession) {
ms = &MysqlSession{
connectionID: sessionKey,
clientHost: clientIP,
clientPort: clientPort,
serverIP: serverIP,
serverPort: serverPort,
beginTime: time.Now().UnixNano() / int64(time.Millisecond),
cachedPrepareStmt: make(map[int]*string),
}
return
}
func (ms *MysqlSession) ReadFromServer(bytes []byte) {
if ms.expectSize < 1 {
ms.expectSize = extractMysqlPayloadSize(bytes)
contents := bytes[4:]
if ms.prepareInfo != nil && contents[0] == 0 {
ms.prepareInfo.prepareStmtID = bytesToInt(contents[1:5])
}
ms.expectSize = ms.expectSize - len(contents)
} else {
ms.expectSize = ms.expectSize - len(bytes)
}
}
func (ms *MysqlSession) ReadFromClient(bytes []byte) {
if ms.expectSize < 1 {
ms.expectSize = extractMysqlPayloadSize(bytes)
contents := bytes[4:]
if contents[0] == ComStmtPrepare {
ms.prepareInfo = &prepareInfo{}
}
ms.expectSize = ms.expectSize - len(contents)
ms.tcpCache = append(ms.tcpCache, contents...)
} else {
ms.expectSize = ms.expectSize - len(bytes)
ms.tcpCache = append(ms.tcpCache, bytes...)
if len(ms.tcpCache) == MaxPayloadLen {
ms.cachedStmtBytes = append(ms.cachedStmtBytes, ms.tcpCache...)
ms.tcpCache = ms.tcpCache[:0]
ms.expectSize = 0
}
}
}
func (ms *MysqlSession) GenerateQueryPiece() (qp model.QueryPiece) {
if len(ms.cachedStmtBytes) < 1 && len(ms.tcpCache) < 1 {
return
}
var mqp *model.MysqlQueryPiece = nil
ms.cachedStmtBytes = append(ms.cachedStmtBytes, ms.tcpCache...)
switch ms.cachedStmtBytes[0] {
case ComAuth:
var userName, dbName string
var err error
userName, dbName, err = parseAuthInfo(ms.cachedStmtBytes)
if err != nil {
return
}
ms.visitUser = &userName
ms.visitDB = &dbName
case ComInitDB:
newDBName := string(ms.cachedStmtBytes[1:])
useSQL := fmt.Sprintf("use %s", newDBName)
mqp = ms.composeQueryPiece()
mqp.QuerySQL = &useSQL
// update session database
ms.visitDB = &newDBName
case ComCreateDB:
case ComDropDB:
case ComQuery:
mqp = ms.composeQueryPiece()
querySQL := string(ms.cachedStmtBytes[1:])
mqp.QuerySQL = &querySQL
case ComStmtPrepare:
mqp = ms.composeQueryPiece()
querySQL := string(ms.cachedStmtBytes[1:])
mqp.QuerySQL = &querySQL
ms.cachedPrepareStmt[ms.prepareInfo.prepareStmtID] = &querySQL
log.Debugf("prepare statement %s, get id:%d", querySQL, ms.prepareInfo.prepareStmtID)
case ComStmtExecute:
prepareStmtID := bytesToInt(ms.cachedStmtBytes[1:5])
mqp = ms.composeQueryPiece()
mqp.QuerySQL = ms.cachedPrepareStmt[prepareStmtID]
log.Debugf("execute prepare statement:%d", prepareStmtID)
case ComStmtClose:
prepareStmtID := bytesToInt(ms.cachedStmtBytes[1:5])
delete(ms.cachedPrepareStmt, prepareStmtID)
log.Debugf("remove prepare statement:%d", prepareStmtID)
default:
}
if strictMode && mqp != nil && mqp.VisitUser == nil {
user, db, err := querySessionInfo(ms.serverPort, mqp.SessionID)
if err != nil {
log.Errorf("query user and db from mysql failed <-- %s", err.Error())
} else {
mqp.VisitUser = user
mqp.VisitDB = db
}
}
ms.tcpCache = ms.tcpCache[:0]
ms.cachedStmtBytes = ms.cachedStmtBytes[:0]
ms.expectSize = 0
ms.prepareInfo = nil
return filterQueryPieceBySQL(mqp)
}
func filterQueryPieceBySQL(mqp *model.MysqlQueryPiece) (model.QueryPiece) {
if mqp == nil || mqp.QuerySQL == nil {
return nil
} else if (uselessSQLPattern.MatchString(*mqp.QuerySQL)) {
return nil
}
return mqp
}
func (ms *MysqlSession) composeQueryPiece() (mqp *model.MysqlQueryPiece) {
nowInMS := time.Now().UnixNano() / int64(time.Millisecond)
mqp = &model.MysqlQueryPiece{
SessionID: ms.connectionID,
ClientHost: ms.clientHost,
ServerIP: ms.serverIP,
ServerPort: ms.serverPort,
VisitUser: ms.visitUser,
VisitDB: ms.visitDB,
BeginTime: time.Unix(ms.beginTime/1000, 0).Format(datetimeFormat),
CostTimeInMS: nowInMS - ms.beginTime,
}
return mqp
}

View File

@@ -0,0 +1,131 @@
package mysql
import (
"bytes"
"encoding/binary"
)
// parseHandshakeResponseHeader parses the common header of SSLRequest and HandshakeResponse41.
func parseHandshakeResponseHeader(packet *handshakeResponse41, data []byte) (parsedBytes int, err error) {
// Ensure there are enough data to read:
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if len(data) < 4+4+1+23 {
return 0, ErrMalformPacket
}
offset := 0
// capability
capability := binary.LittleEndian.Uint32(data[:4])
packet.Capability = capability
offset += 4
// skip max packet size
offset += 4
// charset, skip, if you want to use another charset, use set names
packet.Collation = data[offset]
offset++
// skip reserved 23[00]
offset += 23
return offset, nil
}
// parseHandshakeResponseBody parse the HandshakeResponse (except the common header part).
func parseHandshakeResponseBody(packet *handshakeResponse41, data []byte, offset int) (err error) {
defer func() {
// Check malformat packet cause out of range is disgusting, but don't panic!
if r := recover(); r != nil {
err = ErrMalformPacket
}
}()
// user name
packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)])
offset += len(packet.User) + 1
if packet.Capability&ClientPluginAuthLenencClientData > 0 {
// MySQL client sets the wrong capability, it will set this bit even server doesn't
// support ClientPluginAuthLenencClientData.
// https://github.com/mysql/mysql-server/blob/5.7/sql-common/client.c#L3478
num, null, off := parseLengthEncodedInt(data[offset:])
offset += off
if !null {
packet.Auth = data[offset : offset+int(num)]
offset += int(num)
}
} else if packet.Capability&ClientSecureConnection > 0 {
// auth length and auth
authLen := int(data[offset])
offset++
packet.Auth = data[offset : offset+authLen]
offset += authLen
} else {
packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)]
offset += len(packet.Auth) + 1
}
if packet.Capability&ClientConnectWithDB > 0 {
if len(data[offset:]) > 0 {
idx := bytes.IndexByte(data[offset:], 0)
packet.DBName = string(data[offset : offset+idx])
offset = offset + idx + 1
}
}
if packet.Capability&ClientPluginAuth > 0 {
// TODO: Support mysql.ClientPluginAuth, skip it now
idx := bytes.IndexByte(data[offset:], 0)
offset = offset + idx + 1
}
if packet.Capability&ClientConnectAtts > 0 {
if len(data[offset:]) == 0 {
// Defend some ill-formated packet, connection attribute is not important and can be ignored.
return nil
}
}
return nil
}
func parseLengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
switch b[0] {
// 251: NULL
case 0xfb:
n = 1
isNull = true
return
// 252: value of following 2
case 0xfc:
num = uint64(b[1]) | uint64(b[2])<<8
n = 3
return
// 253: value of following 3
case 0xfd:
num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16
n = 4
return
// 254: value of following 8
case 0xfe:
num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<56
n = 9
return
}
// 0-250: value of first byte
num = uint64(b[0])
n = 1
return
}
func extractMysqlPayloadSize(payload []byte) int {
header := payload[:4]
return int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
}
func bytesToInt(contents []byte) int {
return int(uint32(contents[0]) | uint32(contents[1])<<8 | uint32(contents[2])<<16 | uint32(contents[3])<<24)
}