From d90f37a48761fe767528f31db1955e4f795d652f Mon Sep 17 00:00:00 2001 From: Albert Nigmatzianov Date: Sun, 26 Mar 2017 00:48:22 +0500 Subject: Add SortFlags option (#113) Fixes https://github.com/spf13/cobra/issues/316 --- flag.go | 52 +++++++++++++++++++++++++++++++++++++++++----------- flag_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/flag.go b/flag.go index 746af63..7ebf33e 100644 --- a/flag.go +++ b/flag.go @@ -134,10 +134,16 @@ type FlagSet struct { // a custom error handler. Usage func() + // SortFlags is used to indicate, if user wants to have sorted flags in + // help/usage messages. + SortFlags bool + name string parsed bool actual map[NormalizedName]*Flag + orderedActual []*Flag formal map[NormalizedName]*Flag + orderedFormal []*Flag shorthands map[byte]*Flag args []string // arguments after flags argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no -- @@ -156,7 +162,7 @@ type Flag struct { Value Value // value as set DefValue string // default value (as text); for usage message Changed bool // If the user set the value (or if left to default) - NoOptDefVal string //default value (as text); if the flag is on the command line without any options + NoOptDefVal string // default value (as text); if the flag is on the command line without any options Deprecated string // If this flag is deprecated, this string is the new or now thing to use Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use @@ -194,10 +200,12 @@ func sortFlags(flags map[NormalizedName]*Flag) []*Flag { // "--getUrl" which may also be translated to "geturl" and everything will work. func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) { f.normalizeNameFunc = n + f.orderedFormal = f.orderedFormal[:0] for k, v := range f.formal { delete(f.formal, k) nname := f.normalizeFlagName(string(k)) f.formal[nname] = v + f.orderedFormal = append(f.orderedFormal, v) v.Name = string(nname) } } @@ -229,10 +237,18 @@ func (f *FlagSet) SetOutput(output io.Writer) { f.output = output } -// VisitAll visits the flags in lexicographical order, calling fn for each. +// VisitAll visits the flags in lexicographical order or +// in primordial order if f.SortFlags is false, calling fn for each. // It visits all flags, even those not set. func (f *FlagSet) VisitAll(fn func(*Flag)) { - for _, flag := range sortFlags(f.formal) { + var flags []*Flag + if f.SortFlags { + flags = sortFlags(f.formal) + } else { + flags = f.orderedFormal + } + + for _, flag := range flags { fn(flag) } } @@ -253,22 +269,32 @@ func (f *FlagSet) HasAvailableFlags() bool { return false } -// VisitAll visits the command-line flags in lexicographical order, calling -// fn for each. It visits all flags, even those not set. +// VisitAll visits the command-line flags in lexicographical order or +// in primordial order if f.SortFlags is false, calling fn for each. +// It visits all flags, even those not set. func VisitAll(fn func(*Flag)) { CommandLine.VisitAll(fn) } -// Visit visits the flags in lexicographical order, calling fn for each. +// Visit visits the flags in lexicographical order or +// in primordial order if f.SortFlags is false, calling fn for each. // It visits only those flags that have been set. func (f *FlagSet) Visit(fn func(*Flag)) { - for _, flag := range sortFlags(f.actual) { + var flags []*Flag + if f.SortFlags { + flags = sortFlags(f.actual) + } else { + flags = f.orderedActual + } + + for _, flag := range flags { fn(flag) } } -// Visit visits the command-line flags in lexicographical order, calling fn -// for each. It visits only those flags that have been set. +// Visit visits the command-line flags in lexicographical order or +// in primordial order if f.SortFlags is false, calling fn for each. +// It visits only those flags that have been set. func Visit(fn func(*Flag)) { CommandLine.Visit(fn) } @@ -373,6 +399,7 @@ func (f *FlagSet) Set(name, value string) error { f.actual = make(map[NormalizedName]*Flag) } f.actual[normalName] = flag + f.orderedActual = append(f.orderedActual, flag) flag.Changed = true if len(flag.Deprecated) > 0 { fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated) @@ -729,6 +756,7 @@ func (f *FlagSet) AddFlag(flag *Flag) { flag.Name = string(normalizedFlagName) f.formal[normalizedFlagName] = flag + f.orderedFormal = append(f.orderedFormal, flag) if len(flag.Shorthand) == 0 { return @@ -807,6 +835,7 @@ func (f *FlagSet) setFlag(flag *Flag, value string, origArg string) error { f.actual = make(map[NormalizedName]*Flag) } f.actual[f.normalizeFlagName(flag.Name)] = flag + f.orderedActual = append(f.orderedActual, flag) flag.Changed = true if len(flag.Deprecated) > 0 { fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated) @@ -1036,14 +1065,15 @@ func Parsed() bool { // CommandLine is the default set of command-line flags, parsed from os.Args. var CommandLine = NewFlagSet(os.Args[0], ExitOnError) -// NewFlagSet returns a new, empty flag set with the specified name and -// error handling property. +// NewFlagSet returns a new, empty flag set with the specified name, +// error handling property and SortFlags set to true. func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { f := &FlagSet{ name: name, errorHandling: errorHandling, argsLenAtDash: -1, interspersed: true, + SortFlags: true, } return f } diff --git a/flag_test.go b/flag_test.go index b83a0ed..55dd6c6 100644 --- a/flag_test.go +++ b/flag_test.go @@ -1004,3 +1004,38 @@ func TestPrintDefaults(t *testing.T) { t.Errorf("got %q want %q\n", got, defaultOutput) } } + +func TestVisitAllFlagOrder(t *testing.T) { + fs := NewFlagSet("TestVisitAllFlagOrder", ContinueOnError) + fs.SortFlags = false + names := []string{"C", "B", "A", "D"} + for _, name := range names { + fs.Bool(name, false, "") + } + + i := 0 + fs.VisitAll(func(f *Flag) { + if names[i] != f.Name { + t.Errorf("Incorrect order. Expected %v, got %v", names[i], f.Name) + } + i++ + }) +} + +func TestVisitFlagOrder(t *testing.T) { + fs := NewFlagSet("TestVisitFlagOrder", ContinueOnError) + fs.SortFlags = false + names := []string{"C", "B", "A", "D"} + for _, name := range names { + fs.Bool(name, false, "") + fs.Set(name, "true") + } + + i := 0 + fs.Visit(func(f *Flag) { + if names[i] != f.Name { + t.Errorf("Incorrect order. Expected %v, got %v", names[i], f.Name) + } + i++ + }) +} -- cgit v1.2.3