mirror of
https://github.com/bjdgyc/anylink.git
synced 2025-08-08 02:11:08 +08:00
修改为sql数据库
This commit is contained in:
@@ -1,70 +1,111 @@
|
||||
package dbdata
|
||||
|
||||
import (
|
||||
"time"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/asdine/storm/v3"
|
||||
"github.com/asdine/storm/v3/codec/json"
|
||||
"github.com/bjdgyc/anylink/base"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
var (
|
||||
sdb *storm.DB
|
||||
xdb *xorm.Engine
|
||||
)
|
||||
|
||||
func GetXdb() *xorm.Engine {
|
||||
return xdb
|
||||
}
|
||||
|
||||
func initDb() {
|
||||
var err error
|
||||
sdb, err = storm.Open(base.Cfg.DbFile, storm.Codec(json.Codec),
|
||||
storm.BoltOptions(0600, &bolt.Options{Timeout: 10 * time.Second}))
|
||||
xdb, err = xorm.NewEngine(base.Cfg.DbType, base.Cfg.DbDsn)
|
||||
// xdb.ShowSQL(true)
|
||||
if err != nil {
|
||||
base.Fatal(err)
|
||||
}
|
||||
|
||||
// 初始化数据库
|
||||
err = sdb.Init(&User{})
|
||||
err = xdb.Sync2(&User{}, &Setting{}, &Group{}, &IpMap{})
|
||||
if err != nil {
|
||||
base.Fatal(err)
|
||||
}
|
||||
|
||||
// fmt.Println("s1")
|
||||
// fmt.Println("s1=============", err)
|
||||
}
|
||||
|
||||
func initData() {
|
||||
var (
|
||||
err error
|
||||
install bool
|
||||
err error
|
||||
)
|
||||
|
||||
// 判断是否初次使用
|
||||
err = Get(SettingBucket, Installed, &install)
|
||||
if err == nil && install {
|
||||
s := &Setting{}
|
||||
err = One("name", InstallName, s)
|
||||
if err == nil && s.Data == InstallData {
|
||||
// 已经安装过
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = Set(SettingBucket, Installed, true)
|
||||
}()
|
||||
|
||||
smtp := &SettingSmtp{
|
||||
Host: "127.0.0.1",
|
||||
Port: 25,
|
||||
From: "vpn@xx.com",
|
||||
err = addInitData()
|
||||
if err != nil {
|
||||
base.Fatal(err)
|
||||
}
|
||||
_ = SettingSet(smtp)
|
||||
}
|
||||
|
||||
func addInitData() error {
|
||||
var (
|
||||
err error
|
||||
)
|
||||
|
||||
sess := xdb.NewSession()
|
||||
defer sess.Close()
|
||||
|
||||
err = sess.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// SettingSmtp
|
||||
smtp := &SettingSmtp{
|
||||
Host: "127.0.0.1",
|
||||
Port: 25,
|
||||
From: "vpn@xx.com",
|
||||
Encryption: "None",
|
||||
}
|
||||
v, _ := json.Marshal(smtp)
|
||||
s := &Setting{Name: StructName(smtp), Data: string(v)}
|
||||
_, err = sess.InsertOne(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// SettingOther
|
||||
other := &SettingOther{
|
||||
LinkAddr: "vpn.xx.com",
|
||||
Banner: "您已接入公司网络,请按照公司规定使用。\n请勿进行非工作下载及视频行为!",
|
||||
AccountMail: accountMail,
|
||||
}
|
||||
_ = SettingSet(other)
|
||||
v, _ = json.Marshal(other)
|
||||
s = &Setting{Name: StructName(other), Data: string(v)}
|
||||
_, err = sess.InsertOne(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Install
|
||||
install := &Setting{Name: InstallName, Data: InstallData}
|
||||
_, err = sess.InsertOne(install)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sess.Commit()
|
||||
}
|
||||
|
||||
func CheckErrNotFound(err error) bool {
|
||||
return err == storm.ErrNotFound
|
||||
return err == ErrNotFound
|
||||
}
|
||||
|
||||
const accountMail = `<p>您好:</p>
|
||||
|
@@ -1,66 +1,84 @@
|
||||
package dbdata
|
||||
|
||||
import "github.com/asdine/storm/v3/index"
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
const PageSize = 10
|
||||
|
||||
func Save(data interface{}) error {
|
||||
return sdb.Save(data)
|
||||
var ErrNotFound = errors.New("ErrNotFound")
|
||||
|
||||
func Add(data interface{}) error {
|
||||
_, err := xdb.InsertOne(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func Update(data interface{}) error {
|
||||
return sdb.Update(data)
|
||||
}
|
||||
|
||||
func UpdateField(data interface{}, fieldName string, value interface{}) error {
|
||||
return sdb.UpdateField(data, fieldName, value)
|
||||
func Update(fieldName string, value interface{}, data interface{}) error {
|
||||
_, err := xdb.Where(fieldName+"=?", value).Update(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func Del(data interface{}) error {
|
||||
return sdb.DeleteStruct(data)
|
||||
_, err := xdb.Delete(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func Set(bucket, key string, data interface{}) error {
|
||||
return sdb.Set(bucket, key, data)
|
||||
func extract(data interface{}, fieldName string) interface{} {
|
||||
ref := reflect.ValueOf(data)
|
||||
r := &ref
|
||||
if r.Kind() == reflect.Ptr {
|
||||
e := r.Elem()
|
||||
r = &e
|
||||
}
|
||||
field := r.FieldByName(fieldName).Interface()
|
||||
return field
|
||||
}
|
||||
|
||||
func Get(bucket, key string, data interface{}) error {
|
||||
return sdb.Get(bucket, key, data)
|
||||
// 更新全部字段
|
||||
func Set(data interface{}) error {
|
||||
id := extract(data, "Id")
|
||||
_, err := xdb.ID(id).AllCols().Update(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func One(fieldName string, value interface{}, data interface{}) error {
|
||||
has, err := xdb.Where(fieldName+"=?", value).Get(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !has {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func CountAll(data interface{}) int {
|
||||
n, _ := sdb.Count(data)
|
||||
return n
|
||||
n, _ := xdb.Count(data)
|
||||
return int(n)
|
||||
}
|
||||
|
||||
func One(fieldName string, value interface{}, to interface{}) error {
|
||||
return sdb.One(fieldName, value, to)
|
||||
}
|
||||
|
||||
func Find(fieldName string, value interface{}, to interface{}, options ...func(q *index.Options)) error {
|
||||
return sdb.Find(fieldName, value, to, options...)
|
||||
}
|
||||
|
||||
func All(to interface{}, limit, page int) error {
|
||||
opt := getOpt(limit, page)
|
||||
return sdb.All(to, opt)
|
||||
}
|
||||
|
||||
func Prefix(fieldName string, prefix string, to interface{}, limit, page int) error {
|
||||
opt := getOpt(limit, page)
|
||||
return sdb.Prefix(fieldName, prefix, to, opt)
|
||||
}
|
||||
|
||||
func getOpt(limit, page int) func(*index.Options) {
|
||||
skip := (page - 1) * limit
|
||||
opt := func(opt *index.Options) {
|
||||
opt.Reverse = true
|
||||
if limit > 0 {
|
||||
opt.Limit = limit
|
||||
}
|
||||
if skip > 0 {
|
||||
opt.Skip = skip
|
||||
}
|
||||
func Find(data interface{}, limit, page int) error {
|
||||
if limit == 0 {
|
||||
return xdb.Find(data)
|
||||
}
|
||||
return opt
|
||||
|
||||
start := (page - 1) * limit
|
||||
return xdb.Limit(limit, start).Find(data)
|
||||
}
|
||||
|
||||
func CountPrefix(fieldName string, prefix string, data interface{}) int {
|
||||
n, _ := xdb.Where(fieldName + " like '" + prefix + "%' ").Count(data)
|
||||
return int(n)
|
||||
}
|
||||
|
||||
func Prefix(fieldName string, prefix string, data interface{}, limit, page int) error {
|
||||
where := xdb.Where(fieldName + " like '" + prefix + "%' ")
|
||||
if limit == 0 {
|
||||
return where.Find(data)
|
||||
}
|
||||
|
||||
start := (page - 1) * limit
|
||||
return where.Limit(limit, start).Find(data)
|
||||
}
|
||||
|
@@ -11,12 +11,13 @@ import (
|
||||
|
||||
func preIpData() {
|
||||
tmpDb := path.Join(os.TempDir(), "anylink_test.db")
|
||||
base.Cfg.DbFile = tmpDb
|
||||
base.Cfg.DbType = "sqlite3"
|
||||
base.Cfg.DbDsn = tmpDb
|
||||
initDb()
|
||||
}
|
||||
|
||||
func closeIpdata() {
|
||||
sdb.Close()
|
||||
xdb.Close()
|
||||
tmpDb := path.Join(os.TempDir(), "anylink_test.db")
|
||||
os.Remove(tmpDb)
|
||||
}
|
||||
@@ -27,7 +28,7 @@ func TestDb(t *testing.T) {
|
||||
defer closeIpdata()
|
||||
|
||||
u := User{Username: "a"}
|
||||
err := Save(&u)
|
||||
err := Add(&u)
|
||||
ast.Nil(err)
|
||||
|
||||
ast.Equal(u.Id, 1)
|
||||
|
@@ -30,8 +30,8 @@ type ValData struct {
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
Id int `json:"id" storm:"id,increment"`
|
||||
Name string `json:"name" storm:"unique"`
|
||||
Id int `json:"id" xorm:"pk autoincr not null"`
|
||||
Name string `json:"name" xorm:"not null unique"`
|
||||
Note string `json:"note"`
|
||||
AllowLan bool `json:"allow_lan"`
|
||||
ClientDns []ValData `json:"client_dns"`
|
||||
@@ -46,7 +46,7 @@ type Group struct {
|
||||
|
||||
func GetGroupNames() []string {
|
||||
var datas []Group
|
||||
err := All(&datas, 0, 0)
|
||||
err := Find(&datas, 0, 0)
|
||||
if err != nil {
|
||||
base.Error(err)
|
||||
return nil
|
||||
@@ -116,7 +116,11 @@ func SetGroup(g *Group) error {
|
||||
g.LinkAcl = linkAcl
|
||||
|
||||
g.UpdatedAt = time.Now()
|
||||
err = Save(g)
|
||||
if g.Id > 0 {
|
||||
err = Set(g)
|
||||
} else {
|
||||
err = Add(g)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
@@ -1,14 +1,14 @@
|
||||
package dbdata
|
||||
|
||||
import (
|
||||
"net"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
type IpMap struct {
|
||||
Id int `json:"id" storm:"id,increment"`
|
||||
IpAddr net.IP `json:"ip_addr" storm:"unique"`
|
||||
MacAddr string `json:"mac_addr" storm:"unique"`
|
||||
Id int `json:"id" xorm:"pk autoincr not null"`
|
||||
IpAddr string `json:"ip_addr" xorm:"not null unique"`
|
||||
MacAddr string `json:"mac_addr" xorm:"not null unique"`
|
||||
Username string `json:"username"`
|
||||
Keep bool `json:"keep"` // 保留 ip-mac 绑定
|
||||
KeepTime time.Time `json:"keep_time"`
|
||||
@@ -16,3 +16,19 @@ type IpMap struct {
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func SetIpMap(v *IpMap) error {
|
||||
var err error
|
||||
|
||||
if len(v.IpAddr) < 4 || len(v.MacAddr) < 6 {
|
||||
return errors.New("IP或MAC错误")
|
||||
}
|
||||
|
||||
v.UpdatedAt = time.Now()
|
||||
if v.Id > 0 {
|
||||
err = Set(v)
|
||||
} else {
|
||||
err = Add(v)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@@ -1,35 +1,19 @@
|
||||
package dbdata
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
const (
|
||||
SettingBucket = "SettingBucket"
|
||||
Installed = "Installed"
|
||||
InstallName = "Install"
|
||||
InstallData = "OK"
|
||||
)
|
||||
|
||||
func StructName(data interface{}) string {
|
||||
ref := reflect.ValueOf(data)
|
||||
s := &ref
|
||||
if s.Kind() == reflect.Ptr {
|
||||
e := s.Elem()
|
||||
s = &e
|
||||
}
|
||||
name := s.Type().Name()
|
||||
return name
|
||||
}
|
||||
|
||||
func SettingSet(data interface{}) error {
|
||||
key := StructName(data)
|
||||
err := Set(SettingBucket, key, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func SettingGet(data interface{}) error {
|
||||
key := StructName(data)
|
||||
err := Get(SettingBucket, key, data)
|
||||
return err
|
||||
type Setting struct {
|
||||
Id int `json:"id" xorm:"pk autoincr not null"`
|
||||
Name string `json:"name" xorm:"not null unique"`
|
||||
Data string `json:"data" xorm:"Text"`
|
||||
}
|
||||
|
||||
type SettingSmtp struct {
|
||||
@@ -46,3 +30,41 @@ type SettingOther struct {
|
||||
Banner string `json:"banner"`
|
||||
AccountMail string `json:"account_mail"`
|
||||
}
|
||||
|
||||
func StructName(data interface{}) string {
|
||||
ref := reflect.ValueOf(data)
|
||||
s := &ref
|
||||
if s.Kind() == reflect.Ptr {
|
||||
e := s.Elem()
|
||||
s = &e
|
||||
}
|
||||
name := s.Type().Name()
|
||||
return name
|
||||
}
|
||||
|
||||
func SettingAdd(data interface{}) error {
|
||||
name := StructName(data)
|
||||
v, _ := json.Marshal(data)
|
||||
s := Setting{Name: name, Data: string(v)}
|
||||
err := Add(&s)
|
||||
return err
|
||||
}
|
||||
|
||||
func SettingSet(data interface{}) error {
|
||||
name := StructName(data)
|
||||
v, _ := json.Marshal(data)
|
||||
s := Setting{Data: string(v)}
|
||||
err := Update("name", name, &s)
|
||||
return err
|
||||
}
|
||||
|
||||
func SettingGet(data interface{}) error {
|
||||
name := StructName(data)
|
||||
s := Setting{Name: name}
|
||||
err := One("name", name, &s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal([]byte(s.Data), data)
|
||||
return err
|
||||
}
|
||||
|
@@ -6,5 +6,5 @@ func Start() {
|
||||
}
|
||||
|
||||
func Stop() error {
|
||||
return sdb.Close()
|
||||
return xdb.Close()
|
||||
}
|
||||
|
@@ -11,8 +11,8 @@ import (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Id int `json:"id" storm:"id,increment"`
|
||||
Username string `json:"username" storm:"unique"`
|
||||
Id int `json:"id" xorm:"pk autoincr not null"`
|
||||
Username string `json:"username" storm:"not null unique"`
|
||||
Nickname string `json:"nickname"`
|
||||
Email string `json:"email"`
|
||||
// Password string `json:"password"`
|
||||
@@ -57,7 +57,11 @@ func SetUser(v *User) error {
|
||||
v.Groups = ng
|
||||
|
||||
v.UpdatedAt = time.Now()
|
||||
err = Save(v)
|
||||
if v.Id > 0 {
|
||||
err = Set(v)
|
||||
} else {
|
||||
err = Add(v)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
Reference in New Issue
Block a user