aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTamal Saha <[email protected]>2018-08-15 17:04:06 -0400
committerEric Paris <[email protected]>2018-08-15 17:04:06 -0400
commit947b89bd1b7dabfed991ac30e1a56f5193f0c88b (patch)
tree180690569f91f3fcfdca374a462a6752fef3463b
parent9a97c102cda95a86cec2345a6f09f55a939babf5 (diff)
Add map valued (string->string, string->int) flags. (#133)
Format: --myflag=a=1,b=2
-rw-r--r--string_to_int.go149
-rw-r--r--string_to_int_test.go156
-rw-r--r--string_to_string.go149
-rw-r--r--string_to_string_test.go158
4 files changed, 612 insertions, 0 deletions
diff --git a/string_to_int.go b/string_to_int.go
new file mode 100644
index 0000000..5ceda39
--- /dev/null
+++ b/string_to_int.go
@@ -0,0 +1,149 @@
+package pflag
+
+import (
+ "bytes"
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+// -- stringToInt Value
+type stringToIntValue struct {
+ value *map[string]int
+ changed bool
+}
+
+func newStringToIntValue(val map[string]int, p *map[string]int) *stringToIntValue {
+ ssv := new(stringToIntValue)
+ ssv.value = p
+ *ssv.value = val
+ return ssv
+}
+
+// Format: a=1,b=2
+func (s *stringToIntValue) Set(val string) error {
+ ss := strings.Split(val, ",")
+ out := make(map[string]int, len(ss))
+ for _, pair := range ss {
+ kv := strings.SplitN(pair, "=", 2)
+ if len(kv) != 2 {
+ return fmt.Errorf("%s must be formatted as key=value", pair)
+ }
+ var err error
+ out[kv[0]], err = strconv.Atoi(kv[1])
+ if err != nil {
+ return err
+ }
+ }
+ if !s.changed {
+ *s.value = out
+ } else {
+ for k, v := range out {
+ (*s.value)[k] = v
+ }
+ }
+ s.changed = true
+ return nil
+}
+
+func (s *stringToIntValue) Type() string {
+ return "stringToInt"
+}
+
+func (s *stringToIntValue) String() string {
+ var buf bytes.Buffer
+ i := 0
+ for k, v := range *s.value {
+ if i > 0 {
+ buf.WriteRune(',')
+ }
+ buf.WriteString(k)
+ buf.WriteRune('=')
+ buf.WriteString(strconv.Itoa(v))
+ i++
+ }
+ return "[" + buf.String() + "]"
+}
+
+func stringToIntConv(val string) (interface{}, error) {
+ val = strings.Trim(val, "[]")
+ // An empty string would cause an empty map
+ if len(val) == 0 {
+ return map[string]int{}, nil
+ }
+ ss := strings.Split(val, ",")
+ out := make(map[string]int, len(ss))
+ for _, pair := range ss {
+ kv := strings.SplitN(pair, "=", 2)
+ if len(kv) != 2 {
+ return nil, fmt.Errorf("%s must be formatted as key=value", pair)
+ }
+ var err error
+ out[kv[0]], err = strconv.Atoi(kv[1])
+ if err != nil {
+ return nil, err
+ }
+ }
+ return out, nil
+}
+
+// GetStringToInt return the map[string]int value of a flag with the given name
+func (f *FlagSet) GetStringToInt(name string) (map[string]int, error) {
+ val, err := f.getFlagType(name, "stringToInt", stringToIntConv)
+ if err != nil {
+ return map[string]int{}, err
+ }
+ return val.(map[string]int), nil
+}
+
+// StringToIntVar defines a string flag with specified name, default value, and usage string.
+// The argument p points to a map[string]int variable in which to store the values of the multiple flags.
+// The value of each argument will not try to be separated by comma
+func (f *FlagSet) StringToIntVar(p *map[string]int, name string, value map[string]int, usage string) {
+ f.VarP(newStringToIntValue(value, p), name, "", usage)
+}
+
+// StringToIntVarP is like StringToIntVar, but accepts a shorthand letter that can be used after a single dash.
+func (f *FlagSet) StringToIntVarP(p *map[string]int, name, shorthand string, value map[string]int, usage string) {
+ f.VarP(newStringToIntValue(value, p), name, shorthand, usage)
+}
+
+// StringToIntVar defines a string flag with specified name, default value, and usage string.
+// The argument p points to a map[string]int variable in which to store the value of the flag.
+// The value of each argument will not try to be separated by comma
+func StringToIntVar(p *map[string]int, name string, value map[string]int, usage string) {
+ CommandLine.VarP(newStringToIntValue(value, p), name, "", usage)
+}
+
+// StringToIntVarP is like StringToIntVar, but accepts a shorthand letter that can be used after a single dash.
+func StringToIntVarP(p *map[string]int, name, shorthand string, value map[string]int, usage string) {
+ CommandLine.VarP(newStringToIntValue(value, p), name, shorthand, usage)
+}
+
+// StringToInt defines a string flag with specified name, default value, and usage string.
+// The return value is the address of a map[string]int variable that stores the value of the flag.
+// The value of each argument will not try to be separated by comma
+func (f *FlagSet) StringToInt(name string, value map[string]int, usage string) *map[string]int {
+ p := map[string]int{}
+ f.StringToIntVarP(&p, name, "", value, usage)
+ return &p
+}
+
+// StringToIntP is like StringToInt, but accepts a shorthand letter that can be used after a single dash.
+func (f *FlagSet) StringToIntP(name, shorthand string, value map[string]int, usage string) *map[string]int {
+ p := map[string]int{}
+ f.StringToIntVarP(&p, name, shorthand, value, usage)
+ return &p
+}
+
+// StringToInt defines a string flag with specified name, default value, and usage string.
+// The return value is the address of a map[string]int variable that stores the value of the flag.
+// The value of each argument will not try to be separated by comma
+func StringToInt(name string, value map[string]int, usage string) *map[string]int {
+ return CommandLine.StringToIntP(name, "", value, usage)
+}
+
+// StringToIntP is like StringToInt, but accepts a shorthand letter that can be used after a single dash.
+func StringToIntP(name, shorthand string, value map[string]int, usage string) *map[string]int {
+ return CommandLine.StringToIntP(name, shorthand, value, usage)
+}
diff --git a/string_to_int_test.go b/string_to_int_test.go
new file mode 100644
index 0000000..b60bbaf
--- /dev/null
+++ b/string_to_int_test.go
@@ -0,0 +1,156 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of ths2i source code s2i governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pflag
+
+import (
+ "bytes"
+ "fmt"
+ "strconv"
+ "testing"
+)
+
+func setUpS2IFlagSet(s2ip *map[string]int) *FlagSet {
+ f := NewFlagSet("test", ContinueOnError)
+ f.StringToIntVar(s2ip, "s2i", map[string]int{}, "Command separated ls2it!")
+ return f
+}
+
+func setUpS2IFlagSetWithDefault(s2ip *map[string]int) *FlagSet {
+ f := NewFlagSet("test", ContinueOnError)
+ f.StringToIntVar(s2ip, "s2i", map[string]int{"a": 1, "b": 2}, "Command separated ls2it!")
+ return f
+}
+
+func createS2IFlag(vals map[string]int) string {
+ var buf bytes.Buffer
+ i := 0
+ for k, v := range vals {
+ if i > 0 {
+ buf.WriteRune(',')
+ }
+ buf.WriteString(k)
+ buf.WriteRune('=')
+ buf.WriteString(strconv.Itoa(v))
+ i++
+ }
+ return buf.String()
+}
+
+func TestEmptyS2I(t *testing.T) {
+ var s2i map[string]int
+ f := setUpS2IFlagSet(&s2i)
+ err := f.Parse([]string{})
+ if err != nil {
+ t.Fatal("expected no error; got", err)
+ }
+
+ getS2I, err := f.GetStringToInt("s2i")
+ if err != nil {
+ t.Fatal("got an error from GetStringToInt():", err)
+ }
+ if len(getS2I) != 0 {
+ t.Fatalf("got s2i %v with len=%d but expected length=0", getS2I, len(getS2I))
+ }
+}
+
+func TestS2I(t *testing.T) {
+ var s2i map[string]int
+ f := setUpS2IFlagSet(&s2i)
+
+ vals := map[string]int{"a": 1, "b": 2, "d": 4, "c": 3}
+ arg := fmt.Sprintf("--s2i=%s", createS2IFlag(vals))
+ err := f.Parse([]string{arg})
+ if err != nil {
+ t.Fatal("expected no error; got", err)
+ }
+ for k, v := range s2i {
+ if vals[k] != v {
+ t.Fatalf("expected s2i[%s] to be %d but got: %d", k, vals[k], v)
+ }
+ }
+ getS2I, err := f.GetStringToInt("s2i")
+ if err != nil {
+ t.Fatalf("got error: %v", err)
+ }
+ for k, v := range getS2I {
+ if vals[k] != v {
+ t.Fatalf("expected s2i[%s] to be %d but got: %d from GetStringToInt", k, vals[k], v)
+ }
+ }
+}
+
+func TestS2IDefault(t *testing.T) {
+ var s2i map[string]int
+ f := setUpS2IFlagSetWithDefault(&s2i)
+
+ vals := map[string]int{"a": 1, "b": 2}
+
+ err := f.Parse([]string{})
+ if err != nil {
+ t.Fatal("expected no error; got", err)
+ }
+ for k, v := range s2i {
+ if vals[k] != v {
+ t.Fatalf("expected s2i[%s] to be %d but got: %d", k, vals[k], v)
+ }
+ }
+
+ getS2I, err := f.GetStringToInt("s2i")
+ if err != nil {
+ t.Fatal("got an error from GetStringToInt():", err)
+ }
+ for k, v := range getS2I {
+ if vals[k] != v {
+ t.Fatalf("expected s2i[%s] to be %d from GetStringToInt but got: %d", k, vals[k], v)
+ }
+ }
+}
+
+func TestS2IWithDefault(t *testing.T) {
+ var s2i map[string]int
+ f := setUpS2IFlagSetWithDefault(&s2i)
+
+ vals := map[string]int{"a": 1, "b": 2}
+ arg := fmt.Sprintf("--s2i=%s", createS2IFlag(vals))
+ err := f.Parse([]string{arg})
+ if err != nil {
+ t.Fatal("expected no error; got", err)
+ }
+ for k, v := range s2i {
+ if vals[k] != v {
+ t.Fatalf("expected s2i[%s] to be %d but got: %d", k, vals[k], v)
+ }
+ }
+
+ getS2I, err := f.GetStringToInt("s2i")
+ if err != nil {
+ t.Fatal("got an error from GetStringToInt():", err)
+ }
+ for k, v := range getS2I {
+ if vals[k] != v {
+ t.Fatalf("expected s2i[%s] to be %d from GetStringToInt but got: %d", k, vals[k], v)
+ }
+ }
+}
+
+func TestS2ICalledTwice(t *testing.T) {
+ var s2i map[string]int
+ f := setUpS2IFlagSet(&s2i)
+
+ in := []string{"a=1,b=2", "b=3"}
+ expected := map[string]int{"a": 1, "b": 3}
+ argfmt := "--s2i=%s"
+ arg1 := fmt.Sprintf(argfmt, in[0])
+ arg2 := fmt.Sprintf(argfmt, in[1])
+ err := f.Parse([]string{arg1, arg2})
+ if err != nil {
+ t.Fatal("expected no error; got", err)
+ }
+ for i, v := range s2i {
+ if expected[i] != v {
+ t.Fatalf("expected s2i[%s] to be %d but got: %d", i, expected[i], v)
+ }
+ }
+}
diff --git a/string_to_string.go b/string_to_string.go
new file mode 100644
index 0000000..64892db
--- /dev/null
+++ b/string_to_string.go
@@ -0,0 +1,149 @@
+package pflag
+
+import (
+ "bytes"
+ "encoding/csv"
+ "fmt"
+ "strings"
+)
+
+// -- stringToString Value
+type stringToStringValue struct {
+ value *map[string]string
+ changed bool
+}
+
+func newStringToStringValue(val map[string]string, p *map[string]string) *stringToStringValue {
+ ssv := new(stringToStringValue)
+ ssv.value = p
+ *ssv.value = val
+ return ssv
+}
+
+// Format: a=1,b=2
+func (s *stringToStringValue) Set(val string) error {
+ r := csv.NewReader(strings.NewReader(val))
+ ss, err := r.Read()
+ if err != nil {
+ return err
+ }
+ out := make(map[string]string, len(ss))
+ for _, pair := range ss {
+ kv := strings.SplitN(pair, "=", 2)
+ if len(kv) != 2 {
+ return fmt.Errorf("%s must be formatted as key=value", pair)
+ }
+ out[kv[0]] = kv[1]
+ }
+ if !s.changed {
+ *s.value = out
+ } else {
+ for k, v := range out {
+ (*s.value)[k] = v
+ }
+ }
+ s.changed = true
+ return nil
+}
+
+func (s *stringToStringValue) Type() string {
+ return "stringToString"
+}
+
+func (s *stringToStringValue) String() string {
+ records := make([]string, 0, len(*s.value)>>1)
+ for k, v := range *s.value {
+ records = append(records, k+"="+v)
+ }
+
+ var buf bytes.Buffer
+ w := csv.NewWriter(&buf)
+ if err := w.Write(records); err != nil {
+ panic(err)
+ }
+ w.Flush()
+ return "[" + strings.TrimSpace(buf.String()) + "]"
+}
+
+func stringToStringConv(val string) (interface{}, error) {
+ val = strings.Trim(val, "[]")
+ // An empty string would cause an empty map
+ if len(val) == 0 {
+ return map[string]string{}, nil
+ }
+ r := csv.NewReader(strings.NewReader(val))
+ ss, err := r.Read()
+ if err != nil {
+ return nil, err
+ }
+ out := make(map[string]string, len(ss))
+ for _, pair := range ss {
+ kv := strings.SplitN(pair, "=", 2)
+ if len(kv) != 2 {
+ return nil, fmt.Errorf("%s must be formatted as key=value", pair)
+ }
+ out[kv[0]] = kv[1]
+ }
+ return out, nil
+}
+
+// GetStringToString return the map[string]string value of a flag with the given name
+func (f *FlagSet) GetStringToString(name string) (map[string]string, error) {
+ val, err := f.getFlagType(name, "stringToString", stringToStringConv)
+ if err != nil {
+ return map[string]string{}, err
+ }
+ return val.(map[string]string), nil
+}
+
+// StringToStringVar defines a string flag with specified name, default value, and usage string.
+// The argument p points to a map[string]string variable in which to store the values of the multiple flags.
+// The value of each argument will not try to be separated by comma
+func (f *FlagSet) StringToStringVar(p *map[string]string, name string, value map[string]string, usage string) {
+ f.VarP(newStringToStringValue(value, p), name, "", usage)
+}
+
+// StringToStringVarP is like StringToStringVar, but accepts a shorthand letter that can be used after a single dash.
+func (f *FlagSet) StringToStringVarP(p *map[string]string, name, shorthand string, value map[string]string, usage string) {
+ f.VarP(newStringToStringValue(value, p), name, shorthand, usage)
+}
+
+// StringToStringVar defines a string flag with specified name, default value, and usage string.
+// The argument p points to a map[string]string variable in which to store the value of the flag.
+// The value of each argument will not try to be separated by comma
+func StringToStringVar(p *map[string]string, name string, value map[string]string, usage string) {
+ CommandLine.VarP(newStringToStringValue(value, p), name, "", usage)
+}
+
+// StringToStringVarP is like StringToStringVar, but accepts a shorthand letter that can be used after a single dash.
+func StringToStringVarP(p *map[string]string, name, shorthand string, value map[string]string, usage string) {
+ CommandLine.VarP(newStringToStringValue(value, p), name, shorthand, usage)
+}
+
+// StringToString defines a string flag with specified name, default value, and usage string.
+// The return value is the address of a map[string]string variable that stores the value of the flag.
+// The value of each argument will not try to be separated by comma
+func (f *FlagSet) StringToString(name string, value map[string]string, usage string) *map[string]string {
+ p := map[string]string{}
+ f.StringToStringVarP(&p, name, "", value, usage)
+ return &p
+}
+
+// StringToStringP is like StringToString, but accepts a shorthand letter that can be used after a single dash.
+func (f *FlagSet) StringToStringP(name, shorthand string, value map[string]string, usage string) *map[string]string {
+ p := map[string]string{}
+ f.StringToStringVarP(&p, name, shorthand, value, usage)
+ return &p
+}
+
+// StringToString defines a string flag with specified name, default value, and usage string.
+// The return value is the address of a map[string]string variable that stores the value of the flag.
+// The value of each argument will not try to be separated by comma
+func StringToString(name string, value map[string]string, usage string) *map[string]string {
+ return CommandLine.StringToStringP(name, "", value, usage)
+}
+
+// StringToStringP is like StringToString, but accepts a shorthand letter that can be used after a single dash.
+func StringToStringP(name, shorthand string, value map[string]string, usage string) *map[string]string {
+ return CommandLine.StringToStringP(name, shorthand, value, usage)
+}
diff --git a/string_to_string_test.go b/string_to_string_test.go
new file mode 100644
index 0000000..f1aae04
--- /dev/null
+++ b/string_to_string_test.go
@@ -0,0 +1,158 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of ths2s source code s2s governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pflag
+
+import (
+ "bytes"
+ "encoding/csv"
+ "fmt"
+ "strings"
+ "testing"
+)
+
+func setUpS2SFlagSet(s2sp *map[string]string) *FlagSet {
+ f := NewFlagSet("test", ContinueOnError)
+ f.StringToStringVar(s2sp, "s2s", map[string]string{}, "Command separated ls2st!")
+ return f
+}
+
+func setUpS2SFlagSetWithDefault(s2sp *map[string]string) *FlagSet {
+ f := NewFlagSet("test", ContinueOnError)
+ f.StringToStringVar(s2sp, "s2s", map[string]string{"da": "1", "db": "2", "de": "5,6"}, "Command separated ls2st!")
+ return f
+}
+
+func createS2SFlag(vals map[string]string) string {
+ records := make([]string, 0, len(vals)>>1)
+ for k, v := range vals {
+ records = append(records, k+"="+v)
+ }
+
+ var buf bytes.Buffer
+ w := csv.NewWriter(&buf)
+ if err := w.Write(records); err != nil {
+ panic(err)
+ }
+ w.Flush()
+ return strings.TrimSpace(buf.String())
+}
+
+func TestEmptyS2S(t *testing.T) {
+ var s2s map[string]string
+ f := setUpS2SFlagSet(&s2s)
+ err := f.Parse([]string{})
+ if err != nil {
+ t.Fatal("expected no error; got", err)
+ }
+
+ getS2S, err := f.GetStringToString("s2s")
+ if err != nil {
+ t.Fatal("got an error from GetStringToString():", err)
+ }
+ if len(getS2S) != 0 {
+ t.Fatalf("got s2s %v with len=%d but expected length=0", getS2S, len(getS2S))
+ }
+}
+
+func TestS2S(t *testing.T) {
+ var s2s map[string]string
+ f := setUpS2SFlagSet(&s2s)
+
+ vals := map[string]string{"a": "1", "b": "2", "d": "4", "c": "3", "e": "5,6"}
+ arg := fmt.Sprintf("--s2s=%s", createS2SFlag(vals))
+ err := f.Parse([]string{arg})
+ if err != nil {
+ t.Fatal("expected no error; got", err)
+ }
+ for k, v := range s2s {
+ if vals[k] != v {
+ t.Fatalf("expected s2s[%s] to be %s but got: %s", k, vals[k], v)
+ }
+ }
+ getS2S, err := f.GetStringToString("s2s")
+ if err != nil {
+ t.Fatalf("got error: %v", err)
+ }
+ for k, v := range getS2S {
+ if vals[k] != v {
+ t.Fatalf("expected s2s[%s] to be %s but got: %s from GetStringToString", k, vals[k], v)
+ }
+ }
+}
+
+func TestS2SDefault(t *testing.T) {
+ var s2s map[string]string
+ f := setUpS2SFlagSetWithDefault(&s2s)
+
+ vals := map[string]string{"da": "1", "db": "2", "de": "5,6"}
+
+ err := f.Parse([]string{})
+ if err != nil {
+ t.Fatal("expected no error; got", err)
+ }
+ for k, v := range s2s {
+ if vals[k] != v {
+ t.Fatalf("expected s2s[%s] to be %s but got: %s", k, vals[k], v)
+ }
+ }
+
+ getS2S, err := f.GetStringToString("s2s")
+ if err != nil {
+ t.Fatal("got an error from GetStringToString():", err)
+ }
+ for k, v := range getS2S {
+ if vals[k] != v {
+ t.Fatalf("expected s2s[%s] to be %s from GetStringToString but got: %s", k, vals[k], v)
+ }
+ }
+}
+
+func TestS2SWithDefault(t *testing.T) {
+ var s2s map[string]string
+ f := setUpS2SFlagSetWithDefault(&s2s)
+
+ vals := map[string]string{"a": "1", "b": "2", "e": "5,6"}
+ arg := fmt.Sprintf("--s2s=%s", createS2SFlag(vals))
+ err := f.Parse([]string{arg})
+ if err != nil {
+ t.Fatal("expected no error; got", err)
+ }
+ for k, v := range s2s {
+ if vals[k] != v {
+ t.Fatalf("expected s2s[%s] to be %s but got: %s", k, vals[k], v)
+ }
+ }
+
+ getS2S, err := f.GetStringToString("s2s")
+ if err != nil {
+ t.Fatal("got an error from GetStringToString():", err)
+ }
+ for k, v := range getS2S {
+ if vals[k] != v {
+ t.Fatalf("expected s2s[%s] to be %s from GetStringToString but got: %s", k, vals[k], v)
+ }
+ }
+}
+
+func TestS2SCalledTwice(t *testing.T) {
+ var s2s map[string]string
+ f := setUpS2SFlagSet(&s2s)
+
+ in := []string{"a=1,b=2", "b=3", `"e=5,6"`, `f="7,8"`}
+ expected := map[string]string{"a": "1", "b": "3", "e": "5,6", "f": "7,8"}
+ argfmt := "--s2s=%s"
+ arg1 := fmt.Sprintf(argfmt, in[0])
+ arg2 := fmt.Sprintf(argfmt, in[1])
+ arg3 := fmt.Sprintf(argfmt, in[2])
+ err := f.Parse([]string{arg1, arg2, arg3})
+ if err != nil {
+ t.Fatal("expected no error; got", err)
+ }
+ for i, v := range s2s {
+ if expected[i] != v {
+ t.Fatalf("expected s2s[%s] to be %s but got: %s", i, expected[i], v)
+ }
+ }
+}