360 lines
7.1 KiB
Go
360 lines
7.1 KiB
Go
package db
|
||
|
||
import (
	"context"
	"database/sql"
	"fmt"
	"net"
	"strconv"
	"time"
)
|
||
|
||
// dbTypeMysql is the database/sql driver name that NewMysqlDB stores in
// MysqlDB.DatabaseType and that is later passed to sql.Open.
const dbTypeMysql = "mysql"
|
||
|
||
// Host describes a database server endpoint.
type Host struct {
	IP     string `json:"ip"`     // address placed in the DSN's tcp(...) section
	Domain string `json:"domain"` // domain name of the host
	Port   int    `json:"port"`   // TCP port of the server
}
|
||
|
||
// UnanimityHost is a host identified by its "host:port" pair.
type UnanimityHost struct {
	Host string `json:"host"` // host name or address
	Port int    `json:"port"` // port number
}

// String returns the identifying "host:port" form of uh.
func (uh *UnanimityHost) String() string {
	// Sprint inserts no separator next to string operands, so this is host + ":" + port.
	return fmt.Sprint(uh.Host, ":", uh.Port)
}
|
||
|
||
// UnanimityHostWithDomains 带域名的id标示的主机
|
||
type UnanimityHostWithDomains struct {
|
||
UnanimityHost
|
||
IP string `json:"ip"`
|
||
Domains []string `json:"domains"`
|
||
}
|
||
|
||
// MysqlDB Mysql主机实例
|
||
type MysqlDB struct {
|
||
Host
|
||
UserName string
|
||
Passwd string
|
||
DatabaseType string
|
||
DBName string
|
||
ConnectTimeout int
|
||
}
|
||
|
||
// NewMysqlDB 创建MySQL数据库
|
||
func NewMysqlDB() (md *MysqlDB) {
|
||
md = new(MysqlDB)
|
||
md.DatabaseType = dbTypeMysql
|
||
return
|
||
}
|
||
|
||
// NewMysqlDBWithAllParam 带参数创建MySQL数据库
|
||
func NewMysqlDBWithAllParam(
|
||
ip string, port int, userName, passwd, dbName string) (
|
||
pmd *MysqlDB) {
|
||
pmd = NewMysqlDB()
|
||
pmd.IP = ip
|
||
pmd.Port = port
|
||
pmd.UserName = userName
|
||
pmd.Passwd = passwd
|
||
pmd.DBName = dbName
|
||
|
||
return
|
||
}
|
||
|
||
// GetConnection 获取数据库连接
|
||
func (md *MysqlDB) getConnection() (*sql.DB, error) {
|
||
connStr := md.fillConnStr()
|
||
|
||
stmtDB, err := sql.Open(md.DatabaseType, connStr)
|
||
if err != nil {
|
||
if stmtDB != nil {
|
||
stmtDB.Close()
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
stmtDB.SetMaxOpenConns(0)
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
|
||
defer cancel()
|
||
if err := stmtDB.PingContext(ctx); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return stmtDB, nil
|
||
}
|
||
|
||
// GetConnection 获取数据库连接
|
||
func (md *MysqlDB) getRealConnection(ctx context.Context) (*sql.Conn, error) {
|
||
connStr := md.fillConnStr()
|
||
|
||
stmtDB, err := sql.Open(md.DatabaseType, connStr)
|
||
if err != nil {
|
||
if stmtDB != nil {
|
||
stmtDB.Close()
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
conn, err := stmtDB.Conn(ctx)
|
||
if err != nil {
|
||
if conn != nil {
|
||
conn.Close()
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
return conn, nil
|
||
}
|
||
|
||
// Field describes one result column: its name and the type category chosen
// for it (see FieldType).
type Field struct {
	Name string // column name as reported by the driver
	Type string // type category, e.g. "string", "int64", "float64", "bool"
}

// FieldType returns the column's type category. For fields built by
// MysqlDB.QueryRows it is one of the lower-case values "string", "int64",
// "float64" or "bool" (as produced by getDataType); any other value is
// treated like "string" when scanning.
func (f *Field) FieldType() string {
	return f.Type
}
|
||
|
||
type QueryRow struct {
|
||
Fields []Field
|
||
Record map[string]interface{}
|
||
}
|
||
|
||
type QueryRows struct {
|
||
Fields []Field
|
||
Records []map[string]interface{}
|
||
}
|
||
|
||
func newQueryRow() *QueryRow {
|
||
queryRow := new(QueryRow)
|
||
queryRow.Fields = make([]Field, 0)
|
||
queryRow.Record = make(map[string]interface{})
|
||
return queryRow
|
||
}
|
||
|
||
func newQueryRows() *QueryRows {
|
||
queryRows := new(QueryRows)
|
||
queryRows.Fields = make([]Field, 0)
|
||
queryRows.Records = make([]map[string]interface{}, 0)
|
||
return queryRows
|
||
}
|
||
|
||
// QueryRows 执行MySQL Query语句,返回多条数据
|
||
func (md *MysqlDB) QueryRows(stmt string) (queryRows *QueryRows, err error) {
|
||
defer func() {
|
||
if err != nil {
|
||
err = fmt.Errorf("query rows on %s:%d failed <-- %s", md.IP, md.Port, err.Error())
|
||
}
|
||
}()
|
||
|
||
connStr := md.fillConnStr()
|
||
|
||
db, err := sql.Open(md.DatabaseType, connStr)
|
||
if db != nil {
|
||
defer db.Close()
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
rawRows, err := db.Query(stmt)
|
||
if rawRows != nil {
|
||
defer rawRows.Close()
|
||
}
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
colTypes, err := rawRows.ColumnTypes()
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
fields := make([]Field, 0, len(colTypes))
|
||
for _, colType := range colTypes {
|
||
fields = append(fields, Field{Name: colType.Name(), Type: getDataType(colType.DatabaseTypeName())})
|
||
}
|
||
|
||
queryRows = newQueryRows()
|
||
queryRows.Fields = fields
|
||
for rawRows.Next() {
|
||
receiver := createReceiver(fields)
|
||
err = rawRows.Scan(receiver...)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
queryRows.Records = append(queryRows.Records, getRecordFromReceiver(receiver, fields))
|
||
}
|
||
return
|
||
}
|
||
|
||
func createReceiver(fields []Field) (receiver []interface{}) {
|
||
receiver = make([]interface{}, 0, len(fields))
|
||
for _, field := range fields {
|
||
switch field.Type {
|
||
case "string":
|
||
{
|
||
var val sql.NullString
|
||
receiver = append(receiver, &val)
|
||
}
|
||
case "int64":
|
||
{
|
||
var val sql.NullInt64
|
||
receiver = append(receiver, &val)
|
||
}
|
||
case "float64":
|
||
{
|
||
var val sql.NullFloat64
|
||
receiver = append(receiver, &val)
|
||
}
|
||
case "bool":
|
||
{
|
||
var val sql.NullBool
|
||
receiver = append(receiver, &val)
|
||
}
|
||
default:
|
||
var val sql.NullString
|
||
receiver = append(receiver, &val)
|
||
}
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func getRecordFromReceiver(receiver []interface{}, fields []Field) (record map[string]interface{}) {
|
||
record = make(map[string]interface{})
|
||
for idx := 0; idx < len(fields); idx++ {
|
||
field := fields[idx]
|
||
value := receiver[idx]
|
||
switch field.Type {
|
||
case "string":
|
||
{
|
||
nullVal := value.(*sql.NullString)
|
||
record[field.Name] = nil
|
||
if nullVal.Valid {
|
||
record[field.Name] = nullVal.String
|
||
}
|
||
}
|
||
case "int64":
|
||
{
|
||
nullVal := value.(*sql.NullInt64)
|
||
record[field.Name] = nil
|
||
if nullVal.Valid {
|
||
record[field.Name] = nullVal.Int64
|
||
}
|
||
}
|
||
case "float64":
|
||
{
|
||
nullVal := value.(*sql.NullFloat64)
|
||
record[field.Name] = nil
|
||
if nullVal.Valid {
|
||
record[field.Name] = nullVal.Float64
|
||
}
|
||
}
|
||
case "bool":
|
||
{
|
||
nullVal := value.(*sql.NullBool)
|
||
record[field.Name] = nil
|
||
if nullVal.Valid {
|
||
record[field.Name] = nullVal.Bool
|
||
}
|
||
}
|
||
default:
|
||
nullVal := value.(*sql.NullString)
|
||
record[field.Name] = nil
|
||
if nullVal.Valid {
|
||
record[field.Name] = nullVal.String
|
||
}
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
// getDataType maps a driver-reported column type name (the value of
// sql.ColumnType.DatabaseTypeName, e.g. "VARCHAR" or "BIGINT") to the type
// category used by createReceiver and getRecordFromReceiver: "int64",
// "float64", "bool" or "string". Any type not listed falls back to "string".
//
// DATETIME is deliberately mapped to "string". It used to map to "float64",
// but the raw value the MySQL driver yields for DATETIME is text (or a
// time.Time when parseTime is enabled). database/sql cannot convert either
// into a float64, so every non-NULL DATETIME column made Scan fail.
func getDataType(dbColType string) (colType string) {
	switch dbColType {
	case "INT", "BIGINT":
		return "int64"
	case "DECIMAL":
		return "float64"
	case "BOOL":
		return "bool"
	default:
		// VARCHAR, TEXT, NVARCHAR, DATETIME and anything unrecognised.
		return "string"
	}
}
|
||
|
||
// QueryRow 执行MySQL Query语句,返回1条或0条数据
|
||
func (md *MysqlDB) QueryRow(stmt string) (row *QueryRow, err error) {
|
||
defer func() {
|
||
if err != nil {
|
||
err = fmt.Errorf("query row failed <-- %s", err.Error())
|
||
}
|
||
}()
|
||
|
||
queryRows, err := md.QueryRows(stmt)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
if len(queryRows.Records) < 1 {
|
||
return
|
||
}
|
||
|
||
row = newQueryRow()
|
||
row.Fields = queryRows.Fields
|
||
row.Record = queryRows.Records[0]
|
||
|
||
return
|
||
}
|
||
|
||
// ExecChange 执行MySQL DML Query语句
|
||
func (md *MysqlDB) ExecChange(stmt string, args ...interface{}) (
|
||
result sql.Result, err error) {
|
||
defer func() {
|
||
if err != nil {
|
||
err = fmt.Errorf("execute dml failed <-- %s", err.Error())
|
||
}
|
||
}()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||
defer cancel()
|
||
|
||
conn, err := md.getRealConnection(ctx)
|
||
if conn != nil {
|
||
defer conn.Close()
|
||
}
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
result, err = conn.ExecContext(ctx, stmt, args...)
|
||
return
|
||
}
|
||
|
||
func (md *MysqlDB) fillConnStr() string {
|
||
dbServerInfoStr := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
|
||
md.UserName, md.Passwd, md.IP, md.Port, md.DBName)
|
||
if md.ConnectTimeout > 0 {
|
||
dbServerInfoStr = fmt.Sprintf("%s?timeout=%ds&readTimeout=%ds&writeTimeout=%ds",
|
||
dbServerInfoStr, md.ConnectTimeout, md.ConnectTimeout, md.ConnectTimeout)
|
||
}
|
||
|
||
return dbServerInfoStr
|
||
}
|