diff --git a/cmd/crane/cmd/pull.go b/cmd/crane/cmd/pull.go
index 41c6e95cd..4dd3a3360 100644
--- a/cmd/crane/cmd/pull.go
+++ b/cmd/crane/cmd/pull.go
@@ -32,6 +32,7 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command {
 	var (
 		cachePath, format string
 		annotateRef       bool
+		resumable         bool
 	)
 
 	cmd := &cobra.Command{
@@ -49,6 +50,10 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command {
 				return fmt.Errorf("parsing reference %q: %w", src, err)
 			}
 
+			if resumable {
+				o.Remote = append(o.Remote, remote.WithResumable())
+			}
+
 			rmt, err := remote.Get(ref, o.Remote...)
 			if err != nil {
 				return err
@@ -133,6 +138,7 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command {
 	cmd.Flags().StringVarP(&cachePath, "cache_path", "c", "", "Path to cache image layers")
 	cmd.Flags().StringVar(&format, "format", "tarball", fmt.Sprintf("Format in which to save images (%q, %q, or %q)", "tarball", "legacy", "oci"))
 	cmd.Flags().BoolVar(&annotateRef, "annotate-ref", false, "Preserves image reference used to pull as an annotation when used with --format=oci")
+	cmd.Flags().BoolVar(&resumable, "resumable", false, "Enable resumable transport for pulling images")
 
 	return cmd
 }
diff --git a/pkg/v1/remote/image_test.go b/pkg/v1/remote/image_test.go
index f15e96a6d..302076a15 100644
--- a/pkg/v1/remote/image_test.go
+++ b/pkg/v1/remote/image_test.go
@@ -17,6 +17,8 @@ package remote
 import (
 	"bytes"
 	"context"
+	"crypto/sha256"
+	"encoding/hex"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -747,3 +749,45 @@ func TestData(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+func TestImageResumable(t *testing.T) {
+	ref, err := name.ParseReference("ghcr.io/labring/fastgpt:v4.9.0")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	image, err := Image(ref, WithResumable())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	layers, err := image.Layers()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for _, layer := range layers {
+		digest, err := layer.Digest()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		rc, err := layer.Compressed()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		hash := sha256.New()
+		_, err = io.Copy(hash, rc)
+		rc.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if digest.Hex == hex.EncodeToString(hash.Sum(nil)) {
+			t.Logf("digest matches: %s", digest)
+		} else {
+			t.Errorf("digest mismatch: %s != %s", digest, hex.EncodeToString(hash.Sum(nil)))
+		}
+	}
+}
diff --git a/pkg/v1/remote/options.go b/pkg/v1/remote/options.go
index 15b7da1e4..e1224c84a 100644
--- a/pkg/v1/remote/options.go
+++ b/pkg/v1/remote/options.go
@@ -45,6 +45,8 @@ type options struct {
 	retryBackoff     Backoff
 	retryPredicate   retry.Predicate
 	retryStatusCodes []int
+	resumable        bool
+	resumableBackoff Backoff
 
 	// Only these options can overwrite Reuse()d options.
 	platform v1.Platform
@@ -135,6 +137,7 @@ func makeOptions(opts ...Option) (*options, error) {
 		retryPredicate:   defaultRetryPredicate,
 		retryBackoff:     defaultRetryBackoff,
 		retryStatusCodes: defaultRetryStatusCodes,
+		resumableBackoff: defaultRetryBackoff,
 	}
 
 	for _, option := range opts {
@@ -170,6 +173,11 @@ func makeOptions(opts ...Option) (*options, error) {
 	// Wrap the transport in something that can retry network flakes.
 	o.transport = transport.NewRetry(o.transport, transport.WithRetryBackoff(o.retryBackoff), transport.WithRetryPredicate(predicate), transport.WithRetryStatusCodes(o.retryStatusCodes...))
 
+	if o.resumable {
+		o.transport = transport.NewResumable(o.transport, o.resumableBackoff)
+	}
+
 	// Wrap this last to prevent transport.New from double-wrapping.
if o.userAgent != "" { o.transport = transport.NewUserAgent(o.transport, o.userAgent) @@ -192,6 +200,25 @@ func WithTransport(t http.RoundTripper) Option { } } +// WithResumable is a functional option for enabling resumable downloads. and it will wrap retry transport by default. +// If configures retry and resumable backoff, should be aware of all backoff will be applied. +func WithResumable() Option { + return func(o *options) error { + o.resumable = true + return nil + } +} + +// WithResumableBackoff is a functional option for overriding the default resumable backoff for remote operations. +// Resumable backoff will resume failed requests after a delay, unlike retry actions, resumable backoff will ignore +// transport.RoundTripper.RoundTrip errors. +func WithResumableBackoff(backoff Backoff) Option { + return func(o *options) error { + o.resumableBackoff = backoff + return nil + } +} + // WithAuth is a functional option for overriding the default authenticator // for remote operations. // It is an error to use both WithAuth and WithAuthFromKeychain in the same Option set. diff --git a/pkg/v1/remote/transport/resumable.go b/pkg/v1/remote/transport/resumable.go new file mode 100644 index 000000000..98084aa2f --- /dev/null +++ b/pkg/v1/remote/transport/resumable.go @@ -0,0 +1,302 @@ +package transport + +import ( + "errors" + "fmt" + "io" + "net/http" + "regexp" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/google/go-containerregistry/pkg/logs" +) + +// NewResumable creates a http.RoundTripper that resumes http GET from error, and continue +// transfer data from last successful transfer offset. +func NewResumable(inner http.RoundTripper, backoff Backoff) http.RoundTripper { + if backoff.Steps <= 0 { + // resume once + backoff.Steps = 1 + } + + if backoff.Duration <= 0 { + backoff.Duration = 100 * time.Millisecond + } + + return &resumableTransport{inner: inner, backoff: backoff} +} + +var ( + contentRangeRe = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)$`) + rangeRe = regexp.MustCompile(`bytes=(\d+)-(\d+)?`) +) + +type resumableTransport struct { + inner http.RoundTripper + backoff Backoff +} + +func (rt *resumableTransport) RoundTrip(in *http.Request) (resp *http.Response, err error) { + var total, start, end int64 + // check initial request, maybe resumable transport is already enabled + if contentRange := in.Header.Get("Range"); contentRange != "" { + if matches := rangeRe.FindStringSubmatch(contentRange); len(matches) == 3 { + if start, err = strconv.ParseInt(matches[1], 10, 64); err != nil { + return nil, fmt.Errorf("invalid content range %q: %w", contentRange, err) + } + + if len(matches[2]) == 0 { + // request whole file + end = -1 + } else if end, err = strconv.ParseInt(matches[2], 10, 64); err == nil { + if start > end { + return nil, fmt.Errorf("invalid content range %q", contentRange) + } + } else { + return nil, fmt.Errorf("invalid content range %q: %w", contentRange, err) + } + } + } + + if resp, err = rt.inner.RoundTrip(in); err != nil { + return resp, err + } + + if in.Method != http.MethodGet { + return resp, nil + } + + switch resp.StatusCode { + case http.StatusOK: + if end != 0 { + // request range content, but unexpected status code, cant not resume for this request + return resp, nil + } + + total = resp.ContentLength + case http.StatusPartialContent: + // keep original response status code, which should be processed by original transport or operation + if start, _, total, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil || 
+			total <= 0 {
+			return resp, nil
+		} else if end > 0 {
+			total = end + 1
+		}
+	default:
+		return resp, nil
+	}
+
+	if total > 0 {
+		resp.Body = &resumableBody{
+			rc:          resp.Body,
+			inner:       rt.inner,
+			req:         in,
+			total:       total,
+			transferred: start,
+			backoff:     rt.backoff,
+		}
+	}
+
+	return resp, nil
+}
+
+type resumableBody struct {
+	rc io.ReadCloser
+
+	inner http.RoundTripper
+	req   *http.Request
+
+	backoff Backoff
+
+	transferred int64
+	total       int64
+
+	closed uint32
+}
+
+func (rb *resumableBody) Read(p []byte) (n int, err error) {
+	if atomic.LoadUint32(&rb.closed) == 1 {
+		// response body already closed
+		return 0, http.ErrBodyReadAfterClose
+	} else if rb.total >= 0 && rb.transferred >= rb.total {
+		return 0, io.EOF
+	}
+
+	for {
+		if n, err = rb.rc.Read(p); n > 0 {
+			if rb.transferred+int64(n) >= rb.total {
+				n = int(rb.total - rb.transferred)
+				err = io.EOF
+			}
+			rb.transferred += int64(n)
+		}
+
+		if err == nil {
+			return
+		}
+
+		if errors.Is(err, io.EOF) && rb.total >= 0 && rb.transferred >= rb.total {
+			return
+		}
+
+		if err = rb.resume(rb.backoff, err); err == nil {
+			if n == 0 {
+				// zero bytes read; try again from the new response body
+				continue
+			}
+
+			// some bytes were already read from the previous body; return them and resume on the next Read
+		}
+
+		return n, err
+	}
+}
+
+func (rb *resumableBody) Close() (err error) {
+	if !atomic.CompareAndSwapUint32(&rb.closed, 0, 1) {
+		return nil
+	}
+
+	return rb.rc.Close()
+}
+
+func (rb *resumableBody) resume(backoff Backoff, reason error) error {
+	if backoff.Steps <= 0 {
+		// resumable transport is disabled
+		return reason
+	}
+
+	if reason != nil {
+		logs.Debug.Printf("Resuming HTTP transfer after error: %v", reason)
+	}
+
+	var (
+		resp *http.Response
+		err  error
+	)
+
+	for backoff.Steps > 0 {
+		time.Sleep(backoff.Step())
+
+		ctx := rb.req.Context()
+		select {
+		case <-ctx.Done():
+			// the request context is done; stop resuming
+			return ctx.Err()
+		default:
+		}
+
+		req := rb.req.Clone(ctx)
+		req.Header.Set("Range", "bytes="+strconv.FormatInt(rb.transferred, 10)+"-")
+		if resp, err = rb.inner.RoundTrip(req); err != nil {
+			err = fmt.Errorf("unable to resume from '%v': %w", reason, err)
+			continue
+		}
+
+		if err = rb.validate(resp); err != nil {
+			resp.Body.Close()
+			// wrap the original error
+			return fmt.Errorf("%w, %v", reason, err)
+		}
+
+		if atomic.LoadUint32(&rb.closed) == 1 {
+			resp.Body.Close()
+			return http.ErrBodyReadAfterClose
+		}
+
+		rb.rc.Close()
+		rb.rc = resp.Body
+
+		break
+	}
+
+	return err
+}
+
+const size100m = 100 << 20
+
+func (rb *resumableBody) validate(resp *http.Response) (err error) {
+	var start, total int64
+	switch resp.StatusCode {
+	case http.StatusPartialContent:
+		// do not use the total size from the Content-Range header; keep rb.total unchanged
+		if start, _, _, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil {
+			return err
+		}
+
+		if start == rb.transferred {
+			break
+		} else if start < rb.transferred {
+			// the incoming data overlaps what was already read; discard the overlap
+			if _, err := io.CopyN(io.Discard, resp.Body, rb.transferred-start); err != nil {
+				return fmt.Errorf("failed to discard overlapped data: %v", err)
+			}
+		} else {
+			return fmt.Errorf("unexpected resume start %d, wanted: %d", start, rb.transferred)
+		}
+	case http.StatusOK:
+		if rb.transferred > 0 {
+			// the server does not support range requests; give up if too much data would have to be refetched
+			if rb.transferred > size100m {
+				return fmt.Errorf("too much data already transferred to restart: %d", rb.transferred)
+			}
+
+			// the Range header was ignored; resume by discarding the bytes already transferred
+			if _, err = io.CopyN(io.Discard, resp.Body, rb.transferred); err != nil {
+				return err
+			}
+		}
+	case http.StatusRequestedRangeNotSatisfiable:
+		if contentRange := resp.Header.Get("Content-Range"); contentRange != "" && strings.HasPrefix(contentRange, "bytes */") {
+			if total, err = strconv.ParseInt(strings.TrimPrefix(contentRange, "bytes */"), 10, 64); err == nil && total >= 0 && rb.transferred >= total {
+				return io.EOF
+			}
+		}
+
+		fallthrough
+	default:
+		return fmt.Errorf("unexpected status code %d", resp.StatusCode)
+	}
+
+	return nil
+}
+
+func parseContentRange(contentRange string) (start, end, size int64, err error) {
+	if contentRange == "" {
+		return -1, -1, -1, errors.New("unexpected empty content range")
+	}
+
+	matches := contentRangeRe.FindStringSubmatch(contentRange)
+	if len(matches) != 4 {
+		return -1, -1, -1, fmt.Errorf("invalid content range: %s", contentRange)
+	}
+
+	if start, err = strconv.ParseInt(matches[1], 10, 64); err != nil {
+		return -1, -1, -1, fmt.Errorf("unexpected start from content range '%s', %v", contentRange, err)
+	}
+
+	if end, err = strconv.ParseInt(matches[2], 10, 64); err != nil {
+		return -1, -1, -1, fmt.Errorf("unexpected end from content range '%s', %v", contentRange, err)
+	}
+
+	if start > end {
+		return -1, -1, -1, fmt.Errorf("invalid content range: %s", contentRange)
+	}
+
+	if matches[3] == "*" {
+		size = -1
+	} else {
+		size, err = strconv.ParseInt(matches[3], 10, 64)
+		if err != nil {
+			return -1, -1, -1, fmt.Errorf("unexpected total from content range '%s', %v", contentRange, err)
+		}
+
+		if end >= size {
+			return -1, -1, -1, fmt.Errorf("invalid content range: %s", contentRange)
+		}
+	}
+
+	return
+}
diff --git a/pkg/v1/remote/transport/resumable_test.go b/pkg/v1/remote/transport/resumable_test.go
new file mode 100644
index 000000000..2e109a8d6
--- /dev/null
+++ b/pkg/v1/remote/transport/resumable_test.go
@@ -0,0 +1,344 @@
+package transport
+
+import (
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"os"
+	"strconv"
+	"testing"
+	"time"
+
+	stdrand "math/rand"
+
+	"github.com/google/go-containerregistry/pkg/logs"
+	"github.com/google/go-containerregistry/pkg/v1/random"
+	"github.com/google/go-containerregistry/pkg/v1/types"
+)
+
+func handleResumableLayer(data []byte, w http.ResponseWriter, r *http.Request, t *testing.T) {
+	if r.Method != http.MethodGet {
+		w.WriteHeader(http.StatusMethodNotAllowed)
+		return
+	}
+
+	var (
+		contentLength, start, end int64
+		statusCode                = http.StatusOK
+		err                       error
+	)
+
+	contentLength = int64(len(data))
+	end = contentLength - 1
+	contentRange := r.Header.Get("Range")
+	if contentRange != "" {
+		matches := rangeRe.FindStringSubmatch(contentRange)
+		if len(matches) != 3 {
+			w.WriteHeader(http.StatusBadRequest)
+			return
+		}
+
+		if start, err = strconv.ParseInt(matches[1], 10, 64); err != nil || start < 0 {
+			w.WriteHeader(http.StatusBadRequest)
+			return
+		}
+
+		if start >= int64(contentLength) {
+			w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
+			return
+		}
+
+		if matches[2] != "" {
+			end, err = strconv.ParseInt(matches[2], 10, 64)
+			if err != nil || end < 0 {
+				w.WriteHeader(http.StatusBadRequest)
+				return
+			}
+
+			if end >= int64(contentLength) {
+				w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
+				return
+			}
+		}
+
+		statusCode = http.StatusPartialContent
+	}
+
+	var currentContentLength = end - start + 1
+	if currentContentLength <= 0 {
+		w.WriteHeader(http.StatusInternalServerError)
+		return
+	}
+
+	if currentContentLength > 4096 {
+		if currentContentLength = stdrand.Int63n(currentContentLength); currentContentLength < 1024 {
+			currentContentLength = 1024
+		}
+
+		if r.Header.Get("X-Overlap") == "true" {
+			overlapSize := int64(stdrand.Int31n(64))
+			if start > overlapSize {
+				start -= overlapSize
+				// t.Logf("Overlap data size: %d", overlapSize)
+			}
+		}
+	}
+
+	end = start + currentContentLength - 1
+
+	if statusCode == http.StatusPartialContent {
+		w.Header().Set("Content-Length", strconv.FormatInt(currentContentLength, 10))
+		w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, contentLength))
+	} else {
+		w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10))
+	}
+
+	w.WriteHeader(statusCode)
+	w.Write(data[start : end+1])
+	time.Sleep(time.Second)
+}
+
+func resumableRequest(client *http.Client, url string, leading, trailing []byte, size int64, digest string, overlap bool, t *testing.T) {
+	req, err := http.NewRequest(http.MethodGet, url, http.NoBody)
+	if err != nil {
+		t.Fatalf("http.NewRequest(): %v", err)
+	}
+
+	if overlap {
+		req.Header.Set("X-Overlap", "true")
+	}
+
+	if len(leading) > 0 || len(trailing) > 0 {
+		var buf bytes.Buffer
+		buf.WriteString("bytes=")
+		buf.WriteString(fmt.Sprintf("%d-", len(leading)))
+		if len(trailing) > 0 {
+			buf.WriteString(fmt.Sprintf("%d", size-int64(len(trailing))-1))
+		}
+		req.Header.Set("Range", buf.String())
+	}
+
+	resp, err := client.Do(req.WithContext(t.Context()))
+	if err != nil {
+		t.Fatalf("client.Do(): %v", err)
+	}
+	defer resp.Body.Close()
+
+	if _, ok := resp.Body.(*resumableBody); !ok {
+		t.Error("expected resumable body")
+		return
+	}
+
+	hash := sha256.New()
+	if len(leading) > 0 {
+		io.Copy(hash, bytes.NewReader(leading))
+	}
+
+	if _, err = io.Copy(hash, resp.Body); err != nil {
+		t.Errorf("unexpected error: %v", err)
+		return
+	}
+
+	if len(trailing) > 0 {
+		io.Copy(hash, bytes.NewReader(trailing))
+	}
+
+	actualDigest := "sha256:" + hex.EncodeToString(hash.Sum(nil))
+
+	if actualDigest != digest {
+		t.Errorf("digest mismatch: want %s, got %s", digest, actualDigest)
+	}
+}
+
+func nonResumableRequest(client *http.Client, url string, t *testing.T) {
+	req, err := http.NewRequest(http.MethodGet, url, http.NoBody)
+	if err != nil {
+		t.Fatalf("http.NewRequest(): %v", err)
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatalf("client.Do(): %v", err)
+	}
+
+	_, ok := resp.Body.(*resumableBody)
+	if ok {
+		t.Error("expected non-resumable body")
+	}
+}
+
+func resumableStopByTimeoutRequest(client *http.Client, url string, t *testing.T) {
+	req, err := http.NewRequest(http.MethodGet, url, http.NoBody)
+	if err != nil {
+		t.Fatalf("http.NewRequest(): %v", err)
+	}
+
+	ctx, cancel := context.WithTimeout(t.Context(), time.Second*3)
+	defer cancel()
+
+	resp, err := client.Do(req.WithContext(ctx))
+	if err != nil {
+		t.Fatalf("client.Do(): %v", err)
+	}
+	defer resp.Body.Close()
+
+	if _, ok := resp.Body.(*resumableBody); !ok {
+		t.Error("expected resumable body")
+		return
+	}
+
+	if _, err = io.Copy(io.Discard, resp.Body); err != nil && !errors.Is(err, context.DeadlineExceeded) {
+		t.Error("expected context deadline exceeded error")
+	}
+}
+
+func resumableStopByCancelRequest(client *http.Client, url string, t *testing.T) {
+	req, err := http.NewRequest(http.MethodGet, url, http.NoBody)
+	if err != nil {
+		t.Fatalf("http.NewRequest(): %v", err)
+	}
+
+	ctx, cancel := context.WithCancel(t.Context())
+	time.AfterFunc(time.Second*3, cancel)
+
+	resp, err := client.Do(req.WithContext(ctx))
+	if err != nil {
+		t.Fatalf("client.Do(): %v", err)
+	}
+	defer resp.Body.Close()
+
+	if _, ok := resp.Body.(*resumableBody); !ok {
+		t.Error("expected resumable body")
+		return
+	}
+
+	if _, err = io.Copy(io.Discard, resp.Body); err != nil && !errors.Is(err, context.Canceled) {
+		t.Error("expected context cancel error")
+	}
+}
+
+func TestResumableTransport(t *testing.T) {
+	logs.Debug.SetOutput(os.Stdout)
+	layer, err := random.Layer(2<<20, types.DockerLayer)
+	if err != nil {
+		t.Fatalf("random.Layer(): %v", err)
+	}
+
+	digest, err := layer.Digest()
+	if err != nil {
+		t.Fatalf("layer.Digest(): %v", err)
+	}
+
+	size, err := layer.Size()
+	if err != nil {
+		t.Fatalf("layer.Size(): %v", err)
+	}
+
+	rc, err := layer.Compressed()
+	if err != nil {
+		t.Fatalf("layer.Compressed(): %v", err)
+	}
+
+	data, err := io.ReadAll(rc)
+	if err != nil {
+		t.Fatalf("io.ReadAll(): %v", err)
+	}
+
+	layerPath := fmt.Sprintf("/v2/foo/bar/blobs/%s", digest.String())
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case layerPath:
+			handleResumableLayer(data, w, r, t)
+		default:
+			http.Error(w, "not found", http.StatusNotFound)
+		}
+	}))
+	defer server.Close()
+
+	address, err := url.Parse(server.URL)
+	if err != nil {
+		t.Fatalf("url.Parse(%v) = %v", server.URL, err)
+	}
+
+	client := &http.Client{
+		Transport: NewResumable(http.DefaultTransport.(*http.Transport).Clone(), Backoff{
+			Duration: 1.0 * time.Second,
+			Factor:   3.0,
+			Jitter:   0.1,
+			Steps:    3,
+		}),
+	}
+
+	tests := []struct {
+		name              string
+		digest            string
+		leading, trailing int64
+		timeout           bool
+		cancel            bool
+		nonResumable      bool
+		overlap           bool
+		ranged            bool
+	}{
+		{
+			name:    "resumable",
+			digest:  digest.String(),
+			leading: 0,
+		},
+		{
+			name:    "resumable-range-leading",
+			digest:  digest.String(),
+			leading: 3,
+		},
+		{
+			name:     "resumable-range-trailing",
+			digest:   digest.String(),
+			trailing: 6,
+		},
+		{
+			name:     "resumable-range-leading-trailing",
+			digest:   digest.String(),
+			leading:  3,
+			trailing: 6,
+		},
+		{
+			name:    "resumable-overlap",
+			digest:  digest.String(),
+			leading: 0,
+			overlap: true,
+		},
+		{
+			name:         "non-resumable",
+			nonResumable: true,
+		},
+		{
+			name:    "resumable stop by timeout",
+			timeout: true,
+		},
+		{
+			name:   "resumable stop by cancel",
+			cancel: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			url := address.String() + layerPath
+			if tt.nonResumable {
+				nonResumableRequest(client, address.String(), t)
+			} else if tt.cancel {
+				resumableStopByCancelRequest(client, url, t)
+			} else if tt.timeout {
+				resumableStopByTimeoutRequest(client, url, t)
+			} else if tt.digest != "" {
+				resumableRequest(client, url, data[:tt.leading], data[size-tt.trailing:], size, tt.digest, tt.overlap, t)
+			}
+		})
+	}
+}
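Usage sketch (not part of the diff): the snippet below shows how the options introduced here are intended to be consumed through the existing `remote` API. The image reference, the backoff values, and the output tarball name are placeholders; `remote.Image`, `name.ParseReference`, and `remote.Backoff` are the package's existing public API.

```go
package main

import (
	"log"
	"time"

	"github.com/google/go-containerregistry/pkg/name"
	"github.com/google/go-containerregistry/pkg/v1/remote"
)

func main() {
	// Placeholder reference; any registry whose blob endpoint honors Range
	// requests benefits from the resumable transport.
	ref, err := name.ParseReference("registry.example.com/repo/image:tag")
	if err != nil {
		log.Fatal(err)
	}

	img, err := remote.Image(ref,
		// Wrap the (already retrying) transport so that interrupted GETs are
		// resumed with Range requests instead of restarting from byte zero.
		remote.WithResumable(),
		// Optional: control how many times and how quickly a broken transfer
		// is resumed. Without this, the default retry backoff is reused.
		remote.WithResumableBackoff(remote.Backoff{
			Duration: 500 * time.Millisecond,
			Factor:   2.0,
			Jitter:   0.1,
			Steps:    5,
		}),
	)
	if err != nil {
		log.Fatal(err)
	}

	layers, err := img.Layers()
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("pulled %d layers", len(layers))
}
```

From the CLI, the same behavior is enabled by the flag added above, e.g. `crane pull --resumable registry.example.com/repo/image:tag image.tar`.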