From 73212925ee112189062514987f479ffd3d25018d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 14 Nov 2025 20:17:10 +0000 Subject: [PATCH 1/3] Initial plan From 48664659533708cce5d5fbf7eeed3139e4ea6814 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 14 Nov 2025 20:29:11 +0000 Subject: [PATCH 2/3] Add HTTP header support via --header flag and PGET_HEADERS env var Co-authored-by: markphelps <209477+markphelps@users.noreply.github.com> --- cmd/root/root.go | 19 +++++++++++++++++ pkg/client/client.go | 4 +++- pkg/client/client_test.go | 45 +++++++++++++++++++++++++++++++++++++++ pkg/config/config.go | 34 +++++++++++++++++++++++++++++ pkg/config/config_test.go | 27 +++++++++++++++++++++++ pkg/config/optnames.go | 1 + 6 files changed, 129 insertions(+), 1 deletion(-) diff --git a/cmd/root/root.go b/cmd/root/root.go index 4853f91b..f5dbc674 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -153,6 +153,24 @@ func rootPersistentPreRunEFunc(cmd *cobra.Command, args []string) error { viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor) } + // Process headers from CLI flags + headerSlice := viper.GetStringSlice(config.OptHeader) + if len(headerSlice) > 0 { + headerMap, err := config.HeadersToMap(headerSlice) + if err != nil { + return fmt.Errorf("error parsing headers: %w", err) + } + // Merge with any existing headers from environment variable + existingHeaders := viper.GetStringMapString(config.OptHeaders) + if existingHeaders == nil { + existingHeaders = make(map[string]string) + } + for k, v := range headerMap { + existingHeaders[k] = v + } + viper.Set(config.OptHeaders, existingHeaders) + } + return nil } @@ -179,6 +197,7 @@ func persistentFlags(cmd *cobra.Command) error { cmd.PersistentFlags().Int(config.OptMaxConnPerHost, 40, "Maximum number of (global) concurrent connections per host") cmd.PersistentFlags().StringP(config.OptOutputConsumer, "o", "file", "Output Consumer (file, tar, null)") cmd.PersistentFlags().String(config.OptPIDFile, defaultPidFilePath(), "PID file path") + cmd.PersistentFlags().StringSliceP(config.OptHeader, "H", []string{}, "HTTP headers to include in requests (format: 'Key: Value')") if err := hideAndDeprecateFlags(cmd); err != nil { return err diff --git a/pkg/client/client.go b/pkg/client/client.go index 70c4be19..cb41b148 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -39,10 +39,12 @@ type PGetHTTPClient struct { } func (c *PGetHTTPClient) Do(req *http.Request) (*http.Response, error) { - req.Header.Set("User-Agent", fmt.Sprintf("pget/%s", version.GetVersion())) + // Set custom headers first for k, v := range c.headers { req.Header.Set(k, v) } + // Set User-Agent last to ensure it's always the pget user agent + req.Header.Set("User-Agent", fmt.Sprintf("pget/%s", version.GetVersion())) return c.Client.Do(req) } diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index eafeac3b..fb110afb 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -5,10 +5,13 @@ import ( "fmt" "net" "net/http" + "net/http/httptest" "net/url" "testing" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/config" @@ -160,3 +163,45 @@ func TestRetryPolicy(t *testing.T) { }) } } + +func TestPGetHTTPClient_Headers(t *testing.T) { + // Create a test server that echoes back the headers + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Write back the custom headers as response headers for verification + for key, values := range r.Header { + for _, value := range values { + w.Header().Add("Echo-"+key, value) + } + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Set up viper with custom headers + viper.Set(config.OptHeaders, map[string]string{ + "Authorization": "Bearer test-token", + "X-Custom-Header": "custom-value", + }) + defer viper.Reset() + + // Create client + httpClient := client.NewHTTPClient(client.Options{ + MaxRetries: 0, + }) + + // Make a request + req, err := http.NewRequest("GET", server.URL, nil) + require.NoError(t, err) + + resp, err := httpClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify that our custom headers were sent + assert.Equal(t, "Bearer test-token", resp.Header.Get("Echo-Authorization")) + assert.Equal(t, "custom-value", resp.Header.Get("Echo-X-Custom-Header")) + + // Verify that User-Agent is set and contains "pget" + userAgent := resp.Header.Get("Echo-User-Agent") + assert.Contains(t, userAgent, "pget/") +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 519c632c..a51248d1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -140,6 +140,40 @@ func ResolveOverridesToMap(resolveOverrides []string) (map[string]string, error) return resolveOverrideMap, nil } +// HeadersToMap converts a slice of header strings in the format "Key: Value" to a map[string]string. +// It merges with any existing headers from the PGET_HEADERS environment variable. +func HeadersToMap(headerSlice []string) (map[string]string, error) { + logger := logging.GetLogger() + headerMap := make(map[string]string) + + if len(headerSlice) == 0 { + return nil, nil + } + + for _, header := range headerSlice { + // Split on the first colon to separate key and value + parts := strings.SplitN(header, ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid header format, expected 'Key: Value', got: %s", header) + } + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + if key == "" { + return nil, fmt.Errorf("header key cannot be empty in: %s", header) + } + + headerMap[key] = value + } + + if logger.GetLevel() == zerolog.DebugLevel { + for key, value := range headerMap { + logger.Debug().Str("header", key).Str("value", value).Msg("Header") + } + } + return headerMap, nil +} + // GetConsumer returns the consumer specified by the user on the command line // or an error if the consumer is invalid. Note that this function explicitly // calls viper.GetString(OptExtract) internally. diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index f4cc016d..61f61101 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -185,3 +185,30 @@ func TestGetCacheSRV(t *testing.T) { }) } } + +func TestHeadersToMap(t *testing.T) { + testCases := []struct { + name string + headers []string + expected map[string]string + err bool + }{ + {"empty", []string{}, nil, false}, + {"single", []string{"Authorization: Bearer token123"}, map[string]string{"Authorization": "Bearer token123"}, false}, + {"multiple", []string{"Authorization: Bearer token123", "X-Custom-Header: value"}, map[string]string{"Authorization": "Bearer token123", "X-Custom-Header": "value"}, false}, + {"with spaces", []string{"Content-Type: application/json"}, map[string]string{"Content-Type": "application/json"}, false}, + {"value with colon", []string{"Authorization: Bearer: token:123"}, map[string]string{"Authorization": "Bearer: token:123"}, false}, + {"trim spaces", []string{" Authorization : Bearer token123 "}, map[string]string{"Authorization": "Bearer token123"}, false}, + {"invalid format no colon", []string{"InvalidHeader"}, nil, true}, + {"invalid format empty key", []string{": value"}, nil, true}, + {"empty value", []string{"X-Empty-Header:"}, map[string]string{"X-Empty-Header": ""}, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + headers, err := HeadersToMap(tc.headers) + assert.Equal(t, tc.err, err != nil) + assert.Equal(t, tc.expected, headers) + }) + } +} diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go index 831cfbf7..2b6a7ec4 100644 --- a/pkg/config/optnames.go +++ b/pkg/config/optnames.go @@ -21,6 +21,7 @@ const ( OptExtract = "extract" OptForce = "force" OptForceHTTP2 = "force-http2" + OptHeader = "header" OptLoggingLevel = "log-level" OptMaxChunks = "max-chunks" OptMaxConnPerHost = "max-conn-per-host" From bd9e63462a71dc80bebc657eaa126f1611fcc70f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 14 Nov 2025 20:30:27 +0000 Subject: [PATCH 3/3] Update README with HTTP header documentation and examples Co-authored-by: markphelps <209477+markphelps@users.noreply.github.com> --- README.md | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f03a2e10..d66537a6 100644 --- a/README.md +++ b/README.md @@ -55,12 +55,26 @@ This builds a static binary that can work inside containers. - Type: `bool` - Default: `false` -#### Example +#### Examples + +Download and extract an archive: pget https://storage.googleapis.com/replicant-misc/sd15.tar ./sd15 -x This command will download Stable Diffusion 1.5 weights to the path ./sd15 with high concurrency. After the file is downloaded, it will be automatically extracted. +Download with authentication headers: + + pget -H "Authorization: Bearer token123" https://api.example.com/file.tar ./file.tar + +Download with multiple custom headers: + + pget -H "Authorization: Bearer token123" -H "X-Custom-Header: value" https://api.example.com/file.tar ./file.tar + +Use environment variable for headers: + + PGET_HEADERS='{"Authorization":"Bearer token123"}' pget https://api.example.com/file.tar ./file.tar + ### Multi-File Mode pget multifile @@ -112,6 +126,11 @@ https://example.com/music.mp3 /local/path/to/music.mp3 - Force download, overwriting existing file - Type: `bool` - Default: `false` +- `-H`, `--header` + - HTTP headers to include in requests (format: 'Key: Value'), can be specified multiple times + - Type: `string slice` + - Example: `-H "Authorization: Bearer token123" -H "X-Custom-Header: value"` + - Environment variable: `PGET_HEADERS` (JSON map format: `{"Key":"Value"}`) - `--log-level` - Log level (debug, info, warn, error) - Type: `string`