diff options
| -rw-r--r-- | bool.go | 13 | ||||
| -rw-r--r-- | duration.go | 17 | ||||
| -rw-r--r-- | flag.go | 21 | ||||
| -rw-r--r-- | flag_test.go | 57 | ||||
| -rw-r--r-- | float32.go | 17 | ||||
| -rw-r--r-- | float64.go | 13 | ||||
| -rw-r--r-- | int.go | 13 | ||||
| -rw-r--r-- | int32.go | 17 | ||||
| -rw-r--r-- | int64.go | 13 | ||||
| -rw-r--r-- | int8.go | 17 | ||||
| -rw-r--r-- | ip.go | 17 | ||||
| -rw-r--r-- | ipmask.go | 38 | ||||
| -rw-r--r-- | string.go | 13 | ||||
| -rw-r--r-- | uint.go | 17 | ||||
| -rw-r--r-- | uint16.go | 17 | ||||
| -rw-r--r-- | uint32.go | 17 | ||||
| -rw-r--r-- | uint64.go | 17 | ||||
| -rw-r--r-- | uint8.go | 17 |
18 files changed, 346 insertions, 5 deletions
@@ -34,6 +34,19 @@ func (b *boolValue) String() string { return fmt.Sprintf("%v", *b) } func (b *boolValue) IsBoolFlag() bool { return true } +func boolConv(sval string) (interface{}, error) { + return strconv.ParseBool(sval) +} + +// GetBool return the bool value of a flag with the given name +func (f *FlagSet) GetBool(name string) (bool, error) { + val, err := f.getFlagType(name, "bool", boolConv) + if err != nil { + return false, err + } + return val.(bool), nil +} + // BoolVar defines a bool flag with specified name, default value, and usage string. // The argument p points to a bool variable in which to store the value of the flag. func (f *FlagSet) BoolVar(p *bool, name string, value bool, usage string) { diff --git a/duration.go b/duration.go index 66ed7ac..382ffd3 100644 --- a/duration.go +++ b/duration.go @@ -1,6 +1,8 @@ package pflag -import "time" +import ( + "time" +) // -- time.Duration Value type durationValue time.Duration @@ -22,6 +24,19 @@ func (d *durationValue) Type() string { func (d *durationValue) String() string { return (*time.Duration)(d).String() } +func durationConv(sval string) (interface{}, error) { + return time.ParseDuration(sval) +} + +// GetDuration return the duration value of a flag with the given name +func (f *FlagSet) GetDuration(name string) (time.Duration, error) { + val, err := f.getFlagType(name, "duration", durationConv) + if err != nil { + return 0, err + } + return val.(time.Duration), nil +} + // DurationVar defines a time.Duration flag with specified name, default value, and usage string. // The argument p points to a time.Duration variable in which to store the value of the flag. func (f *FlagSet) DurationVar(p *time.Duration, name string, value time.Duration, usage string) { @@ -257,6 +257,27 @@ func (f *FlagSet) lookup(name NormalizedName) *Flag { return f.formal[name] } +// func to return a given type for a given flag name +func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval string) (interface{}, error)) (interface{}, error) { + flag := f.Lookup(name) + if flag == nil { + err := fmt.Errorf("flag accessed but not defined: %s\n", name) + return nil, err + } + + if flag.Value.Type() != ftype { + err := fmt.Errorf("trying to get %s value of flag of type %s\n", ftype, flag.Value.Type()) + return nil, err + } + + sval := flag.Value.String() + result, err := convFunc(sval) + if err != nil { + return nil, err + } + return result, nil +} + // Mark a flag deprecated in your program func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error { flag := f.Lookup(name) diff --git a/flag_test.go b/flag_test.go index d3c1714..e3c5c97 100644 --- a/flag_test.go +++ b/flag_test.go @@ -160,6 +160,9 @@ func testParse(f *FlagSet, t *testing.T) { if *boolFlag != true { t.Error("bool flag should be true, is ", *boolFlag) } + if v, err := f.GetBool("bool"); err != nil || v != *boolFlag { + t.Error("GetBool does not work.") + } if *bool2Flag != true { t.Error("bool2 flag should be true, is ", *bool2Flag) } @@ -169,48 +172,96 @@ func testParse(f *FlagSet, t *testing.T) { if *intFlag != 22 { t.Error("int flag should be 22, is ", *intFlag) } + if v, err := f.GetInt("int"); err != nil || v != *intFlag { + t.Error("GetInt does not work.") + } if *int8Flag != -8 { t.Error("int8 flag should be 0x23, is ", *int8Flag) } + if v, err := f.GetInt8("int8"); err != nil || v != *int8Flag { + t.Error("GetInt8 does not work.") + } if *int32Flag != -32 { t.Error("int32 flag should be 0x23, is ", *int32Flag) } + if v, err := f.GetInt32("int32"); err != nil || v != *int32Flag { + t.Error("GetInt32 does not work.") + } if *int64Flag != 0x23 { t.Error("int64 flag should be 0x23, is ", *int64Flag) } + if v, err := f.GetInt64("int64"); err != nil || v != *int64Flag { + t.Error("GetInt64 does not work.") + } if *uintFlag != 24 { t.Error("uint flag should be 24, is ", *uintFlag) } + if v, err := f.GetUint("uint"); err != nil || v != *uintFlag { + t.Error("GetUint does not work.") + } if *uint8Flag != 8 { t.Error("uint8 flag should be 8, is ", *uint8Flag) } + if v, err := f.GetUint8("uint8"); err != nil || v != *uint8Flag { + t.Error("GetUint8 does not work.") + } if *uint16Flag != 16 { t.Error("uint16 flag should be 16, is ", *uint16Flag) } + if v, err := f.GetUint16("uint16"); err != nil || v != *uint16Flag { + t.Error("GetUint16 does not work.") + } if *uint32Flag != 32 { t.Error("uint32 flag should be 32, is ", *uint32Flag) } + if v, err := f.GetUint32("uint32"); err != nil || v != *uint32Flag { + t.Error("GetUint32 does not work.") + } if *uint64Flag != 25 { t.Error("uint64 flag should be 25, is ", *uint64Flag) } + if v, err := f.GetUint64("uint64"); err != nil || v != *uint64Flag { + t.Error("GetUint64 does not work.") + } if *stringFlag != "hello" { t.Error("string flag should be `hello`, is ", *stringFlag) } + if v, err := f.GetString("string"); err != nil || v != *stringFlag { + t.Error("GetString does not work.") + } if *float32Flag != -172e12 { - t.Error("float64 flag should be -172e12, is ", *float64Flag) + t.Error("float32 flag should be -172e12, is ", *float32Flag) + } + if v, err := f.GetFloat32("float32"); err != nil || v != *float32Flag { + t.Errorf("GetFloat32 returned %v but float32Flag was %v", v, *float32Flag) } if *float64Flag != 2718e28 { t.Error("float64 flag should be 2718e28, is ", *float64Flag) } - if (*maskFlag).String() != ParseIPv4Mask("255.255.255.0").String() { - t.Error("mask flag should be 255.255.255.0, is ", (*maskFlag).String()) + if v, err := f.GetFloat64("float64"); err != nil || v != *float64Flag { + t.Errorf("GetFloat64 returned %v but float64Flag was %v", v, *float64Flag) } if !(*ipFlag).Equal(net.ParseIP("10.11.12.13")) { t.Error("ip flag should be 10.11.12.13, is ", *ipFlag) } + if v, err := f.GetIP("ip"); err != nil || !v.Equal(*ipFlag) { + t.Errorf("GetIP returned %v but ipFlag was %v", v, *ipFlag) + } + if (*maskFlag).String() != ParseIPv4Mask("255.255.255.0").String() { + t.Error("mask flag should be 255.255.255.0, is ", (*maskFlag).String()) + } + if v, err := f.GetIPv4Mask("mask"); err != nil || v.String() != (*maskFlag).String() { + t.Errorf("GetIP returned %v but maskFlag was %v", v, *maskFlag, err) + } if *durationFlag != 2*time.Minute { t.Error("duration flag should be 2m, is ", *durationFlag) } + if v, err := f.GetDuration("duration"); err != nil || v != *durationFlag { + t.Error("GetDuration does not work.") + } + if _, err := f.GetInt("duration"); err == nil { + t.Error("GetInt parsed a time.Duration?!?!") + } if len(f.Args()) != 1 { t.Error("expected one argument, got", len(f.Args())) } else if f.Args()[0] != extra { @@ -25,6 +25,23 @@ func (f *float32Value) Type() string { func (f *float32Value) String() string { return fmt.Sprintf("%v", *f) } +func float32Conv(sval string) (interface{}, error) { + v, err := strconv.ParseFloat(sval, 32) + if err != nil { + return 0, err + } + return float32(v), nil +} + +// GetFloat32 return the float32 value of a flag with the given name +func (f *FlagSet) GetFloat32(name string) (float32, error) { + val, err := f.getFlagType(name, "float32", float32Conv) + if err != nil { + return 0, err + } + return val.(float32), nil +} + // Float32Var defines a float32 flag with specified name, default value, and usage string. // The argument p points to a float32 variable in which to store the value of the flag. func (f *FlagSet) Float32Var(p *float32, name string, value float32, usage string) { @@ -25,6 +25,19 @@ func (f *float64Value) Type() string { func (f *float64Value) String() string { return fmt.Sprintf("%v", *f) } +func float64Conv(sval string) (interface{}, error) { + return strconv.ParseFloat(sval, 64) +} + +// GetFloat64 return the float64 value of a flag with the given name +func (f *FlagSet) GetFloat64(name string) (float64, error) { + val, err := f.getFlagType(name, "float64", float64Conv) + if err != nil { + return 0, err + } + return val.(float64), nil +} + // Float64Var defines a float64 flag with specified name, default value, and usage string. // The argument p points to a float64 variable in which to store the value of the flag. func (f *FlagSet) Float64Var(p *float64, name string, value float64, usage string) { @@ -25,6 +25,19 @@ func (i *intValue) Type() string { func (i *intValue) String() string { return fmt.Sprintf("%v", *i) } +func intConv(sval string) (interface{}, error) { + return strconv.Atoi(sval) +} + +// GetInt return the int value of a flag with the given name +func (f *FlagSet) GetInt(name string) (int, error) { + val, err := f.getFlagType(name, "int", intConv) + if err != nil { + return 0, err + } + return val.(int), nil +} + // IntVar defines an int flag with specified name, default value, and usage string. // The argument p points to an int variable in which to store the value of the flag. func (f *FlagSet) IntVar(p *int, name string, value int, usage string) { @@ -25,6 +25,23 @@ func (i *int32Value) Type() string { func (i *int32Value) String() string { return fmt.Sprintf("%v", *i) } +func int32Conv(sval string) (interface{}, error) { + v, err := strconv.ParseInt(sval, 0, 32) + if err != nil { + return 0, err + } + return int32(v), nil +} + +// GetInt32 return the int32 value of a flag with the given name +func (f *FlagSet) GetInt32(name string) (int32, error) { + val, err := f.getFlagType(name, "int32", int32Conv) + if err != nil { + return 0, err + } + return val.(int32), nil +} + // Int32Var defines an int32 flag with specified name, default value, and usage string. // The argument p points to an int32 variable in which to store the value of the flag. func (f *FlagSet) Int32Var(p *int32, name string, value int32, usage string) { @@ -25,6 +25,19 @@ func (i *int64Value) Type() string { func (i *int64Value) String() string { return fmt.Sprintf("%v", *i) } +func int64Conv(sval string) (interface{}, error) { + return strconv.ParseInt(sval, 0, 64) +} + +// GetInt64 return the int64 value of a flag with the given name +func (f *FlagSet) GetInt64(name string) (int64, error) { + val, err := f.getFlagType(name, "int64", int64Conv) + if err != nil { + return 0, err + } + return val.(int64), nil +} + // Int64Var defines an int64 flag with specified name, default value, and usage string. // The argument p points to an int64 variable in which to store the value of the flag. func (f *FlagSet) Int64Var(p *int64, name string, value int64, usage string) { @@ -25,6 +25,23 @@ func (i *int8Value) Type() string { func (i *int8Value) String() string { return fmt.Sprintf("%v", *i) } +func int8Conv(sval string) (interface{}, error) { + v, err := strconv.ParseInt(sval, 0, 8) + if err != nil { + return 0, err + } + return int8(v), nil +} + +// GetInt8 return the int8 value of a flag with the given name +func (f *FlagSet) GetInt8(name string) (int8, error) { + val, err := f.getFlagType(name, "int8", int8Conv) + if err != nil { + return 0, err + } + return val.(int8), nil +} + // Int8Var defines an int8 flag with specified name, default value, and usage string. // The argument p points to an int8 variable in which to store the value of the flag. func (f *FlagSet) Int8Var(p *int8, name string, value int8, usage string) { @@ -27,6 +27,23 @@ func (i *ipValue) Type() string { return "ip" } +func ipConv(sval string) (interface{}, error) { + ip := net.ParseIP(sval) + if ip != nil { + return ip, nil + } + return nil, fmt.Errorf("invalid string being converted to IP address: %s", sval) +} + +// GetIP return the net.IP value of a flag with the given name +func (f *FlagSet) GetIP(name string) (net.IP, error) { + val, err := f.getFlagType(name, "ip", ipConv) + if err != nil { + return nil, err + } + return val.(net.IP), nil +} + // IPVar defines an net.IP flag with specified name, default value, and usage string. // The argument p points to an net.IP variable in which to store the value of the flag. func (f *FlagSet) IPVar(p *net.IP, name string, value net.IP, usage string) { @@ -3,6 +3,7 @@ package pflag import ( "fmt" "net" + "strconv" ) // -- net.IPMask value @@ -32,11 +33,46 @@ func (i *ipMaskValue) Type() string { func ParseIPv4Mask(s string) net.IPMask { mask := net.ParseIP(s) if mask == nil { - return nil + if len(s) != 8 { + return nil + } + // net.IPMask.String() actually outputs things like ffffff00 + // so write a horrible parser for that as well :-( + m := []int{} + for i := 0; i < 4; i++ { + b := "0x" + s[2*i:2*i+2] + d, err := strconv.ParseInt(b, 0, 0) + if err != nil { + return nil + } + m = append(m, int(d)) + } + s := fmt.Sprintf("%d.%d.%d.%d", m[0], m[1], m[2], m[3]) + mask = net.ParseIP(s) + if mask == nil { + return nil + } } return net.IPv4Mask(mask[12], mask[13], mask[14], mask[15]) } +func parseIPv4Mask(sval string) (interface{}, error) { + mask := ParseIPv4Mask(sval) + if mask == nil { + return nil, fmt.Errorf("unable to parse %s as net.IPMask", sval) + } + return mask, nil +} + +// GetIPv4Mask return the net.IPv4Mask value of a flag with the given name +func (f *FlagSet) GetIPv4Mask(name string) (net.IPMask, error) { + val, err := f.getFlagType(name, "ipMask", parseIPv4Mask) + if err != nil { + return nil, err + } + return val.(net.IPMask), nil +} + // IPMaskVar defines an net.IPMask flag with specified name, default value, and usage string. // The argument p points to an net.IPMask variable in which to store the value of the flag. func (f *FlagSet) IPMaskVar(p *net.IPMask, name string, value net.IPMask, usage string) { @@ -20,6 +20,19 @@ func (s *stringValue) Type() string { func (s *stringValue) String() string { return fmt.Sprintf("%s", *s) } +func stringConv(sval string) (interface{}, error) { + return sval, nil +} + +// GetString return the string value of a flag with the given name +func (f *FlagSet) GetString(name string) (string, error) { + val, err := f.getFlagType(name, "string", stringConv) + if err != nil { + return "", err + } + return val.(string), nil +} + // StringVar defines a string flag with specified name, default value, and usage string. // The argument p points to a string variable in which to store the value of the flag. func (f *FlagSet) StringVar(p *string, name string, value string, usage string) { @@ -25,6 +25,23 @@ func (i *uintValue) Type() string { func (i *uintValue) String() string { return fmt.Sprintf("%v", *i) } +func uintConv(sval string) (interface{}, error) { + v, err := strconv.ParseUint(sval, 0, 0) + if err != nil { + return 0, err + } + return uint(v), nil +} + +// GetUint return the uint value of a flag with the given name +func (f *FlagSet) GetUint(name string) (uint, error) { + val, err := f.getFlagType(name, "uint", uintConv) + if err != nil { + return 0, err + } + return val.(uint), nil +} + // UintVar defines a uint flag with specified name, default value, and usage string. // The argument p points to a uint variable in which to store the value of the flag. func (f *FlagSet) UintVar(p *uint, name string, value uint, usage string) { @@ -23,6 +23,23 @@ func (i *uint16Value) Type() string { return "uint16" } +func uint16Conv(sval string) (interface{}, error) { + v, err := strconv.ParseUint(sval, 0, 16) + if err != nil { + return 0, err + } + return uint16(v), nil +} + +// GetUint16 return the uint16 value of a flag with the given name +func (f *FlagSet) GetUint16(name string) (uint16, error) { + val, err := f.getFlagType(name, "uint16", uint16Conv) + if err != nil { + return 0, err + } + return val.(uint16), nil +} + // Uint16Var defines a uint flag with specified name, default value, and usage string. // The argument p points to a uint variable in which to store the value of the flag. func (f *FlagSet) Uint16Var(p *uint16, name string, value uint16, usage string) { @@ -23,6 +23,23 @@ func (i *uint32Value) Type() string { return "uint32" } +func uint32Conv(sval string) (interface{}, error) { + v, err := strconv.ParseUint(sval, 0, 32) + if err != nil { + return 0, err + } + return uint32(v), nil +} + +// GetUint32 return the uint32 value of a flag with the given name +func (f *FlagSet) GetUint32(name string) (uint32, error) { + val, err := f.getFlagType(name, "uint32", uint32Conv) + if err != nil { + return 0, err + } + return val.(uint32), nil +} + // Uint32Var defines a uint32 flag with specified name, default value, and usage string. // The argument p points to a uint32 variable in which to store the value of the flag. func (f *FlagSet) Uint32Var(p *uint32, name string, value uint32, usage string) { @@ -25,6 +25,23 @@ func (i *uint64Value) Type() string { func (i *uint64Value) String() string { return fmt.Sprintf("%v", *i) } +func uint64Conv(sval string) (interface{}, error) { + v, err := strconv.ParseUint(sval, 0, 64) + if err != nil { + return 0, err + } + return uint64(v), nil +} + +// GetUint64 return the uint64 value of a flag with the given name +func (f *FlagSet) GetUint64(name string) (uint64, error) { + val, err := f.getFlagType(name, "uint64", uint64Conv) + if err != nil { + return 0, err + } + return val.(uint64), nil +} + // Uint64Var defines a uint64 flag with specified name, default value, and usage string. // The argument p points to a uint64 variable in which to store the value of the flag. func (f *FlagSet) Uint64Var(p *uint64, name string, value uint64, usage string) { @@ -25,6 +25,23 @@ func (i *uint8Value) Type() string { func (i *uint8Value) String() string { return fmt.Sprintf("%v", *i) } +func uint8Conv(sval string) (interface{}, error) { + v, err := strconv.ParseUint(sval, 0, 8) + if err != nil { + return 0, err + } + return uint8(v), nil +} + +// GetUint8 return the uint8 value of a flag with the given name +func (f *FlagSet) GetUint8(name string) (uint8, error) { + val, err := f.getFlagType(name, "uint8", uint8Conv) + if err != nil { + return 0, err + } + return val.(uint8), nil +} + // Uint8Var defines a uint8 flag with specified name, default value, and usage string. // The argument p points to a uint8 variable in which to store the value of the flag. func (f *FlagSet) Uint8Var(p *uint8, name string, value uint8, usage string) { |
