flags.go 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. /*
  2. Copyright 2016 The Rook Authors. All rights reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package flags
  14. import (
  15. "fmt"
  16. "os"
  17. "regexp"
  18. "strings"
  19. "github.com/coreos/pkg/capnslog"
  20. "github.com/spf13/cobra"
  21. "github.com/spf13/pflag"
  22. )
  23. var (
  24. logger = capnslog.NewPackageLogger("github.com/rook/rook", "op-flags")
  25. )
  26. func VerifyRequiredFlags(cmd *cobra.Command, requiredFlags []string) error {
  27. var missingFlags []string
  28. for _, reqFlag := range requiredFlags {
  29. val, err := cmd.Flags().GetString(reqFlag)
  30. if err != nil || val == "" {
  31. missingFlags = append(missingFlags, reqFlag)
  32. }
  33. }
  34. return createRequiredFlagError(cmd.Name(), missingFlags)
  35. }
  36. func createRequiredFlagError(name string, flags []string) error {
  37. if len(flags) == 0 {
  38. return nil
  39. }
  40. if len(flags) == 1 {
  41. return fmt.Errorf("%s is required for %s", flags[0], name)
  42. }
  43. return fmt.Errorf("%s are required for %s", strings.Join(flags, ","), name)
  44. }
  45. func SetFlagsFromEnv(flags *pflag.FlagSet, prefix string) {
  46. var errorFlag bool
  47. var err error
  48. flags.VisitAll(func(f *pflag.Flag) {
  49. envVar := prefix + "_" + strings.Replace(strings.ToUpper(f.Name), "-", "_", -1)
  50. value := os.Getenv(envVar)
  51. if value != "" {
  52. // Set the environment variable. Will override default values, but be overridden by command line parameters.
  53. if err = flags.Set(f.Name, value); err != nil {
  54. errorFlag = true
  55. }
  56. }
  57. })
  58. if errorFlag {
  59. logger.Error("failed to set flag ", err)
  60. }
  61. }
  62. // GetFlagsAndValues returns all flags and their values as a slice with elements in the format of
  63. // "--<flag>=<value>"
  64. func GetFlagsAndValues(flags *pflag.FlagSet, excludeFilter string) []string {
  65. var flagValues []string
  66. flags.VisitAll(func(f *pflag.Flag) {
  67. val := f.Value.String()
  68. if excludeFilter != "" {
  69. if matched, _ := regexp.Match(excludeFilter, []byte(f.Name)); matched {
  70. val = "*****"
  71. }
  72. }
  73. flagValues = append(flagValues, fmt.Sprintf("--%s=%s", f.Name, val))
  74. })
  75. return flagValues
  76. }