package build

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"math"
	"strings"
	"errors"
)

// Stmt describes a prepared statement together with the argument values
// bound for its execution.
type Stmt struct {
	// ID identifies the statement; WriteToText prints it in the "Stm",
	// "Execute" and "Drop" lines it emits.
	ID         uint32
	// Query is the statement text; WriteToText echoes it verbatim.
	Query      string
	// ParamCount is the number of statement parameters. BindArgs decodes,
	// and WriteToText prints, exactly this many entries of Args.
	ParamCount uint16
	// FieldCount is not read by the code shown here.
	// NOTE(review): presumably the number of result columns — confirm.
	FieldCount uint16

	// Args holds one decoded value per parameter (nil for NULL). BindArgs
	// fills it in place without resizing, so it must already have at least
	// ParamCount elements.
	Args []interface{}
}

// WriteToText renders the statement as a human-readable script. The script
// has four parts:
//   - a header line with the statement ID and query text;
//   - one "set @pN = ..." line per parameter;
//   - an "Execute" line that passes @p0..@pN-1 after "using" (the list is
//     omitted when there are no parameters);
//   - a closing "Drop" line.
//
// A nil argument prints as NULL. A []byte argument prints single-quoted,
// with surrounding whitespace trimmed. Any other value prints with %v.
// The first ParamCount entries of Args are read, so Args must hold at least
// that many elements.
func (stmt *Stmt) WriteToText() []byte {
	var out bytes.Buffer
	count := int(stmt.ParamCount)

	fmt.Fprintf(&out, "Stm id[%d]: '%s';\n", stmt.ID, stmt.Query)

	for p := 0; p < count; p++ {
		switch arg := stmt.Args[p].(type) {
		case nil:
			fmt.Fprintf(&out, "set @p%v = NULL;\n", p)
		case []byte:
			fmt.Fprintf(&out, "set @p%v = '%s';\n", p, strings.TrimSpace(string(arg)))
		default:
			fmt.Fprintf(&out, "set @p%v = %v;\n", p, arg)
		}
	}

	fmt.Fprintf(&out, "Execute stm id[%d]: ", stmt.ID)
	for p := 0; p < count; p++ {
		// The first parameter is introduced by "using"; later ones are
		// comma-separated.
		sep := ", "
		if p == 0 {
			sep = " using "
		}
		out.WriteString(sep)
		fmt.Fprintf(&out, "@p%v", p)
	}
	out.WriteString(";\n")

	fmt.Fprintf(&out, "Drop stm id[%d];\n", stmt.ID)

	return out.Bytes()
}

// BindArgs decodes the parameter values of a binary-protocol statement
// execution into stmt.Args, one entry per parameter.
//
// Inputs:
//   - nullBitmap carries one bit per parameter (bit i%8 of byte i/8); a set
//     bit binds NULL.
//   - paramTypes carries two bytes per parameter: a MYSQL_TYPE_* code, then
//     a flag byte whose 0x80 bit marks an unsigned integer.
//   - paramValues holds the non-NULL values back to back. Integers and floats
//     are little-endian fixed-width fields; every other supported type is
//     read with LengthEncodedString.
//
// Decoded Go types: TINY -> int8/uint8, SHORT/YEAR -> int16/uint16,
// INT24/LONG -> int32/uint32, LONGLONG -> int64/uint64, FLOAT -> float32,
// DOUBLE -> float64, string-like types -> []byte, NULL -> nil.
//
// The input slices are presumably taken from a client packet, so every read
// is bounds checked. A short bitmap, short type table or truncated value
// payload returns an error instead of panicking.
//
// stmt.Args must already hold at least stmt.ParamCount elements. It is
// filled in place and may be partially overwritten when an error is
// returned. An unrecognized type code yields "ERR : Unknown stm type N".
// Errors from LengthEncodedString are returned unchanged.
func (stmt *Stmt) BindArgs(nullBitmap, paramTypes, paramValues []byte) error {
	paramCount := int(stmt.ParamCount)
	args := stmt.Args

	// Validate the per-parameter tables once so the loop can index them freely.
	if len(args) < paramCount {
		return fmt.Errorf("bind args: stmt %d has %d arg slots, needs %d", stmt.ID, len(args), paramCount)
	}
	if len(nullBitmap) < (paramCount+7)/8 {
		return errors.New("bind args: null bitmap too short")
	}
	if len(paramTypes) < 2*paramCount {
		return errors.New("bind args: param types too short")
	}

	pos := 0
	// take returns the next size bytes of paramValues for parameter i and
	// advances pos past them, or reports a truncated payload.
	take := func(i, size int) ([]byte, error) {
		if len(paramValues)-pos < size {
			return nil, fmt.Errorf("bind args: param %d: need %d bytes at offset %d, have %d",
				i, size, pos, len(paramValues)-pos)
		}
		b := paramValues[pos : pos+size]
		pos += size
		return b, nil
	}

	for i := 0; i < paramCount; i++ {
		if nullBitmap[i>>3]&(1<<(uint(i)%8)) != 0 {
			args[i] = nil
			continue
		}

		typ := paramTypes[i<<1]
		unsigned := paramTypes[(i<<1)+1]&0x80 != 0

		switch typ {
		case MYSQL_TYPE_NULL:
			args[i] = nil

		case MYSQL_TYPE_TINY:
			b, err := take(i, 1)
			if err != nil {
				return err
			}
			if unsigned {
				args[i] = uint8(b[0])
			} else {
				args[i] = int8(b[0])
			}

		case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
			b, err := take(i, 2)
			if err != nil {
				return err
			}
			value := binary.LittleEndian.Uint16(b)
			if unsigned {
				args[i] = value
			} else {
				args[i] = int16(value)
			}

		case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG:
			// INT24 is transmitted as a full 4-byte field, like LONG.
			b, err := take(i, 4)
			if err != nil {
				return err
			}
			value := binary.LittleEndian.Uint32(b)
			if unsigned {
				args[i] = value
			} else {
				args[i] = int32(value)
			}

		case MYSQL_TYPE_LONGLONG:
			b, err := take(i, 8)
			if err != nil {
				return err
			}
			value := binary.LittleEndian.Uint64(b)
			if unsigned {
				args[i] = value
			} else {
				args[i] = int64(value)
			}

		case MYSQL_TYPE_FLOAT:
			b, err := take(i, 4)
			if err != nil {
				return err
			}
			args[i] = math.Float32frombits(binary.LittleEndian.Uint32(b))

		case MYSQL_TYPE_DOUBLE:
			b, err := take(i, 8)
			if err != nil {
				return err
			}
			args[i] = math.Float64frombits(binary.LittleEndian.Uint64(b))

		case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL,
			MYSQL_TYPE_VARCHAR, MYSQL_TYPE_BIT,
			MYSQL_TYPE_ENUM, MYSQL_TYPE_SET,
			MYSQL_TYPE_TINY_BLOB, MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB,
			MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING,
			MYSQL_TYPE_GEOMETRY,
			MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE, MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIME:
			// NOTE(review): MySQL's binary protocol sends DATE/DATETIME/
			// TIMESTAMP/TIME as a length byte followed by packed binary
			// fields, not text. Reading them as a length-encoded string
			// consumes the right number of bytes, but the stored value is
			// those packed bytes rather than a date string. Confirm callers
			// expect that.
			v, isNull, n, err := LengthEncodedString(paramValues[pos:])
			if err != nil {
				return err
			}
			pos += n
			if isNull {
				args[i] = nil
			} else {
				// NOTE(review): v is presumably a sub-slice of paramValues.
				// If the caller reuses that buffer, this arg changes with it;
				// consider cloning.
				args[i] = v
			}

		default:
			return fmt.Errorf("ERR : Unknown stm type %d", typ)
		}
	}
	return nil
}