diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bc9cbe0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.idea/ +*.log +*.swp +one_key.sh +sniffer-agent +vendor/ diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json new file mode 100644 index 0000000..606625b --- /dev/null +++ b/Godeps/Godeps.json @@ -0,0 +1,93 @@ +{ + "ImportPath": "github.com/zr-hebo/sniffer-agent", + "GoVersion": "go1.12", + "GodepVersion": "v79", + "Packages": [ + "./..." + ], + "Deps": [ + { + "ImportPath": "github.com/Shopify/sarama", + "Comment": "v1.14.0-20-gcd645bf", + "Rev": "cd645bfba7622e7de8971388543ff97ee026aad4" + }, + { + "ImportPath": "github.com/davecgh/go-spew/spew", + "Comment": "v1.1.0-9-gecdeabc", + "Rev": "ecdeabc65495df2dec95d7c4a4c3e021903035e5" + }, + { + "ImportPath": "github.com/eapache/go-resiliency/breaker", + "Comment": "v1.0.0-6-gb1fe83b", + "Rev": "b1fe83b5b03f624450823b751b662259ffc6af70" + }, + { + "ImportPath": "github.com/eapache/go-xerial-snappy", + "Rev": "bb955e01b9346ac19dc29eb16586c90ded99a98c" + }, + { + "ImportPath": "github.com/eapache/queue", + "Comment": "v1.1.0", + "Rev": "44cc805cf13205b55f69e14bcb69867d1ae92f98" + }, + { + "ImportPath": "github.com/go-sql-driver/mysql", + "Comment": "v1.3-42-gfade210", + "Rev": "fade21009797158e7b79e04c340118a9220c6f9e" + }, + { + "ImportPath": "github.com/golang/snappy", + "Rev": "553a641470496b2327abcac10b36396bd98e45c9" + }, + { + "ImportPath": "github.com/google/gopacket", + "Comment": "v1.1.17-26-gce2e696", + "Rev": "ce2e696dc0c9917ecdebd800c892b839f06b2949" + }, + { + "ImportPath": "github.com/google/gopacket/layers", + "Comment": "v1.1.17-26-gce2e696", + "Rev": "ce2e696dc0c9917ecdebd800c892b839f06b2949" + }, + { + "ImportPath": "github.com/google/gopacket/pcap", + "Comment": "v1.1.17-26-gce2e696", + "Rev": "ce2e696dc0c9917ecdebd800c892b839f06b2949" + }, + { + "ImportPath": "github.com/pierrec/lz4", + "Comment": "v1.0.1", + "Rev": "08c27939df1bd95e881e2c2367a749964ad1fceb" + }, + { + "ImportPath": "github.com/pierrec/xxHash/xxHash32", + "Comment": "v0.1-11-ga0006b1", + "Rev": "a0006b13c722f7f12368c00a3d3c2ae8a999a0c6" + }, + { + "ImportPath": "github.com/rcrowley/go-metrics", + "Rev": "e181e095bae94582363434144c61a9653aff6e50" + }, + { + "ImportPath": "github.com/sirupsen/logrus", + "Comment": "v1.0.4", + "Rev": "d682213848ed68c0a260ca37d6dd5ace8423f5ba" + }, + { + "ImportPath": "github.com/zr-hebo/util-db", + "Rev": "3ff29f916f7b712b3adc53c4b9b19b13b8bbed87" + }, + { + "ImportPath": "golang.org/x/crypto/ssh/terminal", + "Rev": "eb71ad9bd329b5ac0fd0148dd99bd62e8be8e035" + }, + { + "ImportPath": "golang.org/x/sys/unix", + "Rev": "ac767d655b305d4e9612f5f6e33120b9176c4ad4" + }, + { + "ImportPath": "golang.org/x/sys/windows", + "Rev": "ac767d655b305d4e9612f5f6e33120b9176c4ad4" + } + ] +} diff --git a/Godeps/Readme b/Godeps/Readme new file mode 100644 index 0000000..4cdaa53 --- /dev/null +++ b/Godeps/Readme @@ -0,0 +1,5 @@ +This directory tree is generated automatically by godep. + +Please do not edit. + +See https://github.com/tools/godep for more information. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..248e15a --- /dev/null +++ b/Makefile @@ -0,0 +1 @@ +go build diff --git a/README.md b/README.md index e6cc1c9..af73569 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,90 @@ -# sniffer-agent -抓取tcp包解析出mysql语句 +### sniffer-agent + +> Sniffer TCP package, parsed with mysql protocol, optional you can just print on screen or send query info to Kafka. +> 抓取tcp包解析出mysql语句,将查询信息打印在屏幕上或者发送到Kafka + +## Architecture + +架构设计: + +本项目采用模块化设计,主要分为四大模块:抓包模块,协议解析模块,输出模块,心跳模块 +![架构设计图](https://github.com/zr-hebo/sniffer-agent/blob/master/images/arch.png) + +## Parse Protocol + +sniffer-agent采用模块化结构,支持用户添加自己的解析模块,只要实现了统一的接口即可 +- [x] MySQL +- [ ] PostgreSQL +- [ ] Redis +- [ ] Mongodb + +目前输出的内容都是解析结果组成的json。 +MySQL协议的解析结果示例如下: +``` +{"sid":"10.XX.XX.XX:54656","sip":"192.168.XX.XX","sport":3306,"user":"root","db":"unibase","sql":"show tables","bt":"2019-08-05 18:23:09","cms":15} +``` +其中sid代表客户端ip:port组成的session标识,sip代表server ip,sport代表server port,user代表查询用户,db代表当前连接的库名,sql代表查询语句,bt代表查询开始时间,cms代表查询消耗的时间,单位是毫秒 + +## Exporter + +输出模块主要负责,将解析的结果对外输出。默认情况下输出到命令行,可以通过指定export_type参数选择kafka,这时候会直接将解析结果发送到kafka。 +同样只要实现了export接口,用户可以自定义自己的输出方式。 + +## Install: + +环境: + +golang:1.12 + +libpcap包 + +测试脚本运行在python3环境下 + + +1.安装依赖,目前自测支持Linux系列操作系统,其他版本的系统有待验证 + +CentOS: +``` +yum install libpcap-devel +``` + +Ubuntu: +``` +apt-get install libpcap-dev +``` +2.执行编译命令 go build + +## Demo + +目前只支持MySQL协议的抓取,需要将编译后的二进制文件上传到MySQL服务器上 +1.最简单的使用 + +`./sniffer-agent` + +2.指定log级别,可以指定的值为debug、info、warn、error,默认是info + +`./sniffer-agent --log_level=debug` + +默认会监听 网卡:eth0,端口3306 + +3.指定网卡和监听端口 + +`./sniffer-agent --interface=eth0 --port=3358` + +4.指定输出到kafka,为了将ddl和select、dml区分处理,这里使用了两个topic来生产消息 + +`./sniffer-agent --export_type=kafka --kafka-server=$kafka_server:$kafka_server --kafka-group-id=sniffer --kafka-async-topic=non_ddl_sql_collector --kafka-sync-topic=ddl_sql_collector` + +5.指定严格模式,通过查询获取长连接的用户名和数据库 + +`./sniffer-agent --strict_mode=true --admin_user=root --admin_passwd=123456` + +#### 题外话 +在做这个功能之前,项目组调研过类似功能的产品,最有名的是 [mysql-sniffer](https://github.com/Qihoo360/mysql-sniffer) 和 [go-sniffer](https://github.com/40t/go-sniffer),这两个产品都很优秀,不过我们的业务场景要求更多。 +我们需要将提取的SQL信息发送到kafka进行处理,之前的两个产品输出的结果需要进行一些处理然后自己发送,在QPS比较高的情况下,这些处理会消耗较多的CPU; +另外mysql-sniffer使用c++开发,平台的适用性较差,后期扩展较难。 +开发的过程中也借鉴了这些产品的思想,另外在MySQL包解析的时候,参考了一些 [TiDB](https://github.com/pingcap/tidb) 的内容,部分私有变量和函数直接复制使用,这里向这些优秀的产品致敬,如有侵权请随时联系。 + +## License +[MIT](https://opensource.org/licenses/MIT) + diff --git a/capture/config.go b/capture/config.go new file mode 100644 index 0000000..ecdbbec --- /dev/null +++ b/capture/config.go @@ -0,0 +1,22 @@ +package capture + +import ( + sd "github.com/zr-hebo/sniffer-agent/session-dealer" + log "github.com/sirupsen/logrus" +) + +var ( + localIPAddr string + + sessionPool = make(map[string]sd.ConnSession) +) + +func init() { + var err error + localIPAddr, err = getLocalIPAddr() + if err != nil { + panic(err) + } + + log.Infof("parsed local ip address:%s", localIPAddr) +} diff --git a/capture/const.go b/capture/const.go new file mode 100644 index 0000000..647d52b --- /dev/null +++ b/capture/const.go @@ -0,0 +1 @@ +package capture \ No newline at end of file diff --git a/capture/network.go b/capture/network.go new file mode 100644 index 0000000..257be39 --- /dev/null +++ b/capture/network.go @@ -0,0 +1,180 @@ +package capture + +import ( + "flag" + "fmt" + log "github.com/sirupsen/logrus" + sd "github.com/zr-hebo/sniffer-agent/session-dealer" + "github.com/zr-hebo/sniffer-agent/model" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" +) + +var ( + DeviceName string + snifferPort int +) + +func init() { + flag.StringVar(&DeviceName, "interface", "eth0", "network device name. Default is eth0") + flag.IntVar(&snifferPort, "port", 3306, "sniffer port. Default is 3306") +} + +// networkCard is network device +type networkCard struct { + name string + listenPort int +} + +func NewNetworkCard() (nc *networkCard) { + // init device + return &networkCard{name: DeviceName, listenPort: snifferPort} +} + +// Listen get a connection. +func (nc *networkCard) Listen() (receiver chan model.QueryPiece) { + receiver = make(chan model.QueryPiece, 100) + + go func() { + defer func() { + close(receiver) + }() + + handle, err := pcap.OpenLive(DeviceName, 65535, false, pcap.BlockForever) + if err != nil { + panic(fmt.Sprintf("cannot open network interface %s <-- %s", nc.name, err.Error())) + } + + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + for packet := range packetSource.Packets() { + if packet.NetworkLayer() == nil || packet.TransportLayer() == nil { + // log.Info("empty network layer") + continue + } + + if packet.TransportLayer().LayerType() != layers.LayerTypeTCP { + // log.Info("packet type is %s, not TCP", packet.TransportLayer().LayerType()) + continue + } + + qp := nc.parseTCPPackage(packet) + if qp != nil { + receiver <- qp + } + } + }() + + return +} + +func (nc *networkCard) parseTCPPackage(packet gopacket.Packet) (qp model.QueryPiece) { + var err error + defer func() { + if err != nil { + log.Error("parse TCP package failed <-- %s", err.Error()) + } + }() + + tcpConn := packet.TransportLayer().(*layers.TCP) + if tcpConn.SYN || tcpConn.RST { + return + } + + if(int(tcpConn.DstPort) != nc.listenPort && int(tcpConn.SrcPort) != nc.listenPort) { + return + } + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + if ipLayer == nil { + err = fmt.Errorf("no ip layer found in package") + return + } + + ipInfo, ok := ipLayer.(*layers.IPv4) + if !ok { + err = fmt.Errorf("parsed no ip address") + return + } + + // get IP from ip layer + srcIP := ipInfo.SrcIP.String() + dstIP := ipInfo.DstIP.String() + srcPort := int(tcpConn.SrcPort) + dstPort := int(tcpConn.DstPort) + if dstIP == localIPAddr && dstPort == nc.listenPort { + // deal mysql server response + err = readToServerPackage(srcIP, srcPort, tcpConn) + if err != nil { + return + } + + } else if srcIP == localIPAddr && srcPort == nc.listenPort { + // deal mysql client request + qp, err = readFromServerPackage(dstIP, dstPort, tcpConn) + if err != nil { + return + } + } + + return +} + +func readFromServerPackage(srcIP string, srcPort int, tcpConn *layers.TCP) (qp model.QueryPiece, err error) { + defer func() { + if err != nil { + log.Error("read Mysql package send from mysql server to client failed <-- %s", err.Error()) + } + }() + + sessionKey := spliceSessionKey(srcIP, srcPort) + if tcpConn.FIN { + delete(sessionPool, sessionKey) + // log.Debugf("close connection from %s", sessionKey) + return + } + + tcpPayload := tcpConn.Payload + if (len(tcpPayload) < 1) { + return + } + + session := sessionPool[sessionKey] + if session != nil { + session.ReadFromServer(tcpPayload) + qp = session.GenerateQueryPiece() + } + + return +} + +func readToServerPackage(srcIP string, srcPort int, tcpConn *layers.TCP) (err error) { + defer func() { + if err != nil { + log.Error("read package send from client to mysql server failed <-- %s", err.Error()) + } + }() + + sessionKey := spliceSessionKey(srcIP, srcPort) + // when client try close connection remove session from session pool + if tcpConn.FIN { + delete(sessionPool, sessionKey) + // log.Debugf("close connection from %s", sessionKey) + return + } + + tcpPayload := tcpConn.Payload + if (len(tcpPayload) < 1) { + return + } + + session := sessionPool[sessionKey] + if session == nil { + session = sd.NewSession(sessionKey, srcIP, srcPort, localIPAddr, snifferPort) + sessionPool[sessionKey] = session + } + + session.ReadFromClient(tcpPayload) + return +} + diff --git a/capture/util.go b/capture/util.go new file mode 100644 index 0000000..9bc4d1a --- /dev/null +++ b/capture/util.go @@ -0,0 +1,40 @@ +package capture + +import ( + "fmt" + "net" + "strings" +) + +func getLocalIPAddr() (ipAddr string, err error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return + } + + for _, addr := range addrs { + addrStr := addr.String() + if strings.Contains(addrStr, "127.0.0.1") || + strings.Contains(addrStr, "::1") || + strings.Contains(addrStr, "/64") { + continue + } + + addrStr = strings.TrimRight(addrStr, "1234567890") + addrStr = strings.TrimRight(addrStr, "/") + if len(addrStr) < 1 { + continue + } + + ipAddr = addrStr + return + } + + err = fmt.Errorf("no valid ip address found") + return +} + +func spliceSessionKey(srcIP string, srcPort int) (sessionKey string) { + sessionKey = fmt.Sprintf("%s:%d", srcIP, srcPort) + return +} diff --git a/communicator/config.go b/communicator/config.go new file mode 100644 index 0000000..1ff034e --- /dev/null +++ b/communicator/config.go @@ -0,0 +1,31 @@ +package communicator + +import ( + "flag" + "net/http" + "time" + + _ "net/http/pprof" + "github.com/gorilla/mux" +) + +var ( + communicatePort int +) + +func init() { + flag.IntVar(&communicatePort, "communicate_port", 8088, "http server port. Default is 8088") +} + +func Server() { + server := &http.Server{ + Addr: ":" + string(communicatePort), + Handler: mux.NewRouter(), + IdleTimeout: time.Second * 5, + } + + if err := server.ListenAndServe(); err != nil { + panic(err) + } +} + diff --git a/exporter/cli.go b/exporter/cli.go new file mode 100644 index 0000000..3ba7718 --- /dev/null +++ b/exporter/cli.go @@ -0,0 +1,18 @@ +package exporter + +import ( + "fmt" + "github.com/zr-hebo/sniffer-agent/model" +) + +type cliExporter struct { +} + +func NewCliExporter() *cliExporter { + return &cliExporter{} +} + +func (c *cliExporter) Export (qp model.QueryPiece) (err error){ + fmt.Println(qp.String()) + return +} \ No newline at end of file diff --git a/exporter/kafka.go b/exporter/kafka.go new file mode 100644 index 0000000..0e7dddc --- /dev/null +++ b/exporter/kafka.go @@ -0,0 +1,117 @@ +package exporter + +import ( + "flag" + "fmt" + "regexp" + "strings" + + log "github.com/sirupsen/logrus" + "github.com/Shopify/sarama" + "github.com/zr-hebo/sniffer-agent/model" +) + +var ( + ddlPatern = regexp.MustCompile(`(?i)^\s*(create|alter|drop)`) + kafkaServer string + kafkaGroupID string + asyncTopic string + syncTopic string +) + +func init() { + flag.StringVar( + &kafkaServer, "kafka-server", "", "kafka server address. No default value") + flag.StringVar( + &kafkaGroupID, + "kafka-group-id", "", "kafka service group. No default value") + flag.StringVar( + &asyncTopic, + "kafka-async-topic", "", "kafka async send topic. No default value") + flag.StringVar( + &syncTopic, + "kafka-sync-topic", "", "kafka sync send topic. No default value") +} + +type kafkaExporter struct { + asyncProducer sarama.AsyncProducer + syncProducer sarama.SyncProducer + asyncTopic string + syncTopic string +} + +func checkParams() { + params := make(map[string]string) + params["kafka-server"] = kafkaServer + params["kafka-group-id"] = kafkaGroupID + params["kafka-async-topic"] = asyncTopic + params["kafka-sync-topic"] = syncTopic + for param := range params { + if len(params[param]) < 1{ + panic(fmt.Sprintf("%s cannot be empty", param)) + } + } +} + +func NewKafkaExporter() (ke *kafkaExporter) { + checkParams() + ke = &kafkaExporter{} + conf := sarama.NewConfig() + conf.Producer.Return.Successes = true + conf.ClientID = kafkaGroupID + addrs := strings.Split(kafkaServer, ",") + syncProducer, err := sarama.NewSyncProducer(addrs, conf) + if err != nil { + panic(err.Error()) + } + ke.syncProducer = syncProducer + + asyncProducer, err := sarama.NewAsyncProducer(addrs, conf) + if err != nil { + panic(err.Error()) + } + ke.asyncProducer = asyncProducer + ke.asyncTopic = asyncTopic + ke.syncTopic = syncTopic + + go func() { + errors := ke.asyncProducer.Errors() + success := ke.asyncProducer.Successes() + for { + select { + case err := <-errors: + if err != nil { + log.Error(err.Error()) + } + + case <-success: + } + } + }() + return +} + +func (ke *kafkaExporter) Export (qp model.QueryPiece) (err error){ + if ddlPatern.MatchString(qp.GetSQL()) { + log.Debugf("deal ddl: %s\n", qp.String()) + + msg := &sarama.ProducerMessage{ + Topic: ke.syncTopic, + Value: sarama.StringEncoder(qp.String()), + } + _, _, err = ke.syncProducer.SendMessage(msg) + if err != nil { + return + } + + } else { + log.Debugf("deal non ddl: %s", qp.String()) + msg := &sarama.ProducerMessage{ + Topic: ke.asyncTopic, + Value: sarama.ByteEncoder(qp.Bytes()), + } + + ke.asyncProducer.Input() <- msg + } + return +} diff --git a/exporter/model.go b/exporter/model.go new file mode 100644 index 0000000..26666f5 --- /dev/null +++ b/exporter/model.go @@ -0,0 +1,30 @@ +package exporter + +import ( + "flag" + + "github.com/zr-hebo/sniffer-agent/model" +) + +var ( + exportType string +) + +func init() { + flag.StringVar(&exportType,"export_type", "cli", "export type. Default is cli, that is command line") +} + +type Exporter interface { + Export(model.QueryPiece) error +} + +func NewExporter() Exporter { + switch exportType { + case "cli": + return NewCliExporter() + case "kafka": + return NewKafkaExporter() + default: + return NewCliExporter() + } +} diff --git a/images/arch.png b/images/arch.png new file mode 100644 index 0000000..1205183 Binary files /dev/null and b/images/arch.png differ diff --git a/main.go b/main.go new file mode 100644 index 0000000..ff38957 --- /dev/null +++ b/main.go @@ -0,0 +1,62 @@ +package main + +import ( + "flag" + "fmt" + "os" + + log "github.com/sirupsen/logrus" + "github.com/zr-hebo/sniffer-agent/capture" + "github.com/zr-hebo/sniffer-agent/exporter" + "github.com/zr-hebo/sniffer-agent/communicator" + sd "github.com/zr-hebo/sniffer-agent/session-dealer" +) + +var ( + logLevel string +) + +func init() { + flag.StringVar(&logLevel, "log_level", "info", "log level. Default is info") +} + +func initLog() { + log.SetFormatter(&log.TextFormatter{}) + log.SetOutput(os.Stdout) + switch logLevel { + case "debug": + log.SetLevel(log.DebugLevel) + case "info": + log.SetLevel(log.InfoLevel) + case "warn": + log.SetLevel(log.WarnLevel) + case "error": + log.SetLevel(log.ErrorLevel) + default: + panic(fmt.Sprintf("cannot set log level:%s, there have four types can set: debug, info, warn, error", logLevel)) + } +} + +func main() { + flag.Parse() + initLog() + sd.CheckParams() + + go communicator.Server() + mainServer() +} + +func mainServer() { + ept := exporter.NewExporter() + networkCard := capture.NewNetworkCard() + log.Info("begin listen") + for queryPiece := range networkCard.Listen() { + err := ept.Export(queryPiece) + if err != nil { + log.Error(err.Error()) + } + } + + log.Errorf("cannot get network package from %s", capture.DeviceName) + os.Exit(1) +} \ No newline at end of file diff --git a/model/query_piece.go b/model/query_piece.go new file mode 100644 index 0000000..f0946e9 --- /dev/null +++ b/model/query_piece.go @@ -0,0 +1,49 @@ +package model + +import ( + "encoding/json" +) + +type QueryPiece interface { + String() string + Bytes() []byte + GetSQL() string +} + +// MysqlQueryPiece 查询信息 +type MysqlQueryPiece struct { + SessionID string `json:"sid"` + ClientHost string `json:"-"` + ServerIP string `json:"sip"` + ServerPort int `json:"sport"` + VisitUser *string `json:"user"` + VisitDB *string `json:"db"` + QuerySQL *string `json:"sql"` + BeginTime string `json:"bt"` + CostTimeInMS int64 `json:"cms"` +} + +func (qp *MysqlQueryPiece) String() (str string) { + content, err := json.Marshal(qp) + if err != nil { + return err.Error() + } + + return string(content) +} + +func (qp *MysqlQueryPiece) Bytes() (bytes []byte) { + content, err := json.Marshal(qp) + if err != nil { + return []byte(err.Error()) + } + + return content +} + +func (qp *MysqlQueryPiece) GetSQL() (str string) { + if qp.QuerySQL != nil { + return *qp.QuerySQL + } + return "" +} diff --git a/scripts/check_kafka.py b/scripts/check_kafka.py new file mode 100644 index 0000000..36448d6 --- /dev/null +++ b/scripts/check_kafka.py @@ -0,0 +1,53 @@ +# import json +from kafka import KafkaConsumer, KafkaProducer + + + +kafka_server = '192.168.XX.XX:9091' + +group_id = 'sniffer' +topic = 'ddl_sql_collector' + + +def check_consume(): + conf = { + 'bootstrap_servers': kafka_server, + 'client_id': group_id, + 'group_id': group_id, + 'auto_offset_reset': 'earliest', + 'session_timeout_ms': 60000, + 'api_version': (0, 9, 0, 1) + } + consumer = KafkaConsumer(topic, **conf) + print('ready to consume') + for msg in consumer: + # event = json.loads(bytes.decode(msg.value)) + print(msg) + + +def check_produce(): + conf = { + 'bootstrap_servers': kafka_server, + 'client_id': group_id + } + # 'api_version': (0, 9, 0, 1) + producer = KafkaProducer(**conf) + try: + future = producer.send(topic, 'haha') + result = future.get(timeout=3) + print('send OK') + print(result) + + except BaseException as e: + print('send failed') + # 发送失败时,用户需根据业务逻辑做异常处理,否则消息可能会丢失 + print(str(e)) + + +def _real_main(): + check_produce() + # check_consume() + + +if __name__ == '__main__': + _real_main() diff --git a/scripts/check_prepare.py b/scripts/check_prepare.py new file mode 100644 index 0000000..16427e9 --- /dev/null +++ b/scripts/check_prepare.py @@ -0,0 +1,75 @@ +import time +import mysql.connector +from mysql.connector.cursor import MySQLCursorPrepared + +config = { + 'host': '192.168.XX.XX', + 'port': 3358, + 'database': 'sniffer', + 'user': 'root', + 'password': '', + 'charset': 'utf8', + 'use_unicode': True, + 'get_warnings': True, +} + + +def _real_main(config): + while True: + _once_check() + + +def _once_check(): + output = [] + conn = mysql.connector.Connect(**config) + + curprep = conn.cursor(cursor_class=MySQLCursorPrepared) + cur = conn.cursor() + + # Drop table if exists, and create it new + stmt_drop = "DROP TABLE IF EXISTS names" + cur.execute(stmt_drop) + + stmt_create = ( + "CREATE TABLE names (" + "id TINYINT UNSIGNED NOT NULL AUTO_INCREMENT, " + "name VARCHAR(30) DEFAULT '' NOT NULL, " + "cnt TINYINT UNSIGNED DEFAULT 0, " + "PRIMARY KEY (id))" + ) + cur.execute(stmt_create) + + # Connector/Python also allows ? as placeholders for MySQL Prepared + # statements. + prepstmt = "INSERT INTO names (name) VALUES (%s)" + + # Preparing the statement is done only once. It can be done before + # without data, or later with data. + curprep.execute(prepstmt) + + # Insert 3 records + names = ( + 'Geert', 'Jan', 'Michel', 'wang', 'Jan', 'Michel', 'wang', + 'Jan', 'Michel', 'wang', 'Jan', 'Michel', 'wang', 'Jan', 'Michel', + 'wang', 'Jan', 'Michel', 'wang', 'Jan', 'Michel', 'wang', 'Jan', + 'Jan', 'Michel', 'wang') + for name in names: + curprep.execute(prepstmt, (name,)) + conn.commit() + time.sleep(0.1) + + # We use a normal cursor issue a SELECT + output.append("Inserted data") + cur.execute("SELECT id, name FROM names") + for row in cur: + output.append("{0} | {1}".format(*row)) + + # Cleaning up, dropping the table again + cur.execute(stmt_drop) + + conn.close() + print(output) + + +if __name__ == '__main__': + _real_main(config) diff --git a/scripts/generate_mysql_select.sh b/scripts/generate_mysql_select.sh new file mode 100755 index 0000000..1e85a79 --- /dev/null +++ b/scripts/generate_mysql_select.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +function execute_real(){ + mysql_host=192.168.XX.XX + mysql_port=3358 + user_name=user + passwd=123456 + + mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e "select 1" + sleep 1 + mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e "use sniffer;show tables;create table haha(id int, name text)" + sleep 1 + mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e "" + sleep 1 + mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e "" + sleep 1 + insert_cmd="insert into unibase.haha(id, name) values(10, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')" + insert_cmd="$insert_cmd,(10, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')" + insert_cmd="$insert_cmd,(10, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')" + insert_cmd="$insert_cmd,(10, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')" + insert_cmd="$insert_cmd,(10, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')" + mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e "$insert_cmd" + sleep 1 + mysql -h$mysql_host -P$mysql_port -u$user_name -p$passwd jmms -e "use unibase; select * from haha; drop table haha" + sleep 1 +} + +while true + do + execute_real + done diff --git a/session-dealer/config.go b/session-dealer/config.go new file mode 100644 index 0000000..bc1e1d8 --- /dev/null +++ b/session-dealer/config.go @@ -0,0 +1,18 @@ +package session_dealer + +import ( + "flag" +) + +const ( + ServiceTypeMysql = "mysql" +) + +var ( + serviceType string +) + +func init() { + flag.StringVar(&serviceType, "service_type", "mysql", "service type. Default is mysql") +} + diff --git a/session-dealer/controller.go b/session-dealer/controller.go new file mode 100644 index 0000000..f1a41f7 --- /dev/null +++ b/session-dealer/controller.go @@ -0,0 +1,24 @@ +package session_dealer + +import ( + "github.com/zr-hebo/sniffer-agent/session-dealer/mysql" +) + +func NewSession(sessionKey string, clientIP string, clientPort int, serverIP string, serverPort int) (session ConnSession) { + switch serviceType { + case ServiceTypeMysql: + session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort) + default: + session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, serverIP, serverPort) + } + return +} + +func CheckParams() { + switch serviceType { + case ServiceTypeMysql: + mysql.CheckParams() + default: + mysql.CheckParams() + } +} diff --git a/session-dealer/model.go b/session-dealer/model.go new file mode 100644 index 0000000..fb306d3 --- /dev/null +++ b/session-dealer/model.go @@ -0,0 +1,9 @@ +package session_dealer + +import "github.com/zr-hebo/sniffer-agent/model" + +type ConnSession interface { + ReadFromClient(bytes []byte) + ReadFromServer(bytes []byte) + GenerateQueryPiece() (qp model.QueryPiece) +} diff --git a/session-dealer/mysql/auth_info.go b/session-dealer/mysql/auth_info.go new file mode 100644 index 0000000..7da3c9e --- /dev/null +++ b/session-dealer/mysql/auth_info.go @@ -0,0 +1,19 @@ +package mysql + +// parseAuthInfo parse username, dbname from mysql client auth info +func parseAuthInfo(data []byte) (userName, dbName string, err error) { + var resp handshakeResponse41 + pos, err := parseHandshakeResponseHeader(&resp, data) + if err != nil { + return + } + + // Read the remaining part of the packet. + if err = parseHandshakeResponseBody(&resp, data, pos); err != nil { + return + } + + userName = resp.User + dbName = resp.DBName + return +} diff --git a/session-dealer/mysql/config.go b/session-dealer/mysql/config.go new file mode 100644 index 0000000..c671275 --- /dev/null +++ b/session-dealer/mysql/config.go @@ -0,0 +1,38 @@ +package mysql + +import ( + "flag" + "fmt" + "regexp" +) + +var ( + uselessSQLPattern = regexp.MustCompile(`(?i)^\s*(select 1|select @@version_comment limit 1|`+ + `SELECT user, db FROM information_schema.processlist WHERE host=|commit|begin)`) +) + +var ( + strictMode bool + adminUser string + adminPasswd string +) + +func init() { + flag.BoolVar(&strictMode,"strict_mode", false, "strict mode. Default is false") + flag.StringVar(&adminUser,"admin_user", "", "admin user name. When set strict mode, must set admin user to query session info") + flag.StringVar(&adminPasswd,"admin_passwd", "", "admin user passwd. When use strict mode, must set admin user to query session info") +} + +func CheckParams() { + if !strictMode { + return + } + + if len(adminUser) < 1 { + panic(fmt.Sprintf("In strict mode, admin user name cannot be empty")) + } + + if len(adminPasswd) < 1 { + panic(fmt.Sprintf("In strict mode, admin passwd cannot be empty")) + } +} diff --git a/session-dealer/mysql/connections.go b/session-dealer/mysql/connections.go new file mode 100644 index 0000000..96cdcda --- /dev/null +++ b/session-dealer/mysql/connections.go @@ -0,0 +1,50 @@ +package mysql + +import ( + "fmt" + + _ "github.com/go-sql-driver/mysql" + du "github.com/zr-hebo/util-db" + // log "github.com/sirupsen/logrus" +) + +func expandLocalMysql(port int) (mysqlHost *du.MysqlDB) { + mysqlHost = new(du.MysqlDB) + mysqlHost.IP = "localhost" + mysqlHost.Port = port + mysqlHost.UserName = adminUser + mysqlHost.Passwd = adminPasswd + mysqlHost.DatabaseType = "mysql" + mysqlHost.ConnectTimeout = 1 + + return +} + +func querySessionInfo(snifferPort int, clientHost string) (user, db *string, err error) { + mysqlServer := expandLocalMysql(snifferPort) + querySQL := fmt.Sprintf( + "SELECT user, db FROM information_schema.processlist WHERE host='%s'", clientHost) + // log.Debug(querySQL) + queryRow, err := mysqlServer.QueryRow(querySQL) + if err != nil { + return + } + + if queryRow == nil { + return + } + + userVal := queryRow.Record["user"] + if userVal != nil { + usrStr := userVal.(string) + user = &usrStr + } + + dbVal := queryRow.Record["db"] + if dbVal != nil { + dbStr := dbVal.(string) + db = &dbStr + } + + return +} \ No newline at end of file diff --git a/session-dealer/mysql/const.go b/session-dealer/mysql/const.go new file mode 100644 index 0000000..cb4b30d --- /dev/null +++ b/session-dealer/mysql/const.go @@ -0,0 +1,105 @@ +package mysql + +import "errors" + +// Command information. +const ( + ComSleep byte = iota + ComQuit + ComInitDB + ComQuery + ComFieldList + ComCreateDB + ComDropDB + ComRefresh + ComShutdown + ComStatistics + ComProcessInfo + ComConnect + ComProcessKill + ComDebug + ComPing + ComTime + ComDelayedInsert + ComChangeUser + ComBinlogDump + ComTableDump + ComConnectOut + ComRegisterSlave + ComStmtPrepare + ComStmtExecute + ComStmtSendLongData + ComStmtClose + ComStmtReset + ComSetOption + ComStmtFetch + ComBinlogDumpGtid + ComResetConnection +) + +const ( + ComAuth = 141 +) + +// Client information. +const ( + ClientLongPassword uint32 = 1 << iota + ClientFoundRows + ClientLongFlag + ClientConnectWithDB + ClientNoSchema + ClientCompress + ClientODBC + ClientLocalFiles + ClientIgnoreSpace + ClientProtocol41 + ClientInteractive + ClientSSL + ClientIgnoreSigpipe + ClientTransactions + ClientReserved + ClientSecureConnection + ClientMultiStatements + ClientMultiResults + ClientPSMultiResults + ClientPluginAuth + ClientConnectAtts + ClientPluginAuthLenencClientData +) + + +// Auth name information. +const ( + AuthName = "mysql_native_password" +) + + +// MySQL database and tables. +const ( + // SystemDB is the name of system database. + SystemDB = "mysql" + // UserTable is the table in system db contains user info. + UserTable = "User" + // DBTable is the table in system db contains db scope privilege info. + DBTable = "DB" + // GlobalVariablesTable is the table contains global system variables. + GlobalVariablesTable = "GLOBAL_VARIABLES" + // GlobalStatusTable is the table contains global status variables. + GlobalStatusTable = "GLOBAL_STATUS" +) + + +// Identifier length limitations. +// See https://dev.mysql.com/doc/refman/5.7/en/identifiers.html +const ( + // MaxPayloadLen is the max packet payload length. + MaxPayloadLen = 1<<24 - 1 +) + +const ( + datetimeFormat = "2006-01-02 15:04:05" +) + +var ( + ErrMalformPacket = errors.New("malform packet error") +) \ No newline at end of file diff --git a/session-dealer/mysql/model.go b/session-dealer/mysql/model.go new file mode 100644 index 0000000..a33b40c --- /dev/null +++ b/session-dealer/mysql/model.go @@ -0,0 +1,9 @@ +package mysql + +type handshakeResponse41 struct { + Capability uint32 + Collation uint8 + User string + DBName string + Auth []byte +} diff --git a/session-dealer/mysql/session.go b/session-dealer/mysql/session.go new file mode 100644 index 0000000..d734a7c --- /dev/null +++ b/session-dealer/mysql/session.go @@ -0,0 +1,175 @@ +package mysql + +import ( + "fmt" + "time" + + "github.com/zr-hebo/sniffer-agent/model" + log "github.com/sirupsen/logrus" +) + +type MysqlSession struct { + connectionID string + visitUser *string + visitDB *string + clientHost string + clientPort int + serverIP string + serverPort int + beginTime int64 + expectSize int + prepareInfo *prepareInfo + cachedPrepareStmt map[int]*string + tcpCache []byte + cachedStmtBytes []byte +} + +type prepareInfo struct { + prepareStmtID int +} + +func NewMysqlSession(sessionKey string, clientIP string, clientPort int, serverIP string, serverPort int) (ms *MysqlSession) { + ms = &MysqlSession{ + connectionID: sessionKey, + clientHost: clientIP, + clientPort: clientPort, + serverIP: serverIP, + serverPort: serverPort, + beginTime: time.Now().UnixNano() / int64(time.Millisecond), + cachedPrepareStmt: make(map[int]*string), + } + return +} + +func (ms *MysqlSession) ReadFromServer(bytes []byte) { + if ms.expectSize < 1 { + ms.expectSize = extractMysqlPayloadSize(bytes) + contents := bytes[4:] + if ms.prepareInfo != nil && contents[0] == 0 { + ms.prepareInfo.prepareStmtID = bytesToInt(contents[1:5]) + } + ms.expectSize = ms.expectSize - len(contents) + + } else { + ms.expectSize = ms.expectSize - len(bytes) + } +} + +func (ms *MysqlSession) ReadFromClient(bytes []byte) { + if ms.expectSize < 1 { + ms.expectSize = extractMysqlPayloadSize(bytes) + contents := bytes[4:] + if contents[0] == ComStmtPrepare { + ms.prepareInfo = &prepareInfo{} + } + + ms.expectSize = ms.expectSize - len(contents) + ms.tcpCache = append(ms.tcpCache, contents...) + + } else { + ms.expectSize = ms.expectSize - len(bytes) + ms.tcpCache = append(ms.tcpCache, bytes...) + if len(ms.tcpCache) == MaxPayloadLen { + ms.cachedStmtBytes = append(ms.cachedStmtBytes, ms.tcpCache...) + ms.tcpCache = ms.tcpCache[:0] + ms.expectSize = 0 + } + } +} + +func (ms *MysqlSession) GenerateQueryPiece() (qp model.QueryPiece) { + if len(ms.cachedStmtBytes) < 1 && len(ms.tcpCache) < 1 { + return + } + + var mqp *model.MysqlQueryPiece = nil + ms.cachedStmtBytes = append(ms.cachedStmtBytes, ms.tcpCache...) + switch ms.cachedStmtBytes[0] { + case ComAuth: + var userName, dbName string + var err error + userName, dbName, err = parseAuthInfo(ms.cachedStmtBytes) + if err != nil { + return + } + ms.visitUser = &userName + ms.visitDB = &dbName + + case ComInitDB: + newDBName := string(ms.cachedStmtBytes[1:]) + useSQL := fmt.Sprintf("use %s", newDBName) + mqp = ms.composeQueryPiece() + mqp.QuerySQL = &useSQL + // update session database + ms.visitDB = &newDBName + + case ComCreateDB: + case ComDropDB: + case ComQuery: + mqp = ms.composeQueryPiece() + querySQL := string(ms.cachedStmtBytes[1:]) + mqp.QuerySQL = &querySQL + + case ComStmtPrepare: + mqp = ms.composeQueryPiece() + querySQL := string(ms.cachedStmtBytes[1:]) + mqp.QuerySQL = &querySQL + ms.cachedPrepareStmt[ms.prepareInfo.prepareStmtID] = &querySQL + log.Debugf("prepare statement %s, get id:%d", querySQL, ms.prepareInfo.prepareStmtID) + + case ComStmtExecute: + prepareStmtID := bytesToInt(ms.cachedStmtBytes[1:5]) + mqp = ms.composeQueryPiece() + mqp.QuerySQL = ms.cachedPrepareStmt[prepareStmtID] + log.Debugf("execute prepare statement:%d", prepareStmtID) + + case ComStmtClose: + prepareStmtID := bytesToInt(ms.cachedStmtBytes[1:5]) + delete(ms.cachedPrepareStmt, prepareStmtID) + log.Debugf("remove prepare statement:%d", prepareStmtID) + + default: + } + + if strictMode && mqp != nil && mqp.VisitUser == nil { + user, db, err := querySessionInfo(ms.serverPort, mqp.SessionID) + if err != nil { + log.Errorf("query user and db from mysql failed <-- %s", err.Error()) + } else { + mqp.VisitUser = user + mqp.VisitDB = db + } + } + + ms.tcpCache = ms.tcpCache[:0] + ms.cachedStmtBytes = ms.cachedStmtBytes[:0] + ms.expectSize = 0 + ms.prepareInfo = nil + return filterQueryPieceBySQL(mqp) +} + +func filterQueryPieceBySQL(mqp *model.MysqlQueryPiece) (model.QueryPiece) { + if mqp == nil || mqp.QuerySQL == nil { + return nil + + } else if (uselessSQLPattern.MatchString(*mqp.QuerySQL)) { + return nil + } + + return mqp +} + +func (ms *MysqlSession) composeQueryPiece() (mqp *model.MysqlQueryPiece) { + nowInMS := time.Now().UnixNano() / int64(time.Millisecond) + mqp = &model.MysqlQueryPiece{ + SessionID: ms.connectionID, + ClientHost: ms.clientHost, + ServerIP: ms.serverIP, + ServerPort: ms.serverPort, + VisitUser: ms.visitUser, + VisitDB: ms.visitDB, + BeginTime: time.Unix(ms.beginTime/1000, 0).Format(datetimeFormat), + CostTimeInMS: nowInMS - ms.beginTime, + } + return mqp +} diff --git a/session-dealer/mysql/util.go b/session-dealer/mysql/util.go new file mode 100644 index 0000000..9153bcf --- /dev/null +++ b/session-dealer/mysql/util.go @@ -0,0 +1,131 @@ +package mysql + +import ( + "bytes" + "encoding/binary" +) + +// parseHandshakeResponseHeader parses the common header of SSLRequest and HandshakeResponse41. +func parseHandshakeResponseHeader(packet *handshakeResponse41, data []byte) (parsedBytes int, err error) { + // Ensure there are enough data to read: + // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + if len(data) < 4+4+1+23 { + return 0, ErrMalformPacket + } + + offset := 0 + // capability + capability := binary.LittleEndian.Uint32(data[:4]) + packet.Capability = capability + offset += 4 + // skip max packet size + offset += 4 + // charset, skip, if you want to use another charset, use set names + packet.Collation = data[offset] + offset++ + // skip reserved 23[00] + offset += 23 + + return offset, nil +} + +// parseHandshakeResponseBody parse the HandshakeResponse (except the common header part). +func parseHandshakeResponseBody(packet *handshakeResponse41, data []byte, offset int) (err error) { + defer func() { + // Check malformat packet cause out of range is disgusting, but don't panic! + if r := recover(); r != nil { + err = ErrMalformPacket + } + }() + // user name + packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)]) + offset += len(packet.User) + 1 + + if packet.Capability&ClientPluginAuthLenencClientData > 0 { + // MySQL client sets the wrong capability, it will set this bit even server doesn't + // support ClientPluginAuthLenencClientData. + // https://github.com/mysql/mysql-server/blob/5.7/sql-common/client.c#L3478 + num, null, off := parseLengthEncodedInt(data[offset:]) + offset += off + if !null { + packet.Auth = data[offset : offset+int(num)] + offset += int(num) + } + } else if packet.Capability&ClientSecureConnection > 0 { + // auth length and auth + authLen := int(data[offset]) + offset++ + packet.Auth = data[offset : offset+authLen] + offset += authLen + } else { + packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)] + offset += len(packet.Auth) + 1 + } + + if packet.Capability&ClientConnectWithDB > 0 { + if len(data[offset:]) > 0 { + idx := bytes.IndexByte(data[offset:], 0) + packet.DBName = string(data[offset : offset+idx]) + offset = offset + idx + 1 + } + } + + if packet.Capability&ClientPluginAuth > 0 { + // TODO: Support mysql.ClientPluginAuth, skip it now + idx := bytes.IndexByte(data[offset:], 0) + offset = offset + idx + 1 + } + + if packet.Capability&ClientConnectAtts > 0 { + if len(data[offset:]) == 0 { + // Defend some ill-formated packet, connection attribute is not important and can be ignored. + return nil + } + } + + return nil +} + +func parseLengthEncodedInt(b []byte) (num uint64, isNull bool, n int) { + switch b[0] { + // 251: NULL + case 0xfb: + n = 1 + isNull = true + return + + // 252: value of following 2 + case 0xfc: + num = uint64(b[1]) | uint64(b[2])<<8 + n = 3 + return + + // 253: value of following 3 + case 0xfd: + num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 + n = 4 + return + + // 254: value of following 8 + case 0xfe: + num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | + uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | + uint64(b[7])<<48 | uint64(b[8])<<56 + n = 9 + return + } + + // 0-250: value of first byte + num = uint64(b[0]) + n = 1 + return +} + +func extractMysqlPayloadSize(payload []byte) int { + header := payload[:4] + return int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) +} + +func bytesToInt(contents []byte) int { + return int(uint32(contents[0]) | uint32(contents[1])<<8 | uint32(contents[2])<<16 | uint32(contents[3])<<24) +} \ No newline at end of file