diff options
| -rw-r--r-- | .editorconfig | 12 | ||||
| -rw-r--r-- | .github/.editorconfig | 2 | ||||
| -rw-r--r-- | .github/dependabot.yaml | 12 | ||||
| -rw-r--r-- | .github/workflows/ci.yaml | 48 | ||||
| -rw-r--r-- | .golangci.yaml | 4 | ||||
| -rw-r--r-- | README.md | 27 | ||||
| -rw-r--r-- | bool_func.go | 40 | ||||
| -rw-r--r-- | bool_func_test.go | 177 | ||||
| -rw-r--r-- | count.go | 2 | ||||
| -rw-r--r-- | errors.go | 149 | ||||
| -rw-r--r-- | errors_test.go | 67 | ||||
| -rw-r--r-- | flag.go | 89 | ||||
| -rw-r--r-- | flag_test.go | 85 | ||||
| -rw-r--r-- | func.go | 37 | ||||
| -rw-r--r-- | func_test.go | 183 | ||||
| -rw-r--r-- | golangflag.go | 22 | ||||
| -rw-r--r-- | golangflag_test.go | 16 | ||||
| -rw-r--r-- | ip.go | 3 | ||||
| -rw-r--r-- | ip_test.go | 2 | ||||
| -rw-r--r-- | ipnet_slice.go | 147 | ||||
| -rw-r--r-- | ipnet_slice_test.go | 239 | ||||
| -rw-r--r-- | string_array.go | 4 | ||||
| -rw-r--r-- | text.go | 81 | ||||
| -rw-r--r-- | text_test.go | 56 | ||||
| -rw-r--r-- | time.go | 118 | ||||
| -rw-r--r-- | time_test.go | 62 |
26 files changed, 1635 insertions, 49 deletions
diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..4492e9f --- /dev/null +++ b/.editorconfig @@ -0,0 +1,12 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 4 +indent_style = space +insert_final_newline = true +trim_trailing_whitespace = true + +[*.go] +indent_style = tab diff --git a/.github/.editorconfig b/.github/.editorconfig new file mode 100644 index 0000000..0902c6a --- /dev/null +++ b/.github/.editorconfig @@ -0,0 +1,2 @@ +[{*.yml,*.yaml}] +indent_size = 2 diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 0000000..73aa36f --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,12 @@ +version: 2 + +updates: + - package-ecosystem: gomod + directory: / + schedule: + interval: daily + + - package-ecosystem: github-actions + directory: / + schedule: + interval: daily diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..42f7614 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,48 @@ +name: CI + +on: + push: + branches: [master] + pull_request: + +jobs: + test: + name: Test + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + go: ["1.21", "1.22", "1.23"] + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Go + uses: actions/setup-go@3041bf56c941b39c61721a86cd11f3bb1338122a # v5.2.0 + with: + go-version: ${{ matrix.go }} + + - name: Test + # Cannot enable shuffle for now because some tests rely on global state and order + # run: go test -race -v -shuffle=on ./... + run: go test -race -v ./... + + lint: + name: Lint + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Go + uses: actions/setup-go@3041bf56c941b39c61721a86cd11f3bb1338122a # v5.2.0 + with: + go-version: "1.23" + + - name: Lint + uses: golangci/golangci-lint-action@971e284b6050e8a5849b72094c50ab08da042db8 # v6.1.1 + with: + version: v1.63.4 diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..b274f24 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,4 @@ +linters: + disable-all: true + enable: + - nolintlint @@ -284,6 +284,33 @@ func main() { } ``` +### Using pflag with go test +`pflag` does not parse the shorthand versions of go test's built-in flags (i.e., those starting with `-test.`). +For more context, see issues [#63](https://github.com/spf13/pflag/issues/63) and [#238](https://github.com/spf13/pflag/issues/238) for more details. + +For example, if you use pflag in your `TestMain` function and call `pflag.Parse()` after defining your custom flags, running a test like this: +```bash +go test /your/tests -run ^YourTest -v --your-test-pflags +``` +will result in the `-v` flag being ignored. This happens because of the way pflag handles flag parsing, skipping over go test's built-in shorthand flags. +To work around this, you can use the `ParseSkippedFlags` function, which ensures that go test's flags are parsed separately using the standard flag package. + +**Example**: You want to parse go test flags that are otherwise ignore by `pflag.Parse()` +```go +import ( + goflag "flag" + flag "github.com/spf13/pflag" +) + +var ip *int = flag.Int("flagname", 1234, "help message for flagname") + +func main() { + flag.CommandLine.AddGoFlagSet(goflag.CommandLine) + flag.ParseSkippedFlags(os.Args[1:], goflag.CommandLine) + flag.Parse() +} +``` + ## More info You can see the full reference documentation of the pflag package diff --git a/bool_func.go b/bool_func.go new file mode 100644 index 0000000..83d77af --- /dev/null +++ b/bool_func.go @@ -0,0 +1,40 @@ +package pflag + +// -- func Value +type boolfuncValue func(string) error + +func (f boolfuncValue) Set(s string) error { return f(s) } + +func (f boolfuncValue) Type() string { return "boolfunc" } + +func (f boolfuncValue) String() string { return "" } // same behavior as stdlib 'flag' package + +func (f boolfuncValue) IsBoolFlag() bool { return true } + +// BoolFunc defines a func flag with specified name, callback function and usage string. +// +// The callback function will be called every time "--{name}" (or any form that matches the flag) is parsed +// on the command line. +func (f *FlagSet) BoolFunc(name string, usage string, fn func(string) error) { + f.BoolFuncP(name, "", usage, fn) +} + +// BoolFuncP is like BoolFunc, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) BoolFuncP(name, shorthand string, usage string, fn func(string) error) { + var val Value = boolfuncValue(fn) + flag := f.VarPF(val, name, shorthand, usage) + flag.NoOptDefVal = "true" +} + +// BoolFunc defines a func flag with specified name, callback function and usage string. +// +// The callback function will be called every time "--{name}" (or any form that matches the flag) is parsed +// on the command line. +func BoolFunc(name string, usage string, fn func(string) error) { + CommandLine.BoolFuncP(name, "", usage, fn) +} + +// BoolFuncP is like BoolFunc, but accepts a shorthand letter that can be used after a single dash. +func BoolFuncP(name, shorthand string, usage string, fn func(string) error) { + CommandLine.BoolFuncP(name, shorthand, usage, fn) +} diff --git a/bool_func_test.go b/bool_func_test.go new file mode 100644 index 0000000..c16be83 --- /dev/null +++ b/bool_func_test.go @@ -0,0 +1,177 @@ +package pflag + +import ( + "errors" + "flag" + "io" + "strings" + "testing" +) + +func TestBoolFunc(t *testing.T) { + var count int + fn := func(_ string) error { + count++ + return nil + } + + fset := NewFlagSet("test", ContinueOnError) + fset.BoolFunc("func", "Callback function", fn) + + err := fset.Parse([]string{"--func", "--func=1", "--func=false"}) + if err != nil { + t.Fatal("expected no error; got", err) + } + + if count != 3 { + t.Fatalf("expected 3 calls to the callback, got %d calls", count) + } +} + +func TestBoolFuncP(t *testing.T) { + var count int + fn := func(_ string) error { + count++ + return nil + } + + fset := NewFlagSet("test", ContinueOnError) + fset.BoolFuncP("bfunc", "b", "Callback function", fn) + + err := fset.Parse([]string{"--bfunc", "--bfunc=0", "--bfunc=false", "-b", "-b=0"}) + if err != nil { + t.Fatal("expected no error; got", err) + } + + if count != 5 { + t.Fatalf("expected 5 calls to the callback, got %d calls", count) + } +} + +func TestBoolFuncCompat(t *testing.T) { + // compare behavior with the stdlib 'flag' package + type BoolFuncFlagSet interface { + BoolFunc(name string, usage string, fn func(string) error) + Parse([]string) error + } + + unitTestErr := errors.New("unit test error") + runCase := func(f BoolFuncFlagSet, name string, args []string) (values []string, err error) { + fn := func(s string) error { + values = append(values, s) + if s == "err" { + return unitTestErr + } + return nil + } + f.BoolFunc(name, "Callback function", fn) + + err = f.Parse(args) + return values, err + } + + t.Run("regular parsing", func(t *testing.T) { + flagName := "bflag" + args := []string{"--bflag", "--bflag=false", "--bflag=1", "--bflag=bar", "--bflag="} + + // It turns out that, even though the function is called "BoolFunc", + // the standard flag package does not try to parse the value assigned to + // that cli flag as a boolean. The string provided on the command line is + // passed as is to the callback. + // e.g: with "--bflag=not_a_bool" on the command line, the FlagSet does not + // generate an error stating "invalid boolean value", and `fn` will be called + // with "not_a_bool" as an argument. + + stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError) + stdValues, err := runCase(stdFSet, flagName, args) + if err != nil { + t.Fatalf("std flag: expected no error, got %v", err) + } + expected := []string{"true", "false", "1", "bar", ""} + if !cmpLists(expected, stdValues) { + t.Fatalf("std flag: expected %v, got %v", expected, stdValues) + } + + fset := NewFlagSet("pflag test", ContinueOnError) + pflagValues, err := runCase(fset, flagName, args) + if err != nil { + t.Fatalf("pflag: expected no error, got %v", err) + } + if !cmpLists(stdValues, pflagValues) { + t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues) + } + }) + + t.Run("error triggered by callback", func(t *testing.T) { + flagName := "bflag" + args := []string{"--bflag", "--bflag=err", "--bflag=after"} + + // test behavior of standard flag.Fset with an error triggered by the callback: + // (note: as can be seen in 'runCase()', if the callback sees "err" as a value + // for the bool flag, it will return an error) + stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError) + stdFSet.SetOutput(io.Discard) // suppress output + + // run test case with standard flag.Fset + stdValues, err := runCase(stdFSet, flagName, args) + + // double check the standard behavior: + // - .Parse() should return an error, which contains the error message + if err == nil { + t.Fatalf("std flag: expected an error triggered by callback, got no error instead") + } + if !strings.HasSuffix(err.Error(), unitTestErr.Error()) { + t.Fatalf("std flag: expected unittest error, got unexpected error value: %T %v", err, err) + } + // - the function should have been called twice, with the first two values, + // the final "=after" should not be recorded + expected := []string{"true", "err"} + if !cmpLists(expected, stdValues) { + t.Fatalf("std flag: expected %v, got %v", expected, stdValues) + } + + // now run the test case on a pflag FlagSet: + fset := NewFlagSet("pflag test", ContinueOnError) + pflagValues, err := runCase(fset, flagName, args) + + // check that there is a similar error (note: pflag will _wrap_ the error, while the stdlib + // currently keeps the original message but creates a flat errors.Error) + if !errors.Is(err, unitTestErr) { + t.Fatalf("pflag: got unexpected error value: %T %v", err, err) + } + // the callback should be called the same number of times, with the same values: + if !cmpLists(stdValues, pflagValues) { + t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues) + } + }) +} + +func TestBoolFuncUsage(t *testing.T) { + t.Run("regular func flag", func(t *testing.T) { + // regular boolfunc flag: + // expect to see '--flag1' followed by the usageMessage, and no mention of a default value + fset := NewFlagSet("unittest", ContinueOnError) + fset.BoolFunc("flag1", "usage message", func(s string) error { return nil }) + usage := fset.FlagUsagesWrapped(80) + + usage = strings.TrimSpace(usage) + expected := "--flag1 usage message" + if usage != expected { + t.Fatalf("unexpected generated usage message\n expected: %s\n got: %s", expected, usage) + } + }) + + t.Run("func flag with placeholder name", func(t *testing.T) { + // func flag, with a placeholder name: + // if usageMesage contains a placeholder, expect '--flag2 {placeholder}'; still expect no mention of a default value + fset := NewFlagSet("unittest", ContinueOnError) + fset.BoolFunc("flag2", "usage message with `name` placeholder", func(s string) error { return nil }) + usage := fset.FlagUsagesWrapped(80) + + usage = strings.TrimSpace(usage) + expected := "--flag2 name usage message with name placeholder" + if usage != expected { + t.Fatalf("unexpected generated usage message\n expected: %s\n got: %s", expected, usage) + } + }) +} @@ -85,7 +85,7 @@ func (f *FlagSet) CountP(name, shorthand string, usage string) *int { // Count defines a count flag with specified name, default value, and usage string. // The return value is the address of an int variable that stores the value of the flag. -// A count flag will add 1 to its value evey time it is found on the command line +// A count flag will add 1 to its value every time it is found on the command line func Count(name string, usage string) *int { return CommandLine.CountP(name, "", usage) } diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..ff11b66 --- /dev/null +++ b/errors.go @@ -0,0 +1,149 @@ +package pflag + +import "fmt" + +// notExistErrorMessageType specifies which flavor of "flag does not exist" +// is printed by NotExistError. This allows the related errors to be grouped +// under a single NotExistError struct without making a breaking change to +// the error message text. +type notExistErrorMessageType int + +const ( + flagNotExistMessage notExistErrorMessageType = iota + flagNotDefinedMessage + flagNoSuchFlagMessage + flagUnknownFlagMessage + flagUnknownShorthandFlagMessage +) + +// NotExistError is the error returned when trying to access a flag that +// does not exist in the FlagSet. +type NotExistError struct { + name string + specifiedShorthands string + messageType notExistErrorMessageType +} + +// Error implements error. +func (e *NotExistError) Error() string { + switch e.messageType { + case flagNotExistMessage: + return fmt.Sprintf("flag %q does not exist", e.name) + + case flagNotDefinedMessage: + return fmt.Sprintf("flag accessed but not defined: %s", e.name) + + case flagNoSuchFlagMessage: + return fmt.Sprintf("no such flag -%v", e.name) + + case flagUnknownFlagMessage: + return fmt.Sprintf("unknown flag: --%s", e.name) + + case flagUnknownShorthandFlagMessage: + c := rune(e.name[0]) + return fmt.Sprintf("unknown shorthand flag: %q in -%s", c, e.specifiedShorthands) + } + + panic(fmt.Errorf("unknown flagNotExistErrorMessageType: %v", e.messageType)) +} + +// GetSpecifiedName returns the name of the flag (without dashes) as it +// appeared in the parsed arguments. +func (e *NotExistError) GetSpecifiedName() string { + return e.name +} + +// GetSpecifiedShortnames returns the group of shorthand arguments +// (without dashes) that the flag appeared within. If the flag was not in a +// shorthand group, this will return an empty string. +func (e *NotExistError) GetSpecifiedShortnames() string { + return e.specifiedShorthands +} + +// ValueRequiredError is the error returned when a flag needs an argument but +// no argument was provided. +type ValueRequiredError struct { + flag *Flag + specifiedName string + specifiedShorthands string +} + +// Error implements error. +func (e *ValueRequiredError) Error() string { + if len(e.specifiedShorthands) > 0 { + c := rune(e.specifiedName[0]) + return fmt.Sprintf("flag needs an argument: %q in -%s", c, e.specifiedShorthands) + } + + return fmt.Sprintf("flag needs an argument: --%s", e.specifiedName) +} + +// GetFlag returns the flag for which the error occurred. +func (e *ValueRequiredError) GetFlag() *Flag { + return e.flag +} + +// GetSpecifiedName returns the name of the flag (without dashes) as it +// appeared in the parsed arguments. +func (e *ValueRequiredError) GetSpecifiedName() string { + return e.specifiedName +} + +// GetSpecifiedShortnames returns the group of shorthand arguments +// (without dashes) that the flag appeared within. If the flag was not in a +// shorthand group, this will return an empty string. +func (e *ValueRequiredError) GetSpecifiedShortnames() string { + return e.specifiedShorthands +} + +// InvalidValueError is the error returned when an invalid value is used +// for a flag. +type InvalidValueError struct { + flag *Flag + value string + cause error +} + +// Error implements error. +func (e *InvalidValueError) Error() string { + flag := e.flag + var flagName string + if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { + flagName = fmt.Sprintf("-%s, --%s", flag.Shorthand, flag.Name) + } else { + flagName = fmt.Sprintf("--%s", flag.Name) + } + return fmt.Sprintf("invalid argument %q for %q flag: %v", e.value, flagName, e.cause) +} + +// Unwrap implements errors.Unwrap. +func (e *InvalidValueError) Unwrap() error { + return e.cause +} + +// GetFlag returns the flag for which the error occurred. +func (e *InvalidValueError) GetFlag() *Flag { + return e.flag +} + +// GetValue returns the invalid value that was provided. +func (e *InvalidValueError) GetValue() string { + return e.value +} + +// InvalidSyntaxError is the error returned when a bad flag name is passed on +// the command line. +type InvalidSyntaxError struct { + specifiedFlag string +} + +// Error implements error. +func (e *InvalidSyntaxError) Error() string { + return fmt.Sprintf("bad flag syntax: %s", e.specifiedFlag) +} + +// GetSpecifiedName returns the exact flag (with dashes) as it +// appeared in the parsed arguments. +func (e *InvalidSyntaxError) GetSpecifiedFlag() string { + return e.specifiedFlag +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..7b4c7a4 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,67 @@ +package pflag + +import ( + "errors" + "testing" +) + +func TestNotExistError(t *testing.T) { + err := &NotExistError{ + name: "foo", + specifiedShorthands: "bar", + } + + if err.GetSpecifiedName() != "foo" { + t.Errorf("Expected GetSpecifiedName to return %q, got %q", "foo", err.GetSpecifiedName()) + } + if err.GetSpecifiedShortnames() != "bar" { + t.Errorf("Expected GetSpecifiedShortnames to return %q, got %q", "bar", err.GetSpecifiedShortnames()) + } +} + +func TestValueRequiredError(t *testing.T) { + err := &ValueRequiredError{ + flag: &Flag{}, + specifiedName: "foo", + specifiedShorthands: "bar", + } + + if err.GetFlag() == nil { + t.Error("Expected GetSpecifiedName to return its flag field, but got nil") + } + if err.GetSpecifiedName() != "foo" { + t.Errorf("Expected GetSpecifiedName to return %q, got %q", "foo", err.GetSpecifiedName()) + } + if err.GetSpecifiedShortnames() != "bar" { + t.Errorf("Expected GetSpecifiedShortnames to return %q, got %q", "bar", err.GetSpecifiedShortnames()) + } +} + +func TestInvalidValueError(t *testing.T) { + expectedCause := errors.New("error") + err := &InvalidValueError{ + flag: &Flag{}, + value: "foo", + cause: expectedCause, + } + + if err.GetFlag() == nil { + t.Error("Expected GetSpecifiedName to return its flag field, but got nil") + } + if err.GetValue() != "foo" { + t.Errorf("Expected GetValue to return %q, got %q", "foo", err.GetValue()) + } + if err.Unwrap() != expectedCause { + t.Errorf("Expected Unwrwap to return %q, got %q", expectedCause, err.Unwrap()) + } +} + +func TestInvalidSyntaxError(t *testing.T) { + err := &InvalidSyntaxError{ + specifiedFlag: "--=", + } + + if err.GetSpecifiedFlag() != "--=" { + t.Errorf("Expected GetSpecifiedFlag to return %q, got %q", "--=", err.GetSpecifiedFlag()) + } +} @@ -27,23 +27,32 @@ unaffected. Define flags using flag.String(), Bool(), Int(), etc. This declares an integer flag, -flagname, stored in the pointer ip, with type *int. + var ip = flag.Int("flagname", 1234, "help message for flagname") + If you like, you can bind the flag to a variable using the Var() functions. + var flagvar int func init() { flag.IntVar(&flagvar, "flagname", 1234, "help message for flagname") } + Or you can create custom flags that satisfy the Value interface (with pointer receivers) and couple them to flag parsing by + flag.Var(&flagVal, "name", "help message for flagname") + For such flags, the default value is just the initial value of the variable. After all flags are defined, call + flag.Parse() + to parse the command line into the defined flags. Flags may then be used directly. If you're using the flags themselves, they are all pointers; if you bind to variables, they're values. + fmt.Println("ip has value ", *ip) fmt.Println("flagvar has value ", flagvar) @@ -54,22 +63,26 @@ The arguments are indexed from 0 through flag.NArg()-1. The pflag package also defines some new functions that are not in flag, that give one-letter shorthands for flags. You can use these by appending 'P' to the name of any function that defines a flag. + var ip = flag.IntP("flagname", "f", 1234, "help message") var flagvar bool func init() { flag.BoolVarP(&flagvar, "boolname", "b", true, "help message") } flag.VarP(&flagval, "varname", "v", "help message") + Shorthand letters can be used with single dashes on the command line. Boolean shorthand flags can be combined with other shorthand flags. Command line flag syntax: + --flag // boolean flags only --flag=x Unlike the flag package, a single dash before an option means something different than a double dash. Single dashes signify a series of shorthand letters for flags. All but the last shorthand letter must be boolean flags. + // boolean flags -f -abc @@ -381,7 +394,7 @@ func (f *FlagSet) lookup(name NormalizedName) *Flag { func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval string) (interface{}, error)) (interface{}, error) { flag := f.Lookup(name) if flag == nil { - err := fmt.Errorf("flag accessed but not defined: %s", name) + err := &NotExistError{name: name, messageType: flagNotDefinedMessage} return nil, err } @@ -411,7 +424,7 @@ func (f *FlagSet) ArgsLenAtDash() int { func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error { flag := f.Lookup(name) if flag == nil { - return fmt.Errorf("flag %q does not exist", name) + return &NotExistError{name: name, messageType: flagNotExistMessage} } if usageMessage == "" { return fmt.Errorf("deprecated message for flag %q must be set", name) @@ -427,7 +440,7 @@ func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error { func (f *FlagSet) MarkShorthandDeprecated(name string, usageMessage string) error { flag := f.Lookup(name) if flag == nil { - return fmt.Errorf("flag %q does not exist", name) + return &NotExistError{name: name, messageType: flagNotExistMessage} } if usageMessage == "" { return fmt.Errorf("deprecated message for flag %q must be set", name) @@ -441,7 +454,7 @@ func (f *FlagSet) MarkShorthandDeprecated(name string, usageMessage string) erro func (f *FlagSet) MarkHidden(name string) error { flag := f.Lookup(name) if flag == nil { - return fmt.Errorf("flag %q does not exist", name) + return &NotExistError{name: name, messageType: flagNotExistMessage} } flag.Hidden = true return nil @@ -464,18 +477,16 @@ func (f *FlagSet) Set(name, value string) error { normalName := f.normalizeFlagName(name) flag, ok := f.formal[normalName] if !ok { - return fmt.Errorf("no such flag -%v", name) + return &NotExistError{name: name, messageType: flagNoSuchFlagMessage} } err := flag.Value.Set(value) if err != nil { - var flagName string - if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { - flagName = fmt.Sprintf("-%s, --%s", flag.Shorthand, flag.Name) - } else { - flagName = fmt.Sprintf("--%s", flag.Name) + return &InvalidValueError{ + flag: flag, + value: value, + cause: err, } - return fmt.Errorf("invalid argument %q for %q flag: %v", value, flagName, err) } if !flag.Changed { @@ -501,7 +512,7 @@ func (f *FlagSet) SetAnnotation(name, key string, values []string) error { normalName := f.normalizeFlagName(name) flag, ok := f.formal[normalName] if !ok { - return fmt.Errorf("no such flag -%v", name) + return &NotExistError{name: name, messageType: flagNoSuchFlagMessage} } if flag.Annotations == nil { flag.Annotations = map[string][]string{} @@ -538,7 +549,7 @@ func (f *FlagSet) PrintDefaults() { func (f *Flag) defaultIsZeroValue() bool { switch f.Value.(type) { case boolFlag: - return f.DefValue == "false" + return f.DefValue == "false" || f.DefValue == "" case *durationValue: // Beginning in Go 1.7, duration zero values are "0s" return f.DefValue == "0" || f.DefValue == "0s" @@ -551,7 +562,7 @@ func (f *Flag) defaultIsZeroValue() bool { case *intSliceValue, *stringSliceValue, *stringArrayValue: return f.DefValue == "[]" default: - switch f.Value.String() { + switch f.DefValue { case "false": return true case "<nil>": @@ -588,8 +599,10 @@ func UnquoteUsage(flag *Flag) (name string, usage string) { name = flag.Value.Type() switch name { - case "bool": + case "bool", "boolfunc": name = "" + case "func": + name = "value" case "float64": name = "float" case "int64": @@ -707,7 +720,7 @@ func (f *FlagSet) FlagUsagesWrapped(cols int) string { switch flag.Value.Type() { case "string": line += fmt.Sprintf("[=\"%s\"]", flag.NoOptDefVal) - case "bool": + case "bool", "boolfunc": if flag.NoOptDefVal != "true" { line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) } @@ -911,12 +924,10 @@ func VarP(value Value, name, shorthand, usage string) { CommandLine.VarP(value, name, shorthand, usage) } -// failf prints to standard error a formatted error and usage message and +// fail prints an error message and usage message to standard error and // returns the error. -func (f *FlagSet) failf(format string, a ...interface{}) error { - err := fmt.Errorf(format, a...) +func (f *FlagSet) fail(err error) error { if f.errorHandling != ContinueOnError { - fmt.Fprintln(f.Output(), err) f.usage() } return err @@ -934,9 +945,9 @@ func (f *FlagSet) usage() { } } -//--unknown (args will be empty) -//--unknown --next-flag ... (args will be --next-flag ...) -//--unknown arg ... (args will be arg ...) +// --unknown (args will be empty) +// --unknown --next-flag ... (args will be --next-flag ...) +// --unknown arg ... (args will be arg ...) func stripUnknownFlagValue(args []string) []string { if len(args) == 0 { //--unknown @@ -960,7 +971,7 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin a = args name := s[2:] if len(name) == 0 || name[0] == '-' || name[0] == '=' { - err = f.failf("bad flag syntax: %s", s) + err = f.fail(&InvalidSyntaxError{specifiedFlag: s}) return } @@ -982,7 +993,7 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin return stripUnknownFlagValue(a), nil default: - err = f.failf("unknown flag: --%s", name) + err = f.fail(&NotExistError{name: name, messageType: flagUnknownFlagMessage}) return } } @@ -1000,13 +1011,16 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin a = a[1:] } else { // '--flag' (arg was required) - err = f.failf("flag needs an argument: %s", s) + err = f.fail(&ValueRequiredError{ + flag: flag, + specifiedName: name, + }) return } err = fn(flag, value) if err != nil { - f.failf(err.Error()) + f.fail(err) } return } @@ -1014,7 +1028,7 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parseFunc) (outShorts string, outArgs []string, err error) { outArgs = args - if strings.HasPrefix(shorthands, "test.") { + if isGotestShorthandFlag(shorthands) { return } @@ -1039,7 +1053,11 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse outArgs = stripUnknownFlagValue(outArgs) return default: - err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands) + err = f.fail(&NotExistError{ + name: string(c), + specifiedShorthands: shorthands, + messageType: flagUnknownShorthandFlagMessage, + }) return } } @@ -1062,7 +1080,11 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse outArgs = args[1:] } else { // '-f' (arg was required) - err = f.failf("flag needs an argument: %q in -%s", c, shorthands) + err = f.fail(&ValueRequiredError{ + flag: flag, + specifiedName: string(c), + specifiedShorthands: shorthands, + }) return } @@ -1072,7 +1094,7 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse err = fn(flag, value) if err != nil { - f.failf(err.Error()) + f.fail(err) } return } @@ -1135,7 +1157,7 @@ func (f *FlagSet) Parse(arguments []string) error { } f.parsed = true - if len(arguments) < 0 { + if len(arguments) == 0 { return nil } @@ -1151,7 +1173,7 @@ func (f *FlagSet) Parse(arguments []string) error { case ContinueOnError: return err case ExitOnError: - fmt.Println(err) + fmt.Fprintln(f.Output(), err) os.Exit(2) case PanicOnError: panic(err) @@ -1177,6 +1199,7 @@ func (f *FlagSet) ParseAll(arguments []string, fn func(flag *Flag, value string) case ContinueOnError: return err case ExitOnError: + fmt.Fprintln(f.Output(), err) os.Exit(2) case PanicOnError: panic(err) diff --git a/flag_test.go b/flag_test.go index 76535f3..2df3ea2 100644 --- a/flag_test.go +++ b/flag_test.go @@ -100,12 +100,23 @@ func TestEverything(t *testing.T) { } } +func TestNoArgument(t *testing.T) { + if GetCommandLine().Parse([]string{}) != nil { + t.Error("parse failed for empty argument list") + } +} + func TestUsage(t *testing.T) { called := false ResetForTesting(func() { called = true }) - if GetCommandLine().Parse([]string{"--x"}) == nil { + err := GetCommandLine().Parse([]string{"--x"}) + expectedErr := "unknown flag: --x" + if err == nil { t.Error("parse did not fail for unknown flag") } + if err.Error() != expectedErr { + t.Errorf("expected error %q, got %q", expectedErr, err.Error()) + } if called { t.Error("did call Usage while using ContinueOnError") } @@ -131,9 +142,14 @@ func TestAddFlagSet(t *testing.T) { func TestAnnotation(t *testing.T) { f := NewFlagSet("shorthand", ContinueOnError) - if err := f.SetAnnotation("missing-flag", "key", nil); err == nil { + err := f.SetAnnotation("missing-flag", "key", nil) + expectedErr := "no such flag -missing-flag" + if err == nil { t.Errorf("Expected error setting annotation on non-existent flag") } + if err.Error() != expectedErr { + t.Errorf("expected error %q, got %q", expectedErr, err.Error()) + } f.StringP("stringa", "a", "", "string value") if err := f.SetAnnotation("stringa", "key", nil); err != nil { @@ -349,6 +365,33 @@ func testParse(f *FlagSet, t *testing.T) { } else if f.Args()[0] != extra { t.Errorf("expected argument %q got %q", extra, f.Args()[0]) } + // Test unknown + err := f.Parse([]string{"--unknown"}) + expectedErr := "unknown flag: --unknown" + if err == nil { + t.Error("parse did not fail for unknown flag") + } + if err.Error() != expectedErr { + t.Errorf("expected error %q, got %q", expectedErr, err.Error()) + } + // Test invalid + err = f.Parse([]string{"--bool=abcdefg"}) + expectedErr = `invalid argument "abcdefg" for "--bool" flag: strconv.ParseBool: parsing "abcdefg": invalid syntax` + if err == nil { + t.Error("parse did not fail for invalid argument") + } + if err.Error() != expectedErr { + t.Errorf("expected error %q, got %q", expectedErr, err.Error()) + } + // Test required + err = f.Parse([]string{"--int"}) + expectedErr = `flag needs an argument: --int` + if err == nil { + t.Error("parse did not fail for missing argument") + } + if err.Error() != expectedErr { + t.Errorf("expected error %q, got %q", expectedErr, err.Error()) + } } func testParseAll(f *FlagSet, t *testing.T) { @@ -433,7 +476,7 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { "-u=unknown3Value", "-p", "unknown4Value", - "-q", //another unknown with bool value + "-q", // another unknown with bool value "-y", "ee", "--unknown7=unknown7value", @@ -538,6 +581,24 @@ func TestShorthand(t *testing.T) { if f.ArgsLenAtDash() != 1 { t.Errorf("expected argsLenAtDash %d got %d", f.ArgsLenAtDash(), 1) } + // Test unknown + err := f.Parse([]string{"-ukn"}) + expectedErr := "unknown shorthand flag: 'u' in -ukn" + if err == nil { + t.Error("parse did not fail for unknown shorthand flag") + } + if err.Error() != expectedErr { + t.Errorf("expected error %q, got %q", expectedErr, err.Error()) + } + // Test required + err = f.Parse([]string{"-as"}) + expectedErr = `flag needs an argument: 's' in -s` + if err == nil { + t.Error("parse did not fail for missing argument") + } + if err.Error() != expectedErr { + t.Errorf("expected error %q, got %q", expectedErr, err.Error()) + } } func TestShorthandLookup(t *testing.T) { @@ -899,7 +960,7 @@ func TestChangingArgs(t *testing.T) { // Test that -help invokes the usage message and returns ErrHelp. func TestHelp(t *testing.T) { - var helpCalled = false + helpCalled := false fs := NewFlagSet("help test", ContinueOnError) fs.Usage = func() { helpCalled = true } var flag bool @@ -998,6 +1059,7 @@ func getDeprecatedFlagSet() *FlagSet { f.MarkDeprecated("badflag", "use --good-flag instead") return f } + func TestDeprecatedFlagInDocs(t *testing.T) { f := getDeprecatedFlagSet() @@ -1134,7 +1196,6 @@ func TestMultipleNormalizeFlagNameInvocations(t *testing.T) { } } -// func TestHiddenFlagInUsage(t *testing.T) { f := NewFlagSet("bob", ContinueOnError) f.Bool("secretFlag", true, "shhh") @@ -1149,7 +1210,6 @@ func TestHiddenFlagInUsage(t *testing.T) { } } -// func TestHiddenFlagUsage(t *testing.T) { f := NewFlagSet("bob", ContinueOnError) f.Bool("secretFlag", true, "shhh") @@ -1184,6 +1244,7 @@ const defaultOutput = ` --A for bootstrapping, allo --StringSlice strings string slice with zero default --Z int an int that defaults to zero --custom custom custom Value implementation + --custom-with-val custom custom value which has been set from command line while help is shown --customP custom a VarP with default (default 10) --maxT timeout set timeout for dial -v, --verbose count verbosity @@ -1235,12 +1296,18 @@ func TestPrintDefaults(t *testing.T) { cv2 := customValue(10) fs.VarP(&cv2, "customP", "", "a VarP with default") + // Simulate case where a value has been provided and the help screen is shown + var cv3 customValue + fs.Var(&cv3, "custom-with-val", "custom value which has been set from command line while help is shown") + err := fs.Parse([]string{"--custom-with-val", "3"}) + if err != nil { + t.Error("Parsing flags failed:", err) + } + fs.PrintDefaults() got := buf.String() if got != defaultOutput { - fmt.Println("\n" + got) - fmt.Println("\n" + defaultOutput) - t.Errorf("got %q want %q\n", got, defaultOutput) + t.Errorf("\n--- Got:\n%s--- Wanted:\n%s\n", got, defaultOutput) } } @@ -0,0 +1,37 @@ +package pflag + +// -- func Value +type funcValue func(string) error + +func (f funcValue) Set(s string) error { return f(s) } + +func (f funcValue) Type() string { return "func" } + +func (f funcValue) String() string { return "" } // same behavior as stdlib 'flag' package + +// Func defines a func flag with specified name, callback function and usage string. +// +// The callback function will be called every time "--{name}={value}" (or equivalent) is +// parsed on the command line, with "{value}" as an argument. +func (f *FlagSet) Func(name string, usage string, fn func(string) error) { + f.FuncP(name, "", usage, fn) +} + +// FuncP is like Func, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) FuncP(name string, shorthand string, usage string, fn func(string) error) { + var val Value = funcValue(fn) + f.VarP(val, name, shorthand, usage) +} + +// Func defines a func flag with specified name, callback function and usage string. +// +// The callback function will be called every time "--{name}={value}" (or equivalent) is +// parsed on the command line, with "{value}" as an argument. +func Func(name string, usage string, fn func(string) error) { + CommandLine.FuncP(name, "", usage, fn) +} + +// FuncP is like Func, but accepts a shorthand letter that can be used after a single dash. +func FuncP(name, shorthand string, usage string, fn func(string) error) { + CommandLine.FuncP(name, shorthand, usage, fn) +} diff --git a/func_test.go b/func_test.go new file mode 100644 index 0000000..d492b48 --- /dev/null +++ b/func_test.go @@ -0,0 +1,183 @@ +package pflag + +import ( + "errors" + "flag" + "io" + "strings" + "testing" +) + +func cmpLists(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func TestFunc(t *testing.T) { + var values []string + fn := func(s string) error { + values = append(values, s) + return nil + } + + fset := NewFlagSet("test", ContinueOnError) + fset.Func("fnflag", "Callback function", fn) + + err := fset.Parse([]string{"--fnflag=aa", "--fnflag", "bb"}) + if err != nil { + t.Fatal("expected no error; got", err) + } + + expected := []string{"aa", "bb"} + if !cmpLists(expected, values) { + t.Fatalf("expected %v, got %v", expected, values) + } +} + +func TestFuncP(t *testing.T) { + var values []string + fn := func(s string) error { + values = append(values, s) + return nil + } + + fset := NewFlagSet("test", ContinueOnError) + fset.FuncP("fnflag", "f", "Callback function", fn) + + err := fset.Parse([]string{"--fnflag=a", "--fnflag", "b", "-fc", "-f=d", "-f", "e"}) + if err != nil { + t.Fatal("expected no error; got", err) + } + + expected := []string{"a", "b", "c", "d", "e"} + if !cmpLists(expected, values) { + t.Fatalf("expected %v, got %v", expected, values) + } +} + +func TestFuncCompat(t *testing.T) { + // compare behavior with the stdlib 'flag' package + type FuncFlagSet interface { + Func(name string, usage string, fn func(string) error) + Parse([]string) error + } + + unitTestErr := errors.New("unit test error") + runCase := func(f FuncFlagSet, name string, args []string) (values []string, err error) { + fn := func(s string) error { + values = append(values, s) + if s == "err" { + return unitTestErr + } + return nil + } + f.Func(name, "Callback function", fn) + + err = f.Parse(args) + return values, err + } + + t.Run("regular parsing", func(t *testing.T) { + flagName := "fnflag" + args := []string{"--fnflag=xx", "--fnflag", "yy", "--fnflag=zz"} + + stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError) + stdValues, err := runCase(stdFSet, flagName, args) + if err != nil { + t.Fatalf("std flag: expected no error, got %v", err) + } + expected := []string{"xx", "yy", "zz"} + if !cmpLists(expected, stdValues) { + t.Fatalf("std flag: expected %v, got %v", expected, stdValues) + } + + fset := NewFlagSet("pflag test", ContinueOnError) + pflagValues, err := runCase(fset, flagName, args) + if err != nil { + t.Fatalf("pflag: expected no error, got %v", err) + } + if !cmpLists(stdValues, pflagValues) { + t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues) + } + }) + + t.Run("error triggered by callback", func(t *testing.T) { + flagName := "fnflag" + args := []string{"--fnflag", "before", "--fnflag", "err", "--fnflag", "after"} + + // test behavior of standard flag.Fset with an error triggered by the callback: + // (note: as can be seen in 'runCase()', if the callback sees "err" as a value + // for the flag, it will return an error) + stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError) + stdFSet.SetOutput(io.Discard) // suppress output + + // run test case with standard flag.Fset + stdValues, err := runCase(stdFSet, flagName, args) + + // double check the standard behavior: + // - .Parse() should return an error, which contains the error message + if err == nil { + t.Fatalf("std flag: expected an error triggered by callback, got no error instead") + } + if !strings.HasSuffix(err.Error(), unitTestErr.Error()) { + t.Fatalf("std flag: expected unittest error, got unexpected error value: %T %v", err, err) + } + // - the function should have been called twice, with the first two values, + // the final "=after" should not be recorded + expected := []string{"before", "err"} + if !cmpLists(expected, stdValues) { + t.Fatalf("std flag: expected %v, got %v", expected, stdValues) + } + + // now run the test case on a pflag FlagSet: + fset := NewFlagSet("pflag test", ContinueOnError) + pflagValues, err := runCase(fset, flagName, args) + + // check that there is a similar error (note: pflag will _wrap_ the error, while the stdlib + // currently keeps the original message but creates a flat errors.Error) + if !errors.Is(err, unitTestErr) { + t.Fatalf("pflag: got unexpected error value: %T %v", err, err) + } + // the callback should be called the same number of times, with the same values: + if !cmpLists(stdValues, pflagValues) { + t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues) + } + }) +} + +func TestFuncUsage(t *testing.T) { + t.Run("regular func flag", func(t *testing.T) { + // regular func flag: + // expect to see '--flag1 value' followed by the usageMessage, and no mention of a default value + fset := NewFlagSet("unittest", ContinueOnError) + fset.Func("flag1", "usage message", func(s string) error { return nil }) + usage := fset.FlagUsagesWrapped(80) + + usage = strings.TrimSpace(usage) + expected := "--flag1 value usage message" + if usage != expected { + t.Fatalf("unexpected generated usage message\n expected: %s\n got: %s", expected, usage) + } + }) + + t.Run("func flag with placeholder name", func(t *testing.T) { + // func flag, with a placeholder name: + // if usageMesage contains a placeholder, expect that name; still expect no mention of a default value + fset := NewFlagSet("unittest", ContinueOnError) + fset.Func("flag2", "usage message with `name` placeholder", func(s string) error { return nil }) + usage := fset.FlagUsagesWrapped(80) + + usage = strings.TrimSpace(usage) + expected := "--flag2 name usage message with name placeholder" + if usage != expected { + t.Fatalf("unexpected generated usage message\n expected: %s\n got: %s", expected, usage) + } + }) +} diff --git a/golangflag.go b/golangflag.go index d3dd72b..f563907 100644 --- a/golangflag.go +++ b/golangflag.go @@ -10,6 +10,15 @@ import ( "strings" ) +// go test flags prefixes +func isGotestFlag(flag string) bool { + return strings.HasPrefix(flag, "-test.") +} + +func isGotestShorthandFlag(flag string) bool { + return strings.HasPrefix(flag, "test.") +} + // flagValueWrapper implements pflag.Value around a flag.Value. The main // difference here is the addition of the Type method that returns a string // name of the type. As this is generally unknown, we approximate that with @@ -103,3 +112,16 @@ func (f *FlagSet) AddGoFlagSet(newSet *goflag.FlagSet) { } f.addedGoFlagSets = append(f.addedGoFlagSets, newSet) } + +// ParseSkippedFlags explicitly Parses go test flags (i.e. the one starting with '-test.') with goflag.Parse(), +// since by default those are skipped by pflag.Parse(). +// Typical usage example: `ParseGoTestFlags(os.Args[1:], goflag.CommandLine)` +func ParseSkippedFlags(osArgs []string, goFlagSet *goflag.FlagSet) error { + var skippedFlags []string + for _, f := range osArgs { + if isGotestFlag(f) { + skippedFlags = append(skippedFlags, f) + } + } + return goFlagSet.Parse(skippedFlags) +} diff --git a/golangflag_test.go b/golangflag_test.go index 5bd831b..2ecefef 100644 --- a/golangflag_test.go +++ b/golangflag_test.go @@ -12,11 +12,14 @@ import ( func TestGoflags(t *testing.T) { goflag.String("stringFlag", "stringFlag", "stringFlag") goflag.Bool("boolFlag", false, "boolFlag") + var testxxxValue string + goflag.StringVar(&testxxxValue, "test.xxx", "test.xxx", "it is a test flag") f := NewFlagSet("test", ContinueOnError) f.AddGoFlagSet(goflag.CommandLine) - err := f.Parse([]string{"--stringFlag=bob", "--boolFlag"}) + args := []string{"--stringFlag=bob", "--boolFlag", "-test.xxx=testvalue"} + err := f.Parse(args) if err != nil { t.Fatal("expected no error; get", err) } @@ -40,6 +43,17 @@ func TestGoflags(t *testing.T) { t.Fatal("f.Parsed() return false after f.Parse() called") } + if testxxxValue != "test.xxx" { + t.Fatalf("expected testxxxValue to be test.xxx but got %v", testxxxValue) + } + err = ParseSkippedFlags(args, goflag.CommandLine) + if err != nil { + t.Fatal("expected no error; ParseSkippedFlags", err) + } + if testxxxValue != "testvalue" { + t.Fatalf("expected testxxxValue to be testvalue but got %v", testxxxValue) + } + // in fact it is useless. because `go test` called flag.Parse() if !goflag.CommandLine.Parsed() { t.Fatal("goflag.CommandLine.Parsed() return false after f.Parse() called") @@ -16,6 +16,9 @@ func newIPValue(val net.IP, p *net.IP) *ipValue { func (i *ipValue) String() string { return net.IP(*i).String() } func (i *ipValue) Set(s string) error { + if s == "" { + return nil + } ip := net.ParseIP(strings.TrimSpace(s)) if ip == nil { return fmt.Errorf("failed to parse IP: %q", s) @@ -24,7 +24,7 @@ func TestIP(t *testing.T) { {"1.2.3.4", true, "1.2.3.4"}, {"127.0.0.1", true, "127.0.0.1"}, {"255.255.255.255", true, "255.255.255.255"}, - {"", false, ""}, + {"", true, "0.0.0.0"}, {"0", false, ""}, {"localhost", false, ""}, {"0.0.0", false, ""}, diff --git a/ipnet_slice.go b/ipnet_slice.go new file mode 100644 index 0000000..c6e89da --- /dev/null +++ b/ipnet_slice.go @@ -0,0 +1,147 @@ +package pflag + +import ( + "fmt" + "io" + "net" + "strings" +) + +// -- ipNetSlice Value +type ipNetSliceValue struct { + value *[]net.IPNet + changed bool +} + +func newIPNetSliceValue(val []net.IPNet, p *[]net.IPNet) *ipNetSliceValue { + ipnsv := new(ipNetSliceValue) + ipnsv.value = p + *ipnsv.value = val + return ipnsv +} + +// Set converts, and assigns, the comma-separated IPNet argument string representation as the []net.IPNet value of this flag. +// If Set is called on a flag that already has a []net.IPNet assigned, the newly converted values will be appended. +func (s *ipNetSliceValue) Set(val string) error { + + // remove all quote characters + rmQuote := strings.NewReplacer(`"`, "", `'`, "", "`", "") + + // read flag arguments with CSV parser + ipNetStrSlice, err := readAsCSV(rmQuote.Replace(val)) + if err != nil && err != io.EOF { + return err + } + + // parse ip values into slice + out := make([]net.IPNet, 0, len(ipNetStrSlice)) + for _, ipNetStr := range ipNetStrSlice { + _, n, err := net.ParseCIDR(strings.TrimSpace(ipNetStr)) + if err != nil { + return fmt.Errorf("invalid string being converted to CIDR: %s", ipNetStr) + } + out = append(out, *n) + } + + if !s.changed { + *s.value = out + } else { + *s.value = append(*s.value, out...) + } + + s.changed = true + + return nil +} + +// Type returns a string that uniquely represents this flag's type. +func (s *ipNetSliceValue) Type() string { + return "ipNetSlice" +} + +// String defines a "native" format for this net.IPNet slice flag value. +func (s *ipNetSliceValue) String() string { + + ipNetStrSlice := make([]string, len(*s.value)) + for i, n := range *s.value { + ipNetStrSlice[i] = n.String() + } + + out, _ := writeAsCSV(ipNetStrSlice) + return "[" + out + "]" +} + +func ipNetSliceConv(val string) (interface{}, error) { + val = strings.Trim(val, "[]") + // Empty string would cause a slice with one (empty) entry + if len(val) == 0 { + return []net.IPNet{}, nil + } + ss := strings.Split(val, ",") + out := make([]net.IPNet, len(ss)) + for i, sval := range ss { + _, n, err := net.ParseCIDR(strings.TrimSpace(sval)) + if err != nil { + return nil, fmt.Errorf("invalid string being converted to CIDR: %s", sval) + } + out[i] = *n + } + return out, nil +} + +// GetIPNetSlice returns the []net.IPNet value of a flag with the given name +func (f *FlagSet) GetIPNetSlice(name string) ([]net.IPNet, error) { + val, err := f.getFlagType(name, "ipNetSlice", ipNetSliceConv) + if err != nil { + return []net.IPNet{}, err + } + return val.([]net.IPNet), nil +} + +// IPNetSliceVar defines a ipNetSlice flag with specified name, default value, and usage string. +// The argument p points to a []net.IPNet variable in which to store the value of the flag. +func (f *FlagSet) IPNetSliceVar(p *[]net.IPNet, name string, value []net.IPNet, usage string) { + f.VarP(newIPNetSliceValue(value, p), name, "", usage) +} + +// IPNetSliceVarP is like IPNetSliceVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) IPNetSliceVarP(p *[]net.IPNet, name, shorthand string, value []net.IPNet, usage string) { + f.VarP(newIPNetSliceValue(value, p), name, shorthand, usage) +} + +// IPNetSliceVar defines a []net.IPNet flag with specified name, default value, and usage string. +// The argument p points to a []net.IPNet variable in which to store the value of the flag. +func IPNetSliceVar(p *[]net.IPNet, name string, value []net.IPNet, usage string) { + CommandLine.VarP(newIPNetSliceValue(value, p), name, "", usage) +} + +// IPNetSliceVarP is like IPNetSliceVar, but accepts a shorthand letter that can be used after a single dash. +func IPNetSliceVarP(p *[]net.IPNet, name, shorthand string, value []net.IPNet, usage string) { + CommandLine.VarP(newIPNetSliceValue(value, p), name, shorthand, usage) +} + +// IPNetSlice defines a []net.IPNet flag with specified name, default value, and usage string. +// The return value is the address of a []net.IPNet variable that stores the value of that flag. +func (f *FlagSet) IPNetSlice(name string, value []net.IPNet, usage string) *[]net.IPNet { + p := []net.IPNet{} + f.IPNetSliceVarP(&p, name, "", value, usage) + return &p +} + +// IPNetSliceP is like IPNetSlice, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) IPNetSliceP(name, shorthand string, value []net.IPNet, usage string) *[]net.IPNet { + p := []net.IPNet{} + f.IPNetSliceVarP(&p, name, shorthand, value, usage) + return &p +} + +// IPNetSlice defines a []net.IPNet flag with specified name, default value, and usage string. +// The return value is the address of a []net.IP variable that stores the value of the flag. +func IPNetSlice(name string, value []net.IPNet, usage string) *[]net.IPNet { + return CommandLine.IPNetSliceP(name, "", value, usage) +} + +// IPNetSliceP is like IPNetSlice, but accepts a shorthand letter that can be used after a single dash. +func IPNetSliceP(name, shorthand string, value []net.IPNet, usage string) *[]net.IPNet { + return CommandLine.IPNetSliceP(name, shorthand, value, usage) +} diff --git a/ipnet_slice_test.go b/ipnet_slice_test.go new file mode 100644 index 0000000..11644c5 --- /dev/null +++ b/ipnet_slice_test.go @@ -0,0 +1,239 @@ +package pflag + +import ( + "fmt" + "net" + "strings" + "testing" +) + +// Helper function to set static slices +func getCIDR(ip net.IP, cidr *net.IPNet, err error) net.IPNet { + return *cidr +} + +func equalCIDR(c1 net.IPNet, c2 net.IPNet) bool { + if c1.String() == c2.String() { + return true + } + return false +} + +func setUpIPNetFlagSet(ipsp *[]net.IPNet) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IPNetSliceVar(ipsp, "cidrs", []net.IPNet{}, "Command separated list!") + return f +} + +func setUpIPNetFlagSetWithDefault(ipsp *[]net.IPNet) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IPNetSliceVar(ipsp, "cidrs", + []net.IPNet{ + getCIDR(net.ParseCIDR("192.168.1.1/16")), + getCIDR(net.ParseCIDR("fd00::/64")), + }, + "Command separated list!") + return f +} + +func TestEmptyIPNet(t *testing.T) { + var cidrs []net.IPNet + f := setUpIPNetFlagSet(&cidrs) + err := f.Parse([]string{}) + if err != nil { + t.Fatal("expected no error; got", err) + } + + getIPNet, err := f.GetIPNetSlice("cidrs") + if err != nil { + t.Fatal("got an error from GetIPNetSlice():", err) + } + if len(getIPNet) != 0 { + t.Fatalf("got ips %v with len=%d but expected length=0", getIPNet, len(getIPNet)) + } +} + +func TestIPNets(t *testing.T) { + var ips []net.IPNet + f := setUpIPNetFlagSet(&ips) + + vals := []string{"192.168.1.1/24", "10.0.0.1/16", "fd00:0:0:0:0:0:0:2/64"} + arg := fmt.Sprintf("--cidrs=%s", strings.Join(vals, ",")) + err := f.Parse([]string{arg}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range ips { + if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil { + t.Fatalf("invalid string being converted to CIDR: %s", vals[i]) + } else if !equalCIDR(*cidr, v) { + t.Fatalf("expected ips[%d] to be %s but got: %s from GetIPSlice", i, vals[i], v) + } + } +} + +func TestIPNetDefault(t *testing.T) { + var cidrs []net.IPNet + f := setUpIPNetFlagSetWithDefault(&cidrs) + + vals := []string{"192.168.1.1/16", "fd00::/64"} + err := f.Parse([]string{}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range cidrs { + if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil { + t.Fatalf("invalid string being converted to CIDR: %s", vals[i]) + } else if !equalCIDR(*cidr, v) { + t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v) + } + } + + getIPNet, err := f.GetIPNetSlice("cidrs") + if err != nil { + t.Fatal("got an error from GetIPNetSlice") + } + for i, v := range getIPNet { + if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil { + t.Fatalf("invalid string being converted to CIDR: %s", vals[i]) + } else if !equalCIDR(*cidr, v) { + t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v) + } + } +} + +func TestIPNetWithDefault(t *testing.T) { + var cidrs []net.IPNet + f := setUpIPNetFlagSetWithDefault(&cidrs) + + vals := []string{"192.168.1.1/16", "fd00::/64"} + arg := fmt.Sprintf("--cidrs=%s", strings.Join(vals, ",")) + err := f.Parse([]string{arg}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range cidrs { + if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil { + t.Fatalf("invalid string being converted to CIDR: %s", vals[i]) + } else if !equalCIDR(*cidr, v) { + t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v) + } + } + + getIPNet, err := f.GetIPNetSlice("cidrs") + if err != nil { + t.Fatal("got an error from GetIPNetSlice") + } + for i, v := range getIPNet { + if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil { + t.Fatalf("invalid string being converted to CIDR: %s", vals[i]) + } else if !equalCIDR(*cidr, v) { + t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v) + } + } +} + +func TestIPNetCalledTwice(t *testing.T) { + var cidrs []net.IPNet + f := setUpIPNetFlagSet(&cidrs) + + in := []string{"192.168.1.2/16,fd00::/64", "10.0.0.1/24"} + + expected := []net.IPNet{ + getCIDR(net.ParseCIDR("192.168.1.2/16")), + getCIDR(net.ParseCIDR("fd00::/64")), + getCIDR(net.ParseCIDR("10.0.0.1/24")), + } + argfmt := "--cidrs=%s" + arg1 := fmt.Sprintf(argfmt, in[0]) + arg2 := fmt.Sprintf(argfmt, in[1]) + err := f.Parse([]string{arg1, arg2}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range cidrs { + if !equalCIDR(expected[i], v) { + t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, expected[i], v) + } + } +} + +func TestIPNetBadQuoting(t *testing.T) { + + tests := []struct { + Want []net.IPNet + FlagArg []string + }{ + { + Want: []net.IPNet{ + getCIDR(net.ParseCIDR("a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568/128")), + getCIDR(net.ParseCIDR("203.107.49.208/32")), + getCIDR(net.ParseCIDR("14.57.204.90/32")), + }, + FlagArg: []string{ + "a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568/128", + "203.107.49.208/32", + "14.57.204.90/32", + }, + }, + { + Want: []net.IPNet{ + getCIDR(net.ParseCIDR("204.228.73.195/32")), + getCIDR(net.ParseCIDR("86.141.15.94/32")), + }, + FlagArg: []string{ + "204.228.73.195/32", + "86.141.15.94/32", + }, + }, + { + Want: []net.IPNet{ + getCIDR(net.ParseCIDR("c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f/128")), + getCIDR(net.ParseCIDR("4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472/128")), + }, + FlagArg: []string{ + "c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f/128", + "4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472/128", + }, + }, + { + Want: []net.IPNet{ + getCIDR(net.ParseCIDR("5170:f971:cfac:7be3:512a:af37:952c:bc33/128")), + getCIDR(net.ParseCIDR("93.21.145.140/32")), + getCIDR(net.ParseCIDR("2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca/128")), + }, + FlagArg: []string{ + " 5170:f971:cfac:7be3:512a:af37:952c:bc33/128 , 93.21.145.140/32 ", + "2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca/128", + }, + }, + { + Want: []net.IPNet{ + getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), + getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), + getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), + getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), + }, + FlagArg: []string{ + `"2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128, 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128,2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128 "`, + " 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128"}, + }, + } + + for i, test := range tests { + + var cidrs []net.IPNet + f := setUpIPNetFlagSet(&cidrs) + + if err := f.Parse([]string{fmt.Sprintf("--cidrs=%s", strings.Join(test.FlagArg, ","))}); err != nil { + t.Fatalf("flag parsing failed with error: %s\nparsing:\t%#v\nwant:\t\t%s", + err, test.FlagArg, test.Want[i]) + } + + for j, b := range cidrs { + if !equalCIDR(b, test.Want[j]) { + t.Fatalf("bad value parsed for test %d on net.IP %d:\nwant:\t%s\ngot:\t%s", i, j, test.Want[j], b) + } + } + } +} diff --git a/string_array.go b/string_array.go index 4894af8..d1ff0a9 100644 --- a/string_array.go +++ b/string_array.go @@ -31,11 +31,7 @@ func (s *stringArrayValue) Append(val string) error { func (s *stringArrayValue) Replace(val []string) error { out := make([]string, len(val)) for i, d := range val { - var err error out[i] = d - if err != nil { - return err - } } *s.value = out return nil @@ -0,0 +1,81 @@ +package pflag + +import ( + "encoding" + "fmt" + "reflect" +) + +// following is copied from go 1.23.4 flag.go +type textValue struct{ p encoding.TextUnmarshaler } + +func newTextValue(val encoding.TextMarshaler, p encoding.TextUnmarshaler) textValue { + ptrVal := reflect.ValueOf(p) + if ptrVal.Kind() != reflect.Ptr { + panic("variable value type must be a pointer") + } + defVal := reflect.ValueOf(val) + if defVal.Kind() == reflect.Ptr { + defVal = defVal.Elem() + } + if defVal.Type() != ptrVal.Type().Elem() { + panic(fmt.Sprintf("default type does not match variable type: %v != %v", defVal.Type(), ptrVal.Type().Elem())) + } + ptrVal.Elem().Set(defVal) + return textValue{p} +} + +func (v textValue) Set(s string) error { + return v.p.UnmarshalText([]byte(s)) +} + +func (v textValue) Get() interface{} { + return v.p +} + +func (v textValue) String() string { + if m, ok := v.p.(encoding.TextMarshaler); ok { + if b, err := m.MarshalText(); err == nil { + return string(b) + } + } + return "" +} + +//end of copy + +func (v textValue) Type() string { + return reflect.ValueOf(v.p).Type().Name() +} + +// GetText set out, which implements encoding.UnmarshalText, to the value of a flag with given name +func (f *FlagSet) GetText(name string, out encoding.TextUnmarshaler) error { + flag := f.Lookup(name) + if flag == nil { + return fmt.Errorf("flag accessed but not defined: %s", name) + } + if flag.Value.Type() != reflect.TypeOf(out).Name() { + return fmt.Errorf("trying to get %s value of flag of type %s", reflect.TypeOf(out).Name(), flag.Value.Type()) + } + return out.UnmarshalText([]byte(flag.Value.String())) +} + +// TextVar defines a flag with a specified name, default value, and usage string. The argument p must be a pointer to a variable that will hold the value of the flag, and p must implement encoding.TextUnmarshaler. If the flag is used, the flag value will be passed to p's UnmarshalText method. The type of the default value must be the same as the type of p. +func (f *FlagSet) TextVar(p encoding.TextUnmarshaler, name string, value encoding.TextMarshaler, usage string) { + f.VarP(newTextValue(value, p), name, "", usage) +} + +// TextVarP is like TextVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) TextVarP(p encoding.TextUnmarshaler, name, shorthand string, value encoding.TextMarshaler, usage string) { + f.VarP(newTextValue(value, p), name, shorthand, usage) +} + +// TextVar defines a flag with a specified name, default value, and usage string. The argument p must be a pointer to a variable that will hold the value of the flag, and p must implement encoding.TextUnmarshaler. If the flag is used, the flag value will be passed to p's UnmarshalText method. The type of the default value must be the same as the type of p. +func TextVar(p encoding.TextUnmarshaler, name string, value encoding.TextMarshaler, usage string) { + CommandLine.VarP(newTextValue(value, p), name, "", usage) +} + +// TextVarP is like TextVar, but accepts a shorthand letter that can be used after a single dash. +func TextVarP(p encoding.TextUnmarshaler, name, shorthand string, value encoding.TextMarshaler, usage string) { + CommandLine.VarP(newTextValue(value, p), name, shorthand, usage) +} diff --git a/text_test.go b/text_test.go new file mode 100644 index 0000000..e60c136 --- /dev/null +++ b/text_test.go @@ -0,0 +1,56 @@ +package pflag + +import ( + "fmt" + "os" + "testing" + "time" +) + +func setUpTime(t *time.Time) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.TextVar(t, "time", time.Now(), "time stamp") + return f +} + +func TestText(t *testing.T) { + testCases := []struct { + input string + success bool + expected time.Time + }{ + {"2003-01-02T15:04:05Z", true, time.Date(2003, 1, 2, 15, 04, 05, 0, time.UTC)}, + {"2003-01-02 15:05:01", false, time.Time{}}, //negative case, invalid layout + {"2024-11-22T03:01:02Z", true, time.Date(2024, 11, 22, 3, 1, 02, 0, time.UTC)}, + {"2006-01-02T15:04:05+07:00", true, time.Date(2006, 1, 2, 15, 4, 5, 0, time.FixedZone("UTC+7", 7*60*60))}, + } + + devnull, _ := os.Open(os.DevNull) + os.Stderr = devnull + for i := range testCases { + var ts time.Time + f := setUpTime(&ts) + tc := &testCases[i] + arg := fmt.Sprintf("--time=%s", tc.input) + err := f.Parse([]string{arg}) + if err != nil { + if tc.success { + t.Errorf("expected parsing to succeed, but got %q", err) + } + continue + } + if !tc.success { + t.Errorf("expected parsing failure, but parsing succeeded") + continue + } + parsedT := new(time.Time) + err = f.GetText("time", parsedT) + if err != nil { + t.Errorf("Got error trying to fetch the time flag: %v", err) + } + if !parsedT.Equal(tc.expected) { + t.Errorf("expected %q, got %q", tc.expected, parsedT) + } + + } +} @@ -0,0 +1,118 @@ +package pflag + +import ( + "fmt" + "strings" + "time" +) + +// TimeValue adapts time.Time for use as a flag. +type timeValue struct { + *time.Time + formats []string +} + +func newTimeValue(val time.Time, p *time.Time, formats []string) *timeValue { + *p = val + return &timeValue{ + Time: p, + formats: formats, + } +} + +// Set time.Time value from string based on accepted formats. +func (d *timeValue) Set(s string) error { + s = strings.TrimSpace(s) + for _, f := range d.formats { + v, err := time.Parse(f, s) + if err != nil { + continue + } + *d.Time = v + return nil + } + + formatsString := "" + for i, f := range d.formats { + if i > 0 { + formatsString += ", " + } + formatsString += fmt.Sprintf("`%s`", f) + } + + return fmt.Errorf("invalid time format `%s` must be one of: %s", s, formatsString) +} + +// Type name for time.Time flags. +func (d *timeValue) Type() string { + return "time" +} + +func (d *timeValue) String() string { return d.Time.Format(time.RFC3339Nano) } + +// GetTime return the time value of a flag with the given name +func (f *FlagSet) GetTime(name string) (time.Time, error) { + flag := f.Lookup(name) + if flag == nil { + err := fmt.Errorf("flag accessed but not defined: %s", name) + return time.Time{}, err + } + + if flag.Value.Type() != "time" { + err := fmt.Errorf("trying to get %s value of flag of type %s", "time", flag.Value.Type()) + return time.Time{}, err + } + + val, ok := flag.Value.(*timeValue) + if !ok { + return time.Time{}, fmt.Errorf("value %s is not a time", flag.Value) + } + + return *val.Time, nil +} + +// TimeVar defines a time.Time flag with specified name, default value, and usage string. +// The argument p points to a time.Time variable in which to store the value of the flag. +func (f *FlagSet) TimeVar(p *time.Time, name string, value time.Time, formats []string, usage string) { + f.TimeVarP(p, name, "", value, formats, usage) +} + +// TimeVarP is like TimeVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) TimeVarP(p *time.Time, name, shorthand string, value time.Time, formats []string, usage string) { + f.VarP(newTimeValue(value, p, formats), name, shorthand, usage) +} + +// TimeVar defines a time.Time flag with specified name, default value, and usage string. +// The argument p points to a time.Time variable in which to store the value of the flag. +func TimeVar(p *time.Time, name string, value time.Time, formats []string, usage string) { + CommandLine.TimeVarP(p, name, "", value, formats, usage) +} + +// TimeVarP is like TimeVar, but accepts a shorthand letter that can be used after a single dash. +func TimeVarP(p *time.Time, name, shorthand string, value time.Time, formats []string, usage string) { + CommandLine.VarP(newTimeValue(value, p, formats), name, shorthand, usage) +} + +// Time defines a time.Time flag with specified name, default value, and usage string. +// The return value is the address of a time.Time variable that stores the value of the flag. +func (f *FlagSet) Time(name string, value time.Time, formats []string, usage string) *time.Time { + return f.TimeP(name, "", value, formats, usage) +} + +// TimeP is like Time, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) TimeP(name, shorthand string, value time.Time, formats []string, usage string) *time.Time { + p := new(time.Time) + f.TimeVarP(p, name, shorthand, value, formats, usage) + return p +} + +// Time defines a time.Time flag with specified name, default value, and usage string. +// The return value is the address of a time.Time variable that stores the value of the flag. +func Time(name string, value time.Time, formats []string, usage string) *time.Time { + return CommandLine.TimeP(name, "", value, formats, usage) +} + +// TimeP is like Time, but accepts a shorthand letter that can be used after a single dash. +func TimeP(name, shorthand string, value time.Time, formats []string, usage string) *time.Time { + return CommandLine.TimeP(name, shorthand, value, formats, usage) +} diff --git a/time_test.go b/time_test.go new file mode 100644 index 0000000..46a5ada --- /dev/null +++ b/time_test.go @@ -0,0 +1,62 @@ +package pflag + +import ( + "fmt" + "testing" + "time" +) + +func setUpTimeVar(t *time.Time, formats []string) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.TimeVar(t, "time", time.Time{}, formats, "Time") + return f +} + +func TestTime(t *testing.T) { + testCases := []struct { + input string + success bool + expected time.Time + }{ + {"2022-01-01T01:01:01+00:00", true, time.Date(2022, 1, 1, 1, 1, 1, 0, time.UTC)}, + {" 2022-01-01T01:01:01+00:00", true, time.Date(2022, 1, 1, 1, 1, 1, 0, time.UTC)}, + {"2022-01-01T01:01:01+00:00 ", true, time.Date(2022, 1, 1, 1, 1, 1, 0, time.UTC)}, + {"2022-01-01T01:01:01+02:00", true, time.Date(2022, 1, 1, 1, 1, 1, 0, time.FixedZone("UTC+2", 2*60*60))}, + {"2022-01-01T01:01:01.01+02:00", true, time.Date(2022, 1, 1, 1, 1, 1, 10000000, time.FixedZone("UTC+2", 2*60*60))}, + {"Sat, 01 Jan 2022 01:01:01 +0000", true, time.Date(2022, 1, 1, 1, 1, 1, 0, time.UTC)}, + {"Sat, 01 Jan 2022 01:01:01 +0200", true, time.Date(2022, 1, 1, 1, 1, 1, 0, time.FixedZone("UTC+2", 2*60*60))}, + {"Sat, 01 Jan 2022 01:01:01 +0000", true, time.Date(2022, 1, 1, 1, 1, 1, 0, time.UTC)}, + {"", false, time.Time{}}, + {"not a date", false, time.Time{}}, + {"2022-01-01 01:01:01", false, time.Time{}}, + {"2022-01-01T01:01:01", false, time.Time{}}, + {"01 Jan 2022 01:01:01 +0000", false, time.Time{}}, + {"Sat, 01 Jan 2022 01:01:01", false, time.Time{}}, + } + + for i := range testCases { + var timeVar time.Time + formats := []string{time.RFC3339Nano, time.RFC1123Z} + f := setUpTimeVar(&timeVar, formats) + + tc := &testCases[i] + + arg := fmt.Sprintf("--time=%s", tc.input) + err := f.Parse([]string{arg}) + if err != nil && tc.success == true { + t.Errorf("expected success, got %q", err) + continue + } else if err == nil && tc.success == false { + t.Errorf("expected failure") + continue + } else if tc.success { + timeResult, err := f.GetTime("time") + if err != nil { + t.Errorf("Got error trying to fetch the Time flag: %v", err) + } + if !timeResult.Equal(tc.expected) { + t.Errorf("expected %q, got %q", tc.expected.Format(time.RFC3339Nano), timeVar.Format(time.RFC3339Nano)) + } + } + } +} |
