mirror of https://github.com/40t/go-sniffer.git
302 lines
5.5 KiB
Go
302 lines
5.5 KiB
Go
package build
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"unicode/utf16"
|
|
|
|
"github.com/google/gopacket"
|
|
)
|
|
|
|
// Plugin defaults and command-line flag names.
const (
	// Port is the default TCP port captured by the plugin (SQL Server's
	// standard listener port). SetFlag can override it.
	Port = 1433
	// Version is the plugin version reported by (*Mssql).Version.
	Version = "0.1"
	// CmdPort is the flag SetFlag accepts to override Port, e.g. "-p 1434".
	CmdPort = "-p"
)
|
|
|
|
type Mssql struct {
|
|
port int
|
|
version string
|
|
source map[string]*stream
|
|
}
|
|
|
|
type stream struct {
|
|
packets chan *packet
|
|
}
|
|
|
|
// packet is a single TDS packet as read by readStream: the decoded header
// fields plus the bytes that follow the 8-byte header.
type packet struct {
	// isClientFlow is set by newPacket: true unless the TCP source port is
	// the configured server port.
	isClientFlow bool

	// Header fields: status (header byte 1), packetType (header byte 0) and
	// length (bytes 2-3, big-endian), which counts the 8 header bytes too.
	status, packetType, length int

	// payload holds the length-8 bytes that follow the header.
	payload []byte
}
|
|
|
|
// TDS packet types: the value of the first header byte of every packet
// (see readStream). Reference:
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
const (
	packSQLBatch    = 1  // SQL batch
	packRPCRequest  = 3  // RPC request
	packReply       = 4  // tabular result
	packAttention   = 6  // attention signal
	packBulkLoadBCP = 7  // bulk load data
	packTransMgrReq = 14 // transaction manager request
	packNormal      = 15 // NOTE(review): 15 does not appear to be an MS-TDS packet type; confirm its purpose
	packLogin7      = 16 // TDS 7 login
	packSSPIMessage = 17 // SSPI message
	packPrelogin    = 18 // pre-login
)
|
|
|
|
var mssql *Mssql
|
|
var once sync.Once
|
|
|
|
func NewInstance() *Mssql {
|
|
|
|
once.Do(func() {
|
|
mssql = &Mssql{
|
|
port: Port,
|
|
version: Version,
|
|
source: make(map[string]*stream),
|
|
}
|
|
})
|
|
return mssql
|
|
}
|
|
|
|
func (m *Mssql) Version() string {
|
|
return m.version
|
|
}
|
|
|
|
func (m *Mssql) BPFFilter() string {
|
|
return "tcp and port " + strconv.Itoa(m.port)
|
|
}
|
|
|
|
func (m *Mssql) SetFlag(flg []string) {
|
|
c := len(flg)
|
|
|
|
if c == 0 {
|
|
return
|
|
}
|
|
|
|
if c>>1 == 0 {
|
|
fmt.Println("ERR : Mssql Number of parameters")
|
|
os.Exit(1)
|
|
}
|
|
|
|
for i := 0; i < c; i = i + 2 {
|
|
key := flg[i]
|
|
val := flg[i+1]
|
|
|
|
switch key {
|
|
case CmdPort:
|
|
port, err := strconv.Atoi(val)
|
|
m.port = port
|
|
if err != nil {
|
|
panic("ERR : port")
|
|
}
|
|
if port < 0 || port > 65535 {
|
|
panic("ERR : port(0-65535)")
|
|
}
|
|
break
|
|
default:
|
|
panic("ERR : mssql's params")
|
|
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
// ResolveStream ...
|
|
func (m *Mssql) ResolveStream(net, transport gopacket.Flow, buf io.Reader) {
|
|
//uuid
|
|
uuid := fmt.Sprintf("%v:%v", net.FastHash(), transport.FastHash())
|
|
|
|
// log.Println(uuid)
|
|
|
|
if _, ok := m.source[uuid]; !ok {
|
|
var newStream = &stream{
|
|
packets: make(chan *packet, 100),
|
|
}
|
|
m.source[uuid] = newStream
|
|
go newStream.resolve()
|
|
}
|
|
|
|
for {
|
|
|
|
// log.Println("ssss")
|
|
|
|
newPacket := m.newPacket(net, transport, buf)
|
|
if newPacket == nil {
|
|
return
|
|
}
|
|
m.source[uuid].packets <- newPacket
|
|
|
|
}
|
|
// log.Println('ddd')
|
|
}
|
|
|
|
func (m *Mssql) newPacket(net, transport gopacket.Flow, r io.Reader) *packet {
|
|
// read packet
|
|
var packet *packet
|
|
var err error
|
|
packet, err = readStream(r)
|
|
|
|
//stream close
|
|
if err == io.EOF {
|
|
fmt.Println(net, transport, " close")
|
|
return nil
|
|
} else if err != nil {
|
|
fmt.Println("ERR : Unknown stream", net, transport, ":", err)
|
|
return nil
|
|
}
|
|
|
|
//set flow direction
|
|
if transport.Src().String() == strconv.Itoa(m.port) {
|
|
packet.isClientFlow = false
|
|
} else {
|
|
packet.isClientFlow = true
|
|
}
|
|
return packet
|
|
}
|
|
|
|
func (m *stream) resolve() {
|
|
for {
|
|
select {
|
|
case packet := <-m.packets:
|
|
if packet.isClientFlow {
|
|
m.resolveClientPacket(packet)
|
|
} else {
|
|
m.resolveServerPacket(packet)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func readStream(r io.Reader) (*packet, error) {
|
|
|
|
var buffer bytes.Buffer
|
|
|
|
header := make([]byte, 8)
|
|
p := &packet{}
|
|
if _, err := io.ReadFull(r, header); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
p.packetType = int(uint32(header[0]))
|
|
p.status = int(uint32(header[1]))
|
|
p.length = int(binary.BigEndian.Uint16(header[2:4]))
|
|
|
|
if p.length > 0 {
|
|
io.CopyN(&buffer, r, int64(p.length-8))
|
|
}
|
|
p.payload = buffer.Bytes()
|
|
return p, nil
|
|
}
|
|
|
|
// ucs22str decodes s as little-endian UCS-2/UTF-16 text into a Go string.
// Surrogate pairs are combined; utf16.Decode maps invalid sequences to
// U+FFFD. An odd byte count cannot be a sequence of 16-bit units and is
// reported as an error.
func ucs22str(s []byte) (string, error) {
	if n := len(s); n%2 == 1 {
		return "", fmt.Errorf("Illegal UCS2 string length: %d", n)
	}
	units := make([]uint16, len(s)>>1)
	for k := range units {
		units[k] = binary.LittleEndian.Uint16(s[2*k:])
	}
	return string(utf16.Decode(units)), nil
}
|
|
|
|
func (m *stream) resolveClientPacket(p *packet) {
|
|
|
|
var msg string
|
|
|
|
switch p.packetType {
|
|
case 1:
|
|
headerLength := int(binary.LittleEndian.Uint32(p.payload[0:4]))
|
|
// fmt.Printf("headers %x %d\n %x \n", p.payload[0:4], headerLength, p.payload)
|
|
if headerLength > 22 {
|
|
//not exists headers
|
|
msg = fmt.Sprintf("【query】 %s", string(p.payload))
|
|
|
|
} else {
|
|
//tds 7.2+
|
|
msg = fmt.Sprintf("【query】 %s", string(p.payload[headerLength:]))
|
|
}
|
|
case 3:
|
|
// 4 byte
|
|
pos := 0
|
|
headerLength := int(binary.LittleEndian.Uint32(p.payload[0:4]))
|
|
// fmt.Printf("rpc headers %x %d\n \n", p.payload[0:4], headerLength)
|
|
pos += headerLength
|
|
|
|
//rpc name length
|
|
rpcLength := int(binary.LittleEndian.Uint16(p.payload[pos : pos+2]))
|
|
|
|
rpcLength = rpcLength * 2
|
|
|
|
pos += 2
|
|
|
|
rpcName, _ := ucs22str(p.payload[pos : pos+rpcLength])
|
|
|
|
// fmt.Printf("rpc name %s %d %x", rpcName, rpcLength, p.payload[pos:pos+rpcLength])
|
|
|
|
pos += rpcLength
|
|
|
|
if strings.Compare(rpcName, `sp_executesql`) != 0 {
|
|
return
|
|
}
|
|
//OPTIONS Flags 2byte
|
|
|
|
pos += 2
|
|
|
|
//name length 1byte
|
|
nameLength := int(p.payload[pos])
|
|
// fmt.Printf("parameter nameLength %d", nameLength)
|
|
|
|
pos = pos + 1 + nameLength*2
|
|
|
|
//STATUS FLAGS 1byte
|
|
pos += 1
|
|
|
|
typeNvarchar := p.payload[pos]
|
|
// fmt.Printf("typeNvarchar %x ", typeNvarchar)
|
|
if typeNvarchar == 0xe7 {
|
|
pos += 7
|
|
|
|
//value
|
|
valueLength := int(binary.LittleEndian.Uint16(p.payload[pos+1 : pos+3]))
|
|
pos += 2
|
|
|
|
msg = fmt.Sprintf("【query】%s", string(p.payload[pos:pos+valueLength]))
|
|
|
|
}
|
|
// ParameterData
|
|
|
|
case 4:
|
|
msg = fmt.Sprintf("【query】 %s", "Tabular result")
|
|
|
|
}
|
|
|
|
fmt.Println(GetNowStr(true), msg)
|
|
}
|
|
|
|
func (m *stream) resolveServerPacket(p *packet) {
|
|
|
|
var msg string
|
|
switch p.packetType {
|
|
case 4: //response
|
|
rows, errMsg := parseToken(p.payload)
|
|
if rows == 0 && len(errMsg) != 0 {
|
|
msg = fmt.Sprintf("【Err】Effect Rows:%d, message: %s", rows, errMsg)
|
|
} else {
|
|
msg = fmt.Sprintf("【OK】Effect Rows:%d", rows)
|
|
}
|
|
}
|
|
|
|
fmt.Println(GetNowStr(false), msg)
|
|
}
|