aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--flag.go52
-rw-r--r--flag_test.go35
2 files changed, 76 insertions, 11 deletions
diff --git a/flag.go b/flag.go
index 746af63..7ebf33e 100644
--- a/flag.go
+++ b/flag.go
@@ -134,10 +134,16 @@ type FlagSet struct {
// a custom error handler.
Usage func()
+ // SortFlags is used to indicate, if user wants to have sorted flags in
+ // help/usage messages.
+ SortFlags bool
+
name string
parsed bool
actual map[NormalizedName]*Flag
+ orderedActual []*Flag
formal map[NormalizedName]*Flag
+ orderedFormal []*Flag
shorthands map[byte]*Flag
args []string // arguments after flags
argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no --
@@ -156,7 +162,7 @@ type Flag struct {
Value Value // value as set
DefValue string // default value (as text); for usage message
Changed bool // If the user set the value (or if left to default)
- NoOptDefVal string //default value (as text); if the flag is on the command line without any options
+ NoOptDefVal string // default value (as text); if the flag is on the command line without any options
Deprecated string // If this flag is deprecated, this string is the new or now thing to use
Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text
ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use
@@ -194,10 +200,12 @@ func sortFlags(flags map[NormalizedName]*Flag) []*Flag {
// "--getUrl" which may also be translated to "geturl" and everything will work.
func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) {
f.normalizeNameFunc = n
+ f.orderedFormal = f.orderedFormal[:0]
for k, v := range f.formal {
delete(f.formal, k)
nname := f.normalizeFlagName(string(k))
f.formal[nname] = v
+ f.orderedFormal = append(f.orderedFormal, v)
v.Name = string(nname)
}
}
@@ -229,10 +237,18 @@ func (f *FlagSet) SetOutput(output io.Writer) {
f.output = output
}
-// VisitAll visits the flags in lexicographical order, calling fn for each.
+// VisitAll visits the flags in lexicographical order or
+// in primordial order if f.SortFlags is false, calling fn for each.
// It visits all flags, even those not set.
func (f *FlagSet) VisitAll(fn func(*Flag)) {
- for _, flag := range sortFlags(f.formal) {
+ var flags []*Flag
+ if f.SortFlags {
+ flags = sortFlags(f.formal)
+ } else {
+ flags = f.orderedFormal
+ }
+
+ for _, flag := range flags {
fn(flag)
}
}
@@ -253,22 +269,32 @@ func (f *FlagSet) HasAvailableFlags() bool {
return false
}
-// VisitAll visits the command-line flags in lexicographical order, calling
-// fn for each. It visits all flags, even those not set.
+// VisitAll visits the command-line flags in lexicographical order or
+// in primordial order if f.SortFlags is false, calling fn for each.
+// It visits all flags, even those not set.
func VisitAll(fn func(*Flag)) {
CommandLine.VisitAll(fn)
}
-// Visit visits the flags in lexicographical order, calling fn for each.
+// Visit visits the flags in lexicographical order or
+// in primordial order if f.SortFlags is false, calling fn for each.
// It visits only those flags that have been set.
func (f *FlagSet) Visit(fn func(*Flag)) {
- for _, flag := range sortFlags(f.actual) {
+ var flags []*Flag
+ if f.SortFlags {
+ flags = sortFlags(f.actual)
+ } else {
+ flags = f.orderedActual
+ }
+
+ for _, flag := range flags {
fn(flag)
}
}
-// Visit visits the command-line flags in lexicographical order, calling fn
-// for each. It visits only those flags that have been set.
+// Visit visits the command-line flags in lexicographical order or
+// in primordial order if f.SortFlags is false, calling fn for each.
+// It visits only those flags that have been set.
func Visit(fn func(*Flag)) {
CommandLine.Visit(fn)
}
@@ -373,6 +399,7 @@ func (f *FlagSet) Set(name, value string) error {
f.actual = make(map[NormalizedName]*Flag)
}
f.actual[normalName] = flag
+ f.orderedActual = append(f.orderedActual, flag)
flag.Changed = true
if len(flag.Deprecated) > 0 {
fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
@@ -729,6 +756,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {
flag.Name = string(normalizedFlagName)
f.formal[normalizedFlagName] = flag
+ f.orderedFormal = append(f.orderedFormal, flag)
if len(flag.Shorthand) == 0 {
return
@@ -807,6 +835,7 @@ func (f *FlagSet) setFlag(flag *Flag, value string, origArg string) error {
f.actual = make(map[NormalizedName]*Flag)
}
f.actual[f.normalizeFlagName(flag.Name)] = flag
+ f.orderedActual = append(f.orderedActual, flag)
flag.Changed = true
if len(flag.Deprecated) > 0 {
fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
@@ -1036,14 +1065,15 @@ func Parsed() bool {
// CommandLine is the default set of command-line flags, parsed from os.Args.
var CommandLine = NewFlagSet(os.Args[0], ExitOnError)
-// NewFlagSet returns a new, empty flag set with the specified name and
-// error handling property.
+// NewFlagSet returns a new, empty flag set with the specified name,
+// error handling property and SortFlags set to true.
func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet {
f := &FlagSet{
name: name,
errorHandling: errorHandling,
argsLenAtDash: -1,
interspersed: true,
+ SortFlags: true,
}
return f
}
diff --git a/flag_test.go b/flag_test.go
index b83a0ed..55dd6c6 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -1004,3 +1004,38 @@ func TestPrintDefaults(t *testing.T) {
t.Errorf("got %q want %q\n", got, defaultOutput)
}
}
+
+func TestVisitAllFlagOrder(t *testing.T) {
+ fs := NewFlagSet("TestVisitAllFlagOrder", ContinueOnError)
+ fs.SortFlags = false
+ names := []string{"C", "B", "A", "D"}
+ for _, name := range names {
+ fs.Bool(name, false, "")
+ }
+
+ i := 0
+ fs.VisitAll(func(f *Flag) {
+ if names[i] != f.Name {
+ t.Errorf("Incorrect order. Expected %v, got %v", names[i], f.Name)
+ }
+ i++
+ })
+}
+
+func TestVisitFlagOrder(t *testing.T) {
+ fs := NewFlagSet("TestVisitFlagOrder", ContinueOnError)
+ fs.SortFlags = false
+ names := []string{"C", "B", "A", "D"}
+ for _, name := range names {
+ fs.Bool(name, false, "")
+ fs.Set(name, "true")
+ }
+
+ i := 0
+ fs.Visit(func(f *Flag) {
+ if names[i] != f.Name {
+ t.Errorf("Incorrect order. Expected %v, got %v", names[i], f.Name)
+ }
+ i++
+ })
+}