package build

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/google/gopacket"
)

// Default settings and flag keys for the MySQL resolver.
const (
	Port    = 3306  // port a new instance starts with (see NewInstance)
	Version = "0.1" // string returned by (*Mysql).Version
	CmdPort = "-p"  // flag key SetFlag accepts for overriding the port
)

// Mysql is the MySQL protocol resolver: it carries the capture settings and
// the table of per-flow streams that ResolveStream fills in.
type Mysql struct {
	port    int                // TCP port used to build the BPF filter; overridden by SetFlag
	version string             // initialised from the Version constant by NewInstance
	source  map[string]*stream // streams keyed by the flow identifier computed in ResolveStream
}

// stream is the per-flow state: a queue of parsed packets waiting to be
// decoded and the prepared statements recorded so far.
type stream struct {
	packets chan *packet     // fed by ResolveStream, drained by resolve
	stmtMap map[uint32]*Stmt // prepared statements keyed by statement ID
}

// packet is one MySQL protocol packet as produced by newPacket.
type packet struct {
	isClientFlow bool   // selects client-side or server-side decoding in resolve
	seq          int    // sequence byte taken from the packet header
	length       int    // number of payload bytes
	payload      []byte // packet body, header excluded
}

// mysql is the shared instance returned by NewInstance; once makes sure it is
// built a single time.
var mysql *Mysql
var once sync.Once

// NewInstance returns the package-wide Mysql value, constructing it on the
// first call with the default Port and Version and an empty stream table.
// Every later call returns the same pointer.
func NewInstance() *Mysql {
	initialise := func() {
		inst := new(Mysql)
		inst.port = Port
		inst.version = Version
		inst.source = map[string]*stream{}
		mysql = inst
	}
	once.Do(initialise)
	return mysql
}

// mysqlSourceMu guards Mysql.source. It lives at package level because Mysql
// is only handed out as a single shared instance by NewInstance.
var mysqlSourceMu sync.Mutex

// ResolveStream reads MySQL packets from buf until newPacket reports the end
// of the reader, and queues each one on the stream that belongs to the flow.
//
// The stream is looked up by a key built from the FastHash of the network and
// transport flows; the first caller for a key creates the stream and starts
// its resolve goroutine. The lookup-or-create step is serialised with
// mysqlSourceMu because this method blocks for the lifetime of buf and is
// therefore expected to run in one goroutine per reader.
// NOTE(review): streams are never removed from m.source and their resolve
// goroutines are never stopped — confirm whether that is acceptable for
// long-running captures.
func (m *Mysql) ResolveStream(net, transport gopacket.Flow, buf io.Reader) {

	// flow identifier
	uuid := fmt.Sprintf("%v:%v", net.FastHash(), transport.FastHash())

	// find or create the stream for this flow
	mysqlSourceMu.Lock()
	stm, ok := m.source[uuid]
	if !ok {
		stm = &stream{
			packets: make(chan *packet, 100),
			stmtMap: make(map[uint32]*Stmt),
		}
		m.source[uuid] = stm
		go stm.resolve()
	}
	mysqlSourceMu.Unlock()

	// read packets in either direction (server -> client || client -> server)
	// and hand them to the stream; stm is kept locally so the shared map is
	// not touched again per packet.
	for {
		newPacket := m.newPacket(net, transport, buf)
		if newPacket == nil {
			return
		}
		stm.packets <- newPacket
	}
}

// BPFFilter returns the capture filter expression for the configured port,
// in the form "tcp and port <port>".
func (m *Mysql) BPFFilter() string {
	const prefix = "tcp and port "
	port := strconv.Itoa(m.port)
	return prefix + port
}

// Version returns the package-level Version constant; the receiver's version
// field is not consulted.
func (m *Mysql) Version() string {
	return Version
}

// SetFlag applies command-line settings given as a flat list of key/value
// pairs, e.g. []string{"-p", "3307"}.
//
// An empty list is a no-op. An odd number of elements prints an error and
// exits the process with status 1. An unknown key, a non-numeric port or a
// port outside 0-65535 panics. m.port is only updated once the value has
// passed validation.
func (m *Mysql) SetFlag(flg []string) {

	c := len(flg)

	if c == 0 {
		return
	}
	// Every key needs a value, so the count must be even. (The previous
	// c>>1 == 0 test only caught c == 1 and let 3, 5, ... run off the end.)
	if c%2 != 0 {
		fmt.Println("ERR : Mysql Number of parameters")
		os.Exit(1)
	}
	for i := 0; i < c; i += 2 {
		key, val := flg[i], flg[i+1]

		switch key {
		case CmdPort:
			port, err := strconv.Atoi(val)
			if err != nil {
				panic("ERR : port")
			}
			if port < 0 || port > 65535 {
				panic("ERR : port(0-65535)")
			}
			m.port = port
		default:
			panic("ERR : mysql's params")
		}
	}
}

// newPacket reads the next MySQL packet from r and wraps it in a packet
// value. It returns nil when no further packet can be read: on a clean end of
// stream it reports the flow as closed, on any other read error it prints the
// error. net and transport are only used for those messages and to decide the
// packet's direction.
func (m *Mysql) newPacket(net, transport gopacket.Flow, r io.Reader) *packet {

	// read packet
	var payload bytes.Buffer
	seq, err := m.resolvePacketTo(r, &payload)
	if err != nil {
		// Report why reading stopped; the partially filled buffer is
		// discarded either way.
		if errors.Is(err, io.EOF) {
			fmt.Println(net, transport, " close")
		} else {
			fmt.Println("ERR : Unknown stream", net, transport, ":", err)
		}
		return nil
	}

	// generate new packet
	return &packet{
		// Anything not sent from the configured server port is treated as
		// client traffic. m.port (not the default Port) is compared so the
		// direction stays correct after SetFlag changes the port.
		isClientFlow: transport.Src().String() != strconv.Itoa(m.port),
		seq:          int(seq),
		length:       payload.Len(),
		payload:      payload.Bytes(),
	}
}

// resolvePacketTo reads one packet from r: a 4-byte header (3-byte
// little-endian payload length followed by the sequence byte) and then the
// payload, which is copied to w.
//
// It returns the sequence byte on success. When r is exhausted before any
// header byte is read it returns io.EOF unwrapped, so callers can tell a clean
// end of stream from a broken one. Every other failure is returned as an
// "ERR : Unknown stream" error that wraps the underlying cause; a payload cut
// short by end of input wraps io.ErrUnexpectedEOF.
func (m *Mysql) resolvePacketTo(r io.Reader, w io.Writer) (uint8, error) {

	var header [4]byte
	if n, err := io.ReadFull(r, header[:]); err != nil {
		if n == 0 && errors.Is(err, io.EOF) {
			return 0, io.EOF
		}
		return 0, fmt.Errorf("ERR : Unknown stream: reading packet header: %w", err)
	}

	length := int64(header[0]) | int64(header[1])<<8 | int64(header[2])<<16
	seq := header[3]

	// io.CopyN already fails when fewer than length bytes are available, so
	// no separate byte-count check is needed.
	if _, err := io.CopyN(w, r, length); err != nil {
		if err == io.EOF {
			// Running out of input mid-packet is not a clean close.
			err = io.ErrUnexpectedEOF
		}
		return 0, fmt.Errorf("ERR : Unknown stream: reading %d-byte payload: %w", length, err)
	}

	return seq, nil
}

// resolve drains stm.packets and decodes each packet with the client-side or
// server-side decoder according to its direction. It returns when the channel
// is closed, so a closed channel ends the goroutine instead of producing a nil
// packet to dereference.
func (stm *stream) resolve() {
	for pk := range stm.packets {
		if pk == nil {
			continue
		}
		if pk.isClientFlow {
			stm.resolveClientPacket(pk.payload, pk.seq)
		} else {
			stm.resolveServerPacket(pk.payload, pk.seq)
		}
	}
}

// findStmtPacket receives from stm.packets until it sees a packet whose
// sequence number equals seq and returns it. Packets with a different
// sequence number are consumed and dropped. It returns nil when the channel
// is closed or when 5 seconds have passed in total without a match.
//
// srv is not read; it is kept so the signature stays unchanged for callers.
func (stm *stream) findStmtPacket(srv chan *packet, seq int) *packet {
	// One timer for the whole search: re-arming it on every received packet
	// would let a busy stream postpone the timeout indefinitely.
	timeout := time.After(5 * time.Second)
	for {
		select {
		case pk, ok := <-stm.packets:
			if !ok {
				return nil
			}
			if pk != nil && pk.seq == seq {
				return pk
			}
		case <-timeout:
			return nil
		}
	}
}

// resolveServerPacket prints a one-line summary of a server response. Only two
// leading bytes are handled: 0xff is reported with the ErrorPacket label, its
// 16-bit error code and message; 0x00 is reported with the OkPacket label and
// the length-encoded value that follows the first byte as the affected row
// count. Any other packet, and any packet too short to hold the fields it
// needs, is ignored. seq is not used.
func (stm *stream) resolveServerPacket(payload []byte, seq int) {

	// A zero-length packet has no status byte to inspect.
	if len(payload) == 0 {
		return
	}

	var msg string
	switch payload[0] {

	case 0xff:
		// status byte + 2-byte error code
		if len(payload) < 3 {
			return
		}
		errorCode := int(binary.LittleEndian.Uint16(payload[1:3]))

		// The message is read from offset 4 onwards when the packet is long
		// enough to be sliced there.
		var errorMsg string
		if len(payload) >= 4 {
			errorMsg, _ = ReadStringFromByte(payload[4:])
		}

		msg = fmt.Sprintf("%s%s Err code:%d,Err msg:%s",
			GetNowStr(false), ErrorPacket, errorCode, strings.TrimSpace(errorMsg))

	case 0x00:
		l, _ := LengthBinary(payload[1:])
		affectedRows := int(l)

		msg = fmt.Sprintf("%s%s Effect Row:%d", GetNowStr(false), OkPacket, affectedRows)

	default:
		return
	}

	fmt.Println(msg)
}

// resolveClientPacket decodes one client command and prints a textual form of
// it. The command is selected by the first payload byte:
//
//   - COM_INIT_DB / COM_DROP_DB / COM_CREATE_DB / COM_QUERY: the rest of the
//     payload is printed as text.
//   - COM_STMT_PREPARE: waits for the packet with sequence seq+1, and if it is
//     a successful prepare response records the statement in stm.stmtMap.
//   - COM_STMT_SEND_LONG_DATA: appends the data to the statement's argument.
//   - COM_STMT_RESET: clears the statement's arguments.
//   - COM_STMT_EXECUTE: binds the supplied parameters and prints the
//     statement text produced by WriteToText.
//
// Other commands are ignored. Packets that are empty, too short for the
// fields a command needs, or that refer to a statement ID that was never
// recorded (for example when capture started mid-connection) are logged and
// skipped instead of being indexed blindly.
func (stm *stream) resolveClientPacket(payload []byte, seq int) {

	// A zero-length packet carries no command byte.
	if len(payload) == 0 {
		return
	}

	var msg string
	switch payload[0] {

	case COM_INIT_DB:

		msg = fmt.Sprintf("USE %s;\n", payload[1:])
	case COM_DROP_DB:

		msg = fmt.Sprintf("Drop DB %s;\n", payload[1:])
	case COM_CREATE_DB, COM_QUERY:

		statement := string(payload[1:])
		msg = fmt.Sprintf("%s %s", ComQueryRequestPacket, statement)
	case COM_STMT_PREPARE:

		serverPacket := stm.findStmtPacket(stm.packets, seq+1)
		if serverPacket == nil {
			log.Println("ERR : Not found stm packet")
			return
		}

		// A successful prepare response is: status 0x00, statement id (4),
		// column count (2), parameter count (2). Anything else (e.g. an
		// error response) must not be recorded as a statement.
		resp := serverPacket.payload
		if len(resp) < 9 || resp[0] != 0x00 {
			log.Println("ERR : Unexpected prepare response")
			return
		}

		//fetch stm id
		stmtID := binary.LittleEndian.Uint32(resp[1:5])
		stmt := &Stmt{
			ID:    stmtID,
			Query: string(payload[1:]),
		}
		stmt.FieldCount = binary.LittleEndian.Uint16(resp[5:7])
		stmt.ParamCount = binary.LittleEndian.Uint16(resp[7:9])
		stmt.Args = make([]interface{}, stmt.ParamCount)

		//record stm sql
		stm.stmtMap[stmtID] = stmt

		msg = PreparePacket + stmt.Query
	case COM_STMT_SEND_LONG_DATA:

		// command (1) + statement id (4) + parameter id (2)
		if len(payload) < 7 {
			log.Println("ERR : Truncated long data packet")
			return
		}
		stmtID := binary.LittleEndian.Uint32(payload[1:5])
		paramID := binary.LittleEndian.Uint16(payload[5:7])

		stmt, ok := stm.stmtMap[stmtID]
		if !ok || stmt == nil {
			log.Println("ERR : Not found stm id", stmtID)
			return
		}
		if int(paramID) >= len(stmt.Args) {
			log.Println("ERR : Param id out of range", paramID)
			return
		}

		// The first chunk is stored as-is; later chunks are appended to it.
		if stmt.Args[paramID] == nil {
			stmt.Args[paramID] = payload[7:]
		} else if b, ok := stmt.Args[paramID].([]byte); ok {
			stmt.Args[paramID] = append(b, payload[7:]...)
		}
		return
	case COM_STMT_RESET:

		// command (1) + statement id (4)
		if len(payload) < 5 {
			log.Println("ERR : Truncated reset packet")
			return
		}
		stmtID := binary.LittleEndian.Uint32(payload[1:5])

		stmt, ok := stm.stmtMap[stmtID]
		if !ok || stmt == nil {
			log.Println("ERR : Not found stm id", stmtID)
			return
		}
		stmt.Args = make([]interface{}, stmt.ParamCount)
		return
	case COM_STMT_EXECUTE:

		// command (1) + statement id (4)
		if len(payload) < 5 {
			log.Println("ERR : Truncated execute packet")
			return
		}
		stmtID := binary.LittleEndian.Uint32(payload[1:5])

		stmt, ok := stm.stmtMap[stmtID]
		if !ok || stmt == nil {
			log.Println("ERR : Not found stm id", stmtID)
			return
		}

		//params
		if stmt.ParamCount > 0 {

			// Skip the 5 bytes that follow the statement id.
			pos := 1 + 4 + 5

			// Computed in int so a very large parameter count cannot wrap
			// around in uint16.
			paramCount := int(stmt.ParamCount)

			//(Null-Bitmap,len = (paramsCount + 7) / 8 byte)
			step := (paramCount + 7) / 8

			// The bitmap and the separator byte after it must be present.
			if len(payload) < pos+step+1 {
				log.Println("ERR : Truncated execute packet")
				return
			}
			nullBitmap := payload[pos : pos+step]
			pos += step

			//Parameter separator
			flag := payload[pos]
			pos++

			var pTypes []byte
			var pValues []byte

			//if flag == 1
			//n (len = paramsCount * 2 byte)
			if flag == 1 {
				if len(payload) < pos+paramCount*2 {
					log.Println("ERR : Truncated execute packet")
					return
				}
				pTypes = payload[pos : pos+paramCount*2]
				pos += paramCount * 2
				pValues = payload[pos:]
			}

			//bind params
			if err := stmt.BindArgs(nullBitmap, pTypes, pValues); err != nil {
				log.Println("ERR : Could not bind params", err)
			}
		}
		msg = string(stmt.WriteToText())
	default:
		return
	}

	fmt.Println(GetNowStr(true) + msg)
}