diff --git a/command.go b/command.go index c91b793..4f31181 100644 --- a/command.go +++ b/command.go @@ -83,6 +83,10 @@ type FlagMetadata struct { // Required indicates whether the flag is required. Required bool + + // Local indicates that the flag should not be inherited by child commands. When true, the + // flag is only available on the command that defines it. + Local bool } // FlagsFunc is a helper function that creates a new [flag.FlagSet] and applies the given function diff --git a/parse.go b/parse.go index c6ec933..d9e57a0 100644 --- a/parse.go +++ b/parse.go @@ -110,12 +110,21 @@ func resolveCommandPath(root *Command, argsToParse []string) (*Command, error) { name := strings.TrimLeft(arg, "-") skipValue := false for _, cmd := range root.state.path { + localFlags := localFlagSet(cmd.FlagsMetadata) + // Skip local flags on ancestor commands (any command already in the + // path is an ancestor of the not-yet-resolved terminal command). + if localFlags[name] { + continue + } // First try direct lookup. f := cmd.Flags.Lookup(name) // If not found, check if it's a short alias. if f == nil { for _, fm := range cmd.FlagsMetadata { if fm.Short == name { + if localFlags[fm.Name] { + break + } f = cmd.Flags.Lookup(fm.Name) break } @@ -161,13 +170,20 @@ func resolveCommandPath(root *Command, argsToParse []string) (*Command, error) { func combineFlags(path []*Command) *flag.FlagSet { combined := flag.NewFlagSet(path[0].Name, flag.ContinueOnError) combined.SetOutput(io.Discard) - for i := len(path) - 1; i >= 0; i-- { + terminalIdx := len(path) - 1 + for i := terminalIdx; i >= 0; i-- { cmd := path[i] if cmd.Flags == nil { continue } + localFlags := localFlagSet(cmd.FlagsMetadata) shortMap := shortFlagMap(cmd.FlagsMetadata) + isAncestor := i < terminalIdx cmd.Flags.VisitAll(func(f *flag.Flag) { + // Skip local flags from ancestor commands — they are not inherited. + if isAncestor && localFlags[f.Name] { + return + } if combined.Lookup(f.Name) == nil { combined.Var(f.Value, f.Name, f.Usage) } @@ -182,6 +198,17 @@ func combineFlags(path []*Command) *flag.FlagSet { return combined } +// localFlagSet builds a set of flag names that are marked as local in FlagsMetadata. +func localFlagSet(metadata []FlagMetadata) map[string]bool { + m := make(map[string]bool, len(metadata)) + for _, fm := range metadata { + if fm.Local { + m[fm.Name] = true + } + } + return m +} + // shortFlagMap builds a map from long flag name to short alias from FlagsMetadata. func shortFlagMap(metadata []FlagMetadata) map[string]string { m := make(map[string]string, len(metadata)) @@ -203,12 +230,17 @@ func checkRequiredFlags(path []*Command, combined *flag.FlagSet) error { setFlags[f.Name] = struct{}{} }) + terminalIdx := len(path) - 1 var missingFlags []string - for _, cmd := range path { + for i, cmd := range path { for _, flagMetadata := range cmd.FlagsMetadata { if !flagMetadata.Required { continue } + // Skip required-flag checks for local flags on ancestor commands. + if flagMetadata.Local && i < terminalIdx { + continue + } if combined.Lookup(flagMetadata.Name) == nil { return fmt.Errorf("command %q: internal error: required flag %s not found in flag set", getCommandPath(path), formatFlagName(flagMetadata.Name)) } diff --git a/parse_test.go b/parse_test.go index ca2c3e0..fc71c9c 100644 --- a/parse_test.go +++ b/parse_test.go @@ -836,6 +836,164 @@ func TestShortFlags(t *testing.T) { }) } +func TestLocalFlags(t *testing.T) { + t.Parallel() + + t.Run("local flag on parent not available to child", func(t *testing.T) { + t.Parallel() + child := &Command{ + Name: "child", + Exec: func(ctx context.Context, s *State) error { return nil }, + } + root := &Command{ + Name: "root", + Flags: FlagsFunc(func(f *flag.FlagSet) { + f.Bool("version", false, "show version") + f.Bool("verbose", false, "enable verbose output") + }), + FlagsMetadata: []FlagMetadata{ + {Name: "version", Local: true}, + }, + SubCommands: []*Command{child}, + Exec: func(ctx context.Context, s *State) error { return nil }, + } + // --version on child should fail because it's local to root + err := Parse(root, []string{"child", "--version"}) + require.Error(t, err) + require.ErrorContains(t, err, "flag provided but not defined") + + // --verbose on child should still work (not local) + root2 := &Command{ + Name: "root", + Flags: FlagsFunc(func(f *flag.FlagSet) { + f.Bool("version", false, "show version") + f.Bool("verbose", false, "enable verbose output") + }), + FlagsMetadata: []FlagMetadata{ + {Name: "version", Local: true}, + }, + SubCommands: []*Command{{ + Name: "child", + Exec: func(ctx context.Context, s *State) error { return nil }, + }}, + Exec: func(ctx context.Context, s *State) error { return nil }, + } + err = Parse(root2, []string{"child", "--verbose"}) + require.NoError(t, err) + assert.True(t, GetFlag[bool](root2.state, "verbose")) + }) + + t.Run("local flag works on defining command", func(t *testing.T) { + t.Parallel() + root := &Command{ + Name: "root", + Flags: FlagsFunc(func(f *flag.FlagSet) { + f.Bool("version", false, "show version") + }), + FlagsMetadata: []FlagMetadata{ + {Name: "version", Local: true}, + }, + Exec: func(ctx context.Context, s *State) error { return nil }, + } + err := Parse(root, []string{"--version"}) + require.NoError(t, err) + assert.True(t, GetFlag[bool](root.state, "version")) + }) + + t.Run("local required flag only enforced on defining command", func(t *testing.T) { + t.Parallel() + child := &Command{ + Name: "child", + Exec: func(ctx context.Context, s *State) error { return nil }, + } + root := &Command{ + Name: "root", + Flags: FlagsFunc(func(f *flag.FlagSet) { + f.String("token", "", "auth token") + }), + FlagsMetadata: []FlagMetadata{ + {Name: "token", Required: true, Local: true}, + }, + SubCommands: []*Command{child}, + Exec: func(ctx context.Context, s *State) error { return nil }, + } + // Child command should not require parent's local required flag + err := Parse(root, []string{"child"}) + require.NoError(t, err) + + // But root command itself should still require it + root2 := &Command{ + Name: "root", + Flags: FlagsFunc(func(f *flag.FlagSet) { + f.String("token", "", "auth token") + }), + FlagsMetadata: []FlagMetadata{ + {Name: "token", Required: true, Local: true}, + }, + Exec: func(ctx context.Context, s *State) error { return nil }, + } + err = Parse(root2, []string{}) + require.Error(t, err) + require.ErrorContains(t, err, "required flag") + }) + + t.Run("usage excludes local parent flags from inherited flags", func(t *testing.T) { + t.Parallel() + child := &Command{ + Name: "child", + Flags: FlagsFunc(func(f *flag.FlagSet) { + f.Bool("dry-run", false, "dry run mode") + }), + Exec: func(ctx context.Context, s *State) error { return nil }, + } + root := &Command{ + Name: "root", + Flags: FlagsFunc(func(f *flag.FlagSet) { + f.Bool("version", false, "show version") + f.Bool("verbose", false, "enable verbose output") + }), + FlagsMetadata: []FlagMetadata{ + {Name: "version", Local: true}, + }, + SubCommands: []*Command{child}, + Exec: func(ctx context.Context, s *State) error { return nil }, + } + err := Parse(root, []string{"child", "--help"}) + require.ErrorIs(t, err, flag.ErrHelp) + + usage := DefaultUsage(root) + // --verbose should appear in inherited flags (not local) + assert.Contains(t, usage, "--verbose") + // --version should NOT appear (local to root, not inherited) + assert.NotContains(t, usage, "--version") + // --dry-run should appear in local flags + assert.Contains(t, usage, "--dry-run") + }) + + t.Run("local flag with short alias not inherited", func(t *testing.T) { + t.Parallel() + child := &Command{ + Name: "child", + Exec: func(ctx context.Context, s *State) error { return nil }, + } + root := &Command{ + Name: "root", + Flags: FlagsFunc(func(f *flag.FlagSet) { + f.Bool("version", false, "show version") + }), + FlagsMetadata: []FlagMetadata{ + {Name: "version", Short: "V", Local: true}, + }, + SubCommands: []*Command{child}, + Exec: func(ctx context.Context, s *State) error { return nil }, + } + // Short alias -V should also not work on child + err := Parse(root, []string{"child", "-V"}) + require.Error(t, err) + require.ErrorContains(t, err, "flag provided but not defined") + }) +} + func getCommand(t *testing.T, c *Command) *Command { require.NotNil(t, c) require.NotNil(t, c.state) diff --git a/usage.go b/usage.go index 0ba6de5..1a5fb54 100644 --- a/usage.go +++ b/usage.go @@ -90,19 +90,26 @@ func DefaultUsage(root *Command) string { var flags []flagInfo if root.state != nil && len(root.state.path) > 0 { + terminalIdx := len(root.state.path) - 1 for i, cmd := range root.state.path { if cmd.Flags == nil { continue } - isGlobal := i < len(root.state.path)-1 + isInherited := i < terminalIdx metaMap := flagMetadataMap(cmd.FlagsMetadata) cmd.Flags.VisitAll(func(f *flag.Flag) { + // Skip local flags from ancestor commands — they don't appear in child help. + if isInherited { + if m, ok := metaMap[f.Name]; ok && m.Local { + return + } + } fi := flagInfo{ - name: "--" + f.Name, - usage: f.Usage, - defval: f.DefValue, - typeName: flagTypeName(f), - global: isGlobal, + name: "--" + f.Name, + usage: f.Usage, + defval: f.DefValue, + typeName: flagTypeName(f), + inherited: isInherited, } if m, ok := metaMap[f.Name]; ok { fi.required = m.Required @@ -150,10 +157,10 @@ func DefaultUsage(root *Command) string { } hasLocal := false - hasGlobal := false + hasInherited := false for _, f := range flags { - if f.global { - hasGlobal = true + if f.inherited { + hasInherited = true } else { hasLocal = true } @@ -165,8 +172,8 @@ func DefaultUsage(root *Command) string { b.WriteString("\n") } - if hasGlobal { - b.WriteString("Global Flags:\n") + if hasInherited { + b.WriteString("Inherited Flags:\n") writeFlagSection(&b, flags, maxFlagLen, true, hasAnyShort) b.WriteString("\n") } @@ -184,12 +191,12 @@ func DefaultUsage(root *Command) string { } // writeFlagSection handles the formatting of flag descriptions -func writeFlagSection(b *strings.Builder, flags []flagInfo, maxLen int, global, hasAnyShort bool) { +func writeFlagSection(b *strings.Builder, flags []flagInfo, maxLen int, inherited, hasAnyShort bool) { nameWidth := maxLen + 4 wrapWidth := defaultTerminalWidth - nameWidth for _, f := range flags { - if f.global != global { + if f.inherited != inherited { continue } @@ -222,13 +229,13 @@ func flagMetadataMap(metadata []FlagMetadata) map[string]FlagMetadata { } type flagInfo struct { - name string - short string - usage string - defval string - typeName string - global bool - required bool + name string + short string + usage string + defval string + typeName string + inherited bool + required bool } // displayName returns the flag name with optional short alias and type hint. When hasAnyShort is diff --git a/usage_test.go b/usage_test.go index 021e3ca..014cf8a 100644 --- a/usage_test.go +++ b/usage_test.go @@ -305,7 +305,7 @@ func TestUsageGeneration(t *testing.T) { require.Contains(t, output, "custom [options] ") }) - t.Run("usage with global and local flags", func(t *testing.T) { + t.Run("usage with inherited and local flags", func(t *testing.T) { t.Parallel() child := &Command{ @@ -487,6 +487,6 @@ func TestWriteFlagSection(t *testing.T) { output := DefaultUsage(cmd) require.NotContains(t, output, "Flags:") - require.NotContains(t, output, "Global Flags:") + require.NotContains(t, output, "Inherited Flags:") }) }