aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Paris <[email protected]>2015-05-10 18:00:23 -0500
committerEric Paris <[email protected]>2015-05-10 18:00:23 -0500
commit0ed81a961505a7dfaab5490049a7a324743e6f03 (patch)
tree91d96435841b18fdedbc0b4d6929f09f31398519
parentf1e68ce945b0710375b5cccee37318a3d13fdf8c (diff)
parentce8e092726fe09c73532462e5f8810a3552270ab (diff)
Merge pull request #22 from andronat/multiple_narmalization_requests
Call normalizeFlagName function only once
-rw-r--r--flag.go7
-rw-r--r--flag_test.go32
2 files changed, 29 insertions, 10 deletions
diff --git a/flag.go b/flag.go
index 362a20a..06fb7a8 100644
--- a/flag.go
+++ b/flag.go
@@ -421,7 +421,10 @@ func (f *FlagSet) VarP(value Value, name, shorthand, usage string) {
}
func (f *FlagSet) AddFlag(flag *Flag) {
- _, alreadythere := f.formal[f.normalizeFlagName(flag.Name)]
+ // Call normalizeFlagName function only once
+ var normalizedFlagName NormalizedName = f.normalizeFlagName(flag.Name)
+
+ _, alreadythere := f.formal[normalizedFlagName]
if alreadythere {
msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name)
fmt.Fprintln(f.out(), msg)
@@ -430,7 +433,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {
if f.formal == nil {
f.formal = make(map[NormalizedName]*Flag)
}
- f.formal[f.normalizeFlagName(flag.Name)] = flag
+ f.formal[normalizedFlagName] = flag
if len(flag.Shorthand) == 0 {
return
diff --git a/flag_test.go b/flag_test.go
index b5956fa..efd6666 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -17,14 +17,15 @@ import (
)
var (
- test_bool = Bool("test_bool", false, "bool value")
- test_int = Int("test_int", 0, "int value")
- test_int64 = Int64("test_int64", 0, "int64 value")
- test_uint = Uint("test_uint", 0, "uint value")
- test_uint64 = Uint64("test_uint64", 0, "uint64 value")
- test_string = String("test_string", "0", "string value")
- test_float64 = Float64("test_float64", 0, "float64 value")
- test_duration = Duration("test_duration", 0, "time.Duration value")
+ test_bool = Bool("test_bool", false, "bool value")
+ test_int = Int("test_int", 0, "int value")
+ test_int64 = Int64("test_int64", 0, "int64 value")
+ test_uint = Uint("test_uint", 0, "uint value")
+ test_uint64 = Uint64("test_uint64", 0, "uint64 value")
+ test_string = String("test_string", "0", "string value")
+ test_float64 = Float64("test_float64", 0, "float64 value")
+ test_duration = Duration("test_duration", 0, "time.Duration value")
+ normalizeFlagNameInvocations = 0
)
func boolString(s string) string {
@@ -254,6 +255,8 @@ func replaceSeparators(name string, from []string, to string) string {
func wordSepNormalizeFunc(f *FlagSet, name string) NormalizedName {
seps := []string{"-", "_"}
name = replaceSeparators(name, seps, ".")
+ normalizeFlagNameInvocations++
+
return NormalizedName(name)
}
@@ -574,3 +577,16 @@ func TestDeprecatedFlagUsageNormalized(t *testing.T) {
t.Errorf("usageMsg not printed when using a deprecated flag!")
}
}
+
+// Name normalization function should be called only once on flag addition
+func TestMultipleNormalizeFlagNameInvocations(t *testing.T) {
+ normalizeFlagNameInvocations = 0
+
+ f := NewFlagSet("normalized", ContinueOnError)
+ f.SetNormalizeFunc(wordSepNormalizeFunc)
+ f.Bool("with_under_flag", false, "bool value")
+
+ if normalizeFlagNameInvocations != 1 {
+ t.Fatal("Expected normalizeFlagNameInvocations to be 1; got ", normalizeFlagNameInvocations)
+ }
+}