implementl get user and database

Signed-off-by: zhuhuijun <zhuhuijunzhj@gmail.com>
This commit is contained in:
zhuhuijun
2022-12-07 18:07:00 +08:00
parent add66245ba
commit af88bad075
166 changed files with 58325 additions and 102 deletions

View File

@@ -4,7 +4,7 @@ import (
"bytes"
"encoding/base64"
"fmt"
"github.com/40t/go-sniffer/plugSrc/mongodb/build/internal/json"
"go-sniffer/plugSrc/mongodb/build/internal/json"
"strconv"
"time"
)

View File

@@ -4,9 +4,9 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"time"
"go-sniffer/plugSrc/mongodb/build/bson"
"io"
"github.com/40t/go-sniffer/plugSrc/mongodb/build/bson"
"time"
)
func GetNowStr(isClient bool) string {
@@ -15,7 +15,7 @@ func GetNowStr(isClient bool) string {
msg += time.Now().Format(layout)
if isClient {
msg += "| cli -> ser |"
}else{
} else {
msg += "| ser -> cli |"
}
return msg
@@ -54,7 +54,7 @@ func ReadString(r io.Reader) string {
return string(result)
}
func ReadBson2Json(r io.Reader) (string) {
func ReadBson2Json(r io.Reader) string {
//read len
docLen := ReadInt32(r)
@@ -83,4 +83,3 @@ func ReadBson2Json(r io.Reader) (string) {
}
return string(jsonStr)
}

View File

@@ -1,53 +1,57 @@
package build
import (
"github.com/google/gopacket"
"io"
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/google/gopacket"
"go-sniffer/pkg/model"
"go-sniffer/pkg/parse"
"io"
"log"
"os"
"strconv"
"strings"
"sync"
"time"
"fmt"
"encoding/binary"
"strings"
"os"
)
const (
Port = 3306
Version = "0.1"
CmdPort = "-p"
Port = 3306
Version = "0.1"
CmdPort = "-p"
)
type Mysql struct {
port int
version string
source map[string]*stream
port int
version string
source map[string]*stream
}
type stream struct {
packets chan *packet
stmtMap map[uint32]*Stmt
session *model.MysqlSession
}
type packet struct {
isClientFlow bool
seq int
length int
payload []byte
seq int
length int
payload []byte
}
var mysql *Mysql
var once sync.Once
func NewInstance() *Mysql {
once.Do(func() {
mysql = &Mysql{
port :Port,
version:Version,
source: make(map[string]*stream),
port: Port,
version: Version,
source: make(map[string]*stream),
}
})
@@ -60,11 +64,13 @@ func (m *Mysql) ResolveStream(net, transport gopacket.Flow, buf io.Reader) {
uuid := fmt.Sprintf("%v:%v", net.FastHash(), transport.FastHash())
//generate resolve's stream
if _, ok := m.source[uuid]; !ok {
old, ok := m.source[uuid]
if !ok {
var newStream = stream{
packets:make(chan *packet, 100),
stmtMap:make(map[uint32]*Stmt),
packets: make(chan *packet, 100),
stmtMap: make(map[uint32]*Stmt),
session: parse.GenSession(net, transport),
}
m.source[uuid] = &newStream
@@ -81,36 +87,40 @@ func (m *Mysql) ResolveStream(net, transport gopacket.Flow, buf io.Reader) {
return
}
if ok && old != nil && old.session != nil {
m.source[uuid].session = old.session
}
m.source[uuid].packets <- newPacket
}
}
func (m *Mysql) BPFFilter() string {
return "tcp and port "+strconv.Itoa(m.port);
return "tcp and port " + strconv.Itoa(m.port)
}
func (m *Mysql) Version() string {
return Version
}
func (m *Mysql) SetFlag(flg []string) {
func (m *Mysql) SetFlag(flg []string) {
c := len(flg)
if c == 0 {
return
}
if c >> 1 == 0 {
if c>>1 == 0 {
fmt.Println("ERR : Mysql Number of parameters")
os.Exit(1)
}
for i:=0;i<c;i=i+2 {
for i := 0; i < c; i = i + 2 {
key := flg[i]
val := flg[i+1]
switch key {
case CmdPort:
port, err := strconv.Atoi(val);
port, err := strconv.Atoi(val)
m.port = port
if err != nil {
panic("ERR : port")
@@ -145,13 +155,13 @@ func (m *Mysql) newPacket(net, transport gopacket.Flow, r io.Reader) *packet {
//generate new packet
var pk = packet{
seq: int(seq),
length:payload.Len(),
payload:payload.Bytes(),
seq: int(seq),
length: payload.Len(),
payload: payload.Bytes(),
}
if transport.Src().String() == strconv.Itoa(m.port) {
pk.isClientFlow = false
}else{
} else {
pk.isClientFlow = true
}
@@ -187,7 +197,7 @@ func (m *Mysql) resolvePacketTo(r io.Reader, w io.Writer) (uint8, error) {
func (stm *stream) resolve() {
for {
select {
case packet := <- stm.packets:
case packet := <-stm.packets:
if packet.length != 0 {
if packet.isClientFlow {
stm.resolveClientPacket(packet.payload, packet.seq)
@@ -199,10 +209,10 @@ func (stm *stream) resolve() {
}
}
func (stm *stream) findStmtPacket (srv chan *packet, seq int) *packet {
func (stm *stream) findStmtPacket(srv chan *packet, seq int) *packet {
for {
select {
case packet, ok := <- stm.packets:
case packet, ok := <-stm.packets:
if !ok {
return nil
}
@@ -221,45 +231,73 @@ func (stm *stream) resolveServerPacket(payload []byte, seq int) {
if len(payload) == 0 {
return
}
//fmt.Println(string(payload))
switch payload[0] {
case 0xff:
errorCode := int(binary.LittleEndian.Uint16(payload[1:3]))
errorMsg, _ := ReadStringFromByte(payload[4:])
case 0xff:
errorCode := int(binary.LittleEndian.Uint16(payload[1:3]))
errorMsg,_ := ReadStringFromByte(payload[4:])
msg = GetNowStr(false) + "%s Err code:%s,Err msg:%s"
msg = fmt.Sprintf(msg, ErrorPacket, strconv.Itoa(errorCode), strings.TrimSpace(errorMsg))
msg = GetNowStr(false)+"%s Err code:%s,Err msg:%s"
msg = fmt.Sprintf(msg, ErrorPacket, strconv.Itoa(errorCode), strings.TrimSpace(errorMsg))
case 0x00:
var pos = 1
l, _, _ := LengthEncodedInt(payload[pos:])
affectedRows := int(l)
case 0x00:
var pos = 1
l,_,_ := LengthEncodedInt(payload[pos:])
affectedRows := int(l)
msg += GetNowStr(false) + "%s Effect Row:%s"
msg = fmt.Sprintf(msg, OkPacket, strconv.Itoa(affectedRows))
msg += GetNowStr(false)+"%s Effect Row:%s"
msg = fmt.Sprintf(msg, OkPacket, strconv.Itoa(affectedRows))
default:
return
default:
return
}
fmt.Println(msg)
if stm.session != nil {
result := model.MysqlQueryPiece{
BaseQueryPiece: model.BaseQueryPiece{},
ClientHost: stm.session.ClientIP,
ClientPort: stm.session.ClientPort,
VisitUser: stm.session.UserName,
VisitDB: stm.session.DBName,
CostTimeInMS: 0,
Message: msg,
}
fmt.Println(result.ToString())
}
}
func (stm *stream) resolveClientPacket(payload []byte, seq int) {
if parse.IsAuth(payload[0]) {
userName, dbName, err := parse.AuthInfoParse(payload)
if err != nil {
fmt.Printf("parse auth info failed <-- %s", err.Error())
return
}
var msg string
stm.session.UserName = userName
stm.session.DBName = dbName
}
var (
msg string
raw string
)
switch payload[0] {
case COM_INIT_DB:
msg = fmt.Sprintf("USE %s;\n", payload[1:])
raw = msg
case COM_DROP_DB:
msg = fmt.Sprintf("Drop DB %s;\n", payload[1:])
raw = msg
case COM_CREATE_DB, COM_QUERY:
statement := string(payload[1:])
msg = fmt.Sprintf("%s %s", ComQueryRequestPacket, statement)
raw = string(payload[1:])
case COM_STMT_PREPARE:
serverPacket := stm.findStmtPacket(stm.packets, seq+1)
@@ -279,14 +317,15 @@ func (stm *stream) resolveClientPacket(payload []byte, seq int) {
stm.stmtMap[stmtID] = stmt
stmt.FieldCount = binary.LittleEndian.Uint16(serverPacket.payload[5:7])
stmt.ParamCount = binary.LittleEndian.Uint16(serverPacket.payload[7:9])
stmt.Args = make([]interface{}, stmt.ParamCount)
stmt.Args = make([]interface{}, stmt.ParamCount)
msg = PreparePacket+stmt.Query
msg = PreparePacket + stmt.Query
raw = stmt.Query
case COM_STMT_SEND_LONG_DATA:
stmtID := binary.LittleEndian.Uint32(payload[1:5])
paramId := binary.LittleEndian.Uint16(payload[5:7])
stmt, _ := stm.stmtMap[stmtID]
stmtID := binary.LittleEndian.Uint32(payload[1:5])
paramId := binary.LittleEndian.Uint16(payload[5:7])
stmt, _ := stm.stmtMap[stmtID]
if stmt.Args[paramId] == nil {
stmt.Args[paramId] = payload[7:]
@@ -300,7 +339,7 @@ func (stm *stream) resolveClientPacket(payload []byte, seq int) {
case COM_STMT_RESET:
stmtID := binary.LittleEndian.Uint32(payload[1:5])
stmt, _:= stm.stmtMap[stmtID]
stmt, _ := stm.stmtMap[stmtID]
stmt.Args = make([]interface{}, stmt.ParamCount)
return
case COM_STMT_EXECUTE:
@@ -329,7 +368,7 @@ func (stm *stream) resolveClientPacket(payload []byte, seq int) {
pos++
var pTypes []byte
var pTypes []byte
var pValues []byte
//if flag == 1
@@ -347,10 +386,27 @@ func (stm *stream) resolveClientPacket(payload []byte, seq int) {
}
}
msg = string(stmt.WriteToText())
raw = string(stmt.WriteToText())
default:
return
}
fmt.Println(GetNowStr(true) + msg)
}
if stm.session != nil {
result := model.MysqlQueryPiece{
BaseQueryPiece: model.BaseQueryPiece{
ServerIP: stm.session.ServerIP,
ServerPort: stm.session.ServerPort,
},
ClientHost: stm.session.ClientIP,
ClientPort: stm.session.ClientPort,
VisitUser: stm.session.UserName,
VisitDB: stm.session.DBName,
QuerySQL: raw,
CostTimeInMS: 0,
}
fmt.Println(result.ToString())
}
}