从proxy获取client ip地址

This commit is contained in:
hebo 2019-12-06 13:46:59 +08:00
parent f70a646902
commit 67e917f4f3
4 changed files with 35 additions and 27 deletions

View File

@ -15,7 +15,7 @@ import (
"golang.org/x/net/bpf" "golang.org/x/net/bpf"
"github.com/google/gopacket/pcapgo" "github.com/google/gopacket/pcapgo"
proto "github.com/pires/go-proxyproto" pp "github.com/pires/go-proxyproto"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/zr-hebo/sniffer-agent/model" "github.com/zr-hebo/sniffer-agent/model"
sd "github.com/zr-hebo/sniffer-agent/session-dealer" sd "github.com/zr-hebo/sniffer-agent/session-dealer"
@ -147,14 +147,8 @@ func (nc *networkCard) listenNormal() {
// deal auth packet // deal auth packet
if sd.IsAuthPacket(tcpPkt.Payload) { if sd.IsAuthPacket(tcpPkt.Payload) {
reader := bufio.NewReader(bytes.NewReader(data)) authHeader, _ := pp.Read(bufio.NewReader(bytes.NewReader(tcpPkt.Payload)))
header, _ := proto.Read(reader) nc.parseTCPPackage(packet, authHeader)
var clientIP *string
if header != nil {
clientIPContent := header.SourceAddress.String()
clientIP = &clientIPContent
}
nc.parseTCPPackage(packet, clientIP)
continue continue
} }
@ -174,7 +168,7 @@ func (nc *networkCard) listenNormal() {
return return
} }
func (nc *networkCard) parseTCPPackage(packet gopacket.Packet, clientIP *string) { func (nc *networkCard) parseTCPPackage(packet gopacket.Packet, authHeader *pp.Header) {
var err error var err error
defer func() { defer func() {
if err != nil { if err != nil {
@ -203,9 +197,19 @@ func (nc *networkCard) parseTCPPackage(packet gopacket.Packet, clientIP *string)
dstIP := ipInfo.DstIP.String() dstIP := ipInfo.DstIP.String()
srcPort := int(tcpPkt.SrcPort) srcPort := int(tcpPkt.SrcPort)
dstPort := int(tcpPkt.DstPort) dstPort := int(tcpPkt.DstPort)
if dstPort == nc.listenPort { if dstPort == nc.listenPort {
// get client ip from proxy auth info
var clientIP *string
var clientPort int
if authHeader != nil && authHeader.SourceAddress.String() != srcIP {
clientIPContent := authHeader.SourceAddress.String()
clientIP = &clientIPContent
clientPort = int(authHeader.SourcePort)
}
// deal mysql server response // deal mysql server response
err = readToServerPackage(clientIP, &srcIP, srcPort, tcpPkt, nc.receiver) err = readToServerPackage(clientIP, clientPort, &srcIP, srcPort, tcpPkt, nc.receiver)
if err != nil { if err != nil {
return return
} }
@ -255,7 +259,8 @@ func readFromServerPackage(
} }
func readToServerPackage( func readToServerPackage(
clientIP, srcIP *string, srcPort int, tcpPkt *layers.TCP, receiver chan model.QueryPiece) (err error) { clientIP *string, clientPort int, srcIP *string, srcPort int, tcpPkt *layers.TCP,
receiver chan model.QueryPiece) (err error) {
defer func() { defer func() {
if err != nil { if err != nil {
log.Error("read package send from client to mysql server failed <-- %s", err.Error()) log.Error("read package send from client to mysql server failed <-- %s", err.Error())
@ -282,7 +287,7 @@ func readToServerPackage(
sessionKey := spliceSessionKey(srcIP, srcPort) sessionKey := spliceSessionKey(srcIP, srcPort)
session := sessionPool[*sessionKey] session := sessionPool[*sessionKey]
if session == nil { if session == nil {
session = sd.NewSession(sessionKey, clientIP, srcIP, srcPort, localIPAddr, snifferPort, receiver) session = sd.NewSession(sessionKey, clientIP, clientPort, srcIP, srcPort, localIPAddr, snifferPort, receiver)
sessionPool[*sessionKey] = session sessionPool[*sessionKey] = session
} }

View File

@ -26,7 +26,7 @@ type PooledMysqlQueryPiece struct {
} }
func NewPooledMysqlQueryPiece( func NewPooledMysqlQueryPiece(
sessionID, clientIP, visitUser, visitDB, clientHost, serverIP *string, sessionID, clientIP, visitUser, visitDB, serverIP *string,
clientPort, serverPort int, throwPacketRate float64, stmtBeginTime int64) ( clientPort, serverPort int, throwPacketRate float64, stmtBeginTime int64) (
mqp *PooledMysqlQueryPiece) { mqp *PooledMysqlQueryPiece) {
mqp = mqpp.Dequeue() mqp = mqpp.Dequeue()
@ -35,7 +35,6 @@ func NewPooledMysqlQueryPiece(
mqp.SessionID = sessionID mqp.SessionID = sessionID
mqp.ClientHost = clientIP mqp.ClientHost = clientIP
mqp.ClientPort = clientPort mqp.ClientPort = clientPort
mqp.ClientHost = clientHost
mqp.ServerIP = serverIP mqp.ServerIP = serverIP
mqp.ServerPort = serverPort mqp.ServerPort = serverPort
mqp.VisitUser = visitUser mqp.VisitUser = visitUser

View File

@ -5,13 +5,13 @@ import (
"github.com/zr-hebo/sniffer-agent/session-dealer/mysql" "github.com/zr-hebo/sniffer-agent/session-dealer/mysql"
) )
func NewSession(sessionKey, clientAlias, clientIP *string, clientPort int, serverIP *string, serverPort int, func NewSession(sessionKey, clientIP *string, clientPort int, srcIP *string, srcPort int, serverIP *string, serverPort int,
receiver chan model.QueryPiece) (session ConnSession) { receiver chan model.QueryPiece) (session ConnSession) {
switch serviceType { switch serviceType {
case ServiceTypeMysql: case ServiceTypeMysql:
session = mysql.NewMysqlSession(sessionKey, clientAlias, clientIP, clientPort, serverIP, serverPort, receiver) session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, srcIP, srcPort, serverIP, serverPort, receiver)
default: default:
session = mysql.NewMysqlSession(sessionKey, clientAlias, clientIP, clientPort, serverIP, serverPort, receiver) session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, srcIP, srcPort, serverIP, serverPort, receiver)
} }
return return
} }

View File

@ -15,9 +15,10 @@ type MysqlSession struct {
connectionID *string connectionID *string
visitUser *string visitUser *string
visitDB *string visitDB *string
clientAlias *string clientIP *string
clientHost *string
clientPort int clientPort int
srcIP *string
srcPort int
serverIP *string serverIP *string
serverPort int serverPort int
stmtBeginTime int64 stmtBeginTime int64
@ -47,13 +48,14 @@ type prepareInfo struct {
} }
func NewMysqlSession( func NewMysqlSession(
sessionKey, clientAlias, clientIP *string, clientPort int, serverIP *string, serverPort int, sessionKey, clientIP *string, clientPort int, srcIP *string, srcPort int, serverIP *string, serverPort int,
receiver chan model.QueryPiece) (ms *MysqlSession) { receiver chan model.QueryPiece) (ms *MysqlSession) {
ms = &MysqlSession{ ms = &MysqlSession{
connectionID: sessionKey, connectionID: sessionKey,
clientAlias: clientAlias, clientIP: clientIP,
clientHost: clientIP,
clientPort: clientPort, clientPort: clientPort,
srcIP: srcIP,
srcPort: srcPort,
serverIP: serverIP, serverIP: serverIP,
serverPort: serverPort, serverPort: serverPort,
stmtBeginTime: time.Now().UnixNano() / millSecondUnit, stmtBeginTime: time.Now().UnixNano() / millSecondUnit,
@ -335,11 +337,13 @@ func filterQueryPieceBySQL(mqp *model.PooledMysqlQueryPiece, querySQL []byte) (*
} }
func (ms *MysqlSession) composeQueryPiece() (mqp *model.PooledMysqlQueryPiece) { func (ms *MysqlSession) composeQueryPiece() (mqp *model.PooledMysqlQueryPiece) {
clientIP := ms.clientAlias clientIP := ms.clientIP
clientPort := ms.clientPort
if clientIP == nil || len(*clientIP) < 1 { if clientIP == nil || len(*clientIP) < 1 {
clientIP = ms.clientHost clientIP = ms.srcIP
clientPort = ms.serverPort
} }
return model.NewPooledMysqlQueryPiece( return model.NewPooledMysqlQueryPiece(
ms.connectionID, ms.clientHost, ms.visitUser, ms.visitDB, ms.clientHost, ms.serverIP, ms.connectionID, clientIP, ms.visitUser, ms.visitDB, ms.serverIP,
ms.clientPort, ms.serverPort, communicator.GetMysqlCapturePacketRate(), ms.stmtBeginTime) clientPort, ms.serverPort, communicator.GetMysqlCapturePacketRate(), ms.stmtBeginTime)
} }