From 67e917f4f3ff0c15818da20bb7d3d2f9ef81f8b0 Mon Sep 17 00:00:00 2001
From: hebo <xiaobiao@jd.com>
Date: Fri, 6 Dec 2019 13:46:59 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BB=8Eproxy=E8=8E=B7=E5=8F=96client=20ip?=
 =?UTF-8?q?=E5=9C=B0=E5=9D=80?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 capture/network.go              | 31 ++++++++++++++++++-------------
 model/mysql_query_piece.go      |  3 +--
 session-dealer/controller.go    |  6 +++---
 session-dealer/mysql/session.go | 22 +++++++++++++---------
 4 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/capture/network.go b/capture/network.go
index fcc5603..2da8a5c 100644
--- a/capture/network.go
+++ b/capture/network.go
@@ -15,7 +15,7 @@ import (
 	"golang.org/x/net/bpf"
 
 	"github.com/google/gopacket/pcapgo"
-	proto "github.com/pires/go-proxyproto"
+	pp "github.com/pires/go-proxyproto"
 	log "github.com/sirupsen/logrus"
 	"github.com/zr-hebo/sniffer-agent/model"
 	sd "github.com/zr-hebo/sniffer-agent/session-dealer"
@@ -147,14 +147,8 @@ func (nc *networkCard) listenNormal() {
 
 			// deal auth packet
 			if sd.IsAuthPacket(tcpPkt.Payload) {
-				reader := bufio.NewReader(bytes.NewReader(data))
-				header, _ := proto.Read(reader)
-				var clientIP *string
-				if header != nil {
-					clientIPContent := header.SourceAddress.String()
-					clientIP = &clientIPContent
-				}
-				nc.parseTCPPackage(packet, clientIP)
+				authHeader, _ := pp.Read(bufio.NewReader(bytes.NewReader(tcpPkt.Payload)))
+				nc.parseTCPPackage(packet, authHeader)
 				continue
 			}
 
@@ -174,7 +168,7 @@ func (nc *networkCard) listenNormal() {
 	return
 }
 
-func (nc *networkCard) parseTCPPackage(packet gopacket.Packet, clientIP *string) {
+func (nc *networkCard) parseTCPPackage(packet gopacket.Packet, authHeader *pp.Header) {
 	var err error
 	defer func() {
 		if err != nil {
@@ -203,9 +197,19 @@ func (nc *networkCard) parseTCPPackage(packet gopacket.Packet, clientIP *string)
 	dstIP := ipInfo.DstIP.String()
 	srcPort := int(tcpPkt.SrcPort)
 	dstPort := int(tcpPkt.DstPort)
+
 	if dstPort == nc.listenPort {
+		// get client ip from proxy auth info
+		var clientIP *string
+		var clientPort int
+		if authHeader != nil && authHeader.SourceAddress.String() != srcIP {
+			clientIPContent := authHeader.SourceAddress.String()
+			clientIP = &clientIPContent
+			clientPort = int(authHeader.SourcePort)
+		}
+
 		// deal mysql server response
-		err = readToServerPackage(clientIP, &srcIP, srcPort, tcpPkt, nc.receiver)
+		err = readToServerPackage(clientIP, clientPort, &srcIP, srcPort, tcpPkt, nc.receiver)
 		if err != nil {
 			return
 		}
@@ -255,7 +259,8 @@ func readFromServerPackage(
 }
 
 func readToServerPackage(
-	clientIP, srcIP *string, srcPort int, tcpPkt *layers.TCP, receiver chan model.QueryPiece) (err error) {
+	clientIP *string, clientPort int, srcIP *string, srcPort int, tcpPkt *layers.TCP,
+	receiver chan model.QueryPiece) (err error) {
 	defer func() {
 		if err != nil {
 			log.Error("read package send from client to mysql server failed <-- %s", err.Error())
@@ -282,7 +287,7 @@ func readToServerPackage(
 	sessionKey := spliceSessionKey(srcIP, srcPort)
 	session := sessionPool[*sessionKey]
 	if session == nil {
-		session = sd.NewSession(sessionKey, clientIP, srcIP, srcPort, localIPAddr, snifferPort, receiver)
+		session = sd.NewSession(sessionKey, clientIP, clientPort, srcIP, srcPort, localIPAddr, snifferPort, receiver)
 		sessionPool[*sessionKey] = session
 	}
 
diff --git a/model/mysql_query_piece.go b/model/mysql_query_piece.go
index 8c75060..a4587e7 100644
--- a/model/mysql_query_piece.go
+++ b/model/mysql_query_piece.go
@@ -26,7 +26,7 @@ type PooledMysqlQueryPiece struct {
 }
 
 func NewPooledMysqlQueryPiece(
-	sessionID, clientIP, visitUser, visitDB, clientHost, serverIP *string,
+	sessionID, clientIP, visitUser, visitDB, serverIP *string,
 	clientPort, serverPort int, throwPacketRate float64, stmtBeginTime int64) (
 	mqp *PooledMysqlQueryPiece) {
 	mqp = mqpp.Dequeue()
@@ -35,7 +35,6 @@ func NewPooledMysqlQueryPiece(
 	mqp.SessionID = sessionID
 	mqp.ClientHost = clientIP
 	mqp.ClientPort = clientPort
-	mqp.ClientHost = clientHost
 	mqp.ServerIP = serverIP
 	mqp.ServerPort = serverPort
 	mqp.VisitUser = visitUser
diff --git a/session-dealer/controller.go b/session-dealer/controller.go
index 585796d..c8174dc 100644
--- a/session-dealer/controller.go
+++ b/session-dealer/controller.go
@@ -5,13 +5,13 @@ import (
 	"github.com/zr-hebo/sniffer-agent/session-dealer/mysql"
 )
 
-func NewSession(sessionKey, clientAlias, clientIP *string, clientPort int, serverIP *string, serverPort int,
+func NewSession(sessionKey, clientIP *string, clientPort int, srcIP *string, srcPort int, serverIP *string, serverPort int,
 	receiver chan model.QueryPiece) (session ConnSession) {
 	switch serviceType {
 	case ServiceTypeMysql:
-		session = mysql.NewMysqlSession(sessionKey, clientAlias, clientIP, clientPort, serverIP, serverPort, receiver)
+		session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, srcIP, srcPort, serverIP, serverPort, receiver)
 	default:
-		session = mysql.NewMysqlSession(sessionKey, clientAlias, clientIP, clientPort, serverIP, serverPort, receiver)
+		session = mysql.NewMysqlSession(sessionKey, clientIP, clientPort, srcIP, srcPort, serverIP, serverPort, receiver)
 	}
 	return
 }
diff --git a/session-dealer/mysql/session.go b/session-dealer/mysql/session.go
index 1d70aa2..c9fc9df 100644
--- a/session-dealer/mysql/session.go
+++ b/session-dealer/mysql/session.go
@@ -15,9 +15,10 @@ type MysqlSession struct {
 	connectionID  *string
 	visitUser     *string
 	visitDB       *string
-	clientAlias   *string
-	clientHost    *string
+	clientIP      *string
 	clientPort    int
+	srcIP         *string
+	srcPort       int
 	serverIP      *string
 	serverPort    int
 	stmtBeginTime int64
@@ -47,13 +48,14 @@ type prepareInfo struct {
 }
 
 func NewMysqlSession(
-	sessionKey, clientAlias, clientIP *string, clientPort int, serverIP *string, serverPort int,
+	sessionKey, clientIP *string, clientPort int, srcIP *string, srcPort int, serverIP *string, serverPort int,
 	receiver chan model.QueryPiece) (ms *MysqlSession) {
 	ms = &MysqlSession{
 		connectionID:       sessionKey,
-		clientAlias:        clientAlias,
-		clientHost:         clientIP,
+		clientIP:           clientIP,
 		clientPort:         clientPort,
+		srcIP:              srcIP,
+		srcPort:            srcPort,
 		serverIP:           serverIP,
 		serverPort:         serverPort,
 		stmtBeginTime:      time.Now().UnixNano() / millSecondUnit,
@@ -335,11 +337,13 @@ func filterQueryPieceBySQL(mqp *model.PooledMysqlQueryPiece, querySQL []byte) (*
 }
 
 func (ms *MysqlSession) composeQueryPiece() (mqp *model.PooledMysqlQueryPiece) {
-	clientIP := ms.clientAlias
+	clientIP := ms.clientIP
+	clientPort := ms.clientPort
 	if clientIP == nil || len(*clientIP) < 1 {
-		clientIP = ms.clientHost
+		clientIP = ms.srcIP
+		clientPort = ms.serverPort
 	}
 	return model.NewPooledMysqlQueryPiece(
-		ms.connectionID, ms.clientHost, ms.visitUser, ms.visitDB, ms.clientHost, ms.serverIP,
-		ms.clientPort, ms.serverPort, communicator.GetMysqlCapturePacketRate(), ms.stmtBeginTime)
+		ms.connectionID, clientIP, ms.visitUser, ms.visitDB, ms.serverIP,
+		clientPort, ms.serverPort, communicator.GetMysqlCapturePacketRate(), ms.stmtBeginTime)
 }