aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTomas Aschan <[email protected]>2025-04-22 17:47:47 +0200
committerGitHub <[email protected]>2025-04-22 17:47:47 +0200
commitd661846b4df0d6611ad95577ddfb240474d21b7c (patch)
tree319f3221c52fb10c2e870d50cdfdc531c22eb4e4
parent19c9c4072e41218b18b93dbfc3798c18835d2fd5 (diff)
parent6ca66b16cbe1b365ce9a6c56faf9b04acb8d8035 (diff)
Merge pull request #425 from eth-p/error-structs
feat: Use structs for errors returned by pflag.
-rw-r--r--errors.go149
-rw-r--r--errors_test.go67
-rw-r--r--flag.go52
-rw-r--r--flag_test.go59
4 files changed, 303 insertions, 24 deletions
diff --git a/errors.go b/errors.go
new file mode 100644
index 0000000..ff11b66
--- /dev/null
+++ b/errors.go
@@ -0,0 +1,149 @@
+package pflag
+
+import "fmt"
+
+// notExistErrorMessageType specifies which flavor of "flag does not exist"
+// is printed by NotExistError. This allows the related errors to be grouped
+// under a single NotExistError struct without making a breaking change to
+// the error message text.
+type notExistErrorMessageType int
+
+const (
+ flagNotExistMessage notExistErrorMessageType = iota
+ flagNotDefinedMessage
+ flagNoSuchFlagMessage
+ flagUnknownFlagMessage
+ flagUnknownShorthandFlagMessage
+)
+
+// NotExistError is the error returned when trying to access a flag that
+// does not exist in the FlagSet.
+type NotExistError struct {
+ name string
+ specifiedShorthands string
+ messageType notExistErrorMessageType
+}
+
+// Error implements error.
+func (e *NotExistError) Error() string {
+ switch e.messageType {
+ case flagNotExistMessage:
+ return fmt.Sprintf("flag %q does not exist", e.name)
+
+ case flagNotDefinedMessage:
+ return fmt.Sprintf("flag accessed but not defined: %s", e.name)
+
+ case flagNoSuchFlagMessage:
+ return fmt.Sprintf("no such flag -%v", e.name)
+
+ case flagUnknownFlagMessage:
+ return fmt.Sprintf("unknown flag: --%s", e.name)
+
+ case flagUnknownShorthandFlagMessage:
+ c := rune(e.name[0])
+ return fmt.Sprintf("unknown shorthand flag: %q in -%s", c, e.specifiedShorthands)
+ }
+
+ panic(fmt.Errorf("unknown flagNotExistErrorMessageType: %v", e.messageType))
+}
+
+// GetSpecifiedName returns the name of the flag (without dashes) as it
+// appeared in the parsed arguments.
+func (e *NotExistError) GetSpecifiedName() string {
+ return e.name
+}
+
+// GetSpecifiedShortnames returns the group of shorthand arguments
+// (without dashes) that the flag appeared within. If the flag was not in a
+// shorthand group, this will return an empty string.
+func (e *NotExistError) GetSpecifiedShortnames() string {
+ return e.specifiedShorthands
+}
+
+// ValueRequiredError is the error returned when a flag needs an argument but
+// no argument was provided.
+type ValueRequiredError struct {
+ flag *Flag
+ specifiedName string
+ specifiedShorthands string
+}
+
+// Error implements error.
+func (e *ValueRequiredError) Error() string {
+ if len(e.specifiedShorthands) > 0 {
+ c := rune(e.specifiedName[0])
+ return fmt.Sprintf("flag needs an argument: %q in -%s", c, e.specifiedShorthands)
+ }
+
+ return fmt.Sprintf("flag needs an argument: --%s", e.specifiedName)
+}
+
+// GetFlag returns the flag for which the error occurred.
+func (e *ValueRequiredError) GetFlag() *Flag {
+ return e.flag
+}
+
+// GetSpecifiedName returns the name of the flag (without dashes) as it
+// appeared in the parsed arguments.
+func (e *ValueRequiredError) GetSpecifiedName() string {
+ return e.specifiedName
+}
+
+// GetSpecifiedShortnames returns the group of shorthand arguments
+// (without dashes) that the flag appeared within. If the flag was not in a
+// shorthand group, this will return an empty string.
+func (e *ValueRequiredError) GetSpecifiedShortnames() string {
+ return e.specifiedShorthands
+}
+
+// InvalidValueError is the error returned when an invalid value is used
+// for a flag.
+type InvalidValueError struct {
+ flag *Flag
+ value string
+ cause error
+}
+
+// Error implements error.
+func (e *InvalidValueError) Error() string {
+ flag := e.flag
+ var flagName string
+ if flag.Shorthand != "" && flag.ShorthandDeprecated == "" {
+ flagName = fmt.Sprintf("-%s, --%s", flag.Shorthand, flag.Name)
+ } else {
+ flagName = fmt.Sprintf("--%s", flag.Name)
+ }
+ return fmt.Sprintf("invalid argument %q for %q flag: %v", e.value, flagName, e.cause)
+}
+
+// Unwrap implements errors.Unwrap.
+func (e *InvalidValueError) Unwrap() error {
+ return e.cause
+}
+
+// GetFlag returns the flag for which the error occurred.
+func (e *InvalidValueError) GetFlag() *Flag {
+ return e.flag
+}
+
+// GetValue returns the invalid value that was provided.
+func (e *InvalidValueError) GetValue() string {
+ return e.value
+}
+
+// InvalidSyntaxError is the error returned when a bad flag name is passed on
+// the command line.
+type InvalidSyntaxError struct {
+ specifiedFlag string
+}
+
+// Error implements error.
+func (e *InvalidSyntaxError) Error() string {
+ return fmt.Sprintf("bad flag syntax: %s", e.specifiedFlag)
+}
+
+// GetSpecifiedName returns the exact flag (with dashes) as it
+// appeared in the parsed arguments.
+func (e *InvalidSyntaxError) GetSpecifiedFlag() string {
+ return e.specifiedFlag
+}
diff --git a/errors_test.go b/errors_test.go
new file mode 100644
index 0000000..7b4c7a4
--- /dev/null
+++ b/errors_test.go
@@ -0,0 +1,67 @@
+package pflag
+
+import (
+ "errors"
+ "testing"
+)
+
+func TestNotExistError(t *testing.T) {
+ err := &NotExistError{
+ name: "foo",
+ specifiedShorthands: "bar",
+ }
+
+ if err.GetSpecifiedName() != "foo" {
+ t.Errorf("Expected GetSpecifiedName to return %q, got %q", "foo", err.GetSpecifiedName())
+ }
+ if err.GetSpecifiedShortnames() != "bar" {
+ t.Errorf("Expected GetSpecifiedShortnames to return %q, got %q", "bar", err.GetSpecifiedShortnames())
+ }
+}
+
+func TestValueRequiredError(t *testing.T) {
+ err := &ValueRequiredError{
+ flag: &Flag{},
+ specifiedName: "foo",
+ specifiedShorthands: "bar",
+ }
+
+ if err.GetFlag() == nil {
+ t.Error("Expected GetSpecifiedName to return its flag field, but got nil")
+ }
+ if err.GetSpecifiedName() != "foo" {
+ t.Errorf("Expected GetSpecifiedName to return %q, got %q", "foo", err.GetSpecifiedName())
+ }
+ if err.GetSpecifiedShortnames() != "bar" {
+ t.Errorf("Expected GetSpecifiedShortnames to return %q, got %q", "bar", err.GetSpecifiedShortnames())
+ }
+}
+
+func TestInvalidValueError(t *testing.T) {
+ expectedCause := errors.New("error")
+ err := &InvalidValueError{
+ flag: &Flag{},
+ value: "foo",
+ cause: expectedCause,
+ }
+
+ if err.GetFlag() == nil {
+ t.Error("Expected GetSpecifiedName to return its flag field, but got nil")
+ }
+ if err.GetValue() != "foo" {
+ t.Errorf("Expected GetValue to return %q, got %q", "foo", err.GetValue())
+ }
+ if err.Unwrap() != expectedCause {
+ t.Errorf("Expected Unwrwap to return %q, got %q", expectedCause, err.Unwrap())
+ }
+}
+
+func TestInvalidSyntaxError(t *testing.T) {
+ err := &InvalidSyntaxError{
+ specifiedFlag: "--=",
+ }
+
+ if err.GetSpecifiedFlag() != "--=" {
+ t.Errorf("Expected GetSpecifiedFlag to return %q, got %q", "--=", err.GetSpecifiedFlag())
+ }
+}
diff --git a/flag.go b/flag.go
index 4bdbd0c..80bd580 100644
--- a/flag.go
+++ b/flag.go
@@ -381,7 +381,7 @@ func (f *FlagSet) lookup(name NormalizedName) *Flag {
func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval string) (interface{}, error)) (interface{}, error) {
flag := f.Lookup(name)
if flag == nil {
- err := fmt.Errorf("flag accessed but not defined: %s", name)
+ err := &NotExistError{name: name, messageType: flagNotDefinedMessage}
return nil, err
}
@@ -411,7 +411,7 @@ func (f *FlagSet) ArgsLenAtDash() int {
func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error {
flag := f.Lookup(name)
if flag == nil {
- return fmt.Errorf("flag %q does not exist", name)
+ return &NotExistError{name: name, messageType: flagNotExistMessage}
}
if usageMessage == "" {
return fmt.Errorf("deprecated message for flag %q must be set", name)
@@ -427,7 +427,7 @@ func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error {
func (f *FlagSet) MarkShorthandDeprecated(name string, usageMessage string) error {
flag := f.Lookup(name)
if flag == nil {
- return fmt.Errorf("flag %q does not exist", name)
+ return &NotExistError{name: name, messageType: flagNotExistMessage}
}
if usageMessage == "" {
return fmt.Errorf("deprecated message for flag %q must be set", name)
@@ -441,7 +441,7 @@ func (f *FlagSet) MarkShorthandDeprecated(name string, usageMessage string) erro
func (f *FlagSet) MarkHidden(name string) error {
flag := f.Lookup(name)
if flag == nil {
- return fmt.Errorf("flag %q does not exist", name)
+ return &NotExistError{name: name, messageType: flagNotExistMessage}
}
flag.Hidden = true
return nil
@@ -464,18 +464,16 @@ func (f *FlagSet) Set(name, value string) error {
normalName := f.normalizeFlagName(name)
flag, ok := f.formal[normalName]
if !ok {
- return fmt.Errorf("no such flag -%v", name)
+ return &NotExistError{name: name, messageType: flagNoSuchFlagMessage}
}
err := flag.Value.Set(value)
if err != nil {
- var flagName string
- if flag.Shorthand != "" && flag.ShorthandDeprecated == "" {
- flagName = fmt.Sprintf("-%s, --%s", flag.Shorthand, flag.Name)
- } else {
- flagName = fmt.Sprintf("--%s", flag.Name)
+ return &InvalidValueError{
+ flag: flag,
+ value: value,
+ cause: err,
}
- return fmt.Errorf("invalid argument %q for %q flag: %v", value, flagName, err)
}
if !flag.Changed {
@@ -501,7 +499,7 @@ func (f *FlagSet) SetAnnotation(name, key string, values []string) error {
normalName := f.normalizeFlagName(name)
flag, ok := f.formal[normalName]
if !ok {
- return fmt.Errorf("no such flag -%v", name)
+ return &NotExistError{name: name, messageType: flagNoSuchFlagMessage}
}
if flag.Annotations == nil {
flag.Annotations = map[string][]string{}
@@ -911,10 +909,9 @@ func VarP(value Value, name, shorthand, usage string) {
CommandLine.VarP(value, name, shorthand, usage)
}
-// failf prints to standard error a formatted error and usage message and
+// fail prints an error message and usage message to standard error and
// returns the error.
-func (f *FlagSet) failf(format string, a ...interface{}) error {
- err := fmt.Errorf(format, a...)
+func (f *FlagSet) fail(err error) error {
if f.errorHandling != ContinueOnError {
fmt.Fprintln(f.Output(), err)
f.usage()
@@ -960,7 +957,7 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
a = args
name := s[2:]
if len(name) == 0 || name[0] == '-' || name[0] == '=' {
- err = f.failf("bad flag syntax: %s", s)
+ err = f.fail(&InvalidSyntaxError{specifiedFlag: s})
return
}
@@ -982,7 +979,7 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
return stripUnknownFlagValue(a), nil
default:
- err = f.failf("unknown flag: --%s", name)
+ err = f.fail(&NotExistError{name: name, messageType: flagUnknownFlagMessage})
return
}
}
@@ -1000,13 +997,16 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
a = a[1:]
} else {
// '--flag' (arg was required)
- err = f.failf("flag needs an argument: %s", s)
+ err = f.fail(&ValueRequiredError{
+ flag: flag,
+ specifiedName: name,
+ })
return
}
err = fn(flag, value)
if err != nil {
- f.failf(err.Error())
+ f.fail(err)
}
return
}
@@ -1039,7 +1039,11 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
outArgs = stripUnknownFlagValue(outArgs)
return
default:
- err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands)
+ err = f.fail(&NotExistError{
+ name: string(c),
+ specifiedShorthands: shorthands,
+ messageType: flagUnknownShorthandFlagMessage,
+ })
return
}
}
@@ -1062,7 +1066,11 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
outArgs = args[1:]
} else {
// '-f' (arg was required)
- err = f.failf("flag needs an argument: %q in -%s", c, shorthands)
+ err = f.fail(&ValueRequiredError{
+ flag: flag,
+ specifiedName: string(c),
+ specifiedShorthands: shorthands,
+ })
return
}
@@ -1072,7 +1080,7 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
err = fn(flag, value)
if err != nil {
- f.failf(err.Error())
+ f.fail(err)
}
return
}
diff --git a/flag_test.go b/flag_test.go
index 0dbe874..aa2f434 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -103,9 +103,14 @@ func TestEverything(t *testing.T) {
func TestUsage(t *testing.T) {
called := false
ResetForTesting(func() { called = true })
- if GetCommandLine().Parse([]string{"--x"}) == nil {
+ err := GetCommandLine().Parse([]string{"--x"})
+ expectedErr := "unknown flag: --x"
+ if err == nil {
t.Error("parse did not fail for unknown flag")
}
+ if err.Error() != expectedErr {
+ t.Errorf("expected error %q, got %q", expectedErr, err.Error())
+ }
if called {
t.Error("did call Usage while using ContinueOnError")
}
@@ -131,9 +136,14 @@ func TestAddFlagSet(t *testing.T) {
func TestAnnotation(t *testing.T) {
f := NewFlagSet("shorthand", ContinueOnError)
- if err := f.SetAnnotation("missing-flag", "key", nil); err == nil {
+ err := f.SetAnnotation("missing-flag", "key", nil)
+ expectedErr := "no such flag -missing-flag"
+ if err == nil {
t.Errorf("Expected error setting annotation on non-existent flag")
}
+ if err.Error() != expectedErr {
+ t.Errorf("expected error %q, got %q", expectedErr, err.Error())
+ }
f.StringP("stringa", "a", "", "string value")
if err := f.SetAnnotation("stringa", "key", nil); err != nil {
@@ -349,6 +359,33 @@ func testParse(f *FlagSet, t *testing.T) {
} else if f.Args()[0] != extra {
t.Errorf("expected argument %q got %q", extra, f.Args()[0])
}
+ // Test unknown
+ err := f.Parse([]string{"--unknown"})
+ expectedErr := "unknown flag: --unknown"
+ if err == nil {
+ t.Error("parse did not fail for unknown flag")
+ }
+ if err.Error() != expectedErr {
+ t.Errorf("expected error %q, got %q", expectedErr, err.Error())
+ }
+ // Test invalid
+ err = f.Parse([]string{"--bool=abcdefg"})
+ expectedErr = `invalid argument "abcdefg" for "--bool" flag: strconv.ParseBool: parsing "abcdefg": invalid syntax`
+ if err == nil {
+ t.Error("parse did not fail for invalid argument")
+ }
+ if err.Error() != expectedErr {
+ t.Errorf("expected error %q, got %q", expectedErr, err.Error())
+ }
+ // Test required
+ err = f.Parse([]string{"--int"})
+ expectedErr = `flag needs an argument: --int`
+ if err == nil {
+ t.Error("parse did not fail for missing argument")
+ }
+ if err.Error() != expectedErr {
+ t.Errorf("expected error %q, got %q", expectedErr, err.Error())
+ }
}
func testParseAll(f *FlagSet, t *testing.T) {
@@ -538,6 +575,24 @@ func TestShorthand(t *testing.T) {
if f.ArgsLenAtDash() != 1 {
t.Errorf("expected argsLenAtDash %d got %d", f.ArgsLenAtDash(), 1)
}
+ // Test unknown
+ err := f.Parse([]string{"-ukn"})
+ expectedErr := "unknown shorthand flag: 'u' in -ukn"
+ if err == nil {
+ t.Error("parse did not fail for unknown shorthand flag")
+ }
+ if err.Error() != expectedErr {
+ t.Errorf("expected error %q, got %q", expectedErr, err.Error())
+ }
+ // Test required
+ err = f.Parse([]string{"-as"})
+ expectedErr = `flag needs an argument: 's' in -s`
+ if err == nil {
+ t.Error("parse did not fail for missing argument")
+ }
+ if err.Error() != expectedErr {
+ t.Errorf("expected error %q, got %q", expectedErr, err.Error())
+ }
}
func TestShorthandLookup(t *testing.T) {