commit 3b4d8613faba4ce0b56d2ca07e7201d7a79d193b Author: wasd <1547422976@qq.com> Date: Tue Sep 25 14:13:45 2018 +0800 Init diff --git a/README.md b/README.md new file mode 100644 index 0000000..c653f78 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# go-sniffer + +Testing... \ No newline at end of file diff --git a/core/assembly.go b/core/assembly.go new file mode 100644 index 0000000..90ffba2 --- /dev/null +++ b/core/assembly.go @@ -0,0 +1,758 @@ +// Copyright (C) MongoDB, Inc. 2014-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package core + +import ( + "fmt" + "log" + "sync" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/tcpassembly" +) + +var memLog = new(bool) +var debugLog = new(bool) + +const invalidSequence = -1 +const uint32Max = 0xFFFFFFFF + +// Sequence is a TCP sequence number. It provides a few convenience functions +// for handling TCP wrap-around. The sequence should always be in the range +// [0,0xFFFFFFFF]... its other bits are simply used in wrap-around calculations +// and should never be set. +type Sequence int64 + +// Difference defines an ordering for comparing TCP sequences that's safe for +// roll-overs. It returns: +// > 0 : if t comes after s +// < 0 : if t comes before s +// 0 : if t == s +// The number returned is the sequence difference, so 4.Difference(8) will +// return 4. +// +// It handles rollovers by considering any sequence in the first quarter of the +// uint32 space to be after any sequence in the last quarter of that space, thus +// wrapping the uint32 space. +func (s Sequence) Difference(t Sequence) int { + if s > uint32Max-uint32Max/4 && t < uint32Max/4 { + t += uint32Max + } else if t > uint32Max-uint32Max/4 && s < uint32Max/4 { + s += uint32Max + } + return int(t - s) +} + +// Add adds an integer to a sequence and returns the resulting sequence. +func (s Sequence) Add(t int) Sequence { + return (s + Sequence(t)) & uint32Max +} + +// Reassembly objects are passed by an Assembler into Streams using the +// Reassembled call. Callers should not need to create these structs themselves +// except for testing. +type Reassembly struct { + // Bytes is the next set of bytes in the stream. May be empty. + Bytes []byte + // Skip is set to non-zero if bytes were skipped between this and the last + // Reassembly. If this is the first packet in a connection and we didn't + // see the start, we have no idea how many bytes we skipped, so we set it to + // -1. Otherwise, it's set to the number of bytes skipped. + Skip int + // Start is set if this set of bytes has a TCP SYN accompanying it. + Start bool + // End is set if this set of bytes has a TCP FIN or RST accompanying it. + End bool + // Seen is the timestamp this set of bytes was pulled off the wire. + Seen time.Time +} + +const pageBytes = 1900 + +// page is used to store TCP data we're not ready for yet (out-of-order +// packets). Unused pages are stored in and returned from a pageCache, which +// avoids memory allocation. Used pages are stored in a doubly-linked list in a +// connection. +type page struct { + tcpassembly.Reassembly + seq Sequence + index int + prev, next *page + buf [pageBytes]byte +} + +// pageCache is a concurrency-unsafe store of page objects we use to avoid +// memory allocation as much as we can. It grows but never shrinks. +type pageCache struct { + free []*page + pcSize int + size, used int + pages [][]page + pageRequests int64 +} + +const initialAllocSize = 1024 + +func newPageCache() *pageCache { + pc := &pageCache{ + free: make([]*page, 0, initialAllocSize), + pcSize: initialAllocSize, + } + pc.grow() + return pc +} + +// grow exponentially increases the size of our page cache as much as necessary. +func (c *pageCache) grow() { + pages := make([]page, c.pcSize) + c.pages = append(c.pages, pages) + c.size += c.pcSize + for i := range pages { + c.free = append(c.free, &pages[i]) + } + if *memLog { + log.Println("PageCache: created", c.pcSize, "new pages") + } + c.pcSize *= 2 +} + +// next returns a clean, ready-to-use page object. +func (c *pageCache) next(ts time.Time) (p *page) { + if *memLog { + c.pageRequests++ + if c.pageRequests&0xFFFF == 0 { + log.Println("PageCache:", c.pageRequests, "requested,", c.used, "used,", len(c.free), "free") + } + } + if len(c.free) == 0 { + c.grow() + } + i := len(c.free) - 1 + p, c.free = c.free[i], c.free[:i] + p.prev = nil + p.next = nil + p.Seen = ts + p.Bytes = p.buf[:0] + c.used++ + return p +} + +// replace replaces a page into the pageCache. +func (c *pageCache) replace(p *page) { + c.used-- + c.free = append(c.free, p) +} + +// Stream is implemented by the caller to handle incoming reassembled TCP data. +// Callers create a StreamFactory, then StreamPool uses it to create a new +// Stream for every TCP stream. +// +// assembly will, in order: +// 1) Create the stream via StreamFactory.New +// 2) Call Reassembled 0 or more times, passing in reassembled TCP data in +// order +// 3) Call ReassemblyComplete one time, after which the stream is +// dereferenced by assembly. +type Stream interface { + // Reassembled is called zero or more times. Assembly guarantees that the + // set of all Reassembly objects passed in during all calls are presented in + // the order they appear in the TCP stream. Reassembly objects are reused + // after each Reassembled call, so it's important to copy anything you need + // out of them (specifically out of Reassembly.Bytes) that you need to stay + // around after you return from the Reassembled call. + Reassembled([]Reassembly) + // ReassemblyComplete is called when assembly decides there is no more data + // for this Stream, either because a FIN or RST packet was seen, or because + // the stream has timed out without any new packet data (due to a call to + // FlushOlderThan). + ReassemblyComplete() +} + +// StreamFactory is used by assembly to create a new stream for each new TCP +// session. +type StreamFactory interface { + // New should return a new stream for the given TCP key. + New(netFlow, tcpFlow gopacket.Flow) tcpassembly.Stream +} + +func (p *StreamPool) connections() []*connection { + p.mu.RLock() + conns := make([]*connection, 0, len(p.conns)) + for _, conn := range p.conns { + conns = append(conns, conn) + } + p.mu.RUnlock() + return conns +} + + //FlushOlderThan finds any streams waiting for packets older than the given + //time, and pushes through the data they have (IE: tells them to stop waiting + //and skip the data they're waiting for). + // + //Each Stream maintains a list of zero or more sets of bytes it has received + //out-of-order. For example, if it has processed up through sequence number + //10, it might have bytes [15-20), [20-25), [30,50) in its list. Each set of + //bytes also has the timestamp it was originally viewed. A flush call will + //look at the smallest subsequent set of bytes, in this case [15-20), and if + //its timestamp is older than the passed-in time, it will push it and all + //contiguous byte-sets out to the Stream's Reassembled function. In this case, + //it will push [15-20), but also [20-25), since that's contiguous. It will + //only push [30-50) if its timestamp is also older than the passed-in time, + //otherwise it will wait until the next FlushOlderThan to see if bytes [25-30) + //come in. + // + //If it pushes all bytes (or there were no sets of bytes to begin with) AND the + //connection has not received any bytes since the passed-in time, the + //connection will be closed. + // + //Returns the number of connections flushed, and of those, the number closed + //because of the flush. +func (a *Assembler) FlushOlderThan(t time.Time) (flushed, closed int) { + conns := a.connPool.connections() + closes := 0 + flushes := 0 + for _, conn := range conns { + flushed := false + conn.mu.Lock() + if conn.closed { + // Already closed connection, nothing to do here. + conn.mu.Unlock() + continue + } + for conn.first != nil && conn.first.Seen.Before(t) { + a.skipFlush(conn) + flushed = true + if conn.closed { + closes++ + break + } + } + if !conn.closed && conn.first == nil && conn.lastSeen.Before(t) { + flushed = true + a.closeConnection(conn) + closes++ + } + if flushed { + flushes++ + } + conn.mu.Unlock() + } + return flushes, closes +} + +// FlushAll flushes all remaining data into all remaining connections, closing +// those connections. It returns the total number of connections flushed/closed +// by the call. +func (a *Assembler) FlushAll() (closed int) { + conns := a.connPool.connections() + closed = len(conns) + for _, conn := range conns { + conn.mu.Lock() + for !conn.closed { + a.skipFlush(conn) + } + conn.mu.Unlock() + } + return +} + +type key [2]gopacket.Flow + +func (k *key) String() string { + return fmt.Sprintf("%s:%s", k[0], k[1]) +} + +// StreamPool stores all streams created by Assemblers, allowing multiple +// assemblers to work together on stream processing while enforcing the fact +// that a single stream receives its data serially. It is safe for concurrency, +// usable by multiple Assemblers at once. +// +// StreamPool handles the creation and storage of Stream objects used by one or +// more Assembler objects. When a new TCP stream is found by an Assembler, it +// creates an associated Stream by calling its StreamFactory's New method. +// Thereafter (until the stream is closed), that Stream object will receive +// assembled TCP data via Assembler's calls to the stream's Reassembled +// function. +// +// Like the Assembler, StreamPool attempts to minimize allocation. Unlike the +// Assembler, though, it does have to do some locking to make sure that the +// connection objects it stores are accessible to multiple Assemblers. +type StreamPool struct { + conns map[key]*connection + users int + mu sync.RWMutex + factory StreamFactory + free []*connection + all [][]connection + nextAlloc int + newConnectionCount int64 +} + +func (p *StreamPool) grow() { + conns := make([]connection, p.nextAlloc) + p.all = append(p.all, conns) + for i := range conns { + p.free = append(p.free, &conns[i]) + } + if *memLog { + log.Println("StreamPool: created", p.nextAlloc, "new connections") + } + p.nextAlloc *= 2 +} + +// NewStreamPool creates a new connection pool. Streams will +// be created as necessary using the passed-in StreamFactory. +func NewStreamPool(factory StreamFactory) *StreamPool { + return &StreamPool{ + conns: make(map[key]*connection, initialAllocSize), + free: make([]*connection, 0, initialAllocSize), + factory: factory, + nextAlloc: initialAllocSize, + } +} + +const assemblerReturnValueInitialSize = 16 + +// NewAssembler creates a new assembler. Pass in the StreamPool +// to use, may be shared across assemblers. +// +// This sets some sane defaults for the assembler options, +// see DefaultAssemblerOptions for details. +func NewAssembler(pool *StreamPool) *Assembler { + pool.mu.Lock() + pool.users++ + pool.mu.Unlock() + return &Assembler{ + ret: make([]tcpassembly.Reassembly, assemblerReturnValueInitialSize), + pc: newPageCache(), + connPool: pool, + AssemblerOptions: DefaultAssemblerOptions, + } +} + +// DefaultAssemblerOptions provides default options for an assembler. +// These options are used by default when calling NewAssembler, so if +// modified before a NewAssembler call they'll affect the resulting Assembler. +// +// Note that the default options can result in ever-increasing memory usage +// unless one of the Flush* methods is called on a regular basis. +var DefaultAssemblerOptions = AssemblerOptions{ + MaxBufferedPagesPerConnection: 0, // unlimited + MaxBufferedPagesTotal: 0, // unlimited +} + +type connection struct { + key key + pages int + first, last *page + nextSeq Sequence + created, lastSeen time.Time + stream tcpassembly.Stream + closed bool + mu sync.Mutex +} + +func (conn *connection) reset(k key, s tcpassembly.Stream, ts time.Time) { + conn.key = k + conn.pages = 0 + conn.first, conn.last = nil, nil + conn.nextSeq = invalidSequence + conn.created = ts + conn.stream = s + conn.closed = false +} + +// AssemblerOptions controls the behavior of each assembler. Modify the +// options of each assembler you create to change their behavior. +type AssemblerOptions struct { + // MaxBufferedPagesTotal is an upper limit on the total number of pages to + // buffer while waiting for out-of-order packets. Once this limit is + // reached, the assembler will degrade to flushing every connection it gets + // a packet for. If <= 0, this is ignored. + MaxBufferedPagesTotal int + // MaxBufferedPagesPerConnection is an upper limit on the number of pages + // buffered for a single connection. Should this limit be reached for a + // particular connection, the smallest sequence number will be flushed, + // along with any contiguous data. If <= 0, this is ignored. + MaxBufferedPagesPerConnection int +} + +// Assembler handles reassembling TCP streams. It is not safe for +// concurrency... after passing a packet in via the Assemble call, the caller +// must wait for that call to return before calling Assemble again. Callers can +// get around this by creating multiple assemblers that share a StreamPool. In +// that case, each individual stream will still be handled serially (each stream +// has an individual mutex associated with it), however multiple assemblers can +// assemble different connections concurrently. +// +// The Assembler provides (hopefully) fast TCP stream re-assembly for sniffing +// applications written in Go. The Assembler uses the following methods to be +// as fast as possible, to keep packet processing speedy: +// +// Avoids Lock Contention +// +// Assemblers locks connections, but each connection has an individual lock, and +// rarely will two Assemblers be looking at the same connection. Assemblers +// lock the StreamPool when looking up connections, but they use Reader locks +// initially, and only force a write lock if they need to create a new +// connection or close one down. These happen much less frequently than +// individual packet handling. +// +// Each assembler runs in its own goroutine, and the only state shared between +// goroutines is through the StreamPool. Thus all internal Assembler state can +// be handled without any locking. +// +// NOTE: If you can guarantee that packets going to a set of Assemblers will +// contain information on different connections per Assembler (for example, +// they're already hashed by PF_RING hashing or some other hashing mechanism), +// then we recommend you use a seperate StreamPool per Assembler, thus avoiding +// all lock contention. Only when different Assemblers could receive packets +// for the same Stream should a StreamPool be shared between them. +// +// Avoids Memory Copying +// +// In the common case, handling of a single TCP packet should result in zero +// memory allocations. The Assembler will look up the connection, figure out +// that the packet has arrived in order, and immediately pass that packet on to +// the appropriate connection's handling code. Only if a packet arrives out of +// order is its contents copied and stored in memory for later. +// +// Avoids Memory Allocation +// +// Assemblers try very hard to not use memory allocation unless absolutely +// necessary. Packet data for sequential packets is passed directly to streams +// with no copying or allocation. Packet data for out-of-order packets is +// copied into reusable pages, and new pages are only allocated rarely when the +// page cache runs out. Page caches are Assembler-specific, thus not used +// concurrently and requiring no locking. +// +// Internal representations for connection objects are also reused over time. +// Because of this, the most common memory allocation done by the Assembler is +// generally what's done by the caller in StreamFactory.New. If no allocation +// is done there, then very little allocation is done ever, mostly to handle +// large increases in bandwidth or numbers of connections. +// +// TODO: The page caches used by an Assembler will grow to the size necessary +// to handle a workload, and currently will never shrink. This means that +// traffic spikes can result in large memory usage which isn't garbage collected +// when typical traffic levels return. +type Assembler struct { + AssemblerOptions + ret []tcpassembly.Reassembly + pc *pageCache + connPool *StreamPool +} + +func (p *StreamPool) newConnection(k key, s tcpassembly.Stream, ts time.Time) (c *connection) { + if *memLog { + p.newConnectionCount++ + if p.newConnectionCount&0x7FFF == 0 { + log.Println("StreamPool:", p.newConnectionCount, "requests,", len(p.conns), "used,", len(p.free), "free") + } + } + if len(p.free) == 0 { + p.grow() + } + index := len(p.free) - 1 + c, p.free = p.free[index], p.free[:index] + c.reset(k, s, ts) + return c +} + +// getConnection returns a connection. If end is true and a connection +// does not already exist, returns nil. This allows us to check for a +// connection without actually creating one if it doesn't already exist. +func (p *StreamPool) getConnection(k key, end bool, ts time.Time) *connection { + p.mu.RLock() + conn := p.conns[k] + p.mu.RUnlock() + if end || conn != nil { + return conn + } + s := p.factory.New(k[0], k[1]) + p.mu.Lock() + conn = p.newConnection(k, s, ts) + if conn2 := p.conns[k]; conn2 != nil { + p.mu.Unlock() + return conn2 + } + p.conns[k] = conn + p.mu.Unlock() + return conn +} + +// Assemble calls AssembleWithTimestamp with the current timestamp, useful for +// packets being read directly off the wire. +func (a *Assembler) Assemble(netFlow gopacket.Flow, t *layers.TCP) { + a.AssembleWithTimestamp(netFlow, t, time.Now()) +} + +// AssembleWithTimestamp reassembles the given TCP packet into its appropriate +// stream. +// +// The timestamp passed in must be the timestamp the packet was seen. For +// packets read off the wire, time.Now() should be fine. For packets read from +// PCAP files, CaptureInfo.Timestamp should be passed in. This timestamp will +// affect which streams are flushed by a call to FlushOlderThan. +// +// Each Assemble call results in, in order: +// +// zero or one calls to StreamFactory.New, creating a stream +// zero or one calls to Reassembled on a single stream +// zero or one calls to ReassemblyComplete on the same stream +func (a *Assembler) AssembleWithTimestamp(netFlow gopacket.Flow, t *layers.TCP, timestamp time.Time) { + // Ignore empty TCP packets + if !t.SYN && !t.FIN && !t.RST && len(t.LayerPayload()) == 0 { + return + } + + a.ret = a.ret[:0] + key := key{netFlow, t.TransportFlow()} + var conn *connection + // This for loop handles a race condition where a connection will close, + // lock the connection pool, and remove itself, but before it locked the + // connection pool it's returned to another Assemble statement. This should + // loop 0-1 times for the VAST majority of cases. + for { + conn = a.connPool.getConnection( + key, !t.SYN && len(t.LayerPayload()) == 0, timestamp) + if conn == nil { + if *debugLog { + log.Printf("%v got empty packet on otherwise empty connection", key) + } + return + } + conn.mu.Lock() + if !conn.closed { + break + } + conn.mu.Unlock() + } + if conn.lastSeen.Before(timestamp) { + conn.lastSeen = timestamp + } + seq, bytes := Sequence(t.Seq), t.Payload + + if conn.nextSeq == invalidSequence { + // Handling the first packet we've seen on the stream. + skip := 0 + if !t.SYN { + // don't add 1 since we're just going to assume the sequence number + // without the SYN packet. + // stream was picked up somewhere in the middle, so indicate that we + // don't know how many packets came before it. + conn.nextSeq = seq.Add(len(bytes)) + skip = -1 + } else { + // for SYN packets, also increment the sequence number by 1 + conn.nextSeq = seq.Add(len(bytes) + 1) + } + a.ret = append(a.ret, tcpassembly.Reassembly{ + Bytes: bytes, + Skip: skip, + Start: t.SYN, + Seen: timestamp, + }) + a.insertIntoConn(t, conn, timestamp) + } else if diff := conn.nextSeq.Difference(seq); diff > 0 { + a.insertIntoConn(t, conn, timestamp) + } else { + bytes, conn.nextSeq = byteSpan(conn.nextSeq, seq, bytes) + a.ret = append(a.ret, tcpassembly.Reassembly{ + Bytes: bytes, + Skip: 0, + End: t.RST || t.FIN, + Seen: timestamp, + }) + } + if len(a.ret) > 0 { + a.sendToConnection(conn) + } + conn.mu.Unlock() +} + +func byteSpan(expected, received Sequence, bytes []byte) (toSend []byte, next Sequence) { + if expected == invalidSequence { + return bytes, received.Add(len(bytes)) + } + span := int(received.Difference(expected)) + if span <= 0 { + return bytes, received.Add(len(bytes)) + } else if len(bytes) < span { + return nil, expected + } + return bytes[span:], expected.Add(len(bytes) - span) +} + +// sendToConnection sends the current values in a.ret to the connection, closing +// the connection if the last thing sent had End set. +func (a *Assembler) sendToConnection(conn *connection) { + a.addContiguous(conn) + if conn.stream == nil { + panic("why?") + } + conn.stream.Reassembled(a.ret) + if a.ret[len(a.ret)-1].End { + a.closeConnection(conn) + } +} + +// addContiguous adds contiguous byte-sets to a connection. +func (a *Assembler) addContiguous(conn *connection) { + for conn.first != nil && conn.nextSeq.Difference(conn.first.seq) <= 0 { + a.addNextFromConn(conn) + } +} + +// skipFlush skips the first set of bytes we're waiting for and returns the +// first set of bytes we have. If we have no bytes pending, it closes the +// connection. +func (a *Assembler) skipFlush(conn *connection) { + if *debugLog { + log.Printf("%v skipFlush %v", conn.key, conn.nextSeq) + } + if conn.first == nil { + a.closeConnection(conn) + return + } + a.ret = a.ret[:0] + a.addNextFromConn(conn) + a.addContiguous(conn) + a.sendToConnection(conn) +} + +func (p *StreamPool) remove(conn *connection) { + p.mu.Lock() + delete(p.conns, conn.key) + p.free = append(p.free, conn) + p.mu.Unlock() +} + +func (a *Assembler) closeConnection(conn *connection) { + if *debugLog { + log.Printf("%v closing", conn.key) + } + conn.stream.ReassemblyComplete() + conn.closed = true + a.connPool.remove(conn) + for p := conn.first; p != nil; p = p.next { + a.pc.replace(p) + } +} + +// traverseConn traverses our doubly-linked list of pages for the correct +// position to put the given sequence number. Note that it traverses backwards, +// starting at the highest sequence number and going down, since we assume the +// common case is that TCP packets for a stream will appear in-order, with +// minimal loss or packet reordering. +func (conn *connection) traverseConn(seq Sequence) (prev, current *page) { + prev = conn.last + for prev != nil && prev.seq.Difference(seq) < 0 { + current = prev + prev = current.prev + } + return +} + +// pushBetween inserts the doubly-linked list first-...-last in between the +// nodes prev-next in another doubly-linked list. If prev is nil, makes first +// the new first page in the connection's list. If next is nil, makes last the +// new last page in the list. first/last may point to the same page. +func (conn *connection) pushBetween(prev, next, first, last *page) { + // Maintain our doubly linked list + if next == nil || conn.last == nil { + conn.last = last + } else { + last.next = next + next.prev = last + } + if prev == nil || conn.first == nil { + conn.first = first + } else { + first.prev = prev + prev.next = first + } +} + +func (a *Assembler) insertIntoConn(t *layers.TCP, conn *connection, ts time.Time) { + if conn.first != nil && conn.first.seq == conn.nextSeq { + panic("wtf") + } + p, p2, numPages := a.pagesFromTCP(t, ts) + prev, current := conn.traverseConn(Sequence(t.Seq)) + conn.pushBetween(prev, current, p, p2) + conn.pages += numPages + if (a.MaxBufferedPagesPerConnection > 0 && conn.pages >= a.MaxBufferedPagesPerConnection) || + (a.MaxBufferedPagesTotal > 0 && a.pc.used >= a.MaxBufferedPagesTotal) { + if *debugLog { + log.Printf("%v hit max buffer size: %+v, %v, %v", conn.key, a.AssemblerOptions, conn.pages, a.pc.used) + } + a.addNextFromConn(conn) + } +} + +// pagesFromTCP creates a page (or set of pages) from a TCP packet. Note that +// it should NEVER receive a SYN packet, as it doesn't handle sequences +// correctly. +// +// It returns the first and last page in its doubly-linked list of new pages. +func (a *Assembler) pagesFromTCP(t *layers.TCP, ts time.Time) (p, p2 *page, numPages int) { + first := a.pc.next(ts) + current := first + numPages++ + seq, bytes := Sequence(t.Seq), t.Payload + for { + length := min(len(bytes), pageBytes) + current.Bytes = current.buf[:length] + copy(current.Bytes, bytes) + current.seq = seq + bytes = bytes[length:] + if len(bytes) == 0 { + break + } + seq = seq.Add(length) + current.next = a.pc.next(ts) + current.next.prev = current + current = current.next + numPages++ + } + current.End = t.RST || t.FIN + return first, current, numPages +} + +// addNextFromConn pops the first page from a connection off and adds it to the +// return array. +func (a *Assembler) addNextFromConn(conn *connection) { + if conn.nextSeq == invalidSequence { + conn.first.Skip = -1 + } else if diff := conn.nextSeq.Difference(conn.first.seq); diff > 0 { + conn.first.Skip = int(diff) + } + conn.first.Bytes, conn.nextSeq = byteSpan(conn.nextSeq, conn.first.seq, conn.first.Bytes) + if *debugLog { + log.Printf("%v adding from conn (%v, %v)", conn.key, conn.first.seq, conn.nextSeq) + } + a.ret = append(a.ret, conn.first.Reassembly) + a.pc.replace(conn.first) + if conn.first == conn.last { + conn.first = nil + conn.last = nil + } else { + conn.first = conn.first.next + conn.first.prev = nil + } + conn.pages-- +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/core/cmd.go b/core/cmd.go new file mode 100644 index 0000000..a9ec039 --- /dev/null +++ b/core/cmd.go @@ -0,0 +1,144 @@ +package core + +import ( + "os" + "strings" + "fmt" + "net" + "strconv" +) + +const InternalCmdPrefix = "--" +const ( + InternalCmdHelp = "help" //帮助文档 + InternalCmdEnv = "env" //环境变量 + InternalCmdList = "list" //插件列表 + InternalCmdVer = "ver" //版本信息 + InternalDevice = "dev" //设备链表 +) + +type Cmd struct { + Device string + plugHandle *Plug +} + +func NewCmd(p *Plug) *Cmd { + + return &Cmd{ + plugHandle:p, + } +} + +//start +func (cm *Cmd) Run() { + + //使用帮助 + if len(os.Args) <= 1 { + cm.printHelpMessage(); + os.Exit(1) + } + + //解析命令 + firstArg := string(os.Args[1]) + if strings.HasPrefix(firstArg, InternalCmdPrefix) { + cm.parseInternalCmd() + } else { + cm.parsePlugCmd() + } +} + +//解析内部参数 +func (cm *Cmd) parseInternalCmd() { + + arg := string(os.Args[1]) + cmd := strings.Trim(arg, InternalCmdPrefix) + + switch cmd { + case InternalCmdHelp: + cm.printHelpMessage() + break; + case InternalCmdEnv: + fmt.Println("插件路径:"+cm.plugHandle.dir) + break + case InternalCmdList: + cm.plugHandle.PrintList() + break + case InternalCmdVer: + fmt.Println(cxt.Version) + break + case InternalDevice: + cm.printDevice() + break; + } + os.Exit(1) +} + +//使用说明 +func (cm *Cmd) printHelpMessage() { + + fmt.Println("==================================================================================") + fmt.Println("[使用说明]") + fmt.Println("") + fmt.Println(" go-sniffer [设备名] [插件名] [插件参数(可选)]") + fmt.Println() + fmt.Println(" [例子]") + fmt.Println(" go-sniffer en0 redis 抓取redis数据包") + fmt.Println(" go-sniffer en0 mysql -p 3306 抓取mysql数据包,端口3306") + fmt.Println() + fmt.Println(" go-sniffer --[命令]") + fmt.Println(" --help 帮助信息") + fmt.Println(" --env 环境变量") + fmt.Println(" --list 插件列表") + fmt.Println(" --ver 版本信息") + fmt.Println(" --dev 设备列表") + fmt.Println(" [例子]") + fmt.Println(" go-sniffer --list 查看可抓取的协议") + fmt.Println() + fmt.Println("==================================================================================") + cm.printDevice() + fmt.Println("==================================================================================") +} + +//打印插件 +func (cm *Cmd) printPlugList() { + l := len(cm.plugHandle.InternalPlugList) + l += len(cm.plugHandle.ExternalPlugList) + fmt.Println("# 插件数量:"+strconv.Itoa(l)) +} + +//打印设备 +func (cm *Cmd) printDevice() { + ifaces, err:= net.Interfaces() + if err != nil { + panic(err) + } + for _, iface := range ifaces { + addrs, _ := iface.Addrs() + for _,a:=range addrs { + if ipnet, ok := a.(*net.IPNet); ok { + if ip4 := ipnet.IP.To4(); ip4 != nil { + fmt.Println("[设备名] : "+iface.Name+" : "+iface.HardwareAddr.String()+" "+ip4.String()) + } + } + } + } +} + +//解析插件需要的参数 +func (cm *Cmd) parsePlugCmd() { + + if len(os.Args) < 3 { + fmt.Println("缺少[插件名]") + fmt.Println("go-sniffer [设备名] [插件名] [插件参数(可选)]") + os.Exit(1) + } + + cm.Device = os.Args[1] + plugName := os.Args[2] + plugParams:= os.Args[3:] + cm.plugHandle.SetOption(plugName, plugParams) +} + + + + diff --git a/core/core.go b/core/core.go new file mode 100644 index 0000000..9613de7 --- /dev/null +++ b/core/core.go @@ -0,0 +1,28 @@ +package core + +type Core struct{ + //版本信息 + Version string +} + +var cxt Core + +func New() Core { + + cxt.Version = "0.1" + + return cxt +} + +func (c *Core) Run() { + + //插件 + plug := NewPlug() + + //解析参数 + cmd := NewCmd(plug) + cmd.Run() + + //开启抓包 + NewDispatch(plug, cmd).Capture() +} \ No newline at end of file diff --git a/core/dispatch.go b/core/dispatch.go new file mode 100644 index 0000000..16ea044 --- /dev/null +++ b/core/dispatch.go @@ -0,0 +1,100 @@ +package core + +import ( + "fmt" + "github.com/google/gopacket/pcap" + "log" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/tcpassembly" + "github.com/google/gopacket/tcpassembly/tcpreader" + "time" +) + +type Dispatch struct { + device string + payload []byte + Plug *Plug +} + +func NewDispatch(plug *Plug, cmd *Cmd) *Dispatch { + return &Dispatch { + Plug: plug, + device:cmd.Device, + } +} + +func (d *Dispatch) Capture() { + + // Init device + handle, err := pcap.OpenLive(d.device, 65535, false, pcap.BlockForever) + if err != nil { + return + } + + // Set filter + fmt.Println(d.Plug.BPF) + err = handle.SetBPFFilter(d.Plug.BPF) + if err != nil { + log.Fatal(err) + } + + // Capture + src := gopacket.NewPacketSource(handle, handle.LinkType()) + packets := src.Packets() + + // Set up assembly + streamFactory := &ProtocolStreamFactory{ + dispatch:d, + } + streamPool := NewStreamPool(streamFactory) + assembler := NewAssembler(streamPool) + ticker := time.Tick(time.Minute) + + // Loop until ctrl+z + for { + select { + case packet := <-packets: + if packet.NetworkLayer() == nil || + packet.TransportLayer() == nil || + packet.TransportLayer().LayerType() != layers.LayerTypeTCP { + fmt.Println("包不能解析") + continue + } + tcp := packet.TransportLayer().(*layers.TCP) + assembler.AssembleWithTimestamp( + packet.NetworkLayer().NetworkFlow(), + tcp, packet.Metadata().Timestamp, + ) + case <-ticker: + assembler.FlushOlderThan(time.Now().Add(time.Minute * -2)) + } + } +} + +type ProtocolStreamFactory struct { + dispatch *Dispatch +} + +type ProtocolStream struct { + net, transport gopacket.Flow + r tcpreader.ReaderStream +} + +func (m *ProtocolStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream { + + //init stream struct + stm := &ProtocolStream { + net: net, + transport: transport, + r: tcpreader.NewReaderStream(), + } + + //new stream + fmt.Println("# 新连接:", net, transport) + + //decode packet + go m.dispatch.Plug.ResolveStream(net, transport, &(stm.r)) + + return &(stm.r) +} \ No newline at end of file diff --git a/core/plug.go b/core/plug.go new file mode 100644 index 0000000..af0efd6 --- /dev/null +++ b/core/plug.go @@ -0,0 +1,198 @@ +package core + +import ( + "io/ioutil" + "plugin" + "github.com/google/gopacket" + "io" + mysql "github.com/40t/go-sniffer/plugSrc/mysql/build" + redis "github.com/40t/go-sniffer/plugSrc/redis/build" + hp "github.com/40t/go-sniffer/plugSrc/http/build" + "path/filepath" + "fmt" + "path" +) + +type Plug struct { + + //当前插件路径 + dir string + //解析包 + ResolveStream func(net gopacket.Flow, transport gopacket.Flow, r io.Reader) + //BPF + BPF string + + //内部插件列表 + InternalPlugList map[string]PlugInterface + //外部插件列表 + ExternalPlugList map[string]ExternalPlug +} + +// 内部插件必须实现此接口 +// ResolvePacket - 包入口 +// BPFFilter - 设置BPF规则,例如mysql: (tcp and port 3306) +// SetFlag - 设置参数 +// Version - 返回插件版本,例如0.1.0 +type PlugInterface interface { + //解析流 + ResolveStream(net gopacket.Flow, transport gopacket.Flow, r io.Reader) + //BPF + BPFFilter() string + //设置插件需要的参数 + SetFlag([]string) + //获取版本 + Version() string +} + +//外部插件 +type ExternalPlug struct { + Name string + Version string + ResolvePacket func(net gopacket.Flow, transport gopacket.Flow, r io.Reader) + BPFFilter func() string + SetFlag func([]string) +} + +//实例化 +func NewPlug() *Plug { + + var p Plug + + //设置默认插件目录 + p.dir, _ = filepath.Abs( "./plug/") + + //加载内部插件 + p.LoadInternalPlugList() + + //加载外部插件 + p.LoadExternalPlugList() + + return &p +} + +//加载内部插件 +func (p *Plug) LoadInternalPlugList() { + + list := make(map[string]PlugInterface) + + //Mysql + list["mysql"] = mysql.NewInstance() + + //TODO Mongodb + + //TODO ARP + + //Redis + list["redis"] = redis.NewInstance() + //Http + list["http"] = hp.NewInstance() + + p.InternalPlugList = list +} + +//加载外部so后缀插件 +func (p *Plug) LoadExternalPlugList() { + + dir, err := ioutil.ReadDir(p.dir) + if err != nil { + panic(p.dir + "不存在,或者无权访问") + } + + p.ExternalPlugList = make(map[string]ExternalPlug) + for _, fi := range dir { + if fi.IsDir() || path.Ext(fi.Name()) != ".so" { + continue + } + + plug, err := plugin.Open(p.dir+"/"+fi.Name()) + if err != nil { + panic(err) + } + + versionFunc, err := plug.Lookup("Version") + if err != nil { + panic(err) + } + + setFlagFunc, err := plug.Lookup("SetFlag") + if err != nil { + panic(err) + } + + BPFFilterFunc, err := plug.Lookup("BPFFilter") + if err != nil { + panic(err) + } + + ResolvePacketFunc, err := plug.Lookup("ResolvePacket") + if err != nil { + panic(err) + } + + version := versionFunc.(func() string)() + p.ExternalPlugList[fi.Name()] = ExternalPlug { + ResolvePacket:ResolvePacketFunc.(func(net gopacket.Flow, transport gopacket.Flow, r io.Reader)), + SetFlag:setFlagFunc.(func([]string)), + BPFFilter:BPFFilterFunc.(func() string), + Version:version, + Name:fi.Name(), + } + } +} + +//改变插件地址 +func (p *Plug) ChangePath(dir string) { + p.dir = dir +} + +//打印插件列表 +func (p *Plug) PrintList() { + + //Print Internal Plug + for inPlugName, _ := range p.InternalPlugList { + fmt.Println("内部插件:"+inPlugName) + } + + //split + fmt.Println("-- --- --") + + //print External Plug + for exPlugName, _ := range p.ExternalPlugList { + fmt.Println("外部插件:"+exPlugName) + } +} + +//选择当前使用的插件 && 加载插件 +func (p *Plug) SetOption(plugName string, plugParams []string) { + + //Load Internal Plug + if internalPlug, ok := p.InternalPlugList[plugName]; ok { + + p.ResolveStream = internalPlug.ResolveStream + internalPlug.SetFlag(plugParams) + p.BPF = internalPlug.BPFFilter() + + return + } + + //Load External Plug + plug, err := plugin.Open("./plug/"+ plugName) + if err != nil { + panic(err) + } + resolvePacket, err := plug.Lookup("ResolvePacket") + if err != nil { + panic(err) + } + setFlag, err := plug.Lookup("SetFlag") + if err != nil { + panic(err) + } + BPFFilter, err := plug.Lookup("BPFFilter") + if err != nil { + panic(err) + } + p.ResolveStream = resolvePacket.(func(net gopacket.Flow, transport gopacket.Flow, r io.Reader)) + setFlag.(func([]string))(plugParams) + p.BPF = BPFFilter.(func()string)() +} \ No newline at end of file diff --git a/main.go b/main.go new file mode 100644 index 0000000..4dfee3a --- /dev/null +++ b/main.go @@ -0,0 +1,10 @@ +package main + +import ( + "github.com/40t/go-sniffer/core" +) + +func main() { + core := core.New() + core.Run() +} \ No newline at end of file diff --git a/plug/.gitkeep b/plug/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/plugSrc/http/build/entry.go b/plugSrc/http/build/entry.go new file mode 100644 index 0000000..38caef3 --- /dev/null +++ b/plugSrc/http/build/entry.go @@ -0,0 +1,106 @@ +package build + +import ( + "github.com/google/gopacket" + "io" + "log" + "strconv" + "fmt" + "os" + "bufio" + "net/http" +) + +const ( + Port = 80 + Version = "0.1" +) + +const ( + CmdPort = "-p" +) + +type H struct { + port int//端口 + version string//插件版本 +} + +var hp *H + +func NewInstance() *H { + if hp == nil { + hp = &H{ + port :Port, + version:Version, + } + } + return hp +} + +func (m *H) ResolveStream(net, transport gopacket.Flow, buf io.Reader) { + + bio := bufio.NewReader(buf) + for { + req, err := http.ReadRequest(bio) + + if err == io.EOF { + return + } else if err != nil { + continue + } else { + + var msg = "[" + msg += req.Method + msg += "] [" + msg += req.Host + req.URL.String() + msg += "] [" + req.ParseForm() + msg += req.Form.Encode() + msg += "]" + + log.Println(msg) + + req.Body.Close() + } + } +} + +func (m *H) BPFFilter() string { + return "tcp and port "+strconv.Itoa(m.port); +} + +func (m *H) Version() string { + return Version +} + +func (m *H) SetFlag(flg []string) { + + c := len(flg) + + if c == 0 { + return + } + if c >> 1 == 0 { + fmt.Println("http参数数量不正确!") + os.Exit(1) + } + for i:=0;i 65535 { + panic("参数不正确: 端口范围(0-65535)") + } + break + default: + panic("参数不正确") + } + } +} \ No newline at end of file diff --git a/plugSrc/mysql/build/const.go b/plugSrc/mysql/build/const.go new file mode 100644 index 0000000..37fd3c6 --- /dev/null +++ b/plugSrc/mysql/build/const.go @@ -0,0 +1,79 @@ +package build + +const ( + ComQueryRequestPacket string = "【查询】" + OkPacket string = "【正确】" + ErrorPacket string = "【错误】" + PreparePacket string = "【预处理】" + SendClientHandshakePacket string = "【用户认证】" + SendServerHandshakePacket string = "【登录认证】" +) + +const ( + COM_SLEEP byte = 0 + COM_QUIT = 1 + COM_INIT_DB = 2 + COM_QUERY = 3 + COM_FIELD_LIST = 4 + COM_CREATE_DB = 5 + COM_DROP_DB = 6 + COM_REFRESH = 7 + COM_SHUTDOWN = 8 + COM_STATISTICS = 9 + COM_PROCESS_INFO = 10 + COM_CONNECT = 11 + COM_PROCESS_KILL = 12 + COM_DEBUG = 13 + COM_PING = 14 + COM_TIME = 15 + COM_DELAYED_INSERT = 16 + COM_CHANGE_USER = 17 + COM_BINLOG_DUMP = 18 + COM_TABLE_DUMP = 19 + COM_CONNECT_OUT = 20 + COM_REGISTER_SLAVE = 21 + COM_STMT_PREPARE = 22 + COM_STMT_EXECUTE = 23 + COM_STMT_SEND_LONG_DATA = 24 + COM_STMT_CLOSE = 25 + COM_STMT_RESET = 26 + COM_SET_OPTION = 27 + COM_STMT_FETCH = 28 + COM_DAEMON = 29 + COM_BINLOG_DUMP_GTID = 30 + COM_RESET_CONNECTION = 31 +) + +const ( + MYSQL_TYPE_DECIMAL byte = 0 + MYSQL_TYPE_TINY = 1 + MYSQL_TYPE_SHORT = 2 + MYSQL_TYPE_LONG = 3 + MYSQL_TYPE_FLOAT = 4 + MYSQL_TYPE_DOUBLE = 5 + MYSQL_TYPE_NULL = 6 + MYSQL_TYPE_TIMESTAMP = 7 + MYSQL_TYPE_LONGLONG = 8 + MYSQL_TYPE_INT24 = 9 + MYSQL_TYPE_DATE = 10 + MYSQL_TYPE_TIME = 11 + MYSQL_TYPE_DATETIME = 12 + MYSQL_TYPE_YEAR = 13 + MYSQL_TYPE_NEWDATE = 14 + MYSQL_TYPE_VARCHAR = 15 + MYSQL_TYPE_BIT = 16 +) + +const ( + MYSQL_TYPE_JSON byte = iota + 0xf5 + MYSQL_TYPE_NEWDECIMAL + MYSQL_TYPE_ENUM + MYSQL_TYPE_SET + MYSQL_TYPE_TINY_BLOB + MYSQL_TYPE_MEDIUM_BLOB + MYSQL_TYPE_LONG_BLOB + MYSQL_TYPE_BLOB + MYSQL_TYPE_VAR_STRING + MYSQL_TYPE_STRING + MYSQL_TYPE_GEOMETRY +) diff --git a/plugSrc/mysql/build/entry.go b/plugSrc/mysql/build/entry.go new file mode 100644 index 0000000..52781a6 --- /dev/null +++ b/plugSrc/mysql/build/entry.go @@ -0,0 +1,350 @@ +package build + +import ( + "github.com/google/gopacket" + "io" + "bytes" + "errors" + "log" + "strconv" + "sync" + "time" + "fmt" + "encoding/binary" + "strings" + "os" +) + +const ( + Port = 3306 + Version = "0.1" + CmdPort = "-p" +) + +type Mysql struct { + port int//端口 + version string//插件版本 + source map[string]*stream//流 +} + +type stream struct { + packets chan *packet + stmtMap map[uint32]*Stmt +} + +type packet struct { + isClientFlow bool + seq int + length int + payload []byte +} + +var mysql *Mysql +var once sync.Once +func NewInstance() *Mysql { + + once.Do(func() { + mysql = &Mysql{ + port :Port, + version:Version, + source: make(map[string]*stream), + } + }) + + return mysql +} + +func (m *Mysql) ResolveStream(net, transport gopacket.Flow, buf io.Reader) { + + //uuid + uuid := fmt.Sprintf("%v:%v", net.FastHash(), transport.FastHash()) + + //generate resolve's stream + if _, ok := m.source[uuid]; !ok { + + var newStream = stream{ + packets:make(chan *packet, 100), + stmtMap:make(map[uint32]*Stmt), + } + + m.source[uuid] = &newStream + go newStream.resolve() + } + + //read bi-directional packet + //server -> client || client -> server + for { + + newPacket := m.newPacket(net, transport, buf) + + if newPacket == nil { + return + } + + m.source[uuid].packets <- newPacket + } +} + +func (m *Mysql) BPFFilter() string { + return "tcp and port "+strconv.Itoa(m.port); +} + +func (m *Mysql) Version() string { + return Version +} + +func (m *Mysql) SetFlag(flg []string) { + + c := len(flg) + + if c == 0 { + return + } + if c >> 1 == 0 { + fmt.Println("Mysql参数数量不正确!") + os.Exit(1) + } + for i:=0;i 65535 { + panic("参数不正确: 端口范围(0-65535)") + } + break + default: + panic("参数不正确") + } + } +} + +func (m *Mysql) newPacket(net, transport gopacket.Flow, r io.Reader) *packet { + + //read packet + var payload bytes.Buffer + var seq uint8 + var err error + if seq, err = m.resolvePacketTo(r, &payload); err != nil { + return nil + } + + //close stream + if err == io.EOF { + fmt.Println(net, transport, " 关闭") + return nil + } else if err != nil { + fmt.Println("错误流:", net, transport, ":", err) + } + + //generate new packet + var pk = packet{ + seq: int(seq), + length:payload.Len(), + payload:payload.Bytes(), + } + if transport.Src().String() == strconv.Itoa(Port) { + pk.isClientFlow = false + }else{ + pk.isClientFlow = true + } + + return &pk +} + +func (m *Mysql) resolvePacketTo(r io.Reader, w io.Writer) (uint8, error) { + + header := make([]byte, 4) + if n, err := io.ReadFull(r, header); err != nil { + if n == 0 && err == io.EOF { + return 0, io.EOF + } + return 0, errors.New("错误流") + } + + length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + + var seq uint8 + seq = header[3] + + if n, err := io.CopyN(w, r, int64(length)); err != nil { + return 0, errors.New("错误流") + } else if n != int64(length) { + return 0, errors.New("错误流") + } else { + return seq, nil + } + + return seq, nil +} + +func (stm *stream) resolve() { + for { + select { + case packet := <- stm.packets: + if packet.isClientFlow { + stm.resolveClientPacket(packet.payload, packet.seq) + } else { + stm.resolveServerPacket(packet.payload, packet.seq) + } + } + } +} + +func (stm *stream) findStmtPacket (srv chan *packet, seq int) *packet { + for { + select { + case packet, ok := <- stm.packets: + if !ok { + return nil + } + if packet.seq == seq { + return packet + } + case <-time.After(5 * time.Second): + return nil + } + } +} + +func (stm *stream) resolveServerPacket(payload []byte, seq int) { + + var msg = "" + switch payload[0] { + + case 0xff: + errorCode := int(binary.LittleEndian.Uint16(payload[1:3])) + errorMsg,_ := ReadStringFromByte(payload[4:]) + + msg = GetNowStr(false)+"%s 错误代码:%s,错误信息:%s" + msg = fmt.Sprintf(msg, ErrorPacket, strconv.Itoa(errorCode), strings.TrimSpace(errorMsg)) + + case 0x00: + var pos = 1 + l,_ := LengthBinary(payload[pos:]) + affectedRows := int(l) + + msg += GetNowStr(false)+"%s 影响行数:%s" + msg = fmt.Sprintf(msg, OkPacket, strconv.Itoa(affectedRows)) + + default: + return + } + + fmt.Println(msg) +} + +func (stm *stream) resolveClientPacket(payload []byte, seq int) { + + var msg string + switch payload[0] { + + case COM_INIT_DB: + + msg = fmt.Sprintf("USE %s;\n", payload[1:]) + case COM_DROP_DB: + + msg = fmt.Sprintf("删除数据库 %s;\n", payload[1:]) + case COM_CREATE_DB, COM_QUERY: + + statement := string(payload[1:]) + msg = fmt.Sprintf("%s %s", ComQueryRequestPacket, statement) + case COM_STMT_PREPARE: + + serverPacket := stm.findStmtPacket(stm.packets, seq+1) + if serverPacket == nil { + log.Println("找不到预处理响应包") + } + + //获取响应包中预处理id + stmtID := binary.LittleEndian.Uint32(serverPacket.payload[1:5]) + stmt := &Stmt{ + ID: stmtID, + Query: string(payload[1:]), + } + + //记录预处理语句 + stm.stmtMap[stmtID] = stmt + stmt.FieldCount = binary.LittleEndian.Uint16(serverPacket.payload[5:7]) + stmt.ParamCount = binary.LittleEndian.Uint16(serverPacket.payload[7:9]) + stmt.Args = make([]interface{}, stmt.ParamCount) + + msg = PreparePacket+stmt.Query + case COM_STMT_SEND_LONG_DATA: + + stmtID := binary.LittleEndian.Uint32(payload[1:5]) + paramId := binary.LittleEndian.Uint16(payload[5:7]) + stmt, _ := stm.stmtMap[stmtID] + + if stmt.Args[paramId] == nil { + stmt.Args[paramId] = payload[7:] + } else { + if b, ok := stmt.Args[paramId].([]byte); ok { + b = append(b, payload[7:]...) + stmt.Args[paramId] = b + } + } + return + case COM_STMT_RESET: + + stmtID := binary.LittleEndian.Uint32(payload[1:5]) + stmt, _:= stm.stmtMap[stmtID] + stmt.Args = make([]interface{}, stmt.ParamCount) + return + case COM_STMT_EXECUTE: + + var pos = 1 + stmtID := binary.LittleEndian.Uint32(payload[pos : pos+4]) + pos += 4 + var stmt *Stmt + var ok bool + if stmt, ok = stm.stmtMap[stmtID]; ok == false { + log.Println("未发现预处理id: ", stmtID) + } + + //参数 + pos += 5 + if stmt.ParamCount > 0 { + + //空位图(Null-Bitmap,长度 = (参数数量 + 7) / 8 字节) + step := int((stmt.ParamCount + 7) / 8) + nullBitmap := payload[pos : pos+step] + pos += step + + //参数分隔标志 + flag := payload[pos] + + pos++ + + var pTypes []byte + var pValues []byte + + //如果参数分隔标志值为1 + //n 每个参数的类型值(长度 = 参数数量 * 2 字节) + //n 每个参数的值 + if flag == 1 { + pTypes = payload[pos : pos+int(stmt.ParamCount)*2] + pos += int(stmt.ParamCount) * 2 + pValues = payload[pos:] + } + + //绑定参数 + err := stmt.BindArgs(nullBitmap, pTypes, pValues) + if err != nil { + log.Println("预处理绑定参数失败: ", err) + } + } + msg = string(stmt.WriteToText()) + default: + return + } + + fmt.Println(GetNowStr(true) + msg) +} + diff --git a/plugSrc/mysql/build/stmt.go b/plugSrc/mysql/build/stmt.go new file mode 100644 index 0000000..22faeec --- /dev/null +++ b/plugSrc/mysql/build/stmt.go @@ -0,0 +1,175 @@ +package build + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" + "strings" + + "errors" +) + +type Stmt struct { + ID uint32 + Query string + ParamCount uint16 + FieldCount uint16 + + Args []interface{} +} + +func (stmt *Stmt) WriteToText() []byte { + + var buf bytes.Buffer + + str := fmt.Sprintf("预处理编号[%d]: '%s';\n", stmt.ID, stmt.Query) + buf.WriteString(str) + + for i := 0; i < int(stmt.ParamCount); i++ { + var str string + switch stmt.Args[i].(type) { + case nil: + str = fmt.Sprintf("set @p%v = NULL;\n", i) + case []byte: + param := string(stmt.Args[i].([]byte)) + str = fmt.Sprintf("set @p%v = '%s';\n", i, strings.TrimSpace(param)) + default: + str = fmt.Sprintf("set @p%v = %v;\n", i, stmt.Args[i]) + } + buf.WriteString(str) + } + + str = fmt.Sprintf("执行预处理[%d]: ", stmt.ID) + buf.WriteString(str) + for i := 0; i < int(stmt.ParamCount); i++ { + if i == 0 { + buf.WriteString(" using ") + } + if i > 0 { + buf.WriteString(", ") + } + str := fmt.Sprintf("@p%v", i) + buf.WriteString(str) + } + buf.WriteString(";\n") + + str = fmt.Sprintf("丢弃预处理[%d];\n", stmt.ID) + buf.WriteString(str) + + return buf.Bytes() +} + +func (stmt *Stmt) BindArgs(nullBitmap, paramTypes, paramValues []byte) error { + + args := stmt.Args + pos := 0 + + var v []byte + var n = 0 + var isNull bool + var err error + + for i := 0; i < int(stmt.ParamCount); i++ { + + //判断参数是否为null + if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 { + args[i] = nil + continue + } + + //参数类型 + typ := paramTypes[i<<1] + unsigned := (paramTypes[(i<<1)+1] & 0x80) > 0 + + switch typ { + case MYSQL_TYPE_NULL: + args[i] = nil + continue + + case MYSQL_TYPE_TINY: + + value := paramValues[pos] + if unsigned { + args[i] = uint8(value) + } else { + args[i] = int8(value) + } + + pos++ + continue + + case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR: + + value := binary.LittleEndian.Uint16(paramValues[pos : pos+2]) + if unsigned { + args[i] = uint16(value) + } else { + args[i] = int16(value) + } + pos += 2 + continue + + case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG: + + value := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) + if unsigned { + args[i] = uint32(value) + } else { + args[i] = int32(value) + } + pos += 4 + continue + + case MYSQL_TYPE_LONGLONG: + + value := binary.LittleEndian.Uint64(paramValues[pos : pos+8]) + if unsigned { + args[i] = value + } else { + args[i] = int64(value) + } + pos += 8 + continue + + case MYSQL_TYPE_FLOAT: + + value := math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4])) + args[i] = float32(value) + pos += 4 + continue + + case MYSQL_TYPE_DOUBLE: + + value := math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8])) + args[i] = value + pos += 8 + continue + + case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, + MYSQL_TYPE_VARCHAR, MYSQL_TYPE_BIT, + MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, + MYSQL_TYPE_TINY_BLOB, MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB, + MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, + MYSQL_TYPE_GEOMETRY, + MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE, MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIME: + + v, isNull, n, err = LengthEncodedString(paramValues[pos:]) + pos += n + if err != nil { + return err + } + + if !isNull { + args[i] = v + continue + } else { + args[i] = nil + continue + } + default: + return errors.New(fmt.Sprintf("预处理未知类型 %d", typ)) + } + } + return nil +} diff --git a/plugSrc/mysql/build/util.go b/plugSrc/mysql/build/util.go new file mode 100644 index 0000000..fd26727 --- /dev/null +++ b/plugSrc/mysql/build/util.go @@ -0,0 +1,94 @@ +package build + +import ( + "bytes" + "encoding/binary" + "io" + "time" +) + +func GetNowStr(isClient bool) string { + var msg string + msg += time.Now().Format("2006-01-02 15:04:05") + if isClient { + msg += "| cli -> ser |" + }else{ + msg += "| ser -> cli |" + } + return msg +} + +func ReadStringFromByte(b []byte) (string,int) { + + var l int + l = bytes.IndexByte(b, 0x00) + if l == -1 { + l = len(b) + } + return string(b[0:l]), l +} + +func LengthBinary(b []byte) (uint32, int) { + + var first = int(b[0]) + if first > 0 && first <= 250 { + return uint32(first), 1 + } + if first == 251 { + return 0,1 + } + if first == 252 { + return binary.LittleEndian.Uint32(b[1:2]),1 + } + if first == 253 { + return binary.LittleEndian.Uint32(b[1:4]),3 + } + if first == 254 { + return binary.LittleEndian.Uint32(b[1:9]),8 + } + return 0,0 +} + +func LengthEncodedInt(input []byte) (num uint64, isNull bool, n int) { + + switch input[0] { + + case 0xfb: + n = 1 + isNull = true + return + case 0xfc: + num = uint64(input[1]) | uint64(input[2])<<8 + n = 3 + return + case 0xfd: + num = uint64(input[1]) | uint64(input[2])<<8 | uint64(input[3])<<16 + n = 4 + return + case 0xfe: + num = uint64(input[1]) | uint64(input[2])<<8 | uint64(input[3])<<16 | + uint64(input[4])<<24 | uint64(input[5])<<32 | uint64(input[6])<<40 | + uint64(input[7])<<48 | uint64(input[8])<<56 + n = 9 + return + } + + num = uint64(input[0]) + n = 1 + return +} + +func LengthEncodedString(b []byte) ([]byte, bool, int, error) { + + num, isNull, n := LengthEncodedInt(b) + if num < 1 { + return nil, isNull, n, nil + } + + n += int(num) + + if len(b) >= n { + return b[n-int(num) : n], false, n, nil + } + return nil, false, n, io.EOF +} diff --git a/plugSrc/redis/build/entry.go b/plugSrc/redis/build/entry.go new file mode 100644 index 0000000..e14288d --- /dev/null +++ b/plugSrc/redis/build/entry.go @@ -0,0 +1,126 @@ +package build + +import ( + "github.com/google/gopacket" + "io" + "strings" + "fmt" + "strconv" + "bufio" +) + +type Redis struct { + port int + version string + cmd chan string + done chan bool +} + +const ( + Port int = 6379 + Version string = "0.1" + CmdPort string = "-p" +) + +var redis = &Redis { + port:Port, + version:Version, +} + +func NewInstance() *Redis{ + return redis +} + +/** + 解析流 + */ +func (red Redis) ResolveStream(net, transport gopacket.Flow, r io.Reader) { + + //只解析clint发出去的包 + buf := bufio.NewReader(r) + var cmd string + var cmdCount = 0 + for { + + line, _, _ := buf.ReadLine() + //判断链接是否断开 + if len(line) == 0 { + buff := make([]byte, 1) + _, err := r.Read(buff) + if err == io.EOF { + red.done <- true + return + } + } + + //过滤无用数据 + if !strings.HasPrefix(string(line), "*") { + continue + } + + //过滤服务器返回数据 + if strings.EqualFold(transport.Src().String(), strconv.Itoa(red.port)) == true { + continue + } + + //解析 + l := string(line[1]) + cmdCount, _ = strconv.Atoi(l) + cmd = "" + for j := 0; j < cmdCount * 2; j++ { + c, _, _ := buf.ReadLine() + if j & 1 == 0 { + continue + } + cmd += " " + string(c) + } + fmt.Println(cmd) + } +} + +/** + SetOption + */ +func (red *Redis) SetFlag(flg []string) { + c := len(flg) + if c == 0 { + return + } + if c >> 1 != 1 { + panic("Mysql参数数量不正确!") + } + for i:=0;i 65535 { + panic("参数不正确: 端口范围(0-65535)") + } + break + default: + panic("参数不正确") + } + } +} + +/** + BPFFilter + */ +func (red *Redis) BPFFilter() string { + return "tcp and port "+strconv.Itoa(redis.port); +} + +/** + Version + */ +func (red *Redis) Version() string { + return red.version; +} +