go-sniffer/plugSrc/mysql/build/entry.go

357 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package build
import (
"github.com/google/gopacket"
"io"
"bytes"
"errors"
"log"
"strconv"
"sync"
"time"
"fmt"
"encoding/binary"
"strings"
"os"
)
// Defaults and flag names of the MySQL plugin.
const (
	// CmdPort is the flag key SetFlag accepts for overriding the port.
	CmdPort = "-p"

	// Port is the port NewInstance configures on a fresh plugin.
	Port = 3306

	// Version is the plugin version reported by (*Mysql).Version.
	Version = "0.1"
)
// Mysql is the plugin state shared by every sniffed connection.
type Mysql struct {
	// port is the server port; it feeds the BPF filter and decides which
	// side of a flow is the server.
	port int

	// version is filled from the Version constant by NewInstance.
	version string

	// source holds one stream per connection, keyed by the flow-hash
	// string built in ResolveStream.
	source map[string]*stream
}
// stream is the per-connection state: a packet queue plus the prepared
// statements seen on that connection.
type stream struct {
	// packets is filled by ResolveStream and drained by resolve.
	packets chan *packet

	// stmtMap maps a prepared-statement ID to its recorded statement.
	stmtMap map[uint32]*Stmt
}
// packet is one MySQL protocol packet taken off a connection.
type packet struct {
	// isClientFlow is true unless the packet's source port equals the
	// configured server port (see newPacket).
	isClientFlow bool

	// seq is the sequence byte of the packet header; length is the number
	// of payload bytes that were read.
	seq, length int

	// payload is the packet body without the 4-byte header.
	payload []byte
}
// Package-level singleton state used by NewInstance.
var (
	// mysql is the shared plugin instance; it stays nil until NewInstance
	// has run.
	mysql *Mysql

	// once makes sure the instance is built a single time.
	once sync.Once
)
// NewInstance returns the process-wide Mysql plugin. The first call builds
// it with the default Port and Version and an empty stream table; later
// calls return that same pointer.
func NewInstance() *Mysql {
	build := func() {
		instance := new(Mysql)
		instance.port = Port
		instance.version = Version
		instance.source = make(map[string]*stream)
		mysql = instance
	}
	once.Do(build)
	return mysql
}
// sourceMu guards Mysql.source. ResolveStream blocks until its reader is
// exhausted, so the calls for different flows (and for the two directions of
// one connection) necessarily overlap and must not touch the map unguarded.
var sourceMu sync.Mutex

// ResolveStream reads MySQL packets from buf until newPacket reports that no
// further packet is available, and queues each one on the stream belonging
// to the connection.
//
// The stream is looked up by a key built from the FastHash of the network
// and transport flows. It is created, and its resolve goroutine started, the
// first time a key is seen.
func (m *Mysql) ResolveStream(net, transport gopacket.Flow, buf io.Reader) {
	uuid := fmt.Sprintf("%v:%v", net.FastHash(), transport.FastHash())

	// Look the stream up once, under the lock, instead of on every packet.
	sourceMu.Lock()
	stm, ok := m.source[uuid]
	if !ok {
		stm = &stream{
			packets: make(chan *packet, 100),
			stmtMap: make(map[uint32]*Stmt),
		}
		m.source[uuid] = stm
		go stm.resolve()
	}
	sourceMu.Unlock()

	// Packets of either direction (server -> client, client -> server) are
	// pushed onto the same queue.
	for {
		newPacket := m.newPacket(net, transport, buf)
		if newPacket == nil {
			return
		}
		stm.packets <- newPacket
	}
}
// BPFFilter returns the capture filter expression for the configured port,
// in the form "tcp and port <port>".
func (m *Mysql) BPFFilter() string {
	portText := strconv.Itoa(m.port)
	return "tcp and port " + portText
}
// Version reports the plugin version. It returns the package-level Version
// constant, not the receiver's version field.
func (m *Mysql) Version() string {
	const current = Version
	return current
}
// SetFlag applies command-line flags given as alternating key/value
// strings. The only recognised key is CmdPort, whose value must be a decimal
// port in the range 0-65535.
//
// An empty slice is a no-op. An odd number of elements prints an error and
// exits the process with status 1. An unknown key, a non-numeric port or an
// out-of-range port panics; in those cases the configured port is left
// untouched.
func (m *Mysql) SetFlag(flg []string) {
	c := len(flg)
	if c == 0 {
		return
	}
	// Every key needs a value, so an odd count means one is missing. Without
	// this check flg[i+1] below would run past the end of the slice.
	if c%2 != 0 {
		fmt.Println("ERR : Mysql Number of parameters")
		os.Exit(1)
	}

	for i := 0; i < c; i += 2 {
		key, val := flg[i], flg[i+1]
		switch key {
		case CmdPort:
			port, err := strconv.Atoi(val)
			if err != nil {
				panic("ERR : port")
			}
			if port < 0 || port > 65535 {
				panic("ERR : port(0-65535)")
			}
			// Store the port only once it is known to be valid.
			m.port = port
		default:
			panic("ERR : mysql's params")
		}
	}
}
// newPacket reads the next MySQL packet from r and wraps it in a packet
// value. It returns nil when no complete packet could be read: a clean end
// of stream is reported as "<net> <transport>  close", any other read
// failure as an "ERR : Unknown stream" line.
//
// A packet counts as server-to-client when the source endpoint of transport
// equals the configured port, and as client-to-server otherwise.
func (m *Mysql) newPacket(net, transport gopacket.Flow, r io.Reader) *packet {
	var payload bytes.Buffer

	seq, err := m.resolvePacketTo(r, &payload)
	if err == io.EOF {
		// The stream ended on a packet boundary.
		fmt.Println(net, transport, " close")
		return nil
	}
	if err != nil {
		// Never build a packet from a partially read payload.
		fmt.Println("ERR : Unknown stream", net, transport, ":", err)
		return nil
	}

	fromServer := transport.Src().String() == strconv.Itoa(m.port)
	return &packet{
		isClientFlow: !fromServer,
		seq:          int(seq),
		length:       payload.Len(),
		payload:      payload.Bytes(),
	}
}
// resolvePacketTo reads one packet from r and copies its payload to w.
//
// The packet starts with a 4-byte header: a 3-byte little-endian payload
// length followed by a 1-byte sequence number. On success the sequence
// number is returned.
//
// If r is already exhausted before any header byte is read, io.EOF is
// returned unchanged so callers can detect a clean end of stream. Every
// other failure is returned as an "ERR : Unknown stream" error wrapping the
// cause; a payload that ends early is reported as io.ErrUnexpectedEOF so it
// cannot be mistaken for a clean close.
func (m *Mysql) resolvePacketTo(r io.Reader, w io.Writer) (uint8, error) {
	var header [4]byte
	if n, err := io.ReadFull(r, header[:]); err != nil {
		if n == 0 && errors.Is(err, io.EOF) {
			return 0, io.EOF
		}
		return 0, fmt.Errorf("ERR : Unknown stream: %w", err)
	}

	length := int64(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
	seq := header[3]

	// io.CopyN returns a nil error only if exactly length bytes were copied,
	// so no separate byte-count check is needed.
	if _, err := io.CopyN(w, r, length); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return 0, fmt.Errorf("ERR : Unknown stream: %w", err)
	}
	return seq, nil
}
// resolve drains the stream's packet queue, handing every non-empty packet
// to resolveClientPacket or resolveServerPacket depending on its direction.
// It returns when the packets channel is closed.
func (stm *stream) resolve() {
	for pkt := range stm.packets {
		// Empty packets carry nothing to decode.
		if pkt == nil || pkt.length == 0 {
			continue
		}
		if pkt.isClientFlow {
			stm.resolveClientPacket(pkt.payload, pkt.seq)
		} else {
			stm.resolveServerPacket(pkt.payload, pkt.seq)
		}
	}
}
// findStmtPacket takes packets off stm.packets until it finds one whose
// sequence number equals seq, and returns it. Packets received before the
// match are discarded. It returns nil if the channel is closed or if no
// match arrives within 5 seconds in total.
//
// srv is kept for call-site compatibility; the method always reads from
// stm.packets.
func (stm *stream) findStmtPacket(srv chan *packet, seq int) *packet {
	// One timer bounds the whole search. Re-arming a timeout on every loop
	// iteration would let a steady flow of non-matching packets extend the
	// wait without limit.
	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()

	for {
		select {
		case pkt, ok := <-stm.packets:
			if !ok {
				return nil
			}
			if pkt != nil && pkt.seq == seq {
				return pkt
			}
		case <-deadline.C:
			return nil
		}
	}
}
// resolveServerPacket prints a one-line summary of a server packet.
//
// Only two packet kinds are reported, chosen by the first payload byte:
// 0xff prints the error code and message, 0x00 prints the affected-row
// count. Anything else, and any payload too short to hold the fields read
// here, is ignored. seq is not used.
func (stm *stream) resolveServerPacket(payload []byte, seq int) {
	if len(payload) == 0 {
		return
	}

	var msg string
	switch payload[0] {
	case 0xff:
		// The code sits in bytes 1-2 and the text is read from byte 4 on, so
		// anything shorter than 4 bytes cannot be decoded.
		if len(payload) < 4 {
			return
		}
		errorCode := int(binary.LittleEndian.Uint16(payload[1:3]))
		errorMsg, _ := ReadStringFromByte(payload[4:])
		msg = fmt.Sprintf("%s%s Err code:%s,Err msg:%s",
			GetNowStr(false), ErrorPacket, strconv.Itoa(errorCode), strings.TrimSpace(errorMsg))
	case 0x00:
		// The affected-row count is a length-encoded integer starting at
		// byte 1; skip packets that carry no byte for it.
		if len(payload) < 2 {
			return
		}
		l, _, _ := LengthEncodedInt(payload[1:])
		msg = fmt.Sprintf("%s%s Effect Row:%s",
			GetNowStr(false), OkPacket, strconv.Itoa(int(l)))
	default:
		return
	}
	fmt.Println(msg)
}
// resolveClientPacket decodes one client command and prints it, prefixed
// with GetNowStr(true). The command is selected by the first payload byte.
//
//   - COM_INIT_DB, COM_DROP_DB, COM_CREATE_DB and COM_QUERY print the text
//     that follows the command byte.
//   - COM_STMT_PREPARE waits for the packet with sequence seq+1, records the
//     statement under the ID found in it, and prints the query. If that
//     packet is not a well-formed OK response, the query is still printed
//     but nothing is recorded.
//   - COM_STMT_SEND_LONG_DATA appends data to a statement argument and
//     COM_STMT_RESET clears the arguments; neither prints anything.
//   - COM_STMT_EXECUTE binds the transmitted parameters and prints the
//     statement text.
//
// Unknown commands are ignored. Truncated packets and references to
// statements that were never recorded are logged and skipped.
func (stm *stream) resolveClientPacket(payload []byte, seq int) {
	if len(payload) == 0 {
		return
	}

	var msg string
	switch payload[0] {
	case COM_INIT_DB:
		msg = fmt.Sprintf("USE %s;\n", payload[1:])

	case COM_DROP_DB:
		msg = fmt.Sprintf("Drop DB %s;\n", payload[1:])

	case COM_CREATE_DB, COM_QUERY:
		statement := string(payload[1:])
		msg = fmt.Sprintf("%s %s", ComQueryRequestPacket, statement)

	case COM_STMT_PREPARE:
		serverPacket := stm.findStmtPacket(stm.packets, seq+1)
		if serverPacket == nil {
			log.Println("ERR : Not found stm packet")
			return
		}
		query := string(payload[1:])
		msg = PreparePacket + query

		// A prepare-OK response is a 0x00 status byte, a 4-byte statement
		// ID, a 2-byte column count and a 2-byte parameter count. Anything
		// else (for example an error response) carries no statement.
		resp := serverPacket.payload
		if len(resp) < 9 || resp[0] != 0x00 {
			log.Println("ERR : Invalid stm prepare response")
			break
		}
		stmt := &Stmt{
			ID:    binary.LittleEndian.Uint32(resp[1:5]),
			Query: query,
		}
		stmt.FieldCount = binary.LittleEndian.Uint16(resp[5:7])
		stmt.ParamCount = binary.LittleEndian.Uint16(resp[7:9])
		stmt.Args = make([]interface{}, stmt.ParamCount)
		stm.stmtMap[stmt.ID] = stmt

	case COM_STMT_SEND_LONG_DATA:
		// Command byte, 4-byte statement ID, 2-byte parameter index, data.
		if len(payload) < 7 {
			log.Println("ERR : Invalid stm long data packet")
			return
		}
		stmtID := binary.LittleEndian.Uint32(payload[1:5])
		paramID := int(binary.LittleEndian.Uint16(payload[5:7]))
		stmt := stm.stmtMap[stmtID]
		if stmt == nil {
			// The statement was never recorded on this stream.
			log.Println("ERR : Not found stm id", stmtID)
			return
		}
		if paramID >= len(stmt.Args) {
			log.Println("ERR : Invalid stm param id", paramID)
			return
		}
		// The first chunk is stored as is; later chunks are appended.
		switch chunk := stmt.Args[paramID].(type) {
		case nil:
			stmt.Args[paramID] = payload[7:]
		case []byte:
			stmt.Args[paramID] = append(chunk, payload[7:]...)
		}
		return

	case COM_STMT_RESET:
		if len(payload) < 5 {
			log.Println("ERR : Invalid stm reset packet")
			return
		}
		stmtID := binary.LittleEndian.Uint32(payload[1:5])
		stmt := stm.stmtMap[stmtID]
		if stmt == nil {
			log.Println("ERR : Not found stm id", stmtID)
			return
		}
		stmt.Args = make([]interface{}, stmt.ParamCount)
		return

	case COM_STMT_EXECUTE:
		if len(payload) < 5 {
			log.Println("ERR : Invalid stm execute packet")
			return
		}
		pos := 1
		stmtID := binary.LittleEndian.Uint32(payload[pos : pos+4])
		pos += 4
		stmt := stm.stmtMap[stmtID]
		if stmt == nil {
			log.Println("ERR : Not found stm id", stmtID)
			return
		}

		// Skip the 5 bytes between the statement ID and the parameter block
		// (flags and iteration count in the MySQL COM_STMT_EXECUTE layout).
		pos += 5
		if stmt.ParamCount > 0 {
			// Widen before the arithmetic: ParamCount+7 can wrap in uint16.
			paramCount := int(stmt.ParamCount)

			// Null bitmap of (paramCount+7)/8 bytes, then one flag byte.
			step := (paramCount + 7) / 8
			if len(payload) < pos+step+1 {
				log.Println("ERR : Invalid stm execute packet")
				return
			}
			nullBitmap := payload[pos : pos+step]
			pos += step

			flag := payload[pos]
			pos++

			var pTypes, pValues []byte
			// With flag == 1 the packet carries 2 type bytes per parameter,
			// followed by the values.
			if flag == 1 {
				if len(payload) < pos+paramCount*2 {
					log.Println("ERR : Invalid stm execute packet")
					return
				}
				pTypes = payload[pos : pos+paramCount*2]
				pos += paramCount * 2
				pValues = payload[pos:]
			}

			if err := stmt.BindArgs(nullBitmap, pTypes, pValues); err != nil {
				log.Println("ERR : Could not bind params", err)
			}
		}
		msg = string(stmt.WriteToText())

	default:
		return
	}

	fmt.Println(GetNowStr(true) + msg)
}