diff --git a/server/base/cfg.go b/server/base/cfg.go index a853f1d..c176fe0 100644 --- a/server/base/cfg.go +++ b/server/base/cfg.go @@ -5,8 +5,6 @@ import ( "os" "path/filepath" "reflect" - - "github.com/spf13/viper" ) const ( @@ -31,7 +29,7 @@ var ( type ServerConfig struct { // LinkAddr string `json:"link_addr"` - ConfFile string `json:"conf_file"` + Conf string `json:"conf"` ServerAddr string `json:"server_addr"` ServerDTLSAddr string `json:"server_dtls_addr"` ServerDTLS bool `json:"server_dtls"` @@ -83,6 +81,10 @@ func initServerCfg() { // Cfg.FilesPath = getAbsPath(base, Cfg.FilesPath) // Cfg.LogPath = getAbsPath(base, Cfg.LogPath) + if Cfg.AdminPass == defaultPwd { + fmt.Fprintln(os.Stderr, "=== 使用默认的admin_pass有安全风险,请设置新的admin_pass ===") + } + if Cfg.JwtSecret == defaultJwt { fmt.Fprintln(os.Stderr, "=== 使用默认的jwt_secret有安全风险,请设置新的jwt_secret ===") } @@ -116,19 +118,18 @@ func initCfg() { for _, v := range configs { if v.Name == tag { if v.Typ == cfgStr { - value.SetString(viper.GetString(v.Name)) + value.SetString(linkViper.GetString(v.Name)) } if v.Typ == cfgInt { - value.SetInt(int64(viper.GetInt(v.Name))) + value.SetInt(int64(linkViper.GetInt(v.Name))) } if v.Typ == cfgBool { - value.SetBool(viper.GetBool(v.Name)) + value.SetBool(linkViper.GetBool(v.Name)) } } } } - Cfg.ConfFile = cfgFile initServerCfg() } @@ -152,9 +153,6 @@ func ServerCfg2Slice() []SCfg { value := s.Field(i) tag := field.Tag.Get("json") usage, env := getUsageEnv(tag) - // if usage == "" { - // continue - // } datas = append(datas, SCfg{Name: tag, Env: env, Info: usage, Data: value.Interface()}) } diff --git a/server/base/cmd.go b/server/base/cmd.go index 2e34649..0abdcb2 100644 --- a/server/base/cmd.go +++ b/server/base/cmd.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "os" + "reflect" "runtime" "strings" @@ -15,8 +16,6 @@ import ( var ( // 提交id CommitId string - // 配置文件 - cfgFile string // pass明文 passwd string // 生成密钥 @@ -29,11 +28,14 @@ var ( // Used for flags. runSrv bool - rootCmd *cobra.Command + linkViper *viper.Viper + rootCmd *cobra.Command ) // Execute executes the root command. func execute() { + initCmd() + err := rootCmd.Execute() if err != nil { fmt.Println(err) @@ -41,13 +43,25 @@ func execute() { } // viper.Debug() + ref := reflect.ValueOf(linkViper) + s := ref.Elem() + ee := s.FieldByName("env") + if ee.Kind() != reflect.Map { + panic("Viper env is err") + } + rr := ee.MapRange() + for rr.Next() { + // fmt.Println(rr.Key(), rr.Value()) + envs[rr.Key().String()] = rr.Value().String() + } if !runSrv { os.Exit(0) } } -func init() { +func initCmd() { + linkViper = viper.New() rootCmd = &cobra.Command{ Use: "anylink", Short: "AnyLink VPN Server", @@ -58,11 +72,9 @@ func init() { }, } - viper.SetEnvPrefix("link") + linkViper.SetEnvPrefix("link") // 基础配置 - rootCmd.Flags().StringVarP(&cfgFile, "conf", "c", "./conf/server.toml", "config file") - _ = viper.BindEnv("conf") for _, v := range configs { if v.Typ == cfgStr { @@ -75,24 +87,25 @@ func init() { rootCmd.Flags().BoolP(v.Name, v.Short, v.ValBool, v.Usage) } - _ = viper.BindPFlag(v.Name, rootCmd.Flags().Lookup(v.Name)) - _ = viper.BindEnv(v.Name) + _ = linkViper.BindPFlag(v.Name, rootCmd.Flags().Lookup(v.Name)) + _ = linkViper.BindEnv(v.Name) // viper.SetDefault(v.Name, v.Value) } rootCmd.AddCommand(initToolCmd()) cobra.OnInitialize(func() { - viper.SetConfigFile(cfgFile) - viper.AutomaticEnv() + linkViper.AutomaticEnv() + conf := linkViper.GetString("conf") - _, err := os.Stat(cfgFile) + _, err := os.Stat(conf) if errors.Is(err, os.ErrNotExist) { // 没有配置文件,不做处理 return } - err = viper.ReadInConfig() + linkViper.SetConfigFile(conf) + err = linkViper.ReadInConfig() if err != nil { fmt.Println("Using config file:", err) } @@ -124,7 +137,7 @@ func initToolCmd() *cobra.Command { pass, _ := utils.PasswordHash(passwd) fmt.Printf("Passwd:%s\n", pass) case debug: - viper.Debug() + linkViper.Debug() default: fmt.Println("Using [anylink tool -h] for help") } diff --git a/server/base/config.go b/server/base/config.go index 24e72f5..958e9cf 100644 --- a/server/base/config.go +++ b/server/base/config.go @@ -6,6 +6,7 @@ const ( cfgBool defaultJwt = "abcdef.0123456789.abcdef" + defaultPwd = "$2a$10$UQ7C.EoPifDeJh6d8.31TeSPQU7hM/NOM2nixmBucJpAuXDQNqNke" ) type config struct { @@ -19,6 +20,7 @@ type config struct { } var configs = []config{ + {Typ: cfgStr, Name: "conf", Usage: "config file", ValStr: "./conf/server.toml", Short: "c"}, {Typ: cfgStr, Name: "server_addr", Usage: "服务监听地址", ValStr: ":443"}, {Typ: cfgBool, Name: "server_dtls", Usage: "开启DTLS", ValBool: false}, {Typ: cfgStr, Name: "server_dtls_addr", Usage: "DTLS监听地址", ValStr: ":4433"}, @@ -34,7 +36,7 @@ var configs = []config{ {Typ: cfgBool, Name: "pprof", Usage: "开启pprof", ValBool: false}, {Typ: cfgStr, Name: "issuer", Usage: "系统名称", ValStr: "XX公司VPN"}, {Typ: cfgStr, Name: "admin_user", Usage: "管理用户名", ValStr: "admin"}, - {Typ: cfgStr, Name: "admin_pass", Usage: "管理用户密码", ValStr: "$2a$10$UQ7C.EoPifDeJh6d8.31TeSPQU7hM/NOM2nixmBucJpAuXDQNqNke"}, + {Typ: cfgStr, Name: "admin_pass", Usage: "管理用户密码", ValStr: defaultPwd}, {Typ: cfgStr, Name: "jwt_secret", Usage: "JWT密钥", ValStr: defaultJwt}, {Typ: cfgStr, Name: "link_mode", Usage: "虚拟网络类型", ValStr: "tun"}, {Typ: cfgStr, Name: "ipv4_cidr", Usage: "ip地址网段", ValStr: "192.168.10.0/24"},