sniffer-agent/vendor/github.com/zr-hebo/util-http/receiver.go

210 lines
3.6 KiB
Go

package easyhttp
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"strings"
"github.com/gorilla/mux"
)
// NewUnpacker creates a request-parameter unpacker.
//
// req is the incoming HTTP request, receiver is the value the parameters are
// decoded into, and logger is stored for optional diagnostic output (the
// methods in this file nil-check it before every use).
func NewUnpacker(
	req *http.Request, receiver interface{}, logger Logger) (
	unpacker *Unpacker) {
	return &Unpacker{
		req:      req,
		receiver: receiver,
		logger:   logger,
	}
}
// Unpack decodes the request parameters into the receiver.
//
// GET requests are decoded via unpackGetParams and POST requests via
// unpackJSONParams. Any other method, or a nil receiver, is a no-op that
// returns nil.
func (u *Unpacker) Unpack() (err error) {
	if u.receiver == nil {
		return nil
	}

	switch u.req.Method {
	case "GET":
		return u.unpackGetParams()
	case "POST":
		return u.unpackJSONParams()
	}
	return nil
}
// unpackGetParams decodes the request's form values into the receiver,
// which must be a non-nil pointer to a struct.
//
// It returns the error reported by ParseForm, an error when the receiver is
// nil or is not a *struct, or the first error produced while filling the
// struct's fields.
func (u *Unpacker) unpackGetParams() (err error) {
	if err = u.req.ParseForm(); err != nil {
		return err
	}

	rv := reflect.ValueOf(u.receiver)
	// A nil interface produces an invalid Value; asking it for its type
	// would panic, so reject it up front.
	if !rv.IsValid() {
		return fmt.Errorf("解析参数类型需要为 *struct ,传入的是 nil")
	}

	rt := rv.Type()
	if rt.Kind() != reflect.Ptr || rt.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("解析参数类型需要为 *struct ,传入的是 %s", rt.String())
	}

	// A typed nil pointer passes the kind check above, but there is no
	// struct behind it to populate.
	if rv.IsNil() {
		return fmt.Errorf("receiver must be a non-nil *struct, got nil %s", rt.String())
	}
	return u.unpackFieldFromParams(rv.Elem(), "")
}
// getFormVal looks up key in the request's form values and, when that
// yields an empty string, falls back to the gorilla/mux route variable of
// the same name. It returns "" when neither source has a value.
func (u *Unpacker) getFormVal(key string) (val string) {
	if val = u.req.FormValue(key); val != "" {
		return val
	}
	// Indexing a nil map yields the zero value, so no nil check is needed.
	return mux.Vars(u.req)[key]
}
// unpackFieldFromParams fills field from the request parameters.
//
//   - Pointer: a nil pointer is allocated, then the pointed-to value is
//     filled using the same varName.
//   - Struct: every settable field is filled. The lookup key is the name
//     part of the field's `json` tag (options after a comma are ignored) or,
//     without a tag, the lower-cased field name. Nested structs and pointers
//     are handled recursively; array and map fields are skipped.
//   - Array, Map: left untouched.
//   - Anything else: the parameter named varName, when non-empty, is
//     converted by populate.
//
// It returns the first conversion error encountered, or an error when field
// is an invalid reflect.Value.
func (u *Unpacker) unpackFieldFromParams(
	field reflect.Value, varName string) (err error) {
	if !field.IsValid() {
		return fmt.Errorf("cannot unpack parameters into an invalid value")
	}

	switch field.Kind() {
	case reflect.Ptr:
		if field.IsNil() {
			field.Set(reflect.New(field.Type().Elem()))
		}
		// Propagate the nested error instead of dropping it.
		return u.unpackFieldFromParams(field.Elem(), varName)

	case reflect.Struct:
		rt := field.Type()
		for i := 0; i < rt.NumField(); i++ {
			rfv := field.Field(i)
			// Unexported fields cannot be set through reflection; any
			// attempt to do so panics, so skip them.
			if !rfv.CanSet() {
				continue
			}

			tf := rt.Field(i)
			key := tf.Tag.Get("json")
			// Drop tag options such as ",omitempty" so that only the
			// name part is used as the parameter key.
			if idx := strings.Index(key, ","); idx >= 0 {
				key = key[:idx]
			}
			if key == "" {
				key = strings.ToLower(tf.Name)
			}

			val := u.getFormVal(key)
			if u.logger != nil {
				u.logger.Debugf("key:%v value:%v", key, val)
			}

			switch rfv.Kind() {
			case reflect.Ptr:
				if rfv.IsNil() {
					rfv.Set(reflect.New(rfv.Type().Elem()))
				}
				err = u.unpackFieldFromParams(rfv.Elem(), key)
			case reflect.Struct:
				err = u.unpackFieldFromParams(rfv, key)
			case reflect.Array, reflect.Map:
				continue
			default:
				if len(val) < 1 {
					continue
				}
				err = populate(rfv, val)
			}
			if err != nil {
				return err
			}
		}

	case reflect.Array, reflect.Map:
		// Arrays and maps are not populated from request parameters.

	default:
		if val := u.getFormVal(varName); len(val) > 0 {
			err = populate(field, val)
		}
	}
	return err
}
// populate parses value according to rv's kind and stores the result in rv.
//
// Supported kinds are string, bool, the signed and unsigned integer kinds,
// and float32/float64. Numbers are parsed with the bit size of rv's type, so
// a value that does not fit (for example "300" into an int8) is reported as
// an error rather than being silently truncated.
//
// It returns an error when rv cannot be set, when value cannot be parsed for
// the target kind, or when the kind is not supported.
func populate(rv reflect.Value, value string) (err error) {
	// Setting an unaddressable or unexported value panics inside reflect;
	// report it as an ordinary error instead.
	if !rv.CanSet() {
		return fmt.Errorf("cannot set unaddressable or unexported value")
	}

	switch rv.Kind() {
	case reflect.String:
		rv.SetString(value)

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		var i int64
		if i, err = strconv.ParseInt(value, 10, rv.Type().Bits()); err != nil {
			return err
		}
		rv.SetInt(i)

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		var n uint64
		if n, err = strconv.ParseUint(value, 10, rv.Type().Bits()); err != nil {
			return err
		}
		rv.SetUint(n)

	case reflect.Bool:
		var b bool
		if b, err = strconv.ParseBool(value); err != nil {
			return err
		}
		rv.SetBool(b)

	case reflect.Float32, reflect.Float64:
		var f float64
		if f, err = strconv.ParseFloat(value, rv.Type().Bits()); err != nil {
			return err
		}
		rv.SetFloat(f)

	default:
		return fmt.Errorf("unsupported kind %s", rv.Type().String())
	}
	return nil
}
// unpackJSONParams reads the entire request body and, when it is non-empty,
// JSON-decodes it into the receiver. The body is closed before returning.
//
// It returns an error when the request or its body is nil, when reading the
// body fails, or when json.Unmarshal rejects the payload. An empty body is
// not an error and leaves the receiver untouched.
func (u *Unpacker) unpackJSONParams() (err error) {
	if u.req == nil || u.req.Body == nil {
		return fmt.Errorf("request body 为空")
	}
	defer u.req.Body.Close()

	var payload []byte
	if payload, err = ioutil.ReadAll(u.req.Body); err != nil {
		return err
	}

	if u.logger != nil {
		u.logger.Info(string(payload))
	}

	if len(payload) == 0 {
		return nil
	}
	return json.Unmarshal(payload, u.receiver)
}
// stringSliceContent reports whether str is an element of strs. A nil or
// empty slice never contains anything.
func stringSliceContent(strs []string, str string) bool {
	for idx := 0; idx < len(strs); idx++ {
		if strs[idx] == str {
			return true
		}
	}
	return false
}