diff --git a/oci/client/download.go b/oci/client/download.go new file mode 100644 index 00000000..6f882f68 --- /dev/null +++ b/oci/client/download.go @@ -0,0 +1,359 @@ +/* +Copyright 2024 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "syscall" + "time" + + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/name" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" + "github.com/google/go-containerregistry/pkg/v1/remote/transport" + "github.com/hashicorp/go-retryablehttp" + "golang.org/x/sync/errgroup" +) + +const ( + minChunkSize = 100 * 1024 * 1024 // 100MB + maxChunkSize = 1 << 30 // 1GB + defaultNumberOfChunks = 50 +) + +var ( + // errRangeRequestNotSupported is returned when the registry does not support range requests. + errRangeRequestNotSupported = fmt.Errorf("range requests are not supported by the registry") + errCopyFailed = errors.New("copy failed") +) + +var ( + retries = 3 + defaultRetryBackoff = remote.Backoff{ + Duration: 1.0 * time.Second, + Factor: 3.0, + Jitter: 0.1, + Steps: retries, + } +) + +type downloadOption func(*downloadOptions) + +type downloadOptions struct { + transport http.RoundTripper + auth authn.Authenticator + keychain authn.Keychain + numberOfChunks int +} + +type blobManager struct { + name name.Reference + c *retryablehttp.Client + layer v1.Layer + path string + digest v1.Hash + size int64 + downloadOptions +} + +func withTransport(t http.RoundTripper) downloadOption { + return func(o *downloadOptions) { + o.transport = t + } +} + +func withAuth(auth authn.Authenticator) downloadOption { + return func(o *downloadOptions) { + o.auth = auth + } +} + +func withKeychain(k authn.Keychain) downloadOption { + return func(o *downloadOptions) { + o.keychain = k + } +} + +func withNumberOfChunks(n int) downloadOption { + return func(o *downloadOptions) { + o.numberOfChunks = n + } +} + +type chunk struct { + n int + offset int64 + size int64 + writeCounter +} + +func makeChunk(n int, offset, size int64) *chunk { + return &chunk{ + n: n, + offset: offset, + size: size, + writeCounter: writeCounter{}, + } +} + +// newDownloader returns a new blobManager with the given options. 
+func newDownloader(name name.Reference, path string, layer v1.Layer, opts ...downloadOption) *blobManager {
+	o := &downloadOptions{
+		numberOfChunks: defaultNumberOfChunks,
+		keychain:       authn.DefaultKeychain,
+		transport:      remote.DefaultTransport.(*http.Transport).Clone(),
+	}
+	d := &blobManager{
+		layer:           layer,
+		name:            name,
+		path:            path,
+		downloadOptions: *o,
+	}
+	for _, opt := range opts {
+		opt(&d.downloadOptions)
+	}
+
+	return d
+}
+
+func (d *blobManager) download(ctx context.Context) error {
+	digest, err := d.layer.Digest()
+	if err != nil {
+		return fmt.Errorf("failed to get layer digest: %w", err)
+	}
+	d.digest = digest
+
+	size, err := d.layer.Size()
+	if err != nil {
+		return fmt.Errorf("failed to get layer size: %w", err)
+	}
+	d.size = size
+
+	if d.c == nil {
+		h, err := makeHttpClient(ctx, d.name.Context(), &d.downloadOptions)
+		if err != nil {
+			return fmt.Errorf("failed to create HTTP client: %w", err)
+		}
+		d.c = h
+	}
+
+	ok, err := d.isRangeRequestEnabled(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to check range request support: %w", err)
+	}
+
+	if !ok {
+		return errRangeRequestNotSupported
+	}
+
+	if err := d.downloadChunks(ctx); err != nil {
+		return fmt.Errorf("failed to download layer in chunks: %w", err)
+	}
+
+	if err := d.verifyDigest(); err != nil {
+		return fmt.Errorf("failed to verify layer digest: %w", err)
+	}
+
+	return nil
+}
+
+func (d *blobManager) downloadChunks(ctx context.Context) error {
+	u := makeUrl(d.name, d.digest)
+
+	file, err := os.OpenFile(d.path+".tmp", os.O_CREATE|os.O_WRONLY, 0644)
+	if err != nil {
+		return fmt.Errorf("failed to create layer file: %w", err)
+	}
+	defer file.Close()
+
+	chunkSize := d.size / int64(d.numberOfChunks)
+	if chunkSize < minChunkSize {
+		chunkSize = minChunkSize
+	} else if chunkSize > maxChunkSize {
+		chunkSize = maxChunkSize
+	}
+
+	var (
+		chunks []*chunk
+		n      int
+	)
+
+	for offset := int64(0); offset < d.size; offset += chunkSize {
+		if offset+chunkSize > d.size {
+			chunkSize = d.size - offset
+		}
+		chunk := makeChunk(n, offset, chunkSize)
+		chunks = append(chunks, chunk)
+		n++
+	}
+
+	g, ctx := errgroup.WithContext(ctx)
+	g.SetLimit(d.numberOfChunks)
+	for _, chunk := range chunks {
+		chunk := chunk
+		g.Go(func() error {
+			b := defaultRetryBackoff
+			var lastErr error
+			for i := 0; i < retries; i++ {
+				w := io.NewOffsetWriter(file, chunk.offset)
+				err := chunk.download(ctx, d.c, w, u)
+				switch {
+				case err == nil:
+					return nil
+				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
+					// a cancelled context or a full disk will not recover on retry
+					return err
+				default:
+					// transient failure (e.g. an interrupted copy): back off and retry the chunk
+					lastErr = err
+					time.Sleep(b.Step())
+				}
+			}
+			return fmt.Errorf("failed to download chunk %d after %d attempts: %w", chunk.n, retries, lastErr)
+		})
+	}
+
+	if err := g.Wait(); err != nil {
+		return fmt.Errorf("failed to download layer in chunks: %w", err)
+	}
+
+	if err := os.Rename(file.Name(), d.path); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (c *chunk) download(ctx context.Context, client *retryablehttp.Client, w io.Writer, u url.URL) error {
+	req, err := retryablehttp.NewRequest(http.MethodGet, u.String(), nil)
+	if err != nil {
+		return err
+	}
+
+	req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", c.offset, c.offset+c.size-1))
+	resp, err := client.Do(req.WithContext(ctx))
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if err := transport.CheckError(resp, http.StatusPartialContent); err != nil {
+		return err
+	}
+
+	_, err = io.Copy(w, io.TeeReader(resp.Body, &c.writeCounter))
+	if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
+		// TODO: if the download was interrupted, we can resume it
+		return fmt.Errorf("failed to download chunk %d: %w", c.n, err)
+	}
+
+	return err
+}
+
+func (d *blobManager) isRangeRequestEnabled(ctx context.Context) (bool, error) {
+	u := makeUrl(d.name, d.digest)
+	req, err := retryablehttp.NewRequest(http.MethodHead, u.String(), nil)
+	if err != nil {
+		return false, err
+	}
+
+	resp, err := d.c.Do(req.WithContext(ctx))
+	if err != nil {
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	if err := transport.CheckError(resp, http.StatusOK); err != nil {
+		return false, err
+	}
+
+	if rangeUnit := resp.Header.Get("Accept-Ranges"); rangeUnit == "bytes" {
+		return true, nil
+	}
+
+	return false, nil
+}
+
+func (d *blobManager) verifyDigest() error {
+	f, err := os.Open(d.path)
+	if err != nil {
+		return fmt.Errorf("failed to open layer file: %w", err)
+	}
+	defer f.Close()
+
+	h := sha256.New()
+	_, err = io.Copy(h, f)
+	if err != nil {
+		return fmt.Errorf("failed to hash layer: %w", err)
+	}
+
+	newDigest := h.Sum(nil)
+	if d.digest.String() != fmt.Sprintf("sha256:%x", newDigest) {
+		return fmt.Errorf("layer digest does not match: %s != sha256:%x", d.digest.String(), newDigest)
+	}
+	return nil
+}
+
+func makeUrl(name name.Reference, digest v1.Hash) url.URL {
+	return url.URL{
+		Scheme: name.Context().Scheme(),
+		Host:   name.Context().RegistryStr(),
+		Path:   fmt.Sprintf("/v2/%s/blobs/%s", name.Context().RepositoryStr(), digest.String()),
+	}
+}
+
+type resource interface {
+	Scheme() string
+	RegistryStr() string
+	Scope(string) string
+
+	authn.Resource
+}
+
+func makeHttpClient(ctx context.Context, target resource, o *downloadOptions) (*retryablehttp.Client, error) {
+	// prefer an explicitly configured authenticator; otherwise resolve one from the keychain
+	auth := o.auth
+	if auth == nil && o.keychain != nil {
+		kauth, err := o.keychain.Resolve(target)
+		if err != nil {
+			return nil, err
+		}
+		auth = kauth
+	}
+
+	reg, ok := target.(name.Registry)
+	if !ok {
+		repo, ok := target.(name.Repository)
+		if !ok {
+			return nil, fmt.Errorf("unexpected resource: %T", target)
+		}
+		reg = repo.Registry
+	}
+
+	tr, err := transport.NewWithContext(ctx, reg, auth, o.transport, []string{target.Scope(transport.PullScope)})
+	if err != nil {
+		return nil, err
+	}
+
+	h := retryablehttp.NewClient()
+	h.HTTPClient = &http.Client{Transport: tr}
+	return h, nil
+}
diff --git a/oci/client/pull.go b/oci/client/pull.go
index 633bf4b3..5ed76515 100644
--- a/oci/client/pull.go
+++ b/oci/client/pull.go
@@ -20,17 +20,29 @@ import (
 	"bufio"
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"io"
+	"net/http"
 	"os"
 
+	"github.com/fluxcd/pkg/tar"
+	"github.com/google/go-containerregistry/pkg/authn"
 	"github.com/google/go-containerregistry/pkg/crane"
 	"github.com/google/go-containerregistry/pkg/name"
-	gcrv1 "github.com/google/go-containerregistry/pkg/v1"
-
-	"github.com/fluxcd/pkg/tar"
+	v1 "github.com/google/go-containerregistry/pkg/v1"
+	"github.com/google/go-containerregistry/pkg/v1/remote"
 )
 
+// const (
+// 	// thresholdForConcurrentPull is the maximum size of a layer to be extracted in one go.
+// 	// If the layer is larger than this, it will be downloaded in chunks.
+// 	thresholdForConcurrentPull = 100 * 1024 * 1024 // 100MB
+// 	// maxConcurrentPulls is the maximum number of concurrent downloads.
+// maxConcurrentPulls = 10 +// ) + var ( // gzipMagicHeader are bytes found at the start of gzip files // https://github.com/google/go-containerregistry/blob/a54d64203cffcbf94146e04069aae4a97f228ee2/internal/gzip/zip.go#L28 @@ -41,6 +53,9 @@ var ( type PullOptions struct { layerIndex int layerType LayerType + transport http.RoundTripper + auth authn.Authenticator + keychain authn.Keychain } // PullOption is a function for configuring PullOptions. @@ -60,22 +75,47 @@ func WithPullLayerIndex(i int) PullOption { } } +func WithTransport(t http.RoundTripper) PullOption { + return func(o *PullOptions) { + o.transport = t + } +} + +func WithAuth(auth authn.Authenticator) PullOption { + return func(o *PullOptions) { + o.auth = auth + } +} + +func WithKeychain(k authn.Keychain) PullOption { + return func(o *PullOptions) { + o.keychain = k + } +} + // Pull downloads an artifact from an OCI repository and extracts the content. // It untar or copies the content to the given outPath depending on the layerType. // If no layer type is given, it tries to determine the right type by checking compressed content of the layer. -func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOption) (*Metadata, error) { +func (c *Client) Pull(ctx context.Context, urlString, outPath string, opts ...PullOption) (*Metadata, error) { o := &PullOptions{ layerIndex: 0, } + o.keychain = authn.DefaultKeychain for _, opt := range opts { opt(o) } - ref, err := name.ParseReference(url) + + if o.transport == nil { + transport := remote.DefaultTransport.(*http.Transport).Clone() + o.transport = transport + } + + ref, err := name.ParseReference(urlString) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) } - img, err := crane.Pull(url, c.optionsWithContext(ctx)...) + img, err := crane.Pull(urlString, c.optionsWithContext(ctx)...) 
 	if err != nil {
 		return nil, err
 	}
@@ -91,7 +131,7 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti
 	}
 	meta := MetadataFromAnnotations(manifest.Annotations)
-	meta.URL = url
+	meta.URL = urlString
 	meta.Digest = ref.Context().Digest(digest.String()).String()
 
 	layers, err := img.Layers()
@@ -107,15 +147,32 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti
 		return nil, fmt.Errorf("index '%d' out of bound for '%d' layers in artifact", o.layerIndex, len(layers))
 	}
 
-	err = extractLayer(layers[o.layerIndex], outPath, o.layerType)
+	size, err := layers[o.layerIndex].Size()
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to get layer size: %w", err)
 	}
+
+	if size > minChunkSize {
+		manager := newDownloader(ref, outPath, layers[o.layerIndex],
+			withTransport(o.transport), withKeychain(o.keychain), withAuth(o.auth))
+		err = manager.download(ctx)
+		if err != nil && !errors.Is(err, errRangeRequestNotSupported) {
+			return nil, fmt.Errorf("failed to download layer: %w", err)
+		}
+	}
+
+	if size <= minChunkSize || errors.Is(err, errRangeRequestNotSupported) {
+		err = extractLayer(layers[o.layerIndex], outPath, o.layerType)
+		if err != nil {
+			return nil, err
+		}
+	}
+
 	return meta, nil
 }
 
 // extractLayer extracts the Layer to the path
-func extractLayer(layer gcrv1.Layer, path string, layerType LayerType) error {
+func extractLayer(layer v1.Layer, path string, layerType LayerType) error {
 	var blob io.Reader
 	blob, err := layer.Compressed()
 	if err != nil {
diff --git a/oci/client/pull_test.go b/oci/client/pull_test.go
index 86795284..b68dd15a 100644
--- a/oci/client/pull_test.go
+++ b/oci/client/pull_test.go
@@ -82,3 +82,21 @@ func Test_PullAnyTarball(t *testing.T) {
 		g.Expect(extractTo + "/" + entry).To(Or(BeAnExistingFile(), BeADirectory()))
 	}
 }
+
+func Test_PullLargeTarball(t *testing.T) {
+	g := NewWithT(t)
+	ctx := context.Background()
+	c := NewClient(DefaultOptions())
+	dst := "vnp505/zephyr-7b-alpha:alpha"
+	extractTo := filepath.Join(t.TempDir(), "artifact")
+	m, err := c.Pull(ctx, dst, extractTo, WithPullLayerIndex(19))
+	g.Expect(err).ToNot(HaveOccurred())
+	g.Expect(m).ToNot(BeNil())
+	g.Expect(m.Annotations).To(BeEmpty())
+	g.Expect(m.Created).To(BeEmpty())
+	g.Expect(m.Revision).To(BeEmpty())
+	g.Expect(m.Source).To(BeEmpty())
+	g.Expect(m.URL).To(Equal(dst))
+	g.Expect(m.Digest).ToNot(BeEmpty())
+	g.Expect(extractTo).To(Or(BeAnExistingFile(), BeADirectory()))
+}
diff --git a/oci/client/push_pull_test.go b/oci/client/push_pull_test.go
index 3c68b253..9d02f101 100644
--- a/oci/client/push_pull_test.go
+++ b/oci/client/push_pull_test.go
@@ -305,6 +305,7 @@ func Test_Push_Pull(t *testing.T) {
 	g.Expect(err).ToNot(HaveOccurred())
 	fileInfo, err := os.Stat(tt.sourcePath)
+	g.Expect(err).ToNot(HaveOccurred())
 	// if a directory was pushed, then the created file should be a gzipped archive
 	if fileInfo.IsDir() {
 		bufReader := bufio.NewReader(bytes.NewReader(got))
diff --git a/oci/go.mod b/oci/go.mod
index e992b575..50681ead 100644
--- a/oci/go.mod
+++ b/oci/go.mod
@@ -21,9 +21,11 @@ require (
 	github.com/fluxcd/pkg/tar v0.4.0
 	github.com/fluxcd/pkg/version v0.2.2
github.com/google/go-containerregistry v0.18.0 + github.com/hashicorp/go-retryablehttp v0.7.5 github.com/onsi/gomega v1.31.1 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/sirupsen/logrus v1.9.3 + golang.org/x/sync v0.6.0 sigs.k8s.io/controller-runtime v0.16.3 ) @@ -80,6 +82,7 @@ require ( github.com/gorilla/handlers v1.5.1 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/golang-lru/arc/v2 v2.0.5 // indirect github.com/hashicorp/golang-lru/v2 v2.0.5 // indirect github.com/imdario/mergo v0.3.15 // indirect @@ -130,7 +133,6 @@ require ( golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect golang.org/x/net v0.20.0 // indirect golang.org/x/oauth2 v0.16.0 // indirect - golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.16.0 // indirect golang.org/x/term v0.16.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/oci/go.sum b/oci/go.sum index 87ecfa5d..45aeee99 100644 --- a/oci/go.sum +++ b/oci/go.sum @@ -155,6 +155,12 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= +github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= +github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M= +github.com/hashicorp/go-retryablehttp v0.7.5/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8= github.com/hashicorp/golang-lru/arc/v2 v2.0.5 h1:l2zaLDubNhW4XO3LnliVj0GXO3+/CGNJAg1dcN2Fpfw= github.com/hashicorp/golang-lru/arc/v2 v2.0.5/go.mod h1:ny6zBSQZi2JxIeYcv7kt2sH2PXJtirBN7RDhRpxPkxU= github.com/hashicorp/golang-lru/v2 v2.0.5 h1:wW7h1TG88eUIJ2i69gaE3uNVtEPIagzhGvHgwfx2Vm4=
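A minimal consumer of the new pull options might look like the following sketch. It is not part of the patch: the repository reference, layer index, and output path are placeholders, and it only relies on the exported NewClient, DefaultOptions, Pull, WithPullLayerIndex, WithTransport, and WithKeychain identifiers that appear in the diff and its tests. When the selected layer exceeds minChunkSize and the registry advertises "Accept-Ranges: bytes", Pull fetches the blob in ranged chunks; otherwise it falls back to the existing extraction path.

```go
package main

import (
	"context"
	"log"
	"net/http"

	"github.com/fluxcd/pkg/oci/client"
	"github.com/google/go-containerregistry/pkg/authn"
	"github.com/google/go-containerregistry/pkg/v1/remote"
)

func main() {
	ctx := context.Background()

	// Client construction as used in the package tests.
	c := client.NewClient(client.DefaultOptions())

	// Placeholder artifact reference and output path.
	url := "registry.example.com/org/artifact:latest"
	outPath := "/tmp/artifact"

	// The transport and keychain are forwarded to the chunked downloader for
	// large layers; small layers are extracted as before.
	meta, err := c.Pull(ctx, url, outPath,
		client.WithPullLayerIndex(0),
		client.WithTransport(remote.DefaultTransport.(*http.Transport).Clone()),
		client.WithKeychain(authn.DefaultKeychain),
	)
	if err != nil {
		log.Fatalf("pull failed: %v", err)
	}
	log.Printf("pulled %s (digest %s)", meta.URL, meta.Digest)
}
```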