Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- New `flagtype` package with common `flag.Value` implementations: `StringSlice`, `Enum`,
`StringMap`, `URL`, and `Regexp`

## [v0.5.0] - 2026-02-17

### Changed
Expand Down
95 changes: 95 additions & 0 deletions docs/design/001-flagtype-api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# 001 - flagtype API

**Date:** 2026-02-18

## Context

Users of pressly/cli must manually implement `flag.Value` (and `flag.Getter`) for common types like
string slices, enums, and maps. This is repetitive boilerplate that most CLI tools need.

## Decision

Use stdlib-native constructors that return `flag.Value`, registered via `f.Var()`.

```go
Flags: cli.FlagsFunc(func(f *flag.FlagSet) {
f.Bool("verbose", false, "enable verbose output")
f.Var(flagtype.StringSlice(), "tag", "add a tag (repeatable)")
f.Var(flagtype.Enum("json", "yaml", "table"), "format", "output format")
f.Var(flagtype.StringMap(), "label", "key=value pair (repeatable)")
})
```

The flagtype package has no knowledge of `flag.FlagSet`. Each constructor returns a value that
implements `flag.Value` and `flag.Getter`. Storage is internal -- no destination pointers needed
since values are retrieved via `cli.GetFlag[T]`.

## Alternatives considered

### A: flagtype takes a FlagSet

```go
Flags: cli.FlagsFunc(func(f *flag.FlagSet) {
f.Bool("verbose", false, "enable verbose output")
flagtype.StringSlice(f, "tag", "add a tag (repeatable)")
flagtype.Enum(f, "format", "output format", "json", "yaml", "table")
})
```

One-liner registration, no `f.Var()` ceremony. Rejected because it introduces a second calling
convention in the same block -- stdlib flags use `f.Type(name, default, usage)` while flagtype would
use `flagtype.Type(f, name, usage)`. The argument ordering inconsistency makes it harder to read at
a glance.

### B: FlagSet wrapper

```go
Flags: cli.FlagsFunc(func(f *flag.FlagSet) {
f.Bool("verbose", false, "enable verbose output")
ft := flagtype.From(f)
ft.StringSlice("tag", "add a tag (repeatable)")
ft.Enum("format", "output format", "json", "yaml", "table")
})
```

Feels like a natural extension of FlagSet. Rejected because it requires managing two objects in the
same closure -- `f` for standard types and `ft` for custom types. Also adds a layer of indirection
that doesn't pull its weight.

### C: Declarative flag list

```go
Flags: []cli.Flag{
cli.String("output", "", "output file"),
cli.Bool("verbose", false, "enable verbose output"),
flagtype.StringSlice("tag", "add a tag (repeatable)"),
flagtype.Enum("format", "output format", "json", "yaml", "table"),
}
```

Fully declarative, no callback, no FlagSet. Rejected because it's a significant departure from the
stdlib `flag` package and would require rethinking the core `Command` type. Essentially a different
framework.

### D: Destination pointer pattern

```go
var tags []string
var re *regexp.Regexp
f.Var(flagtype.StringSlice(&tags), "tag", "add a tag (repeatable)")
f.Var(flagtype.Regexp(&re), "pattern", "regex pattern")
```

The initial implementation. Each constructor takes a pointer to the destination variable. Rejected
because pointer types like `*regexp.Regexp` and `*url.URL` require double pointers
(`**regexp.Regexp`), which is awkward. Since values are always retrieved via `cli.GetFlag[T]`, the
destination pointer serves no purpose.

## Why this approach

- **Zero new concepts.** Anyone who knows `flag.Var` already knows how to use flagtype.
- **No coupling.** flagtype has no dependency on the cli package or `flag.FlagSet`.
- **Consistent with stdlib.** Custom flag types in Go have always been registered via `f.Var()`.
This follows that convention exactly.
- **No double pointers.** Internal storage means the API is clean for all types, including pointer
types like `*url.URL` and `*regexp.Regexp`.
25 changes: 25 additions & 0 deletions flagtype/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Package flagtype provides common [flag.Value] implementations for use with [flag.FlagSet.Var].
//
// All types implement [flag.Getter] so they work with [cli.GetFlag].
//
// The following types are available:
// - [StringSlice] - repeatable flag that collects values into []string
// - [Enum] - restricts values to a predefined set, retrieved as string
// - [StringMap] - repeatable flag that parses key=value pairs into map[string]string
// - [URL] - parses and validates a URL (must have scheme and host), retrieved as *url.URL
// - [Regexp] - compiles a regular expression, retrieved as *regexp.Regexp
//
// Example registration:
//
// Flags: cli.FlagsFunc(func(f *flag.FlagSet) {
// f.Var(flagtype.StringSlice(), "tag", "add a tag (repeatable)")
// f.Var(flagtype.Enum("json", "yaml", "table"), "format", "output format")
// f.Var(flagtype.StringMap(), "label", "key=value pair (repeatable)")
// })
//
// Example retrieval in Exec:
//
// tags := cli.GetFlag[[]string](s, "tag")
// format := cli.GetFlag[string](s, "format")
// labels := cli.GetFlag[map[string]string](s, "label")
package flagtype
37 changes: 37 additions & 0 deletions flagtype/enum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package flagtype

import (
"flag"
"fmt"
"slices"
"strings"
)

type enumValue struct {
val string
allowed []string
}

// Enum returns a [flag.Value] that restricts the flag to one of the allowed values. If a value not
// in the allowed list is provided, an error is returned listing valid options.
//
// Use [cli.GetFlag] with type string to retrieve the value.
func Enum(allowed ...string) flag.Value {
return &enumValue{allowed: allowed}
}

func (v *enumValue) String() string {
return v.val
}

func (v *enumValue) Set(s string) error {
if !slices.Contains(v.allowed, s) {
return fmt.Errorf("invalid value %q, must be one of: %s", s, strings.Join(v.allowed, ", "))
}
v.val = s
return nil
}

func (v *enumValue) Get() any {
return v.val
}
208 changes: 208 additions & 0 deletions flagtype/flagtype_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package flagtype

import (
"flag"
"net/url"
"regexp"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestStringSlice(t *testing.T) {
t.Parallel()

t.Run("single value", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.Var(StringSlice(), "tag", "")
err := fs.Parse([]string{"--tag=foo"})
require.NoError(t, err)
got := fs.Lookup("tag").Value.(flag.Getter).Get().([]string)
assert.Equal(t, []string{"foo"}, got)
})
t.Run("multiple values", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.Var(StringSlice(), "tag", "")
err := fs.Parse([]string{"--tag=foo", "--tag=bar", "--tag=baz"})
require.NoError(t, err)
got := fs.Lookup("tag").Value.(flag.Getter).Get().([]string)
assert.Equal(t, []string{"foo", "bar", "baz"}, got)
})
t.Run("string output", func(t *testing.T) {
t.Parallel()
v := StringSlice()
require.NoError(t, v.Set("a"))
require.NoError(t, v.Set("b"))
assert.Equal(t, "a,b", v.String())
})
t.Run("empty", func(t *testing.T) {
t.Parallel()
v := StringSlice()
assert.Equal(t, "", v.String())
got := v.(flag.Getter).Get().([]string)
assert.Nil(t, got)
})
}

func TestEnum(t *testing.T) {
t.Parallel()

t.Run("valid value", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.Var(Enum("json", "yaml", "table"), "format", "")
err := fs.Parse([]string{"--format=yaml"})
require.NoError(t, err)
got := fs.Lookup("format").Value.(flag.Getter).Get().(string)
assert.Equal(t, "yaml", got)
})
t.Run("invalid value", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.SetOutput(nopWriter{})
fs.Var(Enum("json", "yaml"), "format", "")
err := fs.Parse([]string{"--format=xml"})
require.Error(t, err)
assert.Contains(t, err.Error(), "must be one of")
assert.Contains(t, err.Error(), "json, yaml")
})
t.Run("empty default", func(t *testing.T) {
t.Parallel()
v := Enum("a", "b")
assert.Equal(t, "", v.String())
assert.Equal(t, "", v.(flag.Getter).Get())
})
}

func TestStringMap(t *testing.T) {
t.Parallel()

t.Run("single pair", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.Var(StringMap(), "label", "")
err := fs.Parse([]string{"--label=env=prod"})
require.NoError(t, err)
got := fs.Lookup("label").Value.(flag.Getter).Get().(map[string]string)
assert.Equal(t, map[string]string{"env": "prod"}, got)
})
t.Run("multiple pairs", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.Var(StringMap(), "label", "")
err := fs.Parse([]string{"--label=env=prod", "--label=tier=web"})
require.NoError(t, err)
got := fs.Lookup("label").Value.(flag.Getter).Get().(map[string]string)
assert.Equal(t, map[string]string{"env": "prod", "tier": "web"}, got)
})
t.Run("value contains equals", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.Var(StringMap(), "label", "")
err := fs.Parse([]string{"--label=query=a=b"})
require.NoError(t, err)
got := fs.Lookup("label").Value.(flag.Getter).Get().(map[string]string)
assert.Equal(t, map[string]string{"query": "a=b"}, got)
})
t.Run("missing equals", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.SetOutput(nopWriter{})
fs.Var(StringMap(), "label", "")
err := fs.Parse([]string{"--label=nope"})
require.Error(t, err)
assert.Contains(t, err.Error(), "missing '='")
})
t.Run("empty key", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.SetOutput(nopWriter{})
fs.Var(StringMap(), "label", "")
err := fs.Parse([]string{"--label==value"})
require.Error(t, err)
assert.Contains(t, err.Error(), "empty key")
})
t.Run("string output sorted", func(t *testing.T) {
t.Parallel()
v := StringMap()
require.NoError(t, v.Set("b=2"))
require.NoError(t, v.Set("a=1"))
assert.Equal(t, "a=1,b=2", v.String())
})
t.Run("empty", func(t *testing.T) {
t.Parallel()
v := StringMap()
assert.Equal(t, "", v.String())
assert.Nil(t, v.(flag.Getter).Get())
})
}

func TestURL(t *testing.T) {
t.Parallel()

t.Run("valid url", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.Var(URL(), "endpoint", "")
err := fs.Parse([]string{"--endpoint=https://example.com/api"})
require.NoError(t, err)
got := fs.Lookup("endpoint").Value.(flag.Getter).Get().(*url.URL)
require.NotNil(t, got)
assert.Equal(t, "https", got.Scheme)
assert.Equal(t, "example.com", got.Host)
assert.Equal(t, "/api", got.Path)
})
t.Run("missing scheme", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.SetOutput(nopWriter{})
fs.Var(URL(), "endpoint", "")
err := fs.Parse([]string{"--endpoint=example.com"})
require.Error(t, err)
assert.Contains(t, err.Error(), "must have a scheme and host")
})
t.Run("empty", func(t *testing.T) {
t.Parallel()
v := URL()
assert.Equal(t, "", v.String())
assert.Nil(t, v.(flag.Getter).Get())
})
}

func TestRegexp(t *testing.T) {
t.Parallel()

t.Run("valid pattern", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.Var(Regexp(), "pattern", "")
err := fs.Parse([]string{"--pattern=^foo.*bar$"})
require.NoError(t, err)
got := fs.Lookup("pattern").Value.(flag.Getter).Get().(*regexp.Regexp)
require.NotNil(t, got)
assert.True(t, got.MatchString("fooXbar"))
assert.False(t, got.MatchString("baz"))
})
t.Run("invalid pattern", func(t *testing.T) {
t.Parallel()
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.SetOutput(nopWriter{})
fs.Var(Regexp(), "pattern", "")
err := fs.Parse([]string{"--pattern=[invalid"})
require.Error(t, err)
})
t.Run("empty", func(t *testing.T) {
t.Parallel()
v := Regexp()
assert.Equal(t, "", v.String())
assert.Nil(t, v.(flag.Getter).Get())
})
}

// nopWriter discards all writes, used to suppress flag.FlagSet error output in tests.
type nopWriter struct{}

func (nopWriter) Write(p []byte) (int, error) { return len(p), nil }
Loading