mirror of https://github.com/40t/go-sniffer.git
357 lines
6.8 KiB
Go
357 lines
6.8 KiB
Go
package build
|
||
|
||
import (
|
||
"github.com/google/gopacket"
|
||
"io"
|
||
"bytes"
|
||
"errors"
|
||
"log"
|
||
"strconv"
|
||
"sync"
|
||
"time"
|
||
"fmt"
|
||
"encoding/binary"
|
||
"strings"
|
||
"os"
|
||
)
|
||
|
||
const (
|
||
Port = 3306
|
||
Version = "0.1"
|
||
CmdPort = "-p"
|
||
)
|
||
|
||
type Mysql struct {
|
||
port int
|
||
version string
|
||
source map[string]*stream
|
||
}
|
||
|
||
type stream struct {
|
||
packets chan *packet
|
||
stmtMap map[uint32]*Stmt
|
||
}
|
||
|
||
type packet struct {
|
||
isClientFlow bool
|
||
seq int
|
||
length int
|
||
payload []byte
|
||
}
|
||
|
||
var mysql *Mysql
|
||
var once sync.Once
|
||
func NewInstance() *Mysql {
|
||
|
||
once.Do(func() {
|
||
mysql = &Mysql{
|
||
port :Port,
|
||
version:Version,
|
||
source: make(map[string]*stream),
|
||
}
|
||
})
|
||
|
||
return mysql
|
||
}
|
||
|
||
func (m *Mysql) ResolveStream(net, transport gopacket.Flow, buf io.Reader) {
|
||
|
||
//uuid
|
||
uuid := fmt.Sprintf("%v:%v", net.FastHash(), transport.FastHash())
|
||
|
||
//generate resolve's stream
|
||
if _, ok := m.source[uuid]; !ok {
|
||
|
||
var newStream = stream{
|
||
packets:make(chan *packet, 100),
|
||
stmtMap:make(map[uint32]*Stmt),
|
||
}
|
||
|
||
m.source[uuid] = &newStream
|
||
go newStream.resolve()
|
||
}
|
||
|
||
//read bi-directional packet
|
||
//server -> client || client -> server
|
||
for {
|
||
|
||
newPacket := m.newPacket(net, transport, buf)
|
||
|
||
if newPacket == nil {
|
||
return
|
||
}
|
||
|
||
m.source[uuid].packets <- newPacket
|
||
}
|
||
}
|
||
|
||
func (m *Mysql) BPFFilter() string {
|
||
return "tcp and port "+strconv.Itoa(m.port);
|
||
}
|
||
|
||
func (m *Mysql) Version() string {
|
||
return Version
|
||
}
|
||
|
||
func (m *Mysql) SetFlag(flg []string) {
|
||
|
||
c := len(flg)
|
||
|
||
if c == 0 {
|
||
return
|
||
}
|
||
if c >> 1 == 0 {
|
||
fmt.Println("ERR : Mysql Number of parameters")
|
||
os.Exit(1)
|
||
}
|
||
for i:=0;i<c;i=i+2 {
|
||
key := flg[i]
|
||
val := flg[i+1]
|
||
|
||
switch key {
|
||
case CmdPort:
|
||
port, err := strconv.Atoi(val);
|
||
m.port = port
|
||
if err != nil {
|
||
panic("ERR : port")
|
||
}
|
||
if port < 0 || port > 65535 {
|
||
panic("ERR : port(0-65535)")
|
||
}
|
||
break
|
||
default:
|
||
panic("ERR : mysql's params")
|
||
}
|
||
}
|
||
}
|
||
|
||
func (m *Mysql) newPacket(net, transport gopacket.Flow, r io.Reader) *packet {
|
||
|
||
//read packet
|
||
var payload bytes.Buffer
|
||
var seq uint8
|
||
var err error
|
||
if seq, err = m.resolvePacketTo(r, &payload); err != nil {
|
||
return nil
|
||
}
|
||
|
||
//close stream
|
||
if err == io.EOF {
|
||
fmt.Println(net, transport, " close")
|
||
return nil
|
||
} else if err != nil {
|
||
fmt.Println("ERR : Unknown stream", net, transport, ":", err)
|
||
}
|
||
|
||
//generate new packet
|
||
var pk = packet{
|
||
seq: int(seq),
|
||
length:payload.Len(),
|
||
payload:payload.Bytes(),
|
||
}
|
||
if transport.Src().String() == strconv.Itoa(m.port) {
|
||
pk.isClientFlow = false
|
||
}else{
|
||
pk.isClientFlow = true
|
||
}
|
||
|
||
return &pk
|
||
}
|
||
|
||
func (m *Mysql) resolvePacketTo(r io.Reader, w io.Writer) (uint8, error) {
|
||
|
||
header := make([]byte, 4)
|
||
if n, err := io.ReadFull(r, header); err != nil {
|
||
if n == 0 && err == io.EOF {
|
||
return 0, io.EOF
|
||
}
|
||
return 0, errors.New("ERR : Unknown stream")
|
||
}
|
||
|
||
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
|
||
|
||
var seq uint8
|
||
seq = header[3]
|
||
|
||
if n, err := io.CopyN(w, r, int64(length)); err != nil {
|
||
return 0, errors.New("ERR : Unknown stream")
|
||
} else if n != int64(length) {
|
||
return 0, errors.New("ERR : Unknown stream")
|
||
} else {
|
||
return seq, nil
|
||
}
|
||
|
||
return seq, nil
|
||
}
|
||
|
||
func (stm *stream) resolve() {
|
||
for {
|
||
select {
|
||
case packet := <- stm.packets:
|
||
if packet.length != 0 {
|
||
if packet.isClientFlow {
|
||
stm.resolveClientPacket(packet.payload, packet.seq)
|
||
} else {
|
||
stm.resolveServerPacket(packet.payload, packet.seq)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func (stm *stream) findStmtPacket (srv chan *packet, seq int) *packet {
|
||
for {
|
||
select {
|
||
case packet, ok := <- stm.packets:
|
||
if !ok {
|
||
return nil
|
||
}
|
||
if packet.seq == seq {
|
||
return packet
|
||
}
|
||
case <-time.After(5 * time.Second):
|
||
return nil
|
||
}
|
||
}
|
||
}
|
||
|
||
func (stm *stream) resolveServerPacket(payload []byte, seq int) {
|
||
|
||
var msg = ""
|
||
if len(payload) == 0 {
|
||
return
|
||
}
|
||
switch payload[0] {
|
||
|
||
case 0xff:
|
||
errorCode := int(binary.LittleEndian.Uint16(payload[1:3]))
|
||
errorMsg,_ := ReadStringFromByte(payload[4:])
|
||
|
||
msg = GetNowStr(false)+"%s Err code:%s,Err msg:%s"
|
||
msg = fmt.Sprintf(msg, ErrorPacket, strconv.Itoa(errorCode), strings.TrimSpace(errorMsg))
|
||
|
||
case 0x00:
|
||
var pos = 1
|
||
l,_,_ := LengthEncodedInt(payload[pos:])
|
||
affectedRows := int(l)
|
||
|
||
msg += GetNowStr(false)+"%s Effect Row:%s"
|
||
msg = fmt.Sprintf(msg, OkPacket, strconv.Itoa(affectedRows))
|
||
|
||
default:
|
||
return
|
||
}
|
||
|
||
fmt.Println(msg)
|
||
}
|
||
|
||
func (stm *stream) resolveClientPacket(payload []byte, seq int) {
|
||
|
||
var msg string
|
||
switch payload[0] {
|
||
|
||
case COM_INIT_DB:
|
||
|
||
msg = fmt.Sprintf("USE %s;\n", payload[1:])
|
||
case COM_DROP_DB:
|
||
|
||
msg = fmt.Sprintf("Drop DB %s;\n", payload[1:])
|
||
case COM_CREATE_DB, COM_QUERY:
|
||
|
||
statement := string(payload[1:])
|
||
msg = fmt.Sprintf("%s %s", ComQueryRequestPacket, statement)
|
||
case COM_STMT_PREPARE:
|
||
|
||
serverPacket := stm.findStmtPacket(stm.packets, seq+1)
|
||
if serverPacket == nil {
|
||
log.Println("ERR : Not found stm packet")
|
||
return
|
||
}
|
||
|
||
//fetch stm id
|
||
stmtID := binary.LittleEndian.Uint32(serverPacket.payload[1:5])
|
||
stmt := &Stmt{
|
||
ID: stmtID,
|
||
Query: string(payload[1:]),
|
||
}
|
||
|
||
//record stm sql
|
||
stm.stmtMap[stmtID] = stmt
|
||
stmt.FieldCount = binary.LittleEndian.Uint16(serverPacket.payload[5:7])
|
||
stmt.ParamCount = binary.LittleEndian.Uint16(serverPacket.payload[7:9])
|
||
stmt.Args = make([]interface{}, stmt.ParamCount)
|
||
|
||
msg = PreparePacket+stmt.Query
|
||
case COM_STMT_SEND_LONG_DATA:
|
||
|
||
stmtID := binary.LittleEndian.Uint32(payload[1:5])
|
||
paramId := binary.LittleEndian.Uint16(payload[5:7])
|
||
stmt, _ := stm.stmtMap[stmtID]
|
||
|
||
if stmt.Args[paramId] == nil {
|
||
stmt.Args[paramId] = payload[7:]
|
||
} else {
|
||
if b, ok := stmt.Args[paramId].([]byte); ok {
|
||
b = append(b, payload[7:]...)
|
||
stmt.Args[paramId] = b
|
||
}
|
||
}
|
||
return
|
||
case COM_STMT_RESET:
|
||
|
||
stmtID := binary.LittleEndian.Uint32(payload[1:5])
|
||
stmt, _:= stm.stmtMap[stmtID]
|
||
stmt.Args = make([]interface{}, stmt.ParamCount)
|
||
return
|
||
case COM_STMT_EXECUTE:
|
||
|
||
var pos = 1
|
||
stmtID := binary.LittleEndian.Uint32(payload[pos : pos+4])
|
||
pos += 4
|
||
var stmt *Stmt
|
||
var ok bool
|
||
if stmt, ok = stm.stmtMap[stmtID]; ok == false {
|
||
log.Println("ERR : Not found stm id", stmtID)
|
||
return
|
||
}
|
||
|
||
//params
|
||
pos += 5
|
||
if stmt.ParamCount > 0 {
|
||
|
||
//(Null-Bitmap,len = (paramsCount + 7) / 8 byte)
|
||
step := int((stmt.ParamCount + 7) / 8)
|
||
nullBitmap := payload[pos : pos+step]
|
||
pos += step
|
||
|
||
//Parameter separator
|
||
flag := payload[pos]
|
||
|
||
pos++
|
||
|
||
var pTypes []byte
|
||
var pValues []byte
|
||
|
||
//if flag == 1
|
||
//n (len = paramsCount * 2 byte)
|
||
if flag == 1 {
|
||
pTypes = payload[pos : pos+int(stmt.ParamCount)*2]
|
||
pos += int(stmt.ParamCount) * 2
|
||
pValues = payload[pos:]
|
||
}
|
||
|
||
//bind params
|
||
err := stmt.BindArgs(nullBitmap, pTypes, pValues)
|
||
if err != nil {
|
||
log.Println("ERR : Could not bind params", err)
|
||
}
|
||
}
|
||
msg = string(stmt.WriteToText())
|
||
default:
|
||
return
|
||
}
|
||
|
||
fmt.Println(GetNowStr(true) + msg)
|
||
}
|
||
|