aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHu Jun <[email protected]>2024-12-27 12:06:14 -0800
committerHu Jun <[email protected]>2024-12-27 12:06:14 -0800
commit1118d46dfe34c2d77aee04d819315a21942aa444 (patch)
tree6dc9f270994b1db1a37499015b7596e490e882b4
parentd5e0c0615acee7028e1e2740a11102313be88de1 (diff)
add support equivalent to golang flag.TextVar(), also fixes the test failure as described in #368
-rw-r--r--flag_test.go6
-rw-r--r--text.go81
-rw-r--r--text_test.go53
3 files changed, 136 insertions, 4 deletions
diff --git a/flag_test.go b/flag_test.go
index 58a5d25..9faaba4 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -1134,7 +1134,6 @@ func TestMultipleNormalizeFlagNameInvocations(t *testing.T) {
}
}
-//
func TestHiddenFlagInUsage(t *testing.T) {
f := NewFlagSet("bob", ContinueOnError)
f.Bool("secretFlag", true, "shhh")
@@ -1149,7 +1148,6 @@ func TestHiddenFlagInUsage(t *testing.T) {
}
}
-//
func TestHiddenFlagUsage(t *testing.T) {
f := NewFlagSet("bob", ContinueOnError)
f.Bool("secretFlag", true, "shhh")
@@ -1238,8 +1236,8 @@ func TestPrintDefaults(t *testing.T) {
fs.PrintDefaults()
got := buf.String()
if got != defaultOutput {
- fmt.Println("\n" + got)
- fmt.Println("\n" + defaultOutput)
+ fmt.Print("\n" + got + "\n")
+ fmt.Print("\n" + defaultOutput + "\n")
t.Errorf("got %q want %q\n", got, defaultOutput)
}
}
diff --git a/text.go b/text.go
new file mode 100644
index 0000000..3726606
--- /dev/null
+++ b/text.go
@@ -0,0 +1,81 @@
+package pflag
+
+import (
+ "encoding"
+ "fmt"
+ "reflect"
+)
+
+// following is copied from go 1.23.4 flag.go
+type textValue struct{ p encoding.TextUnmarshaler }
+
+func newTextValue(val encoding.TextMarshaler, p encoding.TextUnmarshaler) textValue {
+ ptrVal := reflect.ValueOf(p)
+ if ptrVal.Kind() != reflect.Ptr {
+ panic("variable value type must be a pointer")
+ }
+ defVal := reflect.ValueOf(val)
+ if defVal.Kind() == reflect.Ptr {
+ defVal = defVal.Elem()
+ }
+ if defVal.Type() != ptrVal.Type().Elem() {
+ panic(fmt.Sprintf("default type does not match variable type: %v != %v", defVal.Type(), ptrVal.Type().Elem()))
+ }
+ ptrVal.Elem().Set(defVal)
+ return textValue{p}
+}
+
+func (v textValue) Set(s string) error {
+ return v.p.UnmarshalText([]byte(s))
+}
+
+func (v textValue) Get() interface{} {
+ return v.p
+}
+
+func (v textValue) String() string {
+ if m, ok := v.p.(encoding.TextMarshaler); ok {
+ if b, err := m.MarshalText(); err == nil {
+ return string(b)
+ }
+ }
+ return ""
+}
+
+//end of copy
+
+func (v textValue) Type() string {
+ return reflect.ValueOf(v.p).Type().Name()
+}
+
+// GetText set out, which implements encoding.UnmarshalText, to the value of a flag with given name
+func (f *FlagSet) GetText(name string, out encoding.TextUnmarshaler) error {
+ flag := f.Lookup(name)
+ if flag == nil {
+ return fmt.Errorf("flag accessed but not defined: %s", name)
+ }
+ if flag.Value.Type() != reflect.TypeOf(out).Name() {
+ fmt.Errorf("trying to get %s value of flag of type %s", reflect.TypeOf(out).Name(), flag.Value.Type())
+ }
+ return out.UnmarshalText([]byte(flag.Value.String()))
+}
+
+// TextVar defines a flag with a specified name, default value, and usage string. The argument p must be a pointer to a variable that will hold the value of the flag, and p must implement encoding.TextUnmarshaler. If the flag is used, the flag value will be passed to p's UnmarshalText method. The type of the default value must be the same as the type of p.
+func (f *FlagSet) TextVar(p encoding.TextUnmarshaler, name string, value encoding.TextMarshaler, usage string) {
+ f.VarP(newTextValue(value, p), name, "", usage)
+}
+
+// TextVarP is like TextVar, but accepts a shorthand letter that can be used after a single dash.
+func (f *FlagSet) TextVarP(p encoding.TextUnmarshaler, name, shorthand string, value encoding.TextMarshaler, usage string) {
+ f.VarP(newTextValue(value, p), name, shorthand, usage)
+}
+
+// TextVar defines a flag with a specified name, default value, and usage string. The argument p must be a pointer to a variable that will hold the value of the flag, and p must implement encoding.TextUnmarshaler. If the flag is used, the flag value will be passed to p's UnmarshalText method. The type of the default value must be the same as the type of p.
+func TextVar(p encoding.TextUnmarshaler, name string, value encoding.TextMarshaler, usage string) {
+ CommandLine.VarP(newTextValue(value, p), name, "", usage)
+}
+
+// TextVarP is like TextVar, but accepts a shorthand letter that can be used after a single dash.
+func TextVarP(p encoding.TextUnmarshaler, name, shorthand string, value encoding.TextMarshaler, usage string) {
+ CommandLine.VarP(newTextValue(value, p), name, shorthand, usage)
+}
diff --git a/text_test.go b/text_test.go
new file mode 100644
index 0000000..2a667ab
--- /dev/null
+++ b/text_test.go
@@ -0,0 +1,53 @@
+package pflag
+
+import (
+ "fmt"
+ "os"
+ "testing"
+ "time"
+)
+
+func setUpTime(t *time.Time) *FlagSet {
+ f := NewFlagSet("test", ContinueOnError)
+ f.TextVar(t, "time", time.Now(), "time stamp")
+ return f
+}
+
+func TestText(t *testing.T) {
+ testCases := []struct {
+ input string
+ success bool
+ expected time.Time
+ }{
+ {"2003-01-02T15:04:05Z", true, time.Date(2003, 1, 2, 15, 04, 05, 0, time.UTC)},
+ {"2003-01-02 15:05:01", false, time.Date(2002, 1, 2, 15, 05, 05, 07, time.UTC)},
+ {"2024-11-22T03:01:02Z", true, time.Date(2024, 11, 22, 3, 1, 02, 0, time.UTC)},
+ {"2006-01-02T15:04:05+07:00", true, time.Date(2006, 1, 2, 15, 4, 5, 0, time.FixedZone("UTC+7", 7*60*60))},
+ }
+
+ devnull, _ := os.Open(os.DevNull)
+ os.Stderr = devnull
+ for i := range testCases {
+ var ts time.Time
+ f := setUpTime(&ts)
+ tc := &testCases[i]
+ arg := fmt.Sprintf("--time=%s", tc.input)
+ err := f.Parse([]string{arg})
+ if err != nil && tc.success == true {
+ t.Errorf("expected success, got %q", err)
+ continue
+ } else if err == nil && tc.success == false {
+ t.Errorf("expected failure, but succeeded")
+ continue
+ } else if tc.success {
+ parsedT := new(time.Time)
+ err := f.GetText("time", parsedT)
+ if err != nil {
+ t.Errorf("Got error trying to fetch the time flag: %v", err)
+ }
+ if !parsedT.Equal(tc.expected) {
+ t.Errorf("expected %q, got %q", tc.expected, parsedT)
+ }
+ }
+ }
+}