diff --git a/internal/update/update.go b/internal/update/update.go index 7202df65..1b71b5bc 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -26,6 +26,8 @@ const ( checksumsAssetName = "SHA256SUMS" defaultGitHubBaseURL = "https://github.com" maxChecksumsBytes = 1 << 20 + metadataTimeout = 30 * time.Second + downloadTimeout = 30 * time.Minute ) type UpdateInfo = selfupdate.Info @@ -79,7 +81,7 @@ func GetCacheDir() string { func NewUpdater(deps Deps) *Updater { if deps.Client == nil { - deps.Client = &http.Client{Timeout: 30 * time.Second} + deps.Client = defaultHTTPClient() } if deps.Now == nil { deps.Now = time.Now @@ -113,7 +115,9 @@ func (u *Updater) CheckForUpdate(forceCheck bool) (*UpdateInfo, error) { if selfupdate.IsDevBuildVersion(u.deps.Version) && !forceCheck { return nil, nil } - info, err := u.client().Check(context.Background(), selfupdate.CheckOptions{ + ctx, cancel := context.WithTimeout(context.Background(), metadataTimeout) + defer cancel() + info, err := u.client().Check(ctx, selfupdate.CheckOptions{ Force: forceCheck, GOOS: u.deps.GOOS, GOARCH: u.deps.GOARCH, @@ -125,7 +129,9 @@ func (u *Updater) CheckForUpdate(forceCheck bool) (*UpdateInfo, error) { // runners, corporate NAT) routinely exhaust the per-IP rate limit with // a 403. Resolve the release through github.com instead, which is not // subject to API rate limits. - info, fallbackErr := u.fallbackCheck(context.Background()) + fallbackCtx, fallbackCancel := context.WithTimeout(context.Background(), metadataTimeout) + defer fallbackCancel() + info, fallbackErr := u.fallbackCheck(fallbackCtx) if fallbackErr != nil { return nil, fmt.Errorf("%w (github.com fallback also failed: %w)", err, fallbackErr) } @@ -312,7 +318,9 @@ func (u *Updater) PerformUpdate(info *UpdateInfo, reporter Reporter) error { dstPath := filepath.Join(installDir, targetBinary) reporter.Stepf("Downloading %s...\n", info.AssetName) - if err := u.client().Install(context.Background(), info, selfupdate.InstallOptions{ + ctx, cancel := context.WithTimeout(context.Background(), downloadTimeout) + defer cancel() + if err := u.client().Install(ctx, info, selfupdate.InstallOptions{ DestinationPath: dstPath, ArchiveBinaryName: targetBinary, Progress: reporter.Progress, @@ -385,6 +393,16 @@ func (r stdoutReporter) Progress(downloaded, total int64) { } } +func defaultHTTPClient() *http.Client { + transport, ok := http.DefaultTransport.(*http.Transport) + if !ok { + return &http.Client{} + } + cloned := transport.Clone() + cloned.ResponseHeaderTimeout = metadataTimeout + return &http.Client{Transport: cloned} +} + func (nopReporter) Stepf(string, ...any) {} func (nopReporter) Progress(int64, int64) {} diff --git a/internal/update/update_test.go b/internal/update/update_test.go index 9efc7e5f..7335f8d3 100644 --- a/internal/update/update_test.go +++ b/internal/update/update_test.go @@ -75,6 +75,13 @@ func TestUpdaterCheckForUpdateSkipsNetworkWithFreshCache(t *testing.T) { assert.Equal(t, 0, requests) } +func TestNewUpdaterDefaultClientDoesNotSetWholeRequestTimeout(t *testing.T) { + updater := NewUpdater(Deps{}) + + require.NotNil(t, updater.deps.Client) + assert.Zero(t, updater.deps.Client.Timeout) +} + func TestUpdaterCheckForUpdateUsesKitConventionalReleaseDiscovery(t *testing.T) { const releaseTag = "v1.3.0" const assetName = "roborev_1.3.0_windows_amd64.zip"