sniffer-agent/vendor/github.com/zr-hebo/util-db/query_db.go

360 lines
7.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package db
import (
	"context"
	"database/sql"
	"fmt"
	"net"
	"strconv"
	"time"
)
// dbTypeMysql is the driver name NewMysqlDB stores in MysqlDB.DatabaseType,
// which is what the helpers in this file pass to sql.Open.
const dbTypeMysql = "mysql"
// Host is a network address made of an IP, a domain name and a port.
// When embedded in MysqlDB, fillConnStr builds the DSN address from IP and
// Port; Domain is not read in the visible part of this file.
type Host struct {
IP string `json:"ip"`
Domain string `json:"domain"`
Port int `json:"port"`
}
// UnanimityHost identifies a server by a host string plus a port.
type UnanimityHost struct {
	Host string `json:"host"`
	Port int    `json:"port"`
}

// String renders the host in "host:port" form.
func (uh *UnanimityHost) String() string {
	return uh.Host + ":" + fmt.Sprint(uh.Port)
}
// UnanimityHostWithDomains extends UnanimityHost (embedded, so its Host and
// Port fields are promoted) with an IP address and a list of domain names.
// NOTE(review): presumably Domains are names resolving to IP — confirm with callers.
type UnanimityHostWithDomains struct {
UnanimityHost
IP string `json:"ip"`
Domains []string `json:"domains"`
}
// MysqlDB describes how to reach a MySQL instance. It is the receiver of the
// query and exec helpers in this file, each of which opens its own connection
// pool from these settings.
type MysqlDB struct {
Host // IP and Port form the DSN address in fillConnStr
UserName string // DSN user name
Passwd string // DSN password
DatabaseType string // driver name passed to sql.Open; NewMysqlDB sets it to dbTypeMysql
DBName string // database name placed after the "/" in the DSN
ConnectTimeout int // seconds; when > 0 fillConnStr sets timeout, readTimeout and writeTimeout to it
}
// NewMysqlDB returns a MysqlDB whose DatabaseType is preset to dbTypeMysql;
// every other field is left at its zero value for the caller to fill in.
func NewMysqlDB() (md *MysqlDB) {
	return &MysqlDB{DatabaseType: dbTypeMysql}
}
// NewMysqlDBWithAllParam builds a MysqlDB via NewMysqlDB and fills in the
// address, credentials and database name. ConnectTimeout stays zero, so
// fillConnStr adds no timeout parameters.
func NewMysqlDBWithAllParam(
	ip string, port int, userName, passwd, dbName string) (
	pmd *MysqlDB) {
	db := NewMysqlDB()
	db.IP, db.Port = ip, port
	db.UserName, db.Passwd, db.DBName = userName, passwd, dbName
	return db
}
// getConnection opens a connection pool for md and verifies it with a ping
// bounded by a 3-second timeout. On success the caller owns the returned
// *sql.DB and must Close it; on failure nothing is left open.
func (md *MysqlDB) getConnection() (*sql.DB, error) {
	db, err := sql.Open(md.DatabaseType, md.fillConnStr())
	if err != nil {
		return nil, err
	}
	// n <= 0 means no limit on open connections.
	db.SetMaxOpenConns(0)

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		// Release the pool: the original returned here without closing it, leaking
		// one *sql.DB per failed ping. The ping error is the one worth reporting.
		_ = db.Close()
		return nil, err
	}
	return db, nil
}
// getRealConnection opens a new pool for md and checks out one dedicated
// connection from it using ctx. The caller must Close the returned *sql.Conn.
//
// NOTE(review): the *sql.DB created here is not returned, so nobody can Close
// the pool itself. Idle retention is disabled so that closing the returned
// *sql.Conn closes the underlying network connection. Without this, the
// connection would be parked in a pool that nothing can reach or close.
func (md *MysqlDB) getRealConnection(ctx context.Context) (*sql.Conn, error) {
	db, err := sql.Open(md.DatabaseType, md.fillConnStr())
	if err != nil {
		return nil, err
	}
	db.SetMaxIdleConns(0)

	conn, err := db.Conn(ctx)
	if err != nil {
		// The pool was never handed out; close it instead of leaking it.
		_ = db.Close()
		return nil, err
	}
	return conn, nil
}
// Field describes one result column: its name and the Go scan type chosen
// for it by getDataType.
type Field struct {
	Name string
	Type string
}

// FieldType returns the scan type name for the column. Values produced by
// getDataType are "string", "int64", "float64" and "bool".
func (f *Field) FieldType() string {
	return f.Type
}

// QueryRow holds a single result row: the column descriptions in Fields and
// the values keyed by column name in Record (SQL NULL is stored as nil).
type QueryRow struct {
	Fields []Field
	Record map[string]interface{}
}

// QueryRows holds a result set: the column descriptions in Fields and one
// map per row, keyed by column name, in Records (SQL NULL is stored as nil).
type QueryRows struct {
	Fields  []Field
	Records []map[string]interface{}
}

// newQueryRow returns a QueryRow whose Fields and Record are empty but non-nil.
func newQueryRow() *QueryRow {
	return &QueryRow{
		Fields: []Field{},
		Record: map[string]interface{}{},
	}
}

// newQueryRows returns a QueryRows whose Fields and Records are empty but non-nil.
func newQueryRows() *QueryRows {
	return &QueryRows{
		Fields:  []Field{},
		Records: []map[string]interface{}{},
	}
}
// QueryRows executes stmt on a freshly opened connection pool (closed before
// returning) and returns every resulting row.
//
// Each column is scanned as the type getDataType picks from its database type
// name, and SQL NULL becomes a nil map value. On any error queryRows is nil
// and err is wrapped with the target IP and port; the original error is still
// reachable through errors.Is/errors.As.
func (md *MysqlDB) QueryRows(stmt string) (queryRows *QueryRows, err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("query rows on %s:%d failed <-- %w", md.IP, md.Port, err)
		}
	}()

	db, err := sql.Open(md.DatabaseType, md.fillConnStr())
	if err != nil {
		return nil, err
	}
	defer db.Close()

	rawRows, err := db.Query(stmt)
	if err != nil {
		return nil, err
	}
	defer rawRows.Close()

	colTypes, err := rawRows.ColumnTypes()
	if err != nil {
		return nil, err
	}
	fields := make([]Field, 0, len(colTypes))
	for _, colType := range colTypes {
		fields = append(fields, Field{Name: colType.Name(), Type: getDataType(colType.DatabaseTypeName())})
	}

	queryRows = newQueryRows()
	queryRows.Fields = fields
	for rawRows.Next() {
		receiver := createReceiver(fields)
		if err = rawRows.Scan(receiver...); err != nil {
			return nil, err
		}
		queryRows.Records = append(queryRows.Records, getRecordFromReceiver(receiver, fields))
	}
	// Next returns false both at end of data and on failure; without this
	// check a broken connection mid-result silently yields a truncated result.
	if err = rawRows.Err(); err != nil {
		return nil, err
	}
	return queryRows, nil
}
// createReceiver allocates one Scan destination per field, picking the
// sql.Null* wrapper that matches the field's Type so NULL values scan without
// error. "string" and any unrecognised type get *sql.NullString.
func createReceiver(fields []Field) (receiver []interface{}) {
	receiver = make([]interface{}, len(fields))
	for i := range fields {
		switch fields[i].Type {
		case "int64":
			receiver[i] = new(sql.NullInt64)
		case "float64":
			receiver[i] = new(sql.NullFloat64)
		case "bool":
			receiver[i] = new(sql.NullBool)
		default: // "string" and everything else
			receiver[i] = new(sql.NullString)
		}
	}
	return receiver
}
// getRecordFromReceiver turns the Scan destinations built by createReceiver
// back into plain values keyed by field name. A NULL column (Valid == false)
// is stored as nil. receiver[i] must hold the sql.Null* pointer matching
// fields[i].Type; otherwise the type assertion panics.
func getRecordFromReceiver(receiver []interface{}, fields []Field) (record map[string]interface{}) {
	record = make(map[string]interface{}, len(fields))
	for i, field := range fields {
		var v interface{} // stays nil for SQL NULL
		switch field.Type {
		case "int64":
			if n := receiver[i].(*sql.NullInt64); n.Valid {
				v = n.Int64
			}
		case "float64":
			if n := receiver[i].(*sql.NullFloat64); n.Valid {
				v = n.Float64
			}
		case "bool":
			if n := receiver[i].(*sql.NullBool); n.Valid {
				v = n.Bool
			}
		default: // "string" and any unrecognised type
			if n := receiver[i].(*sql.NullString); n.Valid {
				v = n.String
			}
		}
		record[field.Name] = v
	}
	return record
}
// getDataType maps a column's database type name (sql.ColumnType's
// DatabaseTypeName) to the scan type name used by createReceiver and
// getRecordFromReceiver: "string", "int64", "float64" or "bool".
// Unrecognised names map to "string".
//
// DATETIME maps to "string", not "float64". The DSN from fillConnStr never
// sets parseTime, so the MySQL driver delivers DATETIME as text like
// "2006-01-02 15:04:05". sql.NullFloat64 cannot scan that, so the old mapping
// made every query touching a DATETIME column fail.
func getDataType(dbColType string) (colType string) {
	switch dbColType {
	case "INT", "BIGINT":
		return "int64"
	case "DECIMAL":
		return "float64"
	case "BOOL":
		return "bool"
	case "VARCHAR", "TEXT", "NVARCHAR", "DATETIME":
		return "string"
	default:
		return "string"
	}
}
// QueryRow executes stmt via QueryRows and returns only the first row.
// When the query yields no rows, both row and err are nil. Errors from
// QueryRows are wrapped with %w so callers can still inspect them.
func (md *MysqlDB) QueryRow(stmt string) (row *QueryRow, err error) {
	queryRows, err := md.QueryRows(stmt)
	if err != nil {
		return nil, fmt.Errorf("query row failed <-- %w", err)
	}
	if len(queryRows.Records) == 0 {
		return nil, nil
	}

	row = newQueryRow()
	row.Fields = queryRows.Fields
	row.Record = queryRows.Records[0]
	return row, nil
}
// ExecChange executes a data-changing statement (with optional placeholder
// args) and returns its sql.Result. Opening the pool and running the statement
// share a 5-second deadline.
//
// The pool is opened here and closed on return. The previous version went
// through getRealConnection, whose *sql.DB was never closed, so every call
// leaked a pool.
func (md *MysqlDB) ExecChange(stmt string, args ...interface{}) (
	result sql.Result, err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("execute dml failed <-- %w", err)
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	db, err := sql.Open(md.DatabaseType, md.fillConnStr())
	if err != nil {
		return nil, err
	}
	defer db.Close()

	return db.ExecContext(ctx, stmt, args...)
}
// fillConnStr builds a DSN of the form "user:passwd@tcp(addr)/dbname".
// When ConnectTimeout > 0 it appends timeout, readTimeout and writeTimeout,
// each set to ConnectTimeout seconds.
//
// addr comes from net.JoinHostPort, so an IPv6 literal is bracketed
// ("[::1]:3306") instead of producing the ambiguous "::1:3306". IPv4
// addresses and host names yield exactly the same string as "%s:%d".
func (md *MysqlDB) fillConnStr() string {
	addr := net.JoinHostPort(md.IP, strconv.Itoa(md.Port))
	dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s", md.UserName, md.Passwd, addr, md.DBName)
	if md.ConnectTimeout > 0 {
		dsn += fmt.Sprintf("?timeout=%ds&readTimeout=%ds&writeTimeout=%ds",
			md.ConnectTimeout, md.ConnectTimeout, md.ConnectTimeout)
	}
	return dsn
}