diff options
| -rw-r--r-- | count.go | 84 | ||||
| -rw-r--r-- | count_test.go | 55 |
2 files changed, 139 insertions, 0 deletions
diff --git a/count.go b/count.go new file mode 100644 index 0000000..d061368 --- /dev/null +++ b/count.go @@ -0,0 +1,84 @@ +package pflag + +import ( + "fmt" + "strconv" +) + +// -- count Value +type countValue int + +func newCountValue(val int, p *int) *countValue { + *p = val + return (*countValue)(p) +} + +func (i *countValue) Set(s string) error { + v, err := strconv.ParseInt(s, 0, 64) + // -1 means that no specific value was passed, so increment + if v == -1 { + *i = countValue(*i + 1) + } else { + *i = countValue(v) + } + return err +} + +func (i *countValue) Type() string { + return "count" +} + +func (i *countValue) String() string { return fmt.Sprintf("%v", *i) } + +func countConv(sval string) (interface{}, error) { + i, err := strconv.Atoi(sval) + if err != nil { + return nil, err + } + return i, nil +} + +func (f *FlagSet) GetCount(name string) (int, error) { + val, err := f.getFlagType(name, "count", countConv) + if err != nil { + return 0, err + } + return val.(int), nil +} + +func (f *FlagSet) CountVar(p *int, name string, usage string) { + f.CountVarP(p, name, "", usage) +} + +func (f *FlagSet) CountVarP(p *int, name, shorthand string, usage string) { + flag := f.VarPF(newCountValue(0, p), name, shorthand, usage) + flag.NoOptDefVal = "-1" +} + +func CountVar(p *int, name string, usage string) { + CommandLine.CountVar(p, name, usage) +} + +func CountVarP(p *int, name, shorthand string, usage string) { + CommandLine.CountVarP(p, name, shorthand, usage) +} + +func (f *FlagSet) Count(name string, usage string) *int { + p := new(int) + f.CountVarP(p, name, "", usage) + return p +} + +func (f *FlagSet) CountP(name, shorthand string, usage string) *int { + p := new(int) + f.CountVarP(p, name, shorthand, usage) + return p +} + +func Count(name string, usage string) *int { + return CommandLine.CountP(name, "", usage) +} + +func CountP(name, shorthand string, usage string) *int { + return CommandLine.CountP(name, shorthand, usage) +} diff --git a/count_test.go b/count_test.go new file mode 100644 index 0000000..716765c --- /dev/null +++ b/count_test.go @@ -0,0 +1,55 @@ +package pflag + +import ( + "fmt" + "os" + "testing" +) + +var _ = fmt.Printf + +func setUpCount(c *int) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.CountVarP(c, "verbose", "v", "a counter") + return f +} + +func TestCount(t *testing.T) { + testCases := []struct { + input []string + success bool + expected int + }{ + {[]string{"-vvv"}, true, 3}, + {[]string{"-v", "-v", "-v"}, true, 3}, + {[]string{"-v", "--verbose", "-v"}, true, 3}, + {[]string{"-v=3", "-v"}, true, 4}, + {[]string{"-v=a"}, false, 0}, + } + + devnull, _ := os.Open(os.DevNull) + os.Stderr = devnull + for i := range testCases { + var count int + f := setUpCount(&count) + + tc := &testCases[i] + + err := f.Parse(tc.input) + if err != nil && tc.success == true { + t.Errorf("expected success, got %q", err) + continue + } else if err == nil && tc.success == false { + t.Errorf("expected failure, got success") + continue + } else if tc.success { + c, err := f.GetCount("verbose") + if err != nil { + t.Errorf("Got error trying to fetch the counter flag") + } + if c != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, c) + } + } + } +} |
