go-sniffer/plugSrc/mssql/build/entry.go

302 lines
5.5 KiB
Go

package build
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"os"
"strconv"
"strings"
"sync"
"unicode/utf16"
"github.com/google/gopacket"
)
// Port is the default TCP port sniffed for MSSQL (TDS) traffic.
const Port = 1433

// Version is the plugin version string returned by (*Mssql).Version.
const Version = "0.1"

// CmdPort is the plugin flag that overrides Port (see (*Mssql).SetFlag).
const CmdPort = "-p"
type (
	// Mssql is the MSSQL sniffer plugin: the TCP port it captures, its
	// version string, and the per-connection streams keyed by the flow-hash
	// id that ResolveStream builds.
	Mssql struct {
		port    int
		version string
		source  map[string]*stream
	}

	// stream carries the decoded packets of one connection to the goroutine
	// that prints them.
	stream struct {
		packets chan *packet
	}

	// packet is one TDS packet: its direction, the type/status/length
	// fields decoded from its 8-byte header, and the bytes after the header.
	packet struct {
		isClientFlow               bool
		status, packetType, length int
		payload                    []byte
	}
)
// TDS packet types, as carried in the first byte of every packet header.
// Reference: https://msdn.microsoft.com/en-us/library/dd304214.aspx
const (
	packSQLBatch    = 1  // SQL batch: statement text sent by the client
	packRPCRequest  = 3  // remote procedure call sent by the client
	packReply       = 4  // tabular result sent by the server
	packAttention   = 6  // attention signal
	packBulkLoadBCP = 7  // bulk load data
	packTransMgrReq = 14 // transaction manager request
	packNormal      = 15
	packLogin7      = 16 // TDS 7 login
	packSSPIMessage = 17 // SSPI message
	packPrelogin    = 18 // pre-login handshake
)
// mssql is the plugin singleton handed out by NewInstance; once guards its
// construction.
var (
	mssql *Mssql
	once  sync.Once
)

// NewInstance returns the package's single Mssql plugin. The first call
// creates it with the default Port and Version and an empty stream table;
// every call returns the same pointer.
func NewInstance() *Mssql {
	once.Do(func() {
		inst := new(Mssql)
		inst.port = Port
		inst.version = Version
		inst.source = map[string]*stream{}
		mssql = inst
	})
	return mssql
}
// Version reports the plugin's version string. For instances built by
// NewInstance, this is the Version constant.
func (m *Mssql) Version() string {
	v := m.version
	return v
}
// BPFFilter returns the capture filter that selects TCP traffic to or from
// the configured port, e.g. "tcp and port 1433".
func (m *Mssql) BPFFilter() string {
	const prefix = "tcp and port "
	portText := strconv.Itoa(m.port)
	return prefix + portText
}
// SetFlag applies the plugin's command-line flags, given as key/value pairs.
//
// The only supported key is CmdPort ("-p"). Its value must be an integer
// port in 0-65535. An odd number of arguments prints an error and exits the
// process. An unknown key or an invalid port panics, and the configured port
// is left unchanged.
func (m *Mssql) SetFlag(flg []string) {
	c := len(flg)
	if c == 0 {
		return
	}
	// Every key needs a value. Any odd count would index past the end of
	// flg below, not only a single argument.
	if c%2 != 0 {
		fmt.Println("ERR : Mssql Number of parameters")
		os.Exit(1)
	}
	for i := 0; i < c; i += 2 {
		key, val := flg[i], flg[i+1]
		switch key {
		case CmdPort:
			port, err := strconv.Atoi(val)
			if err != nil {
				panic("ERR : port")
			}
			if port < 0 || port > 65535 {
				panic("ERR : port(0-65535)")
			}
			// Commit only a validated port.
			m.port = port
		default:
			panic("ERR : mssql's params")
		}
	}
}
// sourceMu guards Mssql.source and sourceRefs. ResolveStream does not
// return until its reader is exhausted, so it runs concurrently, once per
// reassembled TCP half-connection.
// NOTE(review): assumes the core dispatcher calls ResolveStream from one
// goroutine per stream; confirm against the caller.
//
// sourceRefs counts the ResolveStream calls currently feeding each entry of
// Mssql.source, so the entry can be torn down when the last one returns.
var (
	sourceMu   sync.Mutex
	sourceRefs = make(map[string]int)
)

// ResolveStream reads TDS packets from one reassembled TCP half-connection
// and queues them on the stream shared under the same flow-hash id.
// gopacket documents Flow.FastHash as symmetric, so both directions of a
// connection share one stream. That stream's goroutine prints the packets in
// the order they are queued.
//
// It returns when buf is exhausted or a read fails. When the last
// half-connection using a stream returns, the stream is removed and its
// channel is closed, which ends its goroutine.
func (m *Mssql) ResolveStream(net, transport gopacket.Flow, buf io.Reader) {
	uuid := fmt.Sprintf("%v:%v", net.FastHash(), transport.FastHash())

	sourceMu.Lock()
	s, ok := m.source[uuid]
	if !ok {
		s = &stream{packets: make(chan *packet, 100)}
		m.source[uuid] = s
		go s.resolve()
	}
	sourceRefs[uuid]++
	sourceMu.Unlock()

	defer func() {
		sourceMu.Lock()
		defer sourceMu.Unlock()
		sourceRefs[uuid]--
		if sourceRefs[uuid] > 0 {
			return // the other direction is still sending on s
		}
		delete(sourceRefs, uuid)
		delete(m.source, uuid)
		// No goroutine can send on s.packets any more. Closing it lets
		// resolve drain what is buffered and exit instead of leaking.
		close(s.packets)
	}()

	for {
		p := m.newPacket(net, transport, buf)
		if p == nil {
			return
		}
		s.packets <- p
	}
}

// newPacket reads the next TDS packet from r and tags its direction. It logs
// and returns nil once r is exhausted (io.EOF) or a read fails.
func (m *Mssql) newPacket(net, transport gopacket.Flow, r io.Reader) *packet {
	p, err := readStream(r)
	if err == io.EOF {
		fmt.Println(net, transport, " close")
		return nil
	}
	if err != nil {
		fmt.Println("ERR : Unknown stream", net, transport, ":", err)
		return nil
	}
	// Segments whose source port is the server port are server replies;
	// everything else travels client -> server.
	p.isClientFlow = transport.Src().String() != strconv.Itoa(m.port)
	return p
}

// resolve prints every packet queued on m.packets, dispatching on its
// direction, until the channel is closed.
func (m *stream) resolve() {
	for p := range m.packets {
		if p.isClientFlow {
			m.resolveClientPacket(p)
		} else {
			m.resolveServerPacket(p)
		}
	}
}
// readStream reads one TDS packet from r.
//
// A packet is an 8-byte header (type, status, big-endian total length) and
// then length-8 payload bytes. It returns io.EOF when r ends cleanly before
// a new header. It returns io.ErrUnexpectedEOF when r ends partway through a
// header or payload. A declared length below 8 yields an empty payload.
func readStream(r io.Reader) (*packet, error) {
	header := make([]byte, 8)
	if _, err := io.ReadFull(r, header); err != nil {
		return nil, err
	}
	p := &packet{
		packetType: int(header[0]),
		status:     int(header[1]),
		length:     int(binary.BigEndian.Uint16(header[2:4])),
	}
	// The length field counts the header itself.
	var body bytes.Buffer
	if n := int64(p.length - len(header)); n > 0 {
		if _, err := io.CopyN(&body, r, n); err != nil {
			// The header promised more bytes than r delivered. Don't hand
			// back a silently truncated payload.
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return nil, err
		}
	}
	p.payload = body.Bytes()
	return p, nil
}
// ucs22str decodes s as little-endian UCS-2/UTF-16 text. Surrogate pairs
// are combined by utf16.Decode. It returns an error when s holds an odd
// number of bytes.
func ucs22str(s []byte) (string, error) {
	n := len(s)
	if n%2 == 1 {
		return "", fmt.Errorf("Illegal UCS2 string length: %d", n)
	}
	units := make([]uint16, 0, n/2)
	for rest := s; len(rest) > 0; rest = rest[2:] {
		units = append(units, binary.LittleEndian.Uint16(rest))
	}
	return string(utf16.Decode(units)), nil
}
// procIDExecuteSQL is the well-known procedure ID of sp_executesql. A client
// may send it in an RPC request instead of the procedure name.
const procIDExecuteSQL = 10

// resolveClientPacket prints a one-line summary of a client -> server
// packet.
//
// What is printed depends on the packet type:
//   - SQL batches (type 1) print their statement text.
//   - sp_executesql RPC requests (type 3) print the statement passed as
//     their first parameter; other RPC requests print nothing.
//   - Type 4 prints a fixed "Tabular result" note.
//   - Any other type prints an empty summary.
//
// The payload comes straight off the wire, so all decoding is bounds
// checked. A truncated or unexpected packet yields an empty or partial
// statement instead of a panic.
func (m *stream) resolveClientPacket(p *packet) {
	var msg string
	switch p.packetType {
	case 1: // packSQLBatch
		msg = fmt.Sprintf("【query】 %s", decodeSQLBatchText(p.payload))
	case 3: // packRPCRequest
		stmt, isExecuteSQL, found := decodeExecuteSQLText(p.payload)
		if !isExecuteSQL {
			return
		}
		if found {
			msg = fmt.Sprintf("【query】%s", stmt)
		}
	case 4: // packReply
		msg = fmt.Sprintf("【query】 %s", "Tabular result")
	}
	fmt.Println(GetNowStr(true), msg)
}

// skipAllHeaders returns the offset just past the ALL_HEADERS block that
// starts TDS 7.2+ SQLBatch and RPCRequest payloads.
//
// ok is false unless b starts with a well-formed block. A well-formed block
// is a DWORD total length (counting itself) that exactly covers a run of
// headers, each beginning with its own DWORD length.
func skipAllHeaders(b []byte) (end int, ok bool) {
	if len(b) < 4 {
		return 0, false
	}
	total := int(binary.LittleEndian.Uint32(b))
	if total < 4 || total > len(b) {
		return 0, false
	}
	pos := 4
	for pos+4 <= total {
		// A header is HeaderLength (DWORD, counting itself), HeaderType
		// (USHORT) and its data, so it is at least 6 bytes long.
		n := int(binary.LittleEndian.Uint32(b[pos:]))
		if n < 6 || n > total-pos {
			return 0, false
		}
		pos += n
	}
	if pos != total {
		return 0, false
	}
	return total, true
}

// decodeUCS2 decodes b as UCS-2LE text. A trailing odd byte, which is half
// a character left by truncation, is dropped.
func decodeUCS2(b []byte) string {
	// ucs22str fails only on odd lengths, which the slice rules out.
	s, _ := ucs22str(b[:len(b)&^1])
	return s
}

// decodeSQLBatchText returns the UCS-2 statement text of a SQLBatch
// payload. It skips the ALL_HEADERS block that TDS 7.2+ clients send first.
//
// Older clients start directly with the text. Its first two characters then
// normally fail to form a valid ALL_HEADERS length, and the whole payload is
// decoded.
func decodeSQLBatchText(payload []byte) string {
	if end, ok := skipAllHeaders(payload); ok {
		return decodeUCS2(payload[end:])
	}
	return decodeUCS2(payload)
}

// decodeExecuteSQLText decodes a TDS 7.2+ RPCRequest payload.
//
// The results are:
//   - isExecuteSQL reports whether the request calls sp_executesql, either
//     by name or by procIDExecuteSQL.
//   - found reports whether the first parameter is a non-NULL NVARCHAR whose
//     value could be located.
//   - stmt is that value when found is true. It may be cut short when the
//     value continues beyond this TDS packet.
func decodeExecuteSQLText(payload []byte) (stmt string, isExecuteSQL, found bool) {
	size := len(payload)
	pos, ok := skipAllHeaders(payload)
	if !ok || pos+2 > size {
		return "", false, false
	}
	// NameLenProcID: the procedure name length in characters, or 0xFFFF
	// followed by a USHORT procedure ID.
	nameLen := int(binary.LittleEndian.Uint16(payload[pos:]))
	pos += 2
	if nameLen == 0xFFFF {
		if pos+2 > size || binary.LittleEndian.Uint16(payload[pos:]) != procIDExecuteSQL {
			return "", false, false
		}
		pos += 2
	} else {
		end := pos + nameLen*2
		if end > size {
			return "", false, false
		}
		if strings.Compare(decodeUCS2(payload[pos:end]), `sp_executesql`) != 0 {
			return "", false, false
		}
		pos = end
	}

	// The fields that follow are OptionFlags (USHORT) and then the first
	// parameter: its name as a B_VARCHAR (1-byte length in characters),
	// StatusFlags (BYTE) and its TYPE_INFO.
	pos += 2
	if pos >= size {
		return "", true, false
	}
	pos += 1 + int(payload[pos])*2 + 1
	// NVARCHARTYPE (0xE7) TYPE_INFO is the type byte, a USHORT maximum
	// length and a 5-byte collation. The value starts right after it.
	if pos+8 > size || payload[pos] != 0xe7 {
		return "", true, false
	}
	maxLen := binary.LittleEndian.Uint16(payload[pos+1:])
	pos += 8

	var value []byte
	if maxLen == 0xFFFF {
		// nvarchar(max) is sent as PLP. The layout is a ULONGLONG total
		// length (all ones for NULL), then chunks that each start with a
		// DWORD length, ended by a zero-length chunk.
		if pos+8 > size || binary.LittleEndian.Uint64(payload[pos:]) == ^uint64(0) {
			return "", true, false
		}
		pos += 8
		for pos+4 <= size {
			n := int(binary.LittleEndian.Uint32(payload[pos:]))
			pos += 4
			if n == 0 {
				break
			}
			if n < 0 || n > size-pos {
				n = size - pos // the chunk continues in a later packet
			}
			value = append(value, payload[pos:pos+n]...)
			pos += n
		}
	} else {
		// The value is a USHORT byte length (0xFFFF for NULL) followed by
		// the characters.
		if pos+2 > size {
			return "", true, false
		}
		n := int(binary.LittleEndian.Uint16(payload[pos:]))
		pos += 2
		if n == 0xFFFF {
			return "", true, false
		}
		if n > size-pos {
			n = size - pos // the value continues in a later packet
		}
		value = payload[pos : pos+n]
	}
	return decodeUCS2(value), true, true
}
// resolveServerPacket prints a one-line summary of a server -> client
// packet.
//
// Only tabular-result packets (type 4) are decoded, via parseToken. Such a
// packet is reported as an error when parseToken finds zero affected rows
// together with an error message, and as OK otherwise. Any other packet
// type prints an empty summary.
func (m *stream) resolveServerPacket(p *packet) {
	msg := ""
	if p.packetType == 4 { // packReply
		rows, errMsg := parseToken(p.payload)
		failed := rows == 0 && len(errMsg) != 0
		if failed {
			msg = fmt.Sprintf("【Err】Effect Rows:%d, message: %s", rows, errMsg)
		} else {
			msg = fmt.Sprintf("【OK】Effect Rows:%d", rows)
		}
	}
	fmt.Println(GetNowStr(false), msg)
}