mirror of
https://github.com/40t/go-sniffer.git
synced 2025-08-08 20:32:50 +08:00
Init
This commit is contained in:
175
plugSrc/mysql/build/stmt.go
Normal file
175
plugSrc/mysql/build/stmt.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"errors"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
ID uint32
|
||||
Query string
|
||||
ParamCount uint16
|
||||
FieldCount uint16
|
||||
|
||||
Args []interface{}
|
||||
}
|
||||
|
||||
func (stmt *Stmt) WriteToText() []byte {
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
str := fmt.Sprintf("预处理编号[%d]: '%s';\n", stmt.ID, stmt.Query)
|
||||
buf.WriteString(str)
|
||||
|
||||
for i := 0; i < int(stmt.ParamCount); i++ {
|
||||
var str string
|
||||
switch stmt.Args[i].(type) {
|
||||
case nil:
|
||||
str = fmt.Sprintf("set @p%v = NULL;\n", i)
|
||||
case []byte:
|
||||
param := string(stmt.Args[i].([]byte))
|
||||
str = fmt.Sprintf("set @p%v = '%s';\n", i, strings.TrimSpace(param))
|
||||
default:
|
||||
str = fmt.Sprintf("set @p%v = %v;\n", i, stmt.Args[i])
|
||||
}
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
str = fmt.Sprintf("执行预处理[%d]: ", stmt.ID)
|
||||
buf.WriteString(str)
|
||||
for i := 0; i < int(stmt.ParamCount); i++ {
|
||||
if i == 0 {
|
||||
buf.WriteString(" using ")
|
||||
}
|
||||
if i > 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
str := fmt.Sprintf("@p%v", i)
|
||||
buf.WriteString(str)
|
||||
}
|
||||
buf.WriteString(";\n")
|
||||
|
||||
str = fmt.Sprintf("丢弃预处理[%d];\n", stmt.ID)
|
||||
buf.WriteString(str)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (stmt *Stmt) BindArgs(nullBitmap, paramTypes, paramValues []byte) error {
|
||||
|
||||
args := stmt.Args
|
||||
pos := 0
|
||||
|
||||
var v []byte
|
||||
var n = 0
|
||||
var isNull bool
|
||||
var err error
|
||||
|
||||
for i := 0; i < int(stmt.ParamCount); i++ {
|
||||
|
||||
//判断参数是否为null
|
||||
if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 {
|
||||
args[i] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
//参数类型
|
||||
typ := paramTypes[i<<1]
|
||||
unsigned := (paramTypes[(i<<1)+1] & 0x80) > 0
|
||||
|
||||
switch typ {
|
||||
case MYSQL_TYPE_NULL:
|
||||
args[i] = nil
|
||||
continue
|
||||
|
||||
case MYSQL_TYPE_TINY:
|
||||
|
||||
value := paramValues[pos]
|
||||
if unsigned {
|
||||
args[i] = uint8(value)
|
||||
} else {
|
||||
args[i] = int8(value)
|
||||
}
|
||||
|
||||
pos++
|
||||
continue
|
||||
|
||||
case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
|
||||
|
||||
value := binary.LittleEndian.Uint16(paramValues[pos : pos+2])
|
||||
if unsigned {
|
||||
args[i] = uint16(value)
|
||||
} else {
|
||||
args[i] = int16(value)
|
||||
}
|
||||
pos += 2
|
||||
continue
|
||||
|
||||
case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG:
|
||||
|
||||
value := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
|
||||
if unsigned {
|
||||
args[i] = uint32(value)
|
||||
} else {
|
||||
args[i] = int32(value)
|
||||
}
|
||||
pos += 4
|
||||
continue
|
||||
|
||||
case MYSQL_TYPE_LONGLONG:
|
||||
|
||||
value := binary.LittleEndian.Uint64(paramValues[pos : pos+8])
|
||||
if unsigned {
|
||||
args[i] = value
|
||||
} else {
|
||||
args[i] = int64(value)
|
||||
}
|
||||
pos += 8
|
||||
continue
|
||||
|
||||
case MYSQL_TYPE_FLOAT:
|
||||
|
||||
value := math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))
|
||||
args[i] = float32(value)
|
||||
pos += 4
|
||||
continue
|
||||
|
||||
case MYSQL_TYPE_DOUBLE:
|
||||
|
||||
value := math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))
|
||||
args[i] = value
|
||||
pos += 8
|
||||
continue
|
||||
|
||||
case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL,
|
||||
MYSQL_TYPE_VARCHAR, MYSQL_TYPE_BIT,
|
||||
MYSQL_TYPE_ENUM, MYSQL_TYPE_SET,
|
||||
MYSQL_TYPE_TINY_BLOB, MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB,
|
||||
MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING,
|
||||
MYSQL_TYPE_GEOMETRY,
|
||||
MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE, MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIME:
|
||||
|
||||
v, isNull, n, err = LengthEncodedString(paramValues[pos:])
|
||||
pos += n
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !isNull {
|
||||
args[i] = v
|
||||
continue
|
||||
} else {
|
||||
args[i] = nil
|
||||
continue
|
||||
}
|
||||
default:
|
||||
return errors.New(fmt.Sprintf("预处理未知类型 %d", typ))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user