aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Paris <[email protected]>2015-05-30 21:07:46 -0400
committerEric Paris <[email protected]>2015-06-01 18:45:34 -0400
commit1e0a23de9163cb0706137856f1d060f88e3f277c (patch)
treede398763dd538847a8c32d4073ce07d5fe8bb210
parent5644820622454e71517561946e3d94b9f9db6842 (diff)
Add new FlagSet.Get{Int,String,...} accessor functions
If I declared a bool flag named "hello" I can now call b, err := f.GetBool("hello") And b will hold the value of the flag We can see this is already done in https://github.com/codegangsta/cli/blob/bcec9b08c7e5564f7512ad7e7b03778fe1923116/context.go If people use the codegangsta/cli Other projects have done it themselves using pflags (what inspired this patch) https://github.com/GoogleCloudPlatform/kubernetes/blob/cd817aebd848facda29e0befbbd6e31bf22402e6/pkg/kubectl/cmd/util/helpers.go#L176 Lets just do it ourselves...
-rw-r--r--bool.go13
-rw-r--r--duration.go17
-rw-r--r--flag.go21
-rw-r--r--flag_test.go57
-rw-r--r--float32.go17
-rw-r--r--float64.go13
-rw-r--r--int.go13
-rw-r--r--int32.go17
-rw-r--r--int64.go13
-rw-r--r--int8.go17
-rw-r--r--ip.go17
-rw-r--r--ipmask.go38
-rw-r--r--string.go13
-rw-r--r--uint.go17
-rw-r--r--uint16.go17
-rw-r--r--uint32.go17
-rw-r--r--uint64.go17
-rw-r--r--uint8.go17
18 files changed, 346 insertions, 5 deletions
diff --git a/bool.go b/bool.go
index 70e2e0a..ead8dba 100644
--- a/bool.go
+++ b/bool.go
@@ -34,6 +34,19 @@ func (b *boolValue) String() string { return fmt.Sprintf("%v", *b) }
func (b *boolValue) IsBoolFlag() bool { return true }
+func boolConv(sval string) (interface{}, error) {
+ return strconv.ParseBool(sval)
+}
+
+// GetBool return the bool value of a flag with the given name
+func (f *FlagSet) GetBool(name string) (bool, error) {
+ val, err := f.getFlagType(name, "bool", boolConv)
+ if err != nil {
+ return false, err
+ }
+ return val.(bool), nil
+}
+
// BoolVar defines a bool flag with specified name, default value, and usage string.
// The argument p points to a bool variable in which to store the value of the flag.
func (f *FlagSet) BoolVar(p *bool, name string, value bool, usage string) {
diff --git a/duration.go b/duration.go
index 66ed7ac..382ffd3 100644
--- a/duration.go
+++ b/duration.go
@@ -1,6 +1,8 @@
package pflag
-import "time"
+import (
+ "time"
+)
// -- time.Duration Value
type durationValue time.Duration
@@ -22,6 +24,19 @@ func (d *durationValue) Type() string {
func (d *durationValue) String() string { return (*time.Duration)(d).String() }
+func durationConv(sval string) (interface{}, error) {
+ return time.ParseDuration(sval)
+}
+
+// GetDuration return the duration value of a flag with the given name
+func (f *FlagSet) GetDuration(name string) (time.Duration, error) {
+ val, err := f.getFlagType(name, "duration", durationConv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(time.Duration), nil
+}
+
// DurationVar defines a time.Duration flag with specified name, default value, and usage string.
// The argument p points to a time.Duration variable in which to store the value of the flag.
func (f *FlagSet) DurationVar(p *time.Duration, name string, value time.Duration, usage string) {
diff --git a/flag.go b/flag.go
index 0070b93..534fce4 100644
--- a/flag.go
+++ b/flag.go
@@ -257,6 +257,27 @@ func (f *FlagSet) lookup(name NormalizedName) *Flag {
return f.formal[name]
}
+// func to return a given type for a given flag name
+func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval string) (interface{}, error)) (interface{}, error) {
+ flag := f.Lookup(name)
+ if flag == nil {
+ err := fmt.Errorf("flag accessed but not defined: %s\n", name)
+ return nil, err
+ }
+
+ if flag.Value.Type() != ftype {
+ err := fmt.Errorf("trying to get %s value of flag of type %s\n", ftype, flag.Value.Type())
+ return nil, err
+ }
+
+ sval := flag.Value.String()
+ result, err := convFunc(sval)
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
// Mark a flag deprecated in your program
func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error {
flag := f.Lookup(name)
diff --git a/flag_test.go b/flag_test.go
index d3c1714..e3c5c97 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -160,6 +160,9 @@ func testParse(f *FlagSet, t *testing.T) {
if *boolFlag != true {
t.Error("bool flag should be true, is ", *boolFlag)
}
+ if v, err := f.GetBool("bool"); err != nil || v != *boolFlag {
+ t.Error("GetBool does not work.")
+ }
if *bool2Flag != true {
t.Error("bool2 flag should be true, is ", *bool2Flag)
}
@@ -169,48 +172,96 @@ func testParse(f *FlagSet, t *testing.T) {
if *intFlag != 22 {
t.Error("int flag should be 22, is ", *intFlag)
}
+ if v, err := f.GetInt("int"); err != nil || v != *intFlag {
+ t.Error("GetInt does not work.")
+ }
if *int8Flag != -8 {
t.Error("int8 flag should be 0x23, is ", *int8Flag)
}
+ if v, err := f.GetInt8("int8"); err != nil || v != *int8Flag {
+ t.Error("GetInt8 does not work.")
+ }
if *int32Flag != -32 {
t.Error("int32 flag should be 0x23, is ", *int32Flag)
}
+ if v, err := f.GetInt32("int32"); err != nil || v != *int32Flag {
+ t.Error("GetInt32 does not work.")
+ }
if *int64Flag != 0x23 {
t.Error("int64 flag should be 0x23, is ", *int64Flag)
}
+ if v, err := f.GetInt64("int64"); err != nil || v != *int64Flag {
+ t.Error("GetInt64 does not work.")
+ }
if *uintFlag != 24 {
t.Error("uint flag should be 24, is ", *uintFlag)
}
+ if v, err := f.GetUint("uint"); err != nil || v != *uintFlag {
+ t.Error("GetUint does not work.")
+ }
if *uint8Flag != 8 {
t.Error("uint8 flag should be 8, is ", *uint8Flag)
}
+ if v, err := f.GetUint8("uint8"); err != nil || v != *uint8Flag {
+ t.Error("GetUint8 does not work.")
+ }
if *uint16Flag != 16 {
t.Error("uint16 flag should be 16, is ", *uint16Flag)
}
+ if v, err := f.GetUint16("uint16"); err != nil || v != *uint16Flag {
+ t.Error("GetUint16 does not work.")
+ }
if *uint32Flag != 32 {
t.Error("uint32 flag should be 32, is ", *uint32Flag)
}
+ if v, err := f.GetUint32("uint32"); err != nil || v != *uint32Flag {
+ t.Error("GetUint32 does not work.")
+ }
if *uint64Flag != 25 {
t.Error("uint64 flag should be 25, is ", *uint64Flag)
}
+ if v, err := f.GetUint64("uint64"); err != nil || v != *uint64Flag {
+ t.Error("GetUint64 does not work.")
+ }
if *stringFlag != "hello" {
t.Error("string flag should be `hello`, is ", *stringFlag)
}
+ if v, err := f.GetString("string"); err != nil || v != *stringFlag {
+ t.Error("GetString does not work.")
+ }
if *float32Flag != -172e12 {
- t.Error("float64 flag should be -172e12, is ", *float64Flag)
+ t.Error("float32 flag should be -172e12, is ", *float32Flag)
+ }
+ if v, err := f.GetFloat32("float32"); err != nil || v != *float32Flag {
+ t.Errorf("GetFloat32 returned %v but float32Flag was %v", v, *float32Flag)
}
if *float64Flag != 2718e28 {
t.Error("float64 flag should be 2718e28, is ", *float64Flag)
}
- if (*maskFlag).String() != ParseIPv4Mask("255.255.255.0").String() {
- t.Error("mask flag should be 255.255.255.0, is ", (*maskFlag).String())
+ if v, err := f.GetFloat64("float64"); err != nil || v != *float64Flag {
+ t.Errorf("GetFloat64 returned %v but float64Flag was %v", v, *float64Flag)
}
if !(*ipFlag).Equal(net.ParseIP("10.11.12.13")) {
t.Error("ip flag should be 10.11.12.13, is ", *ipFlag)
}
+ if v, err := f.GetIP("ip"); err != nil || !v.Equal(*ipFlag) {
+ t.Errorf("GetIP returned %v but ipFlag was %v", v, *ipFlag)
+ }
+ if (*maskFlag).String() != ParseIPv4Mask("255.255.255.0").String() {
+ t.Error("mask flag should be 255.255.255.0, is ", (*maskFlag).String())
+ }
+ if v, err := f.GetIPv4Mask("mask"); err != nil || v.String() != (*maskFlag).String() {
+ t.Errorf("GetIP returned %v but maskFlag was %v", v, *maskFlag, err)
+ }
if *durationFlag != 2*time.Minute {
t.Error("duration flag should be 2m, is ", *durationFlag)
}
+ if v, err := f.GetDuration("duration"); err != nil || v != *durationFlag {
+ t.Error("GetDuration does not work.")
+ }
+ if _, err := f.GetInt("duration"); err == nil {
+ t.Error("GetInt parsed a time.Duration?!?!")
+ }
if len(f.Args()) != 1 {
t.Error("expected one argument, got", len(f.Args()))
} else if f.Args()[0] != extra {
diff --git a/float32.go b/float32.go
index b7ad67d..30174cb 100644
--- a/float32.go
+++ b/float32.go
@@ -25,6 +25,23 @@ func (f *float32Value) Type() string {
func (f *float32Value) String() string { return fmt.Sprintf("%v", *f) }
+func float32Conv(sval string) (interface{}, error) {
+ v, err := strconv.ParseFloat(sval, 32)
+ if err != nil {
+ return 0, err
+ }
+ return float32(v), nil
+}
+
+// GetFloat32 return the float32 value of a flag with the given name
+func (f *FlagSet) GetFloat32(name string) (float32, error) {
+ val, err := f.getFlagType(name, "float32", float32Conv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(float32), nil
+}
+
// Float32Var defines a float32 flag with specified name, default value, and usage string.
// The argument p points to a float32 variable in which to store the value of the flag.
func (f *FlagSet) Float32Var(p *float32, name string, value float32, usage string) {
diff --git a/float64.go b/float64.go
index 0315512..10e17e4 100644
--- a/float64.go
+++ b/float64.go
@@ -25,6 +25,19 @@ func (f *float64Value) Type() string {
func (f *float64Value) String() string { return fmt.Sprintf("%v", *f) }
+func float64Conv(sval string) (interface{}, error) {
+ return strconv.ParseFloat(sval, 64)
+}
+
+// GetFloat64 return the float64 value of a flag with the given name
+func (f *FlagSet) GetFloat64(name string) (float64, error) {
+ val, err := f.getFlagType(name, "float64", float64Conv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(float64), nil
+}
+
// Float64Var defines a float64 flag with specified name, default value, and usage string.
// The argument p points to a float64 variable in which to store the value of the flag.
func (f *FlagSet) Float64Var(p *float64, name string, value float64, usage string) {
diff --git a/int.go b/int.go
index dca9da6..23f70dd 100644
--- a/int.go
+++ b/int.go
@@ -25,6 +25,19 @@ func (i *intValue) Type() string {
func (i *intValue) String() string { return fmt.Sprintf("%v", *i) }
+func intConv(sval string) (interface{}, error) {
+ return strconv.Atoi(sval)
+}
+
+// GetInt return the int value of a flag with the given name
+func (f *FlagSet) GetInt(name string) (int, error) {
+ val, err := f.getFlagType(name, "int", intConv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(int), nil
+}
+
// IntVar defines an int flag with specified name, default value, and usage string.
// The argument p points to an int variable in which to store the value of the flag.
func (f *FlagSet) IntVar(p *int, name string, value int, usage string) {
diff --git a/int32.go b/int32.go
index 18eaacd..515f90b 100644
--- a/int32.go
+++ b/int32.go
@@ -25,6 +25,23 @@ func (i *int32Value) Type() string {
func (i *int32Value) String() string { return fmt.Sprintf("%v", *i) }
+func int32Conv(sval string) (interface{}, error) {
+ v, err := strconv.ParseInt(sval, 0, 32)
+ if err != nil {
+ return 0, err
+ }
+ return int32(v), nil
+}
+
+// GetInt32 return the int32 value of a flag with the given name
+func (f *FlagSet) GetInt32(name string) (int32, error) {
+ val, err := f.getFlagType(name, "int32", int32Conv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(int32), nil
+}
+
// Int32Var defines an int32 flag with specified name, default value, and usage string.
// The argument p points to an int32 variable in which to store the value of the flag.
func (f *FlagSet) Int32Var(p *int32, name string, value int32, usage string) {
diff --git a/int64.go b/int64.go
index 0114aaa..b77ade4 100644
--- a/int64.go
+++ b/int64.go
@@ -25,6 +25,19 @@ func (i *int64Value) Type() string {
func (i *int64Value) String() string { return fmt.Sprintf("%v", *i) }
+func int64Conv(sval string) (interface{}, error) {
+ return strconv.ParseInt(sval, 0, 64)
+}
+
+// GetInt64 return the int64 value of a flag with the given name
+func (f *FlagSet) GetInt64(name string) (int64, error) {
+ val, err := f.getFlagType(name, "int64", int64Conv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(int64), nil
+}
+
// Int64Var defines an int64 flag with specified name, default value, and usage string.
// The argument p points to an int64 variable in which to store the value of the flag.
func (f *FlagSet) Int64Var(p *int64, name string, value int64, usage string) {
diff --git a/int8.go b/int8.go
index aab1022..c51cb4f 100644
--- a/int8.go
+++ b/int8.go
@@ -25,6 +25,23 @@ func (i *int8Value) Type() string {
func (i *int8Value) String() string { return fmt.Sprintf("%v", *i) }
+func int8Conv(sval string) (interface{}, error) {
+ v, err := strconv.ParseInt(sval, 0, 8)
+ if err != nil {
+ return 0, err
+ }
+ return int8(v), nil
+}
+
+// GetInt8 return the int8 value of a flag with the given name
+func (f *FlagSet) GetInt8(name string) (int8, error) {
+ val, err := f.getFlagType(name, "int8", int8Conv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(int8), nil
+}
+
// Int8Var defines an int8 flag with specified name, default value, and usage string.
// The argument p points to an int8 variable in which to store the value of the flag.
func (f *FlagSet) Int8Var(p *int8, name string, value int8, usage string) {
diff --git a/ip.go b/ip.go
index efa75fb..746eefd 100644
--- a/ip.go
+++ b/ip.go
@@ -27,6 +27,23 @@ func (i *ipValue) Type() string {
return "ip"
}
+func ipConv(sval string) (interface{}, error) {
+ ip := net.ParseIP(sval)
+ if ip != nil {
+ return ip, nil
+ }
+ return nil, fmt.Errorf("invalid string being converted to IP address: %s", sval)
+}
+
+// GetIP return the net.IP value of a flag with the given name
+func (f *FlagSet) GetIP(name string) (net.IP, error) {
+ val, err := f.getFlagType(name, "ip", ipConv)
+ if err != nil {
+ return nil, err
+ }
+ return val.(net.IP), nil
+}
+
// IPVar defines an net.IP flag with specified name, default value, and usage string.
// The argument p points to an net.IP variable in which to store the value of the flag.
func (f *FlagSet) IPVar(p *net.IP, name string, value net.IP, usage string) {
diff --git a/ipmask.go b/ipmask.go
index 09b9533..1b10efb 100644
--- a/ipmask.go
+++ b/ipmask.go
@@ -3,6 +3,7 @@ package pflag
import (
"fmt"
"net"
+ "strconv"
)
// -- net.IPMask value
@@ -32,11 +33,46 @@ func (i *ipMaskValue) Type() string {
func ParseIPv4Mask(s string) net.IPMask {
mask := net.ParseIP(s)
if mask == nil {
- return nil
+ if len(s) != 8 {
+ return nil
+ }
+ // net.IPMask.String() actually outputs things like ffffff00
+ // so write a horrible parser for that as well :-(
+ m := []int{}
+ for i := 0; i < 4; i++ {
+ b := "0x" + s[2*i:2*i+2]
+ d, err := strconv.ParseInt(b, 0, 0)
+ if err != nil {
+ return nil
+ }
+ m = append(m, int(d))
+ }
+ s := fmt.Sprintf("%d.%d.%d.%d", m[0], m[1], m[2], m[3])
+ mask = net.ParseIP(s)
+ if mask == nil {
+ return nil
+ }
}
return net.IPv4Mask(mask[12], mask[13], mask[14], mask[15])
}
+func parseIPv4Mask(sval string) (interface{}, error) {
+ mask := ParseIPv4Mask(sval)
+ if mask == nil {
+ return nil, fmt.Errorf("unable to parse %s as net.IPMask", sval)
+ }
+ return mask, nil
+}
+
+// GetIPv4Mask return the net.IPv4Mask value of a flag with the given name
+func (f *FlagSet) GetIPv4Mask(name string) (net.IPMask, error) {
+ val, err := f.getFlagType(name, "ipMask", parseIPv4Mask)
+ if err != nil {
+ return nil, err
+ }
+ return val.(net.IPMask), nil
+}
+
// IPMaskVar defines an net.IPMask flag with specified name, default value, and usage string.
// The argument p points to an net.IPMask variable in which to store the value of the flag.
func (f *FlagSet) IPMaskVar(p *net.IPMask, name string, value net.IPMask, usage string) {
diff --git a/string.go b/string.go
index 362fbf8..f89ea8b 100644
--- a/string.go
+++ b/string.go
@@ -20,6 +20,19 @@ func (s *stringValue) Type() string {
func (s *stringValue) String() string { return fmt.Sprintf("%s", *s) }
+func stringConv(sval string) (interface{}, error) {
+ return sval, nil
+}
+
+// GetString return the string value of a flag with the given name
+func (f *FlagSet) GetString(name string) (string, error) {
+ val, err := f.getFlagType(name, "string", stringConv)
+ if err != nil {
+ return "", err
+ }
+ return val.(string), nil
+}
+
// StringVar defines a string flag with specified name, default value, and usage string.
// The argument p points to a string variable in which to store the value of the flag.
func (f *FlagSet) StringVar(p *string, name string, value string, usage string) {
diff --git a/uint.go b/uint.go
index c063fe7..d6f8e5b 100644
--- a/uint.go
+++ b/uint.go
@@ -25,6 +25,23 @@ func (i *uintValue) Type() string {
func (i *uintValue) String() string { return fmt.Sprintf("%v", *i) }
+func uintConv(sval string) (interface{}, error) {
+ v, err := strconv.ParseUint(sval, 0, 0)
+ if err != nil {
+ return 0, err
+ }
+ return uint(v), nil
+}
+
+// GetUint return the uint value of a flag with the given name
+func (f *FlagSet) GetUint(name string) (uint, error) {
+ val, err := f.getFlagType(name, "uint", uintConv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(uint), nil
+}
+
// UintVar defines a uint flag with specified name, default value, and usage string.
// The argument p points to a uint variable in which to store the value of the flag.
func (f *FlagSet) UintVar(p *uint, name string, value uint, usage string) {
diff --git a/uint16.go b/uint16.go
index ab1c1f9..1cdc3df 100644
--- a/uint16.go
+++ b/uint16.go
@@ -23,6 +23,23 @@ func (i *uint16Value) Type() string {
return "uint16"
}
+func uint16Conv(sval string) (interface{}, error) {
+ v, err := strconv.ParseUint(sval, 0, 16)
+ if err != nil {
+ return 0, err
+ }
+ return uint16(v), nil
+}
+
+// GetUint16 return the uint16 value of a flag with the given name
+func (f *FlagSet) GetUint16(name string) (uint16, error) {
+ val, err := f.getFlagType(name, "uint16", uint16Conv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(uint16), nil
+}
+
// Uint16Var defines a uint flag with specified name, default value, and usage string.
// The argument p points to a uint variable in which to store the value of the flag.
func (f *FlagSet) Uint16Var(p *uint16, name string, value uint16, usage string) {
diff --git a/uint32.go b/uint32.go
index db635ae..1326e4a 100644
--- a/uint32.go
+++ b/uint32.go
@@ -23,6 +23,23 @@ func (i *uint32Value) Type() string {
return "uint32"
}
+func uint32Conv(sval string) (interface{}, error) {
+ v, err := strconv.ParseUint(sval, 0, 32)
+ if err != nil {
+ return 0, err
+ }
+ return uint32(v), nil
+}
+
+// GetUint32 return the uint32 value of a flag with the given name
+func (f *FlagSet) GetUint32(name string) (uint32, error) {
+ val, err := f.getFlagType(name, "uint32", uint32Conv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(uint32), nil
+}
+
// Uint32Var defines a uint32 flag with specified name, default value, and usage string.
// The argument p points to a uint32 variable in which to store the value of the flag.
func (f *FlagSet) Uint32Var(p *uint32, name string, value uint32, usage string) {
diff --git a/uint64.go b/uint64.go
index 99c7e80..6788bbf 100644
--- a/uint64.go
+++ b/uint64.go
@@ -25,6 +25,23 @@ func (i *uint64Value) Type() string {
func (i *uint64Value) String() string { return fmt.Sprintf("%v", *i) }
+func uint64Conv(sval string) (interface{}, error) {
+ v, err := strconv.ParseUint(sval, 0, 64)
+ if err != nil {
+ return 0, err
+ }
+ return uint64(v), nil
+}
+
+// GetUint64 return the uint64 value of a flag with the given name
+func (f *FlagSet) GetUint64(name string) (uint64, error) {
+ val, err := f.getFlagType(name, "uint64", uint64Conv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(uint64), nil
+}
+
// Uint64Var defines a uint64 flag with specified name, default value, and usage string.
// The argument p points to a uint64 variable in which to store the value of the flag.
func (f *FlagSet) Uint64Var(p *uint64, name string, value uint64, usage string) {
diff --git a/uint8.go b/uint8.go
index 6fef508..560c569 100644
--- a/uint8.go
+++ b/uint8.go
@@ -25,6 +25,23 @@ func (i *uint8Value) Type() string {
func (i *uint8Value) String() string { return fmt.Sprintf("%v", *i) }
+func uint8Conv(sval string) (interface{}, error) {
+ v, err := strconv.ParseUint(sval, 0, 8)
+ if err != nil {
+ return 0, err
+ }
+ return uint8(v), nil
+}
+
+// GetUint8 return the uint8 value of a flag with the given name
+func (f *FlagSet) GetUint8(name string) (uint8, error) {
+ val, err := f.getFlagType(name, "uint8", uint8Conv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(uint8), nil
+}
+
// Uint8Var defines a uint8 flag with specified name, default value, and usage string.
// The argument p points to a uint8 variable in which to store the value of the flag.
func (f *FlagSet) Uint8Var(p *uint8, name string, value uint8, usage string) {