go-sniffer/plugSrc/mysql/build/stmt.go

176 lines
3.4 KiB
Go

package build
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"strings"
"errors"
)
// Stmt is the state kept for one prepared statement: its ID, SQL text,
// parameter/field counts and the parameter values most recently decoded by
// BindArgs. WriteToText renders ID, Query and Args as text.
type Stmt struct {
	// ID identifies the prepared statement.
	ID uint32
	// Query is the statement's SQL text.
	Query string
	// ParamCount is the number of parameters; BindArgs and WriteToText walk
	// exactly this many. FieldCount is not read in this file (presumably the
	// result column count — confirm at the call site).
	ParamCount, FieldCount uint16
	// Args holds one decoded value per parameter, indexed by position.
	Args []interface{}
}
// stmtArgQuoter escapes backslashes and single quotes so that a string
// parameter rendered inside '...' by WriteToText stays one SQL string literal.
var stmtArgQuoter = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// WriteToText renders the prepared statement and its bound arguments as
// human-readable text, one line per step:
//
//	预处理编号[ID]: '<query>';       (statement ID and SQL text)
//	set @pN = <value>;              (one line per parameter)
//	执行预处理[ID]: using @p0, @p1;  (execute with the parameters)
//	丢弃预处理[ID];                  (discard the statement)
//
// Parameter values are written as NULL for nil, as a whitespace-trimmed,
// quote-escaped '...' string for []byte, and with %v for any other type.
// When Args is shorter than ParamCount the missing parameters are written as
// NULL rather than causing an index-out-of-range panic.
func (stmt *Stmt) WriteToText() []byte {
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "预处理编号[%d]: '%s';\n", stmt.ID, stmt.Query)

	count := int(stmt.ParamCount)
	for i := 0; i < count; i++ {
		var arg interface{}
		if i < len(stmt.Args) {
			arg = stmt.Args[i]
		}
		switch v := arg.(type) {
		case nil:
			fmt.Fprintf(&buf, "set @p%d = NULL;\n", i)
		case []byte:
			// Escape so values containing ' or \ cannot break out of the literal.
			quoted := stmtArgQuoter.Replace(strings.TrimSpace(string(v)))
			fmt.Fprintf(&buf, "set @p%d = '%s';\n", i, quoted)
		default:
			fmt.Fprintf(&buf, "set @p%d = %v;\n", i, v)
		}
	}

	// The header already ends in a space, so "using " must not start with one.
	fmt.Fprintf(&buf, "执行预处理[%d]: ", stmt.ID)
	for i := 0; i < count; i++ {
		if i == 0 {
			buf.WriteString("using ")
		} else {
			buf.WriteString(", ")
		}
		fmt.Fprintf(&buf, "@p%d", i)
	}
	buf.WriteString(";\n")

	fmt.Fprintf(&buf, "丢弃预处理[%d];\n", stmt.ID)
	return buf.Bytes()
}
// errStmtParamsTruncated reports that the parameter data ended before every
// declared parameter could be decoded (e.g. a partially captured packet).
var errStmtParamsTruncated = errors.New("预处理参数数据不完整")

// BindArgs decodes binary-protocol parameter data into stmt.Args, one entry
// per parameter (stmt.ParamCount of them).
//
//   - nullBitmap has one bit per parameter; a set bit means the parameter is NULL.
//   - paramTypes has two bytes per parameter: the type code, then a flag byte
//     whose 0x80 bit marks an integer value as unsigned.
//   - paramValues holds the encoded values of the non-NULL parameters back to back.
//
// Decoded values are stored as nil (NULL), int8/uint8, int16/uint16,
// int32/uint32, int64/uint64, float32, float64, or []byte. String-like types
// are stored raw. Date/time types are stored as formatted text:
// "YYYY-MM-DD", "YYYY-MM-DD hh:mm:ss[.ffffff]" or "[-]hh:mm:ss[.ffffff]".
//
// If stmt.Args is shorter than ParamCount it is grown. Every read is
// bounds-checked. A short buffer yields an error wrapping
// errStmtParamsTruncated instead of a panic, and an unknown type code yields
// an error. On error, stmt.Args may be partially updated.
func (stmt *Stmt) BindArgs(nullBitmap, paramTypes, paramValues []byte) error {
	count := int(stmt.ParamCount)
	if len(stmt.Args) < count {
		grown := make([]interface{}, count)
		copy(grown, stmt.Args)
		stmt.Args = grown
	}
	args := stmt.Args

	pos := 0
	for i := 0; i < count; i++ {
		if len(nullBitmap) <= i>>3 {
			return fmt.Errorf("预处理参数 %d: %w", i, errStmtParamsTruncated)
		}
		if nullBitmap[i>>3]&(1<<(uint(i)%8)) != 0 {
			args[i] = nil
			continue
		}
		// Types are only consulted for non-NULL parameters, so an all-NULL
		// call with no type data still succeeds.
		if len(paramTypes) < 2*i+2 {
			return fmt.Errorf("预处理参数 %d: %w", i, errStmtParamsTruncated)
		}
		typ := paramTypes[2*i]
		unsigned := paramTypes[2*i+1]&0x80 != 0

		value, n, err := decodeStmtParam(typ, unsigned, paramValues[pos:])
		if err != nil {
			return fmt.Errorf("预处理参数 %d: %w", i, err)
		}
		args[i] = value
		pos += n
	}
	return nil
}

// decodeStmtParam decodes one value of type typ from the front of b and
// returns the value together with the number of bytes consumed.
func decodeStmtParam(typ byte, unsigned bool, b []byte) (interface{}, int, error) {
	switch typ {
	case MYSQL_TYPE_NULL:
		return nil, 0, nil
	case MYSQL_TYPE_TINY:
		if len(b) < 1 {
			return nil, 0, errStmtParamsTruncated
		}
		if unsigned {
			return b[0], 1, nil // byte is uint8
		}
		return int8(b[0]), 1, nil
	case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
		if len(b) < 2 {
			return nil, 0, errStmtParamsTruncated
		}
		v := binary.LittleEndian.Uint16(b)
		if unsigned {
			return v, 2, nil
		}
		return int16(v), 2, nil
	case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG:
		if len(b) < 4 {
			return nil, 0, errStmtParamsTruncated
		}
		v := binary.LittleEndian.Uint32(b)
		if unsigned {
			return v, 4, nil
		}
		return int32(v), 4, nil
	case MYSQL_TYPE_LONGLONG:
		if len(b) < 8 {
			return nil, 0, errStmtParamsTruncated
		}
		v := binary.LittleEndian.Uint64(b)
		if unsigned {
			return v, 8, nil
		}
		return int64(v), 8, nil
	case MYSQL_TYPE_FLOAT:
		if len(b) < 4 {
			return nil, 0, errStmtParamsTruncated
		}
		return math.Float32frombits(binary.LittleEndian.Uint32(b)), 4, nil
	case MYSQL_TYPE_DOUBLE:
		if len(b) < 8 {
			return nil, 0, errStmtParamsTruncated
		}
		return math.Float64frombits(binary.LittleEndian.Uint64(b)), 8, nil
	case MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE, MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME:
		return decodeStmtDateTime(typ == MYSQL_TYPE_DATE || typ == MYSQL_TYPE_NEWDATE, b)
	case MYSQL_TYPE_TIME:
		return decodeStmtTime(b)
	case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL,
		MYSQL_TYPE_VARCHAR, MYSQL_TYPE_BIT,
		MYSQL_TYPE_ENUM, MYSQL_TYPE_SET,
		MYSQL_TYPE_TINY_BLOB, MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB,
		MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING,
		MYSQL_TYPE_GEOMETRY:
		// Make sure the whole length prefix (1, 3, 4 or 9 bytes) is present
		// before handing off. NOTE(review): LengthEncodedString is defined
		// elsewhere; this guard assumes it may index the prefix unchecked.
		if len(b) == 0 {
			return nil, 0, errStmtParamsTruncated
		}
		prefix := 1
		switch b[0] {
		case 0xfc:
			prefix = 3
		case 0xfd:
			prefix = 4
		case 0xfe:
			prefix = 9
		}
		if len(b) < prefix {
			return nil, 0, errStmtParamsTruncated
		}
		v, isNull, n, err := LengthEncodedString(b)
		if err != nil {
			return nil, n, err
		}
		if isNull {
			return nil, n, nil
		}
		return v, n, nil
	default:
		return nil, 0, fmt.Errorf("预处理未知类型 %d", typ)
	}
}

// stmtTemporalBody strips the one-byte length prefix of a binary date/time
// value and returns the bytes that follow it.
func stmtTemporalBody(b []byte) ([]byte, error) {
	if len(b) < 1 || len(b) < 1+int(b[0]) {
		return nil, errStmtParamsTruncated
	}
	return b[1 : 1+int(b[0])], nil
}

// decodeStmtDateTime formats a binary DATE/DATETIME/TIMESTAMP value. The body
// holds year (2 bytes LE), month and day, then optionally hour, minute and
// second, then optionally microseconds (4 bytes LE). dateOnly selects the
// zero value printed for an empty body.
func decodeStmtDateTime(dateOnly bool, b []byte) (interface{}, int, error) {
	d, err := stmtTemporalBody(b)
	if err != nil {
		return nil, 0, err
	}
	var s string
	switch len(d) {
	case 0:
		s = "0000-00-00 00:00:00"
		if dateOnly {
			s = "0000-00-00"
		}
	case 4, 7, 11:
		s = fmt.Sprintf("%04d-%02d-%02d", binary.LittleEndian.Uint16(d), d[2], d[3])
		if len(d) >= 7 {
			s += fmt.Sprintf(" %02d:%02d:%02d", d[4], d[5], d[6])
		}
		if len(d) == 11 {
			s += fmt.Sprintf(".%06d", binary.LittleEndian.Uint32(d[7:]))
		}
	default:
		return nil, 0, fmt.Errorf("预处理日期参数长度非法 %d", len(d))
	}
	return []byte(s), 1 + len(d), nil
}

// decodeStmtTime formats a binary TIME value. The body holds a sign byte,
// days (4 bytes LE), hour, minute and second, then optionally microseconds
// (4 bytes LE). Days are folded into the hour count.
func decodeStmtTime(b []byte) (interface{}, int, error) {
	d, err := stmtTemporalBody(b)
	if err != nil {
		return nil, 0, err
	}
	var s string
	switch len(d) {
	case 0:
		s = "00:00:00"
	case 8, 12:
		sign := ""
		if d[0] != 0 {
			sign = "-"
		}
		hours := uint64(binary.LittleEndian.Uint32(d[1:5]))*24 + uint64(d[5])
		s = fmt.Sprintf("%s%02d:%02d:%02d", sign, hours, d[6], d[7])
		if len(d) == 12 {
			s += fmt.Sprintf(".%06d", binary.LittleEndian.Uint32(d[8:]))
		}
	default:
		return nil, 0, fmt.Errorf("预处理时间参数长度非法 %d", len(d))
	}
	return []byte(s), 1 + len(d), nil
}