aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDiego Becciolini <[email protected]>2017-10-01 23:02:52 +0100
committerAlbert Nigmatzianov <[email protected]>2017-10-02 00:02:52 +0200
commitbe7121dd7a937a85e1e4b1ddda6a3edce3466110 (patch)
tree477625a9e13468dd61981baeade64e055d93214f
parent5c2d607c75df0540c877524f9e82d3edb7748668 (diff)
Fix SetNormalizeFunc (#137)
Related to https://github.com/spf13/cobra/issues/521
-rw-r--r--flag.go32
-rw-r--r--flag_test.go64
2 files changed, 84 insertions, 12 deletions
diff --git a/flag.go b/flag.go
index 85c34a7..b0eb4ff 100644
--- a/flag.go
+++ b/flag.go
@@ -202,12 +202,18 @@ func sortFlags(flags map[NormalizedName]*Flag) []*Flag {
func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) {
f.normalizeNameFunc = n
f.sortedFormal = f.sortedFormal[:0]
- for k, v := range f.orderedFormal {
- delete(f.formal, NormalizedName(v.Name))
- nname := f.normalizeFlagName(v.Name)
- v.Name = string(nname)
- f.formal[nname] = v
- f.orderedFormal[k] = v
+ for fname, flag := range f.formal {
+ nname := f.normalizeFlagName(flag.Name)
+ if fname == nname {
+ continue
+ }
+ flag.Name = string(nname)
+ delete(f.formal, fname)
+ f.formal[nname] = flag
+ if _, set := f.actual[fname]; set {
+ delete(f.actual, fname)
+ f.actual[nname] = flag
+ }
}
}
@@ -440,13 +446,15 @@ func (f *FlagSet) Set(name, value string) error {
return fmt.Errorf("invalid argument %q for %q flag: %v", value, flagName, err)
}
- if f.actual == nil {
- f.actual = make(map[NormalizedName]*Flag)
- }
- f.actual[normalName] = flag
- f.orderedActual = append(f.orderedActual, flag)
+ if !flag.Changed {
+ if f.actual == nil {
+ f.actual = make(map[NormalizedName]*Flag)
+ }
+ f.actual[normalName] = flag
+ f.orderedActual = append(f.orderedActual, flag)
- flag.Changed = true
+ flag.Changed = true
+ }
if flag.Deprecated != "" {
fmt.Fprintf(f.out(), "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
diff --git a/flag_test.go b/flag_test.go
index 29fec86..fe9a4a3 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -658,6 +658,70 @@ func TestNormalizationFuncShouldChangeFlagName(t *testing.T) {
}
}
+// Related to https://github.com/spf13/cobra/issues/521.
+func TestNormalizationSharedFlags(t *testing.T) {
+ f := NewFlagSet("set f", ContinueOnError)
+ g := NewFlagSet("set g", ContinueOnError)
+ nfunc := wordSepNormalizeFunc
+ testName := "valid_flag"
+ normName := nfunc(nil, testName)
+ if testName == string(normName) {
+ t.Error("TestNormalizationSharedFlags meaningless: the original and normalized flag names are identical:", testName)
+ }
+
+ f.Bool(testName, false, "bool value")
+ g.AddFlagSet(f)
+
+ f.SetNormalizeFunc(nfunc)
+ g.SetNormalizeFunc(nfunc)
+
+ if len(f.formal) != 1 {
+ t.Error("Normalizing flags should not result in duplications in the flag set:", f.formal)
+ }
+ if f.orderedFormal[0].Name != string(normName) {
+ t.Error("Flag name not normalized")
+ }
+ for k := range f.formal {
+ if k != "valid.flag" {
+ t.Errorf("The key in the flag map should have been normalized: wanted \"%s\", got \"%s\" instead", normName, k)
+ }
+ }
+
+ if !reflect.DeepEqual(f.formal, g.formal) || !reflect.DeepEqual(f.orderedFormal, g.orderedFormal) {
+ t.Error("Two flag sets sharing the same flags should stay consistent after being normalized. Original set:", f.formal, "Duplicate set:", g.formal)
+ }
+}
+
+func TestNormalizationSetFlags(t *testing.T) {
+ f := NewFlagSet("normalized", ContinueOnError)
+ nfunc := wordSepNormalizeFunc
+ testName := "valid_flag"
+ normName := nfunc(nil, testName)
+ if testName == string(normName) {
+ t.Error("TestNormalizationSetFlags meaningless: the original and normalized flag names are identical:", testName)
+ }
+
+ f.Bool(testName, false, "bool value")
+ f.Set(testName, "true")
+ f.SetNormalizeFunc(nfunc)
+
+ if len(f.formal) != 1 {
+ t.Error("Normalizing flags should not result in duplications in the flag set:", f.formal)
+ }
+ if f.orderedFormal[0].Name != string(normName) {
+ t.Error("Flag name not normalized")
+ }
+ for k := range f.formal {
+ if k != "valid.flag" {
+ t.Errorf("The key in the flag map should have been normalized: wanted \"%s\", got \"%s\" instead", normName, k)
+ }
+ }
+
+ if !reflect.DeepEqual(f.formal, f.actual) {
+ t.Error("The map of set flags should get normalized. Formal:", f.formal, "Actual:", f.actual)
+ }
+}
+
// Declare a user-defined flag type.
type flagVar []string