aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--golangflag.go34
-rw-r--r--golangflag_test.go74
2 files changed, 108 insertions, 0 deletions
diff --git a/golangflag.go b/golangflag.go
index f563907..e62eab5 100644
--- a/golangflag.go
+++ b/golangflag.go
@@ -8,6 +8,7 @@ import (
goflag "flag"
"reflect"
"strings"
+ "time"
)
// go test flags prefixes
@@ -113,6 +114,38 @@ func (f *FlagSet) AddGoFlagSet(newSet *goflag.FlagSet) {
f.addedGoFlagSets = append(f.addedGoFlagSets, newSet)
}
+// CopyToGoFlagSet will add all current flags to the given Go flag set.
+// Deprecation remarks get copied into the usage description.
+// Whenever possible, a flag gets added for which Go flags shows
+// a proper type in the help message.
+func (f *FlagSet) CopyToGoFlagSet(newSet *goflag.FlagSet) {
+ f.VisitAll(func(flag *Flag) {
+ usage := flag.Usage
+ if flag.Deprecated != "" {
+ usage += " (DEPRECATED: " + flag.Deprecated + ")"
+ }
+
+ switch value := flag.Value.(type) {
+ case *stringValue:
+ newSet.StringVar((*string)(value), flag.Name, flag.DefValue, usage)
+ case *intValue:
+ newSet.IntVar((*int)(value), flag.Name, *(*int)(value), usage)
+ case *int64Value:
+ newSet.Int64Var((*int64)(value), flag.Name, *(*int64)(value), usage)
+ case *uintValue:
+ newSet.UintVar((*uint)(value), flag.Name, *(*uint)(value), usage)
+ case *uint64Value:
+ newSet.Uint64Var((*uint64)(value), flag.Name, *(*uint64)(value), usage)
+ case *durationValue:
+ newSet.DurationVar((*time.Duration)(value), flag.Name, *(*time.Duration)(value), usage)
+ case *float64Value:
+ newSet.Float64Var((*float64)(value), flag.Name, *(*float64)(value), usage)
+ default:
+ newSet.Var(flag.Value, flag.Name, usage)
+ }
+ })
+}
+
// ParseSkippedFlags explicitly Parses go test flags (i.e. the one starting with '-test.') with goflag.Parse(),
// since by default those are skipped by pflag.Parse().
// Typical usage example: `ParseGoTestFlags(os.Args[1:], goflag.CommandLine)`
@@ -125,3 +158,4 @@ func ParseSkippedFlags(osArgs []string, goFlagSet *goflag.FlagSet) error {
}
return goFlagSet.Parse(skippedFlags)
}
+
diff --git a/golangflag_test.go b/golangflag_test.go
index 2ecefef..7309808 100644
--- a/golangflag_test.go
+++ b/golangflag_test.go
@@ -7,6 +7,7 @@ package pflag
import (
goflag "flag"
"testing"
+ "time"
)
func TestGoflags(t *testing.T) {
@@ -59,3 +60,76 @@ func TestGoflags(t *testing.T) {
t.Fatal("goflag.CommandLine.Parsed() return false after f.Parse() called")
}
}
+
+func TestToGoflags(t *testing.T) {
+ pfs := FlagSet{}
+ gfs := goflag.FlagSet{}
+ pfs.String("StringFlag", "String value", "String flag usage")
+ pfs.Int("IntFlag", 1, "Int flag usage")
+ pfs.Uint("UintFlag", 2, "Uint flag usage")
+ pfs.Int64("Int64Flag", 3, "Int64 flag usage")
+ pfs.Uint64("Uint64Flag", 4, "Uint64 flag usage")
+ pfs.Int8("Int8Flag", 5, "Int8 flag usage")
+ pfs.Float64("Float64Flag", 6.0, "Float64 flag usage")
+ pfs.Duration("DurationFlag", time.Second, "Duration flag usage")
+ pfs.Bool("BoolFlag", true, "Bool flag usage")
+ pfs.String("deprecated", "Deprecated value", "Deprecated flag usage")
+ pfs.MarkDeprecated("deprecated", "obsolete")
+
+ pfs.CopyToGoFlagSet(&gfs)
+
+ // Modify via pfs. Should be visible via gfs because both share the
+ // same values.
+ for name, value := range map[string]string{
+ "StringFlag": "Modified String value",
+ "IntFlag": "11",
+ "UintFlag": "12",
+ "Int64Flag": "13",
+ "Uint64Flag": "14",
+ "Int8Flag": "15",
+ "Float64Flag": "16.0",
+ "BoolFlag": "false",
+ } {
+ pf := pfs.Lookup(name)
+ if pf == nil {
+ t.Errorf("%s: not found in pflag flag set", name)
+ continue
+ }
+ if err := pf.Value.Set(value); err != nil {
+ t.Errorf("error setting %s = %s: %v", name, value, err)
+ }
+ }
+
+ // Check that all flags were added and share the same value.
+ pfs.VisitAll(func(pf *Flag) {
+ gf := gfs.Lookup(pf.Name)
+ if gf == nil {
+ t.Errorf("%s: not found in Go flag set", pf.Name)
+ return
+ }
+ if gf.Value.String() != pf.Value.String() {
+ t.Errorf("%s: expected value %v from Go flag set, got %v",
+ pf.Name, pf.Value, gf.Value)
+ return
+ }
+ })
+
+ // Check for unexpected additional flags.
+ gfs.VisitAll(func(gf *goflag.Flag) {
+ pf := gfs.Lookup(gf.Name)
+ if pf == nil {
+ t.Errorf("%s: not found in pflag flag set", gf.Name)
+ return
+ }
+ })
+
+ deprecated := gfs.Lookup("deprecated")
+ if deprecated == nil {
+ t.Error("deprecated: not found in Go flag set")
+ } else {
+ expectedUsage := "Deprecated flag usage (DEPRECATED: obsolete)"
+ if deprecated.Usage != expectedUsage {
+ t.Errorf("deprecation remark not added, expected usage %q, got %q", expectedUsage, deprecated.Usage)
+ }
+ }
+}