Initial commit

This commit is contained in:
hebo 2019-08-08 15:20:56 +08:00
parent 16b067ec89
commit 0f4a202c60
29 changed files with 1482 additions and 2 deletions

6
.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
.idea/
*.log
*.swp
one_key.sh
sniffer-agent
vendor/

93
Godeps/Godeps.json generated Normal file
View File

@ -0,0 +1,93 @@
{
"ImportPath": "github.com/zr-hebo/sniffer-agent",
"GoVersion": "go1.12",
"GodepVersion": "v79",
"Packages": [
"./..."
],
"Deps": [
{
"ImportPath": "github.com/Shopify/sarama",
"Comment": "v1.14.0-20-gcd645bf",
"Rev": "cd645bfba7622e7de8971388543ff97ee026aad4"
},
{
"ImportPath": "github.com/davecgh/go-spew/spew",
"Comment": "v1.1.0-9-gecdeabc",
"Rev": "ecdeabc65495df2dec95d7c4a4c3e021903035e5"
},
{
"ImportPath": "github.com/eapache/go-resiliency/breaker",
"Comment": "v1.0.0-6-gb1fe83b",
"Rev": "b1fe83b5b03f624450823b751b662259ffc6af70"
},
{
"ImportPath": "github.com/eapache/go-xerial-snappy",
"Rev": "bb955e01b9346ac19dc29eb16586c90ded99a98c"
},
{
"ImportPath": "github.com/eapache/queue",
"Comment": "v1.1.0",
"Rev": "44cc805cf13205b55f69e14bcb69867d1ae92f98"
},
{
"ImportPath": "github.com/go-sql-driver/mysql",
"Comment": "v1.3-42-gfade210",
"Rev": "fade21009797158e7b79e04c340118a9220c6f9e"
},
{
"ImportPath": "github.com/golang/snappy",
"Rev": "553a641470496b2327abcac10b36396bd98e45c9"
},
{
"ImportPath": "github.com/google/gopacket",
"Comment": "v1.1.17-26-gce2e696",
"Rev": "ce2e696dc0c9917ecdebd800c892b839f06b2949"
},
{
"ImportPath": "github.com/google/gopacket/layers",
"Comment": "v1.1.17-26-gce2e696",
"Rev": "ce2e696dc0c9917ecdebd800c892b839f06b2949"
},
{
"ImportPath": "github.com/google/gopacket/pcap",
"Comment": "v1.1.17-26-gce2e696",
"Rev": "ce2e696dc0c9917ecdebd800c892b839f06b2949"
},
{
"ImportPath": "github.com/pierrec/lz4",
"Comment": "v1.0.1",
"Rev": "08c27939df1bd95e881e2c2367a749964ad1fceb"
},
{
"ImportPath": "github.com/pierrec/xxHash/xxHash32",
"Comment": "v0.1-11-ga0006b1",
"Rev": "a0006b13c722f7f12368c00a3d3c2ae8a999a0c6"
},
{
"ImportPath": "github.com/rcrowley/go-metrics",
"Rev": "e181e095bae94582363434144c61a9653aff6e50"
},
{
"ImportPath": "github.com/sirupsen/logrus",
"Comment": "v1.0.4",
"Rev": "d682213848ed68c0a260ca37d6dd5ace8423f5ba"
},
{
"ImportPath": "github.com/zr-hebo/util-db",
"Rev": "3ff29f916f7b712b3adc53c4b9b19b13b8bbed87"
},
{
"ImportPath": "golang.org/x/crypto/ssh/terminal",
"Rev": "eb71ad9bd329b5ac0fd0148dd99bd62e8be8e035"
},
{
"ImportPath": "golang.org/x/sys/unix",
"Rev": "ac767d655b305d4e9612f5f6e33120b9176c4ad4"
},
{
"ImportPath": "golang.org/x/sys/windows",
"Rev": "ac767d655b305d4e9612f5f6e33120b9176c4ad4"
}
]
}

5
Godeps/Readme generated Normal file
View File

@ -0,0 +1,5 @@
This directory tree is generated automatically by godep.
Please do not edit.
See https://github.com/tools/godep for more information.

1
Makefile Normal file
View File

@ -0,0 +1 @@
go build

View File

@ -1,2 +1,90 @@
# sniffer-agent ### sniffer-agent
抓取tcp包解析出mysql语句
> Sniffer TCP package, parsed with mysql protocol, optional you can just print on screen or send query info to Kafka.
> 抓取tcp包解析出mysql语句将查询信息打印在屏幕上或者发送到Kafka
## Architecture
架构设计:
本项目采用模块化设计,主要分为四大模块:抓包模块,协议解析模块,输出模块,心跳模块
![架构设计图](https://github.com/zr-hebo/sniffer-agent/blob/master/images/arch.png)
## Parse Protocol
sniffer-agent采用模块化结构支持用户添加自己的解析模块只要实现了统一的接口即可
- [x] MySQL
- [ ] PostgreSQL
- [ ] Redis
- [ ] Mongodb
目前输出的内容都是解析结果组成的json。
MySQL协议的解析结果示例如下
```
{"sid":"10.XX.XX.XX:54656","sip":"192.168.XX.XX","sport":3306,"user":"root","db":"unibase","sql":"show tables","bt":"2019-08-05 18:23:09","cms":15}
```
其中sid代表客户端ipport组成的session标识sip代表server ipsport代表server portuser代表查询用户db代表当前连接的库名sql代表查询语句bt代表查询开始时间cms代表查询消耗的时间单位是毫秒
## Exporter
输出模块主要负责将解析的结果对外输出。默认情况下输出到命令行可以通过指定export_type参数选择kafka这时候会直接将解析结果发送到kafka。
同样只要实现了export接口用户可以自定义自己的输出方式。
## Install:
环境:
golang1.12
libpcap包
测试脚本运行在python3环境下
1.安装依赖目前自测支持Linux系列操作系统其他版本的系统有待验证
CentOS:
```
yum install libpcap-devel
```
Ubuntu:
```
apt-get install libpcap-dev
```
2.执行编译命令 go build
## Demo
目前只支持MySQL协议的抓取需要将编译后的二进制文件上传到MySQL服务器上
1.最简单的使用
`./sniffer-agent`
2.指定log级别可以指定的值为debug、info、warn、error默认是info
`./sniffer-agent --log_level=debug`
默认会监听 网卡eth0端口3306
3.指定网卡和监听端口
`./sniffer-agent --interface=eth0 --port=3358`
4.指定输出到kafka为了将ddl和select、dml区分处理这里使用了两个topic来生产消息
`./sniffer-agent --export_type=kafka --kafka-server=$kafka_server:$kafka_server --kafka-group-id=sniffer --kafka-async-topic=non_ddl_sql_collector --kafka-sync-topic=ddl_sql_collector`
5.指定严格模式,通过查询获取长连接的用户名和数据库
`./sniffer-agent --strict_mode=true --admin_user=root --admin_passwd=123456`
#### 题外话
在做这个功能之前,项目组调研过类似功能的产品,最有名的是 [mysql-sniffer](https://github.com/Qihoo360/mysql-sniffer) 和 [go-sniffer](https://github.com/40t/go-sniffer),这两个产品都很优秀,不过我们的业务场景要求更多。
我们需要将提取的SQL信息发送到kafka进行处理之前的两个产品输出的结果需要进行一些处理然后自己发送在QPS比较高的情况下这些处理会消耗较多的CPU
另外mysql-sniffer使用c++开发,平台的适用性较差,后期扩展较难。
开发的过程中也借鉴了这些产品的思想另外在MySQL包解析的时候参考了一些 [TiDB](https://github.com/pingcap/tidb) 的内容,部分私有变量和函数直接复制使用,这里向这些优秀的产品致敬,如有侵权请随时联系。
## License
[MIT](https://opensource.org/licenses/MIT)

22
capture/config.go Normal file
View File

@ -0,0 +1,22 @@
package capture
import (
sd "github.com/zr-hebo/sniffer-agent/session-dealer"
log "github.com/sirupsen/logrus"
)
var (
localIPAddr string
sessionPool = make(map[string]sd.ConnSession)
)
func init() {
var err error
localIPAddr, err = getLocalIPAddr()
if err != nil {
panic(err)
}
log.Infof("parsed local ip address:%s", localIPAddr)
}

1
capture/const.go Normal file
View File

@ -0,0 +1 @@
package capture

180
capture/network.go Normal file
View File

@ -0,0 +1,180 @@
package capture
import (
"flag"
"fmt"
log "github.com/sirupsen/logrus"
sd "github.com/zr-hebo/sniffer-agent/session-dealer"
"github.com/zr-hebo/sniffer-agent/model"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/pcap"
)
var (
DeviceName string
snifferPort int
)
func init() {
flag.StringVar(&DeviceName, "interface", "eth0", "network device name. Default is eth0")
flag.IntVar(&snifferPort, "port", 3306, "sniffer port. Default is 3306")
}
// networkCard is network device
type networkCard struct {
name string
listenPort int
}
func NewNetworkCard() (nc *networkCard) {
// init device
return &networkCard{name: DeviceName, listenPort: snifferPort}
}
// Listen get a connection.
func (nc *networkCard) Listen() (receiver chan model.QueryPiece) {
receiver = make(chan model.QueryPiece, 100)
go func() {
defer func() {
close(receiver)
}()
handle, err := pcap.OpenLive(DeviceName, 65535, false, pcap.BlockForever)
if err != nil {
panic(fmt.Sprintf("cannot open network interface %s <-- %s", nc.name, err.Error()))
}
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
for packet := range packetSource.Packets() {
if packet.NetworkLayer() == nil || packet.TransportLayer() == nil {
// log.Info("empty network layer")
continue
}
if packet.TransportLayer().LayerType() != layers.LayerTypeTCP {
// log.Info("packet type is %s, not TCP", packet.TransportLayer().LayerType())
continue
}
qp := nc.parseTCPPackage(packet)
if qp != nil {
receiver <- qp
}
}
}()
return
}
func (nc *networkCard) parseTCPPackage(packet gopacket.Packet) (qp model.QueryPiece) {
var err error
defer func() {
if err != nil {
log.Error("parse TCP package failed <-- %s", err.Error())
}
}()
tcpConn := packet.TransportLayer().(*layers.TCP)
if tcpConn.SYN || tcpConn.RST {
return
}
if(int(tcpConn.DstPort) != nc.listenPort && int(tcpConn.SrcPort) != nc.listenPort) {
return
}
ipLayer := packet.Layer(layers.LayerTypeIPv4)
if ipLayer == nil {
err = fmt.Errorf("no ip layer found in package")
return
}
ipInfo, ok := ipLayer.(*layers.IPv4)
if !ok {
err = fmt.Errorf("parsed no ip address")
return
}
// get IP from ip layer
srcIP := ipInfo.SrcIP.String()
dstIP := ipInfo.DstIP.String()
srcPort := int(tcpConn.SrcPort)
dstPort := int(tcpConn.DstPort)
if dstIP == localIPAddr && dstPort == nc.listenPort {
// deal mysql server response
err = readToServerPackage(srcIP, srcPort, tcpConn)
if err != nil {
return
}
} else if srcIP == localIPAddr && srcPort == nc.listenPort {
// deal mysql client request
qp, err = readFromServerPackage(dstIP, dstPort, tcpConn)
if err != nil {
return
}
}
return
}
func readFromServerPackage(srcIP string, srcPort int, tcpConn *layers.TCP) (qp model.QueryPiece, err error) {
defer func() {
if err != nil {
log.Error("read Mysql package send from mysql server to client failed <-- %s", err.Error())
}
}()
sessionKey := spliceSessionKey(srcIP, srcPort)
if tcpConn.FIN {
delete(sessionPool, sessionKey)
// log.Debugf("close connection from %s", sessionKey)
return
}
tcpPayload := tcpConn.Payload
if (len(tcpPayload) < 1) {
return
}
session := sessionPool[sessionKey]
if session != nil {
session.ReadFromServer(tcpPayload)
qp = session.GenerateQueryPiece()
}
return
}
func readToServerPackage(srcIP string, srcPort int, tcpConn *layers.TCP) (err error) {
defer func() {
if err != nil {
log.Error("read package send from client to mysql server failed <-- %s", err.Error())
}
}()
sessionKey := spliceSessionKey(srcIP, srcPort)
// when client try close connection remove session from session pool
if tcpConn.FIN {
delete(sessionPool, sessionKey)
// log.Debugf("close connection from %s", sessionKey)
return
}
tcpPayload := tcpConn.Payload
if (len(tcpPayload) < 1) {
return
}
session := sessionPool[sessionKey]
if session == nil {
session = sd.NewSession(sessionKey, srcIP, srcPort, localIPAddr, snifferPort)
sessionPool[sessionKey] = session
}
session.ReadFromClient(tcpPayload)
return
}

40
capture/util.go Normal file
View File

@ -0,0 +1,40 @@
package capture
import (
"fmt"
"net"
"strings"
)
func getLocalIPAddr() (ipAddr string, err error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return
}
for _, addr := range addrs {
addrStr := addr.String()
if strings.Contains(addrStr, "127.0.0.1") ||
strings.Contains(addrStr, "::1") ||
strings.Contains(addrStr, "/64") {
continue
}
addrStr = strings.TrimRight(addrStr, "1234567890")
addrStr = strings.TrimRight(addrStr, "/")
if len(addrStr) < 1 {
continue
}
ipAddr = addrStr
return
}
err = fmt.Errorf("no valid ip address found")
return
}
func spliceSessionKey(srcIP string, srcPort int) (sessionKey string) {
sessionKey = fmt.Sprintf("%s:%d", srcIP, srcPort)
return
}

31
communicator/config.go Normal file
View File

@ -0,0 +1,31 @@
package communicator
import (
"flag"
"net/http"
"time"
_ "net/http/pprof"
"github.com/gorilla/mux"
)
var (
communicatePort int
)
func init() {
flag.IntVar(&communicatePort, "communicate_port", 8088, "http server port. Default is 8088")
}
func Server() {
server := &http.Server{
Addr: ":" + string(communicatePort),
Handler: mux.NewRouter(),
IdleTimeout: time.Second * 5,
}
if err := server.ListenAndServe(); err != nil {
panic(err)
}
}

18
exporter/cli.go Normal file
View File

@ -0,0 +1,18 @@
package exporter
import (
"fmt"
"github.com/zr-hebo/sniffer-agent/model"
)
type cliExporter struct {
}
func NewCliExporter() *cliExporter {
return &cliExporter{}
}
func (c *cliExporter) Export (qp model.QueryPiece) (err error){
fmt.Println(qp.String())
return
}

117
exporter/kafka.go Normal file
View File

@ -0,0 +1,117 @@
package exporter
import (
"flag"
"fmt"
"regexp"
"strings"
log "github.com/sirupsen/logrus"
"github.com/Shopify/sarama"
"github.com/zr-hebo/sniffer-agent/model"
)
var (
ddlPatern = regexp.MustCompile(`(?i)^\s*(create|alter|drop)`)
kafkaServer string
kafkaGroupID string
asyncTopic string
syncTopic string
)
func init() {
flag.StringVar(
&kafkaServer, "kafka-server", "", "kafka server address. No default value")
flag.StringVar(
&kafkaGroupID,
"kafka-group-id", "", "kafka service group. No default value")
flag.StringVar(
&asyncTopic,
"kafka-async-topic", "", "kafka async send topic. No default value")
flag.StringVar(
&syncTopic,
"kafka-sync-topic", "", "kafka sync send topic. No default value")
}
type kafkaExporter struct {
asyncProducer sarama.AsyncProducer
syncProducer sarama.SyncProducer
asyncTopic string
syncTopic string
}
func checkParams() {
params := make(map[string]string)
params["kafka-server"] = kafkaServer
params["kafka-group-id"] = kafkaGroupID
params["kafka-async-topic"] = asyncTopic
params["kafka-sync-topic"] = syncTopic
for param := range params {
if len(params[param]) < 1{
panic(fmt.Sprintf("%s cannot be empty", param))
}
}
}
func NewKafkaExporter() (ke *kafkaExporter) {
checkParams()
ke = &kafkaExporter{}
conf := sarama.NewConfig()
conf.Producer.Return.Successes = true
conf.ClientID = kafkaGroupID
addrs := strings.Split(kafkaServer, ",")
syncProducer, err := sarama.NewSyncProducer(addrs, conf)
if err != nil {
panic(err.Error())
}
ke.syncProducer = syncProducer
asyncProducer, err := sarama.NewAsyncProducer(addrs, conf)
if err != nil {
panic(err.Error())
}
ke.asyncProducer = asyncProducer
ke.asyncTopic = asyncTopic
ke.syncTopic = syncTopic
go func() {
errors := ke.asyncProducer.Errors()
success := ke.asyncProducer.Successes()
for {
select {
case err := <-errors:
if err != nil {
log.Error(err.Error())
}
case <-success:
}
}
}()
return
}
func (ke *kafkaExporter) Export (qp model.QueryPiece) (err error){
if ddlPatern.MatchString(qp.GetSQL()) {
log.Debugf("deal ddl: %s\n", qp.String())
msg := &sarama.ProducerMessage{
Topic: ke.syncTopic,
Value: sarama.StringEncoder(qp.String()),
}
_, _, err = ke.syncProducer.SendMessage(msg)
if err != nil {
return
}
} else {
log.Debugf("deal non ddl: %s", qp.String())
msg := &sarama.ProducerMessage{
Topic: ke.asyncTopic,
Value: sarama.ByteEncoder(qp.Bytes()),
}
ke.asyncProducer.Input() <- msg
}
return
}

30
exporter/model.go Normal file
View File

@ -0,0 +1,30 @@
package exporter
import (
"flag"
"github.com/zr-hebo/sniffer-agent/model"
)
var (
exportType string
)
func init() {
flag.StringVar(&exportType,"export_type", "cli", "export type. Default is cli, that is command line")
}
type Exporter interface {
Export(model.QueryPiece) error
}
func NewExporter() Exporter {
switch exportType {
case "cli":
return NewCliExporter()
case "kafka":
return NewKafkaExporter()
default:
return NewCliExporter()
}
}

BIN
images/arch.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

62
main.go Normal file
View File

@ -0,0 +1,62 @@
package main
import (
"flag"
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/zr-hebo/sniffer-agent/capture"
"github.com/zr-hebo/sniffer-agent/exporter"
"github.com/zr-hebo/sniffer-agent/communicator"
sd "github.com/zr-hebo/sniffer-agent/session-dealer"
)
var (
logLevel string
)
func init() {
flag.StringVar(&logLevel, "log_level", "info", "log level. Default is info")
}
func initLog() {
log.SetFormatter(&log.TextFormatter{})
log.SetOutput(os.Stdout)
switch logLevel {
case "debug":
log.SetLevel(log.DebugLevel)
case "info":
log.SetLevel(log.InfoLevel)
case "warn":
log.SetLevel(log.WarnLevel)
case "error":
log.SetLevel(log.ErrorLevel)
default:
panic(fmt.Sprintf("cannot set log level:%s, there have four types can set: debug, info, warn, error", logLevel))
}
}
func main() {
flag.Parse()
initLog()
sd.CheckParams()
go communicator.Server()
mainServer()
}
func mainServer() {
ept := exporter.NewExporter()
networkCard := capture.NewNetworkCard()
log.Info("begin listen")
for queryPiece := range networkCard.Listen() {
err := ept.Export(queryPiece)
if err != nil {
log.Error(err.Error())
}
}
log.Errorf("cannot get network package from %s", capture.DeviceName)
os.Exit(1)
}

49
model/query_piece.go Normal file
View File

@ -0,0 +1,49 @@
package model
import (
"encoding/json"
)
type QueryPiece interface {
String() string
Bytes() []byte
GetSQL() string
}
// MysqlQueryPiece 查询信息
type MysqlQueryPiece struct {
SessionID string `json:"sid"`
ClientHost string `json:"-"`
ServerIP string `json:"sip"`
ServerPort int `json:"sport"`
VisitUser *string `json:"user"`
VisitDB *string `json:"db"`
QuerySQL *string `json:"sql"`
BeginTime string `json:"bt"`
CostTimeInMS int64 `json:"cms"`
}
func (qp *MysqlQueryPiece) String() (str string) {
content, err := json.Marshal(qp)
if err != nil {
return err.Error()
}
return string(content)
}
func (qp *MysqlQueryPiece) Bytes() (bytes []byte) {
content, err := json.Marshal(qp)
if err != nil {
return []byte(err.Error())
}
return content
}
func (qp *MysqlQueryPiece) GetSQL() (str string) {
if qp.QuerySQL != nil {
return *qp.QuerySQL
}
return ""
}

53
scripts/check_kafka.py Normal file
View File

@ -0,0 +1,53 @@
# import json
from kafka import KafkaConsumer, KafkaProducer
kafka_server = '192.168.XX.XX:9091'
group_id = 'sniffer'
topic = 'ddl_sql_collector'
def check_consume():
conf = {
'bootstrap_servers': kafka_server,
'client_id': group_id,
'group_id': group_id,
'auto_offset_reset': 'earliest',
'session_timeout_ms': 60000,
'api_version': (0, 9, 0, 1)
}
consumer = KafkaConsumer(topic, **conf)
print('ready to consume')
for msg in consumer:
# event = json.loads(bytes.decode(msg.value))
print(msg)
def check_produce():
conf = {
'bootstrap_servers': kafka_server,
'client_id': group_id
}
# 'api_version': (0, 9, 0, 1)
producer = KafkaProducer(**conf)
try:
future = producer.send(topic, 'haha')
result = future.get(timeout=3)
print('send OK')
print(result)
except BaseException as e:
print('send failed')
# 发送失败时,用户需根据业务逻辑做异常处理,否则消息可能会丢失
print(str(e))
def _real_main():
check_produce()
# check_consume()
if __name__ == '__main__':
_real_main()

75
scripts/check_prepare.py Normal file
View File

@ -0,0 +1,75 @@
import time
import mysql.connector
from mysql.connector.cursor import MySQLCursorPrepared
config = {
'host': '192.168.XX.XX',
'port': 3358,
'database': 'sniffer',
'user': 'root',
'password': '',
'charset': 'utf8',
'use_unicode': True,
'get_warnings': True,
}
def _real_main(config):
while True:
_once_check()
def _once_check():
output = []
conn = mysql.connector.Connect(**config)
curprep = conn.cursor(cursor_class=MySQLCursorPrepared)
cur = conn.cursor()
# Drop table if exists, and create it new
stmt_drop = "DROP TABLE IF EXISTS names"
cur.execute(stmt_drop)
stmt_create = (
"CREATE TABLE names ("
"id TINYINT UNSIGNED NOT NULL AUTO_INCREMENT, "
"name VARCHAR(30) DEFAULT '' NOT NULL, "
"cnt TINYINT UNSIGNED DEFAULT 0, "
"PRIMARY KEY (id))"
)
cur.execute(stmt_create)
# Connector/Python also allows ? as placeholders for MySQL Prepared
# statements.
prepstmt = "INSERT INTO names (name) VALUES (%s)"
# Preparing the statement is done only once. It can be done before
# without data, or later with data.
curprep.execute(prepstmt)
# Insert 3 records
names = (
'Geert', 'Jan', 'Michel', 'wang', 'Jan', 'Michel', 'wang',
'Jan', 'Michel', 'wang', 'Jan', 'Michel', 'wang', 'Jan', 'Michel',
'wang', 'Jan', 'Michel', 'wang', 'Jan', 'Michel', 'wang', 'Jan',
'Jan', 'Michel', 'wang')
for name in names:
curprep.execute(prepstmt, (name,))
conn.commit()
time.sleep(0.1)
# We use a normal cursor issue a SELECT
output.append("Inserted data")
cur.execute("SELECT id, name FROM names")
for row in cur:
output.append("{0} | {1}".format(*row))
# Cleaning up, dropping the table again
cur.execute(stmt_drop)
conn.close()
print(output)
if __name__ == '__main__':
_real_main(config)

View File

@ -0,0 +1,31 @@
#!/usr/bin/env bash
function execute_real(){
mysql_host=192.168.XX.XX
mysql_port=3358
user_name=user
passwd=123456
mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e "select 1"
sleep 1
mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e "use sniffer;show tables;create table haha(id int, name text)"
sleep 1
mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e ""
sleep 1
mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e ""
sleep 1
insert_cmd="insert into unibase.haha(id, name) values(10, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')"
insert_cmd="$insert_cmd,(10, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')"
insert_cmd="$insert_cmd,(10, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')"
insert_cmd="$insert_cmd,(10, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')"
insert_cmd="$insert_cmd,(10, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')"
mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e "$insert_cmd"
sleep 1
mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e "use unibase; select * from haha; drop table haha"
sleep 1
}
while true
do
execute_real
done

18
session-dealer/config.go Normal file
View File

@ -0,0 +1,18 @@
package session_dealer
import (
"flag"
)
const (
ServiceTypeMysql = "mysql"
)
var (
serviceType string
)
func init() {
flag.StringVar(&serviceType, "service_type", "mysql", "service type. Default is mysql")
}

View File

@ -0,0 +1,24 @@
package session_dealer
import (
"github.com/zr-hebo/sniffer-agent/session-dealer/mysql"
)
func NewSession(sessionKey string, clientIP string, clientPort int, serverIP string, serverPort int) (session ConnSession) {
switch serviceType {
case ServiceTypeMysql:
session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort)
default:
session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort)
}
return
}
func CheckParams() {
switch serviceType {
case ServiceTypeMysql:
mysql.CheckParams()
default:
mysql.CheckParams()
}
}

9
session-dealer/model.go Normal file
View File

@ -0,0 +1,9 @@
package session_dealer
import "github.com/zr-hebo/sniffer-agent/model"
type ConnSession interface {
ReadFromClient(bytes []byte)
ReadFromServer(bytes []byte)
GenerateQueryPiece() (qp model.QueryPiece)
}

View File

@ -0,0 +1,19 @@
package mysql
// parseAuthInfo parse username, dbname from mysql client auth info
func parseAuthInfo(data []byte) (userName, dbName string, err error) {
var resp handshakeResponse41
pos, err := parseHandshakeResponseHeader(&resp, data)
if err != nil {
return
}
// Read the remaining part of the packet.
if err = parseHandshakeResponseBody(&resp, data, pos); err != nil {
return
}
userName = resp.User
dbName = resp.DBName
return
}

View File

@ -0,0 +1,38 @@
package mysql
import (
"flag"
"fmt"
"regexp"
)
var (
uselessSQLPattern = regexp.MustCompile(`(?i)^\s*(select 1|select @@version_comment limit 1|`+
`SELECT user, db FROM information_schema.processlist WHERE host=|commit|begin)`)
)
var (
strictMode bool
adminUser string
adminPasswd string
)
func init() {
flag.BoolVar(&strictMode,"strict_mode", false, "strict mode. Default is false")
flag.StringVar(&adminUser,"admin_user", "", "admin user name. When set strict mode, must set admin user to query session info")
flag.StringVar(&adminPasswd,"admin_passwd", "", "admin user passwd. When use strict mode, must set admin user to query session info")
}
func CheckParams() {
if !strictMode {
return
}
if len(adminUser) < 1 {
panic(fmt.Sprintf("In strict mode, admin user name cannot be empty"))
}
if len(adminPasswd) < 1 {
panic(fmt.Sprintf("In strict mode, admin passwd cannot be empty"))
}
}

View File

@ -0,0 +1,50 @@
package mysql
import (
"fmt"
_ "github.com/go-sql-driver/mysql"
du "github.com/zr-hebo/util-db"
// log "github.com/sirupsen/logrus"
)
func expandLocalMysql(port int) (mysqlHost *du.MysqlDB) {
mysqlHost = new(du.MysqlDB)
mysqlHost.IP = "localhost"
mysqlHost.Port = port
mysqlHost.UserName = adminUser
mysqlHost.Passwd = adminPasswd
mysqlHost.DatabaseType = "mysql"
mysqlHost.ConnectTimeout = 1
return
}
func querySessionInfo(snifferPort int, clientHost string) (user, db *string, err error) {
mysqlServer := expandLocalMysql(snifferPort)
querySQL := fmt.Sprintf(
"SELECT user, db FROM information_schema.processlist WHERE host='%s'", clientHost)
// log.Debug(querySQL)
queryRow, err := mysqlServer.QueryRow(querySQL)
if err != nil {
return
}
if queryRow == nil {
return
}
userVal := queryRow.Record["user"]
if userVal != nil {
usrStr := userVal.(string)
user = &usrStr
}
dbVal := queryRow.Record["db"]
if dbVal != nil {
dbStr := dbVal.(string)
db = &dbStr
}
return
}

View File

@ -0,0 +1,105 @@
package mysql
import "errors"
// Command information.
const (
ComSleep byte = iota
ComQuit
ComInitDB
ComQuery
ComFieldList
ComCreateDB
ComDropDB
ComRefresh
ComShutdown
ComStatistics
ComProcessInfo
ComConnect
ComProcessKill
ComDebug
ComPing
ComTime
ComDelayedInsert
ComChangeUser
ComBinlogDump
ComTableDump
ComConnectOut
ComRegisterSlave
ComStmtPrepare
ComStmtExecute
ComStmtSendLongData
ComStmtClose
ComStmtReset
ComSetOption
ComStmtFetch
ComBinlogDumpGtid
ComResetConnection
)
const (
ComAuth = 141
)
// Client information.
const (
ClientLongPassword uint32 = 1 << iota
ClientFoundRows
ClientLongFlag
ClientConnectWithDB
ClientNoSchema
ClientCompress
ClientODBC
ClientLocalFiles
ClientIgnoreSpace
ClientProtocol41
ClientInteractive
ClientSSL
ClientIgnoreSigpipe
ClientTransactions
ClientReserved
ClientSecureConnection
ClientMultiStatements
ClientMultiResults
ClientPSMultiResults
ClientPluginAuth
ClientConnectAtts
ClientPluginAuthLenencClientData
)
// Auth name information.
const (
AuthName = "mysql_native_password"
)
// MySQL database and tables.
const (
// SystemDB is the name of system database.
SystemDB = "mysql"
// UserTable is the table in system db contains user info.
UserTable = "User"
// DBTable is the table in system db contains db scope privilege info.
DBTable = "DB"
// GlobalVariablesTable is the table contains global system variables.
GlobalVariablesTable = "GLOBAL_VARIABLES"
// GlobalStatusTable is the table contains global status variables.
GlobalStatusTable = "GLOBAL_STATUS"
)
// Identifier length limitations.
// See https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
const (
// MaxPayloadLen is the max packet payload length.
MaxPayloadLen = 1<<24 - 1
)
const (
datetimeFormat = "2006-01-02 15:04:05"
)
var (
ErrMalformPacket = errors.New("malform packet error")
)

View File

@ -0,0 +1,9 @@
package mysql
type handshakeResponse41 struct {
Capability uint32
Collation uint8
User string
DBName string
Auth []byte
}

View File

@ -0,0 +1,175 @@
package mysql
import (
"fmt"
"time"
"github.com/zr-hebo/sniffer-agent/model"
log "github.com/sirupsen/logrus"
)
type MysqlSession struct {
connectionID string
visitUser *string
visitDB *string
clientHost string
clientPort int
serverIP string
serverPort int
beginTime int64
expectSize int
prepareInfo *prepareInfo
cachedPrepareStmt map[int]*string
tcpCache []byte
cachedStmtBytes []byte
}
type prepareInfo struct {
prepareStmtID int
}
func NewMysqlSession(sessionKey string, clientIP string, clientPort int, serverIP string, serverPort int) (ms *MysqlSession) {
ms = &MysqlSession{
connectionID: sessionKey,
clientHost: clientIP,
clientPort: clientPort,
serverIP: serverIP,
serverPort: serverPort,
beginTime: time.Now().UnixNano() / int64(time.Millisecond),
cachedPrepareStmt: make(map[int]*string),
}
return
}
func (ms *MysqlSession) ReadFromServer(bytes []byte) {
if ms.expectSize < 1 {
ms.expectSize = extractMysqlPayloadSize(bytes)
contents := bytes[4:]
if ms.prepareInfo != nil && contents[0] == 0 {
ms.prepareInfo.prepareStmtID = bytesToInt(contents[1:5])
}
ms.expectSize = ms.expectSize - len(contents)
} else {
ms.expectSize = ms.expectSize - len(bytes)
}
}
func (ms *MysqlSession) ReadFromClient(bytes []byte) {
if ms.expectSize < 1 {
ms.expectSize = extractMysqlPayloadSize(bytes)
contents := bytes[4:]
if contents[0] == ComStmtPrepare {
ms.prepareInfo = &prepareInfo{}
}
ms.expectSize = ms.expectSize - len(contents)
ms.tcpCache = append(ms.tcpCache, contents...)
} else {
ms.expectSize = ms.expectSize - len(bytes)
ms.tcpCache = append(ms.tcpCache, bytes...)
if len(ms.tcpCache) == MaxPayloadLen {
ms.cachedStmtBytes = append(ms.cachedStmtBytes, ms.tcpCache...)
ms.tcpCache = ms.tcpCache[:0]
ms.expectSize = 0
}
}
}
func (ms *MysqlSession) GenerateQueryPiece() (qp model.QueryPiece) {
if len(ms.cachedStmtBytes) < 1 && len(ms.tcpCache) < 1 {
return
}
var mqp *model.MysqlQueryPiece = nil
ms.cachedStmtBytes = append(ms.cachedStmtBytes, ms.tcpCache...)
switch ms.cachedStmtBytes[0] {
case ComAuth:
var userName, dbName string
var err error
userName, dbName, err = parseAuthInfo(ms.cachedStmtBytes)
if err != nil {
return
}
ms.visitUser = &userName
ms.visitDB = &dbName
case ComInitDB:
newDBName := string(ms.cachedStmtBytes[1:])
useSQL := fmt.Sprintf("use %s", newDBName)
mqp = ms.composeQueryPiece()
mqp.QuerySQL = &useSQL
// update session database
ms.visitDB = &newDBName
case ComCreateDB:
case ComDropDB:
case ComQuery:
mqp = ms.composeQueryPiece()
querySQL := string(ms.cachedStmtBytes[1:])
mqp.QuerySQL = &querySQL
case ComStmtPrepare:
mqp = ms.composeQueryPiece()
querySQL := string(ms.cachedStmtBytes[1:])
mqp.QuerySQL = &querySQL
ms.cachedPrepareStmt[ms.prepareInfo.prepareStmtID] = &querySQL
log.Debugf("prepare statement %s, get id:%d", querySQL, ms.prepareInfo.prepareStmtID)
case ComStmtExecute:
prepareStmtID := bytesToInt(ms.cachedStmtBytes[1:5])
mqp = ms.composeQueryPiece()
mqp.QuerySQL = ms.cachedPrepareStmt[prepareStmtID]
log.Debugf("execute prepare statement:%d", prepareStmtID)
case ComStmtClose:
prepareStmtID := bytesToInt(ms.cachedStmtBytes[1:5])
delete(ms.cachedPrepareStmt, prepareStmtID)
log.Debugf("remove prepare statement:%d", prepareStmtID)
default:
}
if strictMode && mqp != nil && mqp.VisitUser == nil {
user, db, err := querySessionInfo(ms.serverPort, mqp.SessionID)
if err != nil {
log.Errorf("query user and db from mysql failed <-- %s", err.Error())
} else {
mqp.VisitUser = user
mqp.VisitDB = db
}
}
ms.tcpCache = ms.tcpCache[:0]
ms.cachedStmtBytes = ms.cachedStmtBytes[:0]
ms.expectSize = 0
ms.prepareInfo = nil
return filterQueryPieceBySQL(mqp)
}
func filterQueryPieceBySQL(mqp *model.MysqlQueryPiece) (model.QueryPiece) {
if mqp == nil || mqp.QuerySQL == nil {
return nil
} else if (uselessSQLPattern.MatchString(*mqp.QuerySQL)) {
return nil
}
return mqp
}
func (ms *MysqlSession) composeQueryPiece() (mqp *model.MysqlQueryPiece) {
nowInMS := time.Now().UnixNano() / int64(time.Millisecond)
mqp = &model.MysqlQueryPiece{
SessionID: ms.connectionID,
ClientHost: ms.clientHost,
ServerIP: ms.serverIP,
ServerPort: ms.serverPort,
VisitUser: ms.visitUser,
VisitDB: ms.visitDB,
BeginTime: time.Unix(ms.beginTime/1000, 0).Format(datetimeFormat),
CostTimeInMS: nowInMS - ms.beginTime,
}
return mqp
}

View File

@ -0,0 +1,131 @@
package mysql
import (
"bytes"
"encoding/binary"
)
// parseHandshakeResponseHeader parses the common header of SSLRequest and HandshakeResponse41.
func parseHandshakeResponseHeader(packet *handshakeResponse41, data []byte) (parsedBytes int, err error) {
// Ensure there are enough data to read:
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if len(data) < 4+4+1+23 {
return 0, ErrMalformPacket
}
offset := 0
// capability
capability := binary.LittleEndian.Uint32(data[:4])
packet.Capability = capability
offset += 4
// skip max packet size
offset += 4
// charset, skip, if you want to use another charset, use set names
packet.Collation = data[offset]
offset++
// skip reserved 23[00]
offset += 23
return offset, nil
}
// parseHandshakeResponseBody parse the HandshakeResponse (except the common header part).
func parseHandshakeResponseBody(packet *handshakeResponse41, data []byte, offset int) (err error) {
defer func() {
// Check malformat packet cause out of range is disgusting, but don't panic!
if r := recover(); r != nil {
err = ErrMalformPacket
}
}()
// user name
packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)])
offset += len(packet.User) + 1
if packet.Capability&ClientPluginAuthLenencClientData > 0 {
// MySQL client sets the wrong capability, it will set this bit even server doesn't
// support ClientPluginAuthLenencClientData.
// https://github.com/mysql/mysql-server/blob/5.7/sql-common/client.c#L3478
num, null, off := parseLengthEncodedInt(data[offset:])
offset += off
if !null {
packet.Auth = data[offset : offset+int(num)]
offset += int(num)
}
} else if packet.Capability&ClientSecureConnection > 0 {
// auth length and auth
authLen := int(data[offset])
offset++
packet.Auth = data[offset : offset+authLen]
offset += authLen
} else {
packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)]
offset += len(packet.Auth) + 1
}
if packet.Capability&ClientConnectWithDB > 0 {
if len(data[offset:]) > 0 {
idx := bytes.IndexByte(data[offset:], 0)
packet.DBName = string(data[offset : offset+idx])
offset = offset + idx + 1
}
}
if packet.Capability&ClientPluginAuth > 0 {
// TODO: Support mysql.ClientPluginAuth, skip it now
idx := bytes.IndexByte(data[offset:], 0)
offset = offset + idx + 1
}
if packet.Capability&ClientConnectAtts > 0 {
if len(data[offset:]) == 0 {
// Defend some ill-formated packet, connection attribute is not important and can be ignored.
return nil
}
}
return nil
}
func parseLengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
switch b[0] {
// 251: NULL
case 0xfb:
n = 1
isNull = true
return
// 252: value of following 2
case 0xfc:
num = uint64(b[1]) | uint64(b[2])<<8
n = 3
return
// 253: value of following 3
case 0xfd:
num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16
n = 4
return
// 254: value of following 8
case 0xfe:
num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<56
n = 9
return
}
// 0-250: value of first byte
num = uint64(b[0])
n = 1
return
}
func extractMysqlPayloadSize(payload []byte) int {
header := payload[:4]
return int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
}
func bytesToInt(contents []byte) int {
return int(uint32(contents[0]) | uint32(contents[1])<<8 | uint32(contents[2])<<16 | uint32(contents[3])<<24)
}