mirror of
https://github.com/zr-hebo/sniffer-agent.git
synced 2025-08-07 14:39:02 +08:00
add vendor packages避免用户网络太慢无法下载编译
This commit is contained in:
359
vendor/github.com/zr-hebo/util-db/query_db.go
generated
vendored
Normal file
359
vendor/github.com/zr-hebo/util-db/query_db.go
generated
vendored
Normal file
@@ -0,0 +1,359 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
dbTypeMysql = "mysql"
|
||||
)
|
||||
|
||||
// Host 主机
|
||||
type Host struct {
|
||||
IP string `json:"ip"`
|
||||
Domain string `json:"domain"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
// UnanimityHost id标示的主机
|
||||
type UnanimityHost struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
func (uh *UnanimityHost) String() string {
|
||||
return fmt.Sprintf("%s:%d", uh.Host, uh.Port)
|
||||
}
|
||||
|
||||
// UnanimityHostWithDomains 带域名的id标示的主机
|
||||
type UnanimityHostWithDomains struct {
|
||||
UnanimityHost
|
||||
IP string `json:"ip"`
|
||||
Domains []string `json:"domains"`
|
||||
}
|
||||
|
||||
// MysqlDB Mysql主机实例
|
||||
type MysqlDB struct {
|
||||
Host
|
||||
UserName string
|
||||
Passwd string
|
||||
DatabaseType string
|
||||
DBName string
|
||||
ConnectTimeout int
|
||||
}
|
||||
|
||||
// NewMysqlDB 创建MySQL数据库
|
||||
func NewMysqlDB() (md *MysqlDB) {
|
||||
md = new(MysqlDB)
|
||||
md.DatabaseType = dbTypeMysql
|
||||
return
|
||||
}
|
||||
|
||||
// NewMysqlDBWithAllParam 带参数创建MySQL数据库
|
||||
func NewMysqlDBWithAllParam(
|
||||
ip string, port int, userName, passwd, dbName string) (
|
||||
pmd *MysqlDB) {
|
||||
pmd = NewMysqlDB()
|
||||
pmd.IP = ip
|
||||
pmd.Port = port
|
||||
pmd.UserName = userName
|
||||
pmd.Passwd = passwd
|
||||
pmd.DBName = dbName
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetConnection 获取数据库连接
|
||||
func (md *MysqlDB) getConnection() (*sql.DB, error) {
|
||||
connStr := md.fillConnStr()
|
||||
|
||||
stmtDB, err := sql.Open(md.DatabaseType, connStr)
|
||||
if err != nil {
|
||||
if stmtDB != nil {
|
||||
stmtDB.Close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmtDB.SetMaxOpenConns(0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
|
||||
defer cancel()
|
||||
if err := stmtDB.PingContext(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stmtDB, nil
|
||||
}
|
||||
|
||||
// GetConnection 获取数据库连接
|
||||
func (md *MysqlDB) getRealConnection(ctx context.Context) (*sql.Conn, error) {
|
||||
connStr := md.fillConnStr()
|
||||
|
||||
stmtDB, err := sql.Open(md.DatabaseType, connStr)
|
||||
if err != nil {
|
||||
if stmtDB != nil {
|
||||
stmtDB.Close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, err := stmtDB.Conn(ctx)
|
||||
if err != nil {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type Field struct {
|
||||
Name string
|
||||
Type string
|
||||
}
|
||||
|
||||
// FieldType Common type include "STRING", "FLOAT", "INT", "BOOL"
|
||||
func (f *Field) FieldType() string {
|
||||
return f.Type
|
||||
}
|
||||
|
||||
type QueryRow struct {
|
||||
Fields []Field
|
||||
Record map[string]interface{}
|
||||
}
|
||||
|
||||
type QueryRows struct {
|
||||
Fields []Field
|
||||
Records []map[string]interface{}
|
||||
}
|
||||
|
||||
func newQueryRow() *QueryRow {
|
||||
queryRow := new(QueryRow)
|
||||
queryRow.Fields = make([]Field, 0)
|
||||
queryRow.Record = make(map[string]interface{})
|
||||
return queryRow
|
||||
}
|
||||
|
||||
func newQueryRows() *QueryRows {
|
||||
queryRows := new(QueryRows)
|
||||
queryRows.Fields = make([]Field, 0)
|
||||
queryRows.Records = make([]map[string]interface{}, 0)
|
||||
return queryRows
|
||||
}
|
||||
|
||||
// QueryRows 执行MySQL Query语句,返回多条数据
|
||||
func (md *MysqlDB) QueryRows(stmt string) (queryRows *QueryRows, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("query rows on %s:%d failed <-- %s", md.IP, md.Port, err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
connStr := md.fillConnStr()
|
||||
|
||||
db, err := sql.Open(md.DatabaseType, connStr)
|
||||
if db != nil {
|
||||
defer db.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawRows, err := db.Query(stmt)
|
||||
if rawRows != nil {
|
||||
defer rawRows.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
colTypes, err := rawRows.ColumnTypes()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
fields := make([]Field, 0, len(colTypes))
|
||||
for _, colType := range colTypes {
|
||||
fields = append(fields, Field{Name: colType.Name(), Type: getDataType(colType.DatabaseTypeName())})
|
||||
}
|
||||
|
||||
queryRows = newQueryRows()
|
||||
queryRows.Fields = fields
|
||||
for rawRows.Next() {
|
||||
receiver := createReceiver(fields)
|
||||
err = rawRows.Scan(receiver...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
queryRows.Records = append(queryRows.Records, getRecordFromReceiver(receiver, fields))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func createReceiver(fields []Field) (receiver []interface{}) {
|
||||
receiver = make([]interface{}, 0, len(fields))
|
||||
for _, field := range fields {
|
||||
switch field.Type {
|
||||
case "string":
|
||||
{
|
||||
var val sql.NullString
|
||||
receiver = append(receiver, &val)
|
||||
}
|
||||
case "int64":
|
||||
{
|
||||
var val sql.NullInt64
|
||||
receiver = append(receiver, &val)
|
||||
}
|
||||
case "float64":
|
||||
{
|
||||
var val sql.NullFloat64
|
||||
receiver = append(receiver, &val)
|
||||
}
|
||||
case "bool":
|
||||
{
|
||||
var val sql.NullBool
|
||||
receiver = append(receiver, &val)
|
||||
}
|
||||
default:
|
||||
var val sql.NullString
|
||||
receiver = append(receiver, &val)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getRecordFromReceiver(receiver []interface{}, fields []Field) (record map[string]interface{}) {
|
||||
record = make(map[string]interface{})
|
||||
for idx := 0; idx < len(fields); idx++ {
|
||||
field := fields[idx]
|
||||
value := receiver[idx]
|
||||
switch field.Type {
|
||||
case "string":
|
||||
{
|
||||
nullVal := value.(*sql.NullString)
|
||||
record[field.Name] = nil
|
||||
if nullVal.Valid {
|
||||
record[field.Name] = nullVal.String
|
||||
}
|
||||
}
|
||||
case "int64":
|
||||
{
|
||||
nullVal := value.(*sql.NullInt64)
|
||||
record[field.Name] = nil
|
||||
if nullVal.Valid {
|
||||
record[field.Name] = nullVal.Int64
|
||||
}
|
||||
}
|
||||
case "float64":
|
||||
{
|
||||
nullVal := value.(*sql.NullFloat64)
|
||||
record[field.Name] = nil
|
||||
if nullVal.Valid {
|
||||
record[field.Name] = nullVal.Float64
|
||||
}
|
||||
}
|
||||
case "bool":
|
||||
{
|
||||
nullVal := value.(*sql.NullBool)
|
||||
record[field.Name] = nil
|
||||
if nullVal.Valid {
|
||||
record[field.Name] = nullVal.Bool
|
||||
}
|
||||
}
|
||||
default:
|
||||
nullVal := value.(*sql.NullString)
|
||||
record[field.Name] = nil
|
||||
if nullVal.Valid {
|
||||
record[field.Name] = nullVal.String
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getDataType(dbColType string) (colType string) {
|
||||
var columnTypeDict = map[string]string{
|
||||
"VARCHAR": "string",
|
||||
"TEXT": "string",
|
||||
"NVARCHAR": "string",
|
||||
"DATETIME": "float64",
|
||||
"DECIMAL": "float64",
|
||||
"BOOL": "bool",
|
||||
"INT": "int64",
|
||||
"BIGINT": "int64",
|
||||
}
|
||||
|
||||
colType, ok := columnTypeDict[dbColType]
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
|
||||
colType = "string"
|
||||
return
|
||||
}
|
||||
|
||||
// QueryRow 执行MySQL Query语句,返回1条或0条数据
|
||||
func (md *MysqlDB) QueryRow(stmt string) (row *QueryRow, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("query row failed <-- %s", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
queryRows, err := md.QueryRows(stmt)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(queryRows.Records) < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
row = newQueryRow()
|
||||
row.Fields = queryRows.Fields
|
||||
row.Record = queryRows.Records[0]
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ExecChange 执行MySQL DML Query语句
|
||||
func (md *MysqlDB) ExecChange(stmt string, args ...interface{}) (
|
||||
result sql.Result, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("execute dml failed <-- %s", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := md.getRealConnection(ctx)
|
||||
if conn != nil {
|
||||
defer conn.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result, err = conn.ExecContext(ctx, stmt, args...)
|
||||
return
|
||||
}
|
||||
|
||||
func (md *MysqlDB) fillConnStr() string {
|
||||
dbServerInfoStr := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
|
||||
md.UserName, md.Passwd, md.IP, md.Port, md.DBName)
|
||||
if md.ConnectTimeout > 0 {
|
||||
dbServerInfoStr = fmt.Sprintf("%s?timeout=%ds&readTimeout=%ds&writeTimeout=%ds",
|
||||
dbServerInfoStr, md.ConnectTimeout, md.ConnectTimeout, md.ConnectTimeout)
|
||||
}
|
||||
|
||||
return dbServerInfoStr
|
||||
}
|
Reference in New Issue
Block a user