This commit is contained in:
wasd
2018-09-25 14:13:45 +08:00
commit 3b4d8613fa
14 changed files with 2171 additions and 0 deletions

View File

@@ -0,0 +1,350 @@
package build
import (
"github.com/google/gopacket"
"io"
"bytes"
"errors"
"log"
"strconv"
"sync"
"time"
"fmt"
"encoding/binary"
"strings"
"os"
)
const (
Port = 3306
Version = "0.1"
CmdPort = "-p"
)
type Mysql struct {
port int//端口
version string//插件版本
source map[string]*stream//流
}
type stream struct {
packets chan *packet
stmtMap map[uint32]*Stmt
}
type packet struct {
isClientFlow bool
seq int
length int
payload []byte
}
var mysql *Mysql
var once sync.Once
func NewInstance() *Mysql {
once.Do(func() {
mysql = &Mysql{
port :Port,
version:Version,
source: make(map[string]*stream),
}
})
return mysql
}
func (m *Mysql) ResolveStream(net, transport gopacket.Flow, buf io.Reader) {
//uuid
uuid := fmt.Sprintf("%v:%v", net.FastHash(), transport.FastHash())
//generate resolve's stream
if _, ok := m.source[uuid]; !ok {
var newStream = stream{
packets:make(chan *packet, 100),
stmtMap:make(map[uint32]*Stmt),
}
m.source[uuid] = &newStream
go newStream.resolve()
}
//read bi-directional packet
//server -> client || client -> server
for {
newPacket := m.newPacket(net, transport, buf)
if newPacket == nil {
return
}
m.source[uuid].packets <- newPacket
}
}
func (m *Mysql) BPFFilter() string {
return "tcp and port "+strconv.Itoa(m.port);
}
func (m *Mysql) Version() string {
return Version
}
func (m *Mysql) SetFlag(flg []string) {
c := len(flg)
if c == 0 {
return
}
if c >> 1 == 0 {
fmt.Println("Mysql参数数量不正确!")
os.Exit(1)
}
for i:=0;i<c;i=i+2 {
key := flg[i]
val := flg[i+1]
switch key {
case CmdPort:
port, err := strconv.Atoi(val);
m.port = port
if err != nil {
panic("端口数不正确")
}
if port < 0 || port > 65535 {
panic("参数不正确: 端口范围(0-65535)")
}
break
default:
panic("参数不正确")
}
}
}
func (m *Mysql) newPacket(net, transport gopacket.Flow, r io.Reader) *packet {
//read packet
var payload bytes.Buffer
var seq uint8
var err error
if seq, err = m.resolvePacketTo(r, &payload); err != nil {
return nil
}
//close stream
if err == io.EOF {
fmt.Println(net, transport, " 关闭")
return nil
} else if err != nil {
fmt.Println("错误流:", net, transport, ":", err)
}
//generate new packet
var pk = packet{
seq: int(seq),
length:payload.Len(),
payload:payload.Bytes(),
}
if transport.Src().String() == strconv.Itoa(Port) {
pk.isClientFlow = false
}else{
pk.isClientFlow = true
}
return &pk
}
func (m *Mysql) resolvePacketTo(r io.Reader, w io.Writer) (uint8, error) {
header := make([]byte, 4)
if n, err := io.ReadFull(r, header); err != nil {
if n == 0 && err == io.EOF {
return 0, io.EOF
}
return 0, errors.New("错误流")
}
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
var seq uint8
seq = header[3]
if n, err := io.CopyN(w, r, int64(length)); err != nil {
return 0, errors.New("错误流")
} else if n != int64(length) {
return 0, errors.New("错误流")
} else {
return seq, nil
}
return seq, nil
}
func (stm *stream) resolve() {
for {
select {
case packet := <- stm.packets:
if packet.isClientFlow {
stm.resolveClientPacket(packet.payload, packet.seq)
} else {
stm.resolveServerPacket(packet.payload, packet.seq)
}
}
}
}
func (stm *stream) findStmtPacket (srv chan *packet, seq int) *packet {
for {
select {
case packet, ok := <- stm.packets:
if !ok {
return nil
}
if packet.seq == seq {
return packet
}
case <-time.After(5 * time.Second):
return nil
}
}
}
func (stm *stream) resolveServerPacket(payload []byte, seq int) {
var msg = ""
switch payload[0] {
case 0xff:
errorCode := int(binary.LittleEndian.Uint16(payload[1:3]))
errorMsg,_ := ReadStringFromByte(payload[4:])
msg = GetNowStr(false)+"%s 错误代码:%s,错误信息:%s"
msg = fmt.Sprintf(msg, ErrorPacket, strconv.Itoa(errorCode), strings.TrimSpace(errorMsg))
case 0x00:
var pos = 1
l,_ := LengthBinary(payload[pos:])
affectedRows := int(l)
msg += GetNowStr(false)+"%s 影响行数:%s"
msg = fmt.Sprintf(msg, OkPacket, strconv.Itoa(affectedRows))
default:
return
}
fmt.Println(msg)
}
func (stm *stream) resolveClientPacket(payload []byte, seq int) {
var msg string
switch payload[0] {
case COM_INIT_DB:
msg = fmt.Sprintf("USE %s;\n", payload[1:])
case COM_DROP_DB:
msg = fmt.Sprintf("删除数据库 %s;\n", payload[1:])
case COM_CREATE_DB, COM_QUERY:
statement := string(payload[1:])
msg = fmt.Sprintf("%s %s", ComQueryRequestPacket, statement)
case COM_STMT_PREPARE:
serverPacket := stm.findStmtPacket(stm.packets, seq+1)
if serverPacket == nil {
log.Println("找不到预处理响应包")
}
//获取响应包中预处理id
stmtID := binary.LittleEndian.Uint32(serverPacket.payload[1:5])
stmt := &Stmt{
ID: stmtID,
Query: string(payload[1:]),
}
//记录预处理语句
stm.stmtMap[stmtID] = stmt
stmt.FieldCount = binary.LittleEndian.Uint16(serverPacket.payload[5:7])
stmt.ParamCount = binary.LittleEndian.Uint16(serverPacket.payload[7:9])
stmt.Args = make([]interface{}, stmt.ParamCount)
msg = PreparePacket+stmt.Query
case COM_STMT_SEND_LONG_DATA:
stmtID := binary.LittleEndian.Uint32(payload[1:5])
paramId := binary.LittleEndian.Uint16(payload[5:7])
stmt, _ := stm.stmtMap[stmtID]
if stmt.Args[paramId] == nil {
stmt.Args[paramId] = payload[7:]
} else {
if b, ok := stmt.Args[paramId].([]byte); ok {
b = append(b, payload[7:]...)
stmt.Args[paramId] = b
}
}
return
case COM_STMT_RESET:
stmtID := binary.LittleEndian.Uint32(payload[1:5])
stmt, _:= stm.stmtMap[stmtID]
stmt.Args = make([]interface{}, stmt.ParamCount)
return
case COM_STMT_EXECUTE:
var pos = 1
stmtID := binary.LittleEndian.Uint32(payload[pos : pos+4])
pos += 4
var stmt *Stmt
var ok bool
if stmt, ok = stm.stmtMap[stmtID]; ok == false {
log.Println("未发现预处理id: ", stmtID)
}
//参数
pos += 5
if stmt.ParamCount > 0 {
//空位图Null-Bitmap长度 = (参数数量 + 7) / 8 字节)
step := int((stmt.ParamCount + 7) / 8)
nullBitmap := payload[pos : pos+step]
pos += step
//参数分隔标志
flag := payload[pos]
pos++
var pTypes []byte
var pValues []byte
//如果参数分隔标志值为1
//n 每个参数的类型值(长度 = 参数数量 * 2 字节)
//n 每个参数的值
if flag == 1 {
pTypes = payload[pos : pos+int(stmt.ParamCount)*2]
pos += int(stmt.ParamCount) * 2
pValues = payload[pos:]
}
//绑定参数
err := stmt.BindArgs(nullBitmap, pTypes, pValues)
if err != nil {
log.Println("预处理绑定参数失败: ", err)
}
}
msg = string(stmt.WriteToText())
default:
return
}
fmt.Println(GetNowStr(true) + msg)
}