diff options
| author | Tamal Saha <[email protected]> | 2018-08-21 07:45:17 -0400 |
|---|---|---|
| committer | Eric Paris <[email protected]> | 2018-08-21 07:45:17 -0400 |
| commit | d929dcbb10863323c436af3cf76cb16a6dfc9b29 (patch) | |
| tree | 57482235bc31513e488433c7275f5d5f104edebc | |
| parent | 947b89bd1b7dabfed991ac30e1a56f5193f0c88b (diff) | |
Handle single string=>string flags without quotes (#179)
OK: --f1 "a=5,6" --f2 b=3,4 --f3 "c=5,6",d=7
Not OK: --f4 c="5,6"
| -rw-r--r-- | string_to_string.go | 19 | ||||
| -rw-r--r-- | string_to_string_test.go | 14 |
2 files changed, 24 insertions, 9 deletions
diff --git a/string_to_string.go b/string_to_string.go index 64892db..890a01a 100644 --- a/string_to_string.go +++ b/string_to_string.go @@ -22,11 +22,22 @@ func newStringToStringValue(val map[string]string, p *map[string]string) *string // Format: a=1,b=2 func (s *stringToStringValue) Set(val string) error { - r := csv.NewReader(strings.NewReader(val)) - ss, err := r.Read() - if err != nil { - return err + var ss []string + n := strings.Count(val, "=") + switch n { + case 0: + return fmt.Errorf("%s must be formatted as key=value", val) + case 1: + ss = append(ss, strings.Trim(val, `"`)) + default: + r := csv.NewReader(strings.NewReader(val)) + var err error + ss, err = r.Read() + if err != nil { + return err + } } + out := make(map[string]string, len(ss)) for _, pair := range ss { kv := strings.SplitN(pair, "=", 2) diff --git a/string_to_string_test.go b/string_to_string_test.go index f1aae04..0777f03 100644 --- a/string_to_string_test.go +++ b/string_to_string_test.go @@ -140,16 +140,20 @@ func TestS2SCalledTwice(t *testing.T) { var s2s map[string]string f := setUpS2SFlagSet(&s2s) - in := []string{"a=1,b=2", "b=3", `"e=5,6"`, `f="7,8"`} + in := []string{"a=1,b=2", "b=3", `"e=5,6"`, `f=7,8`} expected := map[string]string{"a": "1", "b": "3", "e": "5,6", "f": "7,8"} argfmt := "--s2s=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - arg3 := fmt.Sprintf(argfmt, in[2]) - err := f.Parse([]string{arg1, arg2, arg3}) + arg0 := fmt.Sprintf(argfmt, in[0]) + arg1 := fmt.Sprintf(argfmt, in[1]) + arg2 := fmt.Sprintf(argfmt, in[2]) + arg3 := fmt.Sprintf(argfmt, in[3]) + err := f.Parse([]string{arg0, arg1, arg2, arg3}) if err != nil { t.Fatal("expected no error; got", err) } + if len(s2s) != len(expected) { + t.Fatalf("expected %d flags; got %d flags", len(expected), len(s2s)) + } for i, v := range s2s { if expected[i] != v { t.Fatalf("expected s2s[%s] to be %s but got: %s", i, expected[i], v) |
