From 67e917f4f3ff0c15818da20bb7d3d2f9ef81f8b0 Mon Sep 17 00:00:00 2001 From: hebo <xiaobiao@jd.com> Date: Fri, 6 Dec 2019 13:46:59 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=8Eproxy=E8=8E=B7=E5=8F=96client=20ip?= =?UTF-8?q?=E5=9C=B0=E5=9D=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- capture/network.go | 31 ++++++++++++++++++------------- model/mysql_query_piece.go | 3 +-- session-dealer/controller.go | 6 +++--- session-dealer/mysql/session.go | 22 +++++++++++++--------- 4 files changed, 35 insertions(+), 27 deletions(-) diff --git a/capture/network.go b/capture/network.go index fcc5603..2da8a5c 100644 --- a/capture/network.go +++ b/capture/network.go @@ -15,7 +15,7 @@ import ( "golang.org/x/net/bpf" "github.com/google/gopacket/pcapgo" - proto "github.com/pires/go-proxyproto" + pp "github.com/pires/go-proxyproto" log "github.com/sirupsen/logrus" "github.com/zr-hebo/sniffer-agent/model" sd "github.com/zr-hebo/sniffer-agent/session-dealer" @@ -147,14 +147,8 @@ func (nc *networkCard) listenNormal() { // deal auth packet if sd.IsAuthPacket(tcpPkt.Payload) { - reader := bufio.NewReader(bytes.NewReader(data)) - header, _ := proto.Read(reader) - var clientIP *string - if header != nil { - clientIPContent := header.SourceAddress.String() - clientIP = &clientIPContent - } - nc.parseTCPPackage(packet, clientIP) + authHeader, _ := pp.Read(bufio.NewReader(bytes.NewReader(tcpPkt.Payload))) + nc.parseTCPPackage(packet, authHeader) continue } @@ -174,7 +168,7 @@ func (nc *networkCard) listenNormal() { return } -func (nc *networkCard) parseTCPPackage(packet gopacket.Packet, clientIP *string) { +func (nc *networkCard) parseTCPPackage(packet gopacket.Packet, authHeader *pp.Header) { var err error defer func() { if err != nil { @@ -203,9 +197,19 @@ func (nc *networkCard) parseTCPPackage(packet gopacket.Packet, clientIP *string) dstIP := ipInfo.DstIP.String() srcPort := int(tcpPkt.SrcPort) dstPort := int(tcpPkt.DstPort) + if dstPort == nc.listenPort { + // get client ip from proxy auth info + var clientIP *string + var clientPort int + if authHeader != nil && authHeader.SourceAddress.String() != srcIP { + clientIPContent := authHeader.SourceAddress.String() + clientIP = &clientIPContent + clientPort = int(authHeader.SourcePort) + } + // deal mysql server response - err = readToServerPackage(clientIP, &srcIP, srcPort, tcpPkt, nc.receiver) + err = readToServerPackage(clientIP, clientPort, &srcIP, srcPort, tcpPkt, nc.receiver) if err != nil { return } @@ -255,7 +259,8 @@ func readFromServerPackage( } func readToServerPackage( - clientIP, srcIP *string, srcPort int, tcpPkt *layers.TCP, receiver chan model.QueryPiece) (err error) { + clientIP *string, clientPort int, srcIP *string, srcPort int, tcpPkt *layers.TCP, + receiver chan model.QueryPiece) (err error) { defer func() { if err != nil { log.Error("read package send from client to mysql server failed <-- %s", err.Error()) @@ -282,7 +287,7 @@ func readToServerPackage( sessionKey := spliceSessionKey(srcIP, srcPort) session := sessionPool[*sessionKey] if session == nil { - session = sd.NewSession(sessionKey, clientIP, srcIP, srcPort, localIPAddr, snifferPort, receiver) + session = sd.NewSession(sessionKey, clientIP, clientPort, srcIP, srcPort, localIPAddr, snifferPort, receiver) sessionPool[*sessionKey] = session } diff --git a/model/mysql_query_piece.go b/model/mysql_query_piece.go index 8c75060..a4587e7 100644 --- a/model/mysql_query_piece.go +++ b/model/mysql_query_piece.go @@ -26,7 +26,7 @@ type PooledMysqlQueryPiece struct { } func NewPooledMysqlQueryPiece( - sessionID, clientIP, visitUser, visitDB, clientHost, serverIP *string, + sessionID, clientIP, visitUser, visitDB, serverIP *string, clientPort, serverPort int, throwPacketRate float64, stmtBeginTime int64) ( mqp *PooledMysqlQueryPiece) { mqp = mqpp.Dequeue() @@ -35,7 +35,6 @@ func NewPooledMysqlQueryPiece( mqp.SessionID = sessionID mqp.ClientHost = clientIP mqp.ClientPort = clientPort - mqp.ClientHost = clientHost mqp.ServerIP = serverIP mqp.ServerPort = serverPort mqp.VisitUser = visitUser diff --git a/session-dealer/controller.go b/session-dealer/controller.go index 585796d..c8174dc 100644 --- a/session-dealer/controller.go +++ b/session-dealer/controller.go @@ -5,13 +5,13 @@ import ( "github.com/zr-hebo/sniffer-agent/session-dealer/mysql" ) -func NewSession(sessionKey, clientAlias, clientIP *string, clientPort int, serverIP *string, serverPort int, +func NewSession(sessionKey, clientIP *string, clientPort int, srcIP *string, srcPort int, serverIP *string, serverPort int, receiver chan model.QueryPiece) (session ConnSession) { switch serviceType { case ServiceTypeMysql: - session = mysql.NewMysqlSession(sessionKey, clientAlias, clientIP, clientPort, serverIP, serverPort, receiver) + session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, srcIP, srcPort, serverIP, serverPort, receiver) default: - session = mysql.NewMysqlSession(sessionKey, clientAlias, clientIP, clientPort, serverIP, serverPort, receiver) + session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, srcIP, srcPort, serverIP, serverPort, receiver) } return } diff --git a/session-dealer/mysql/session.go b/session-dealer/mysql/session.go index 1d70aa2..c9fc9df 100644 --- a/session-dealer/mysql/session.go +++ b/session-dealer/mysql/session.go @@ -15,9 +15,10 @@ type MysqlSession struct { connectionID *string visitUser *string visitDB *string - clientAlias *string - clientHost *string + clientIP *string clientPort int + srcIP *string + srcPort int serverIP *string serverPort int stmtBeginTime int64 @@ -47,13 +48,14 @@ type prepareInfo struct { } func NewMysqlSession( - sessionKey, clientAlias, clientIP *string, clientPort int, serverIP *string, serverPort int, + sessionKey, clientIP *string, clientPort int, srcIP *string, srcPort int, serverIP *string, serverPort int, receiver chan model.QueryPiece) (ms *MysqlSession) { ms = &MysqlSession{ connectionID: sessionKey, - clientAlias: clientAlias, - clientHost: clientIP, + clientIP: clientIP, clientPort: clientPort, + srcIP: srcIP, + srcPort: srcPort, serverIP: serverIP, serverPort: serverPort, stmtBeginTime: time.Now().UnixNano() / millSecondUnit, @@ -335,11 +337,13 @@ func filterQueryPieceBySQL(mqp *model.PooledMysqlQueryPiece, querySQL []byte) (* } func (ms *MysqlSession) composeQueryPiece() (mqp *model.PooledMysqlQueryPiece) { - clientIP := ms.clientAlias + clientIP := ms.clientIP + clientPort := ms.clientPort if clientIP == nil || len(*clientIP) < 1 { - clientIP = ms.clientHost + clientIP = ms.srcIP + clientPort = ms.serverPort } return model.NewPooledMysqlQueryPiece( - ms.connectionID, ms.clientHost, ms.visitUser, ms.visitDB, ms.clientHost, ms.serverIP, - ms.clientPort, ms.serverPort, communicator.GetMysqlCapturePacketRate(), ms.stmtBeginTime) + ms.connectionID, clientIP, ms.visitUser, ms.visitDB, ms.serverIP, + clientPort, ms.serverPort, communicator.GetMysqlCapturePacketRate(), ms.stmtBeginTime) }