aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Paris <[email protected]>2015-08-05 18:42:34 -0400
committerEric Paris <[email protected]>2015-08-05 19:32:15 -0400
commit686b63fc3227f05db10725b91e8952fe58766b70 (patch)
tree0112a8a541820ee6b593de6b5a5aa7af379529fa
parent6fc1d08da4a3ef902957f13557f8471f78229712 (diff)
New counter flag type
``` -vvv ``` Will give a value of 3 ``` -v=4 -v ``` Will give a value of 5
-rw-r--r--count.go84
-rw-r--r--count_test.go55
2 files changed, 139 insertions, 0 deletions
diff --git a/count.go b/count.go
new file mode 100644
index 0000000..d061368
--- /dev/null
+++ b/count.go
@@ -0,0 +1,84 @@
+package pflag
+
+import (
+ "fmt"
+ "strconv"
+)
+
+// -- count Value
+type countValue int
+
+func newCountValue(val int, p *int) *countValue {
+ *p = val
+ return (*countValue)(p)
+}
+
+func (i *countValue) Set(s string) error {
+ v, err := strconv.ParseInt(s, 0, 64)
+ // -1 means that no specific value was passed, so increment
+ if v == -1 {
+ *i = countValue(*i + 1)
+ } else {
+ *i = countValue(v)
+ }
+ return err
+}
+
+func (i *countValue) Type() string {
+ return "count"
+}
+
+func (i *countValue) String() string { return fmt.Sprintf("%v", *i) }
+
+func countConv(sval string) (interface{}, error) {
+ i, err := strconv.Atoi(sval)
+ if err != nil {
+ return nil, err
+ }
+ return i, nil
+}
+
+func (f *FlagSet) GetCount(name string) (int, error) {
+ val, err := f.getFlagType(name, "count", countConv)
+ if err != nil {
+ return 0, err
+ }
+ return val.(int), nil
+}
+
+func (f *FlagSet) CountVar(p *int, name string, usage string) {
+ f.CountVarP(p, name, "", usage)
+}
+
+func (f *FlagSet) CountVarP(p *int, name, shorthand string, usage string) {
+ flag := f.VarPF(newCountValue(0, p), name, shorthand, usage)
+ flag.NoOptDefVal = "-1"
+}
+
+func CountVar(p *int, name string, usage string) {
+ CommandLine.CountVar(p, name, usage)
+}
+
+func CountVarP(p *int, name, shorthand string, usage string) {
+ CommandLine.CountVarP(p, name, shorthand, usage)
+}
+
+func (f *FlagSet) Count(name string, usage string) *int {
+ p := new(int)
+ f.CountVarP(p, name, "", usage)
+ return p
+}
+
+func (f *FlagSet) CountP(name, shorthand string, usage string) *int {
+ p := new(int)
+ f.CountVarP(p, name, shorthand, usage)
+ return p
+}
+
+func Count(name string, usage string) *int {
+ return CommandLine.CountP(name, "", usage)
+}
+
+func CountP(name, shorthand string, usage string) *int {
+ return CommandLine.CountP(name, shorthand, usage)
+}
diff --git a/count_test.go b/count_test.go
new file mode 100644
index 0000000..716765c
--- /dev/null
+++ b/count_test.go
@@ -0,0 +1,55 @@
+package pflag
+
+import (
+ "fmt"
+ "os"
+ "testing"
+)
+
+var _ = fmt.Printf
+
+func setUpCount(c *int) *FlagSet {
+ f := NewFlagSet("test", ContinueOnError)
+ f.CountVarP(c, "verbose", "v", "a counter")
+ return f
+}
+
+func TestCount(t *testing.T) {
+ testCases := []struct {
+ input []string
+ success bool
+ expected int
+ }{
+ {[]string{"-vvv"}, true, 3},
+ {[]string{"-v", "-v", "-v"}, true, 3},
+ {[]string{"-v", "--verbose", "-v"}, true, 3},
+ {[]string{"-v=3", "-v"}, true, 4},
+ {[]string{"-v=a"}, false, 0},
+ }
+
+ devnull, _ := os.Open(os.DevNull)
+ os.Stderr = devnull
+ for i := range testCases {
+ var count int
+ f := setUpCount(&count)
+
+ tc := &testCases[i]
+
+ err := f.Parse(tc.input)
+ if err != nil && tc.success == true {
+ t.Errorf("expected success, got %q", err)
+ continue
+ } else if err == nil && tc.success == false {
+ t.Errorf("expected failure, got success")
+ continue
+ } else if tc.success {
+ c, err := f.GetCount("verbose")
+ if err != nil {
+ t.Errorf("Got error trying to fetch the counter flag")
+ }
+ if c != tc.expected {
+ t.Errorf("expected %q, got %q", tc.expected, c)
+ }
+ }
+ }
+}