diff --git a/Makefile b/Makefile
index 9eb3d3f540a..fd8452e6846 100644
--- a/Makefile
+++ b/Makefile
@@ -24,7 +24,7 @@ SDK_COMPA_PKGS=${SDK_CORE_PKGS} ${SDK_CLIENT_PKGS}
 SDK_EXAMPLES_PKGS=
 SDK_ALL_PKGS=${SDK_COMPA_PKGS} ${SDK_EXAMPLES_PKGS}
 
-RUN_NONE=-run '^$$'
+RUN_NONE=-run NONE
 RUN_INTEG=-run '^TestInteg_'
 
 CODEGEN_RESOURCES_PATH=$(shell pwd)/codegen/smithy-aws-go-codegen/src/main/resources/software/amazon/smithy/aws/go/codegen
@@ -98,7 +98,7 @@ gen-endpoint-prefix.json:
 # Unit Testing #
 ################
 
-unit: lint unit-modules-.
+unit: lint unit-modules-.
 unit-race: lint unit-race-modules-.
 
 unit-test: test-modules-.
@@ -194,7 +194,7 @@ integ-modules-%:
 		"go test -timeout=10m -tags "integration" -v ${RUN_INTEG} -count 1 ./..."
 
 cleanup-integ-buckets:
-	@echo "Cleaning up SDK integraiton resources"
+	@echo "Cleaning up SDK integration resources"
 	go run -tags "integration" ./internal/awstesting/cmd/bucket_cleanup/main.go "aws-sdk-go-integration"
 
 ##############
diff --git a/feature/s3/manager/api.go b/feature/s3/manager/api.go
new file mode 100644
index 00000000000..4059f9851d7
--- /dev/null
+++ b/feature/s3/manager/api.go
@@ -0,0 +1,37 @@
+package manager
+
+import (
+	"context"
+
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+// DeleteObjectsAPIClient is an S3 API client that can invoke the DeleteObjects operation.
+type DeleteObjectsAPIClient interface {
+	DeleteObjects(context.Context, *s3.DeleteObjectsInput, ...func(*s3.Options)) (*s3.DeleteObjectsOutput, error)
+}
+
+// DownloadAPIClient is an S3 API client that can invoke the GetObject operation.
+type DownloadAPIClient interface {
+	GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
+}
+
+// HeadBucketAPIClient is an S3 API client that can invoke the HeadBucket operation.
+type HeadBucketAPIClient interface {
+	HeadBucket(context.Context, *s3.HeadBucketInput, ...func(*s3.Options)) (*s3.HeadBucketOutput, error)
+}
+
+// ListObjectsV2APIClient is an S3 API client that can invoke the ListObjectsV2 operation.
+type ListObjectsV2APIClient interface {
+	ListObjectsV2(context.Context, *s3.ListObjectsV2Input, ...func(*s3.Options)) (*s3.ListObjectsV2Output, error)
+}
+
+// UploadAPIClient is an S3 API client that can invoke PutObject, UploadPart, CreateMultipartUpload,
+// CompleteMultipartUpload, and AbortMultipartUpload operations.
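+//
+// The concrete s3.Client returned by s3.NewFromConfig satisfies every
+// interface in this file; as a sketch, that can be checked at compile time
+// with an assertion such as:
+//
+//	var _ UploadAPIClient = (*s3.Client)(nil)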
+type UploadAPIClient interface {
+	PutObject(context.Context, *s3.PutObjectInput, ...func(*s3.Options)) (*s3.PutObjectOutput, error)
+	UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
+	CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
+	CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
+	AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
+}
diff --git a/feature/s3/manager/bucket_region.go b/feature/s3/manager/bucket_region.go
new file mode 100644
index 00000000000..e810877bc4d
--- /dev/null
+++ b/feature/s3/manager/bucket_region.go
@@ -0,0 +1,133 @@
+package manager
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+	"github.com/awslabs/smithy-go/middleware"
+	smithyhttp "github.com/awslabs/smithy-go/transport/http"
+)
+
+const bucketRegionHeader = "X-Amz-Bucket-Region"
+
+// GetBucketRegion will attempt to get the region for a bucket using the
+// client's configured region to determine which AWS partition to perform the query on.
+//
+// The request will not be signed, and will not use your AWS credentials.
+//
+// A BucketNotFound error will be returned if the bucket does not exist in the
+// AWS partition the client region belongs to.
+//
+// For example, to get the region of a bucket which exists in "eu-central-1"
+// you could call the utility with a client configured for the "us-west-2"
+// region, using that region as the hint.
+//
+//	cfg, err := config.LoadDefaultConfig()
+//	if err != nil {
+//		panic(err)
+//	}
+//
+//	bucket := "my-bucket"
+//	region, err := s3manager.GetBucketRegion(ctx, s3.NewFromConfig(cfg), bucket)
+//	if err != nil {
+//		var bnf BucketNotFound
+//		if errors.As(err, &bnf) {
+//			fmt.Fprintf(os.Stderr, "unable to find bucket %s's region\n", bucket)
+//		}
+//	}
+//	fmt.Printf("Bucket %s is in %s region\n", bucket, region)
+//
+// By default the request will be made to the Amazon S3 endpoint using virtual-hosted-style addressing.
+//
+//	bucketname.s3.us-west-2.amazonaws.com/
+//
+// To configure GetBucketRegion to make a request via the Amazon S3 FIPS
+// endpoints directly when a FIPS region name is not available (e.g.
+// fips-us-gov-west-1), set the EndpointResolver on the config or client the
+// utility is called with.
+//
+//	cfg, err := config.LoadDefaultConfig(config.WithEndpointResolver{
+//		EndpointResolver: aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
+//			return aws.Endpoint{URL: "https://s3-fips.us-west-2.amazonaws.com"}, nil
+//		}),
+//	})
+//	if err != nil {
+//		panic(err)
+//	}
+func GetBucketRegion(ctx context.Context, client HeadBucketAPIClient, bucket string, optFns ...func(*s3.Options)) (string, error) {
+	var captureBucketRegion deserializeBucketRegion
+
+	clientOptionFns := make([]func(*s3.Options), len(optFns)+1)
+	clientOptionFns[0] = func(options *s3.Options) {
+		options.Credentials = aws.AnonymousCredentials{}
+		options.APIOptions = append(options.APIOptions, captureBucketRegion.RegisterMiddleware)
+	}
+	copy(clientOptionFns[1:], optFns)
+
+	_, err := client.HeadBucket(ctx, &s3.HeadBucketInput{
+		Bucket: aws.String(bucket),
+	}, clientOptionFns...)
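+
+	// S3 reports the bucket's region in the X-Amz-Bucket-Region header even on
+	// 301 and 403 responses, so an error from HeadBucket is only treated as
+	// fatal when that header is absent from the response.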
+	if len(captureBucketRegion.BucketRegion) == 0 && err != nil {
+		var httpStatusErr interface {
+			HTTPStatusCode() int
+		}
+		if !errors.As(err, &httpStatusErr) {
+			return "", err
+		}
+
+		if httpStatusErr.HTTPStatusCode() == http.StatusNotFound {
+			return "", &bucketNotFound{}
+		}
+
+		return "", err
+	}
+
+	return captureBucketRegion.BucketRegion, nil
+}
+
+type deserializeBucketRegion struct {
+	BucketRegion string
+}
+
+func (d *deserializeBucketRegion) RegisterMiddleware(stack *middleware.Stack) error {
+	return stack.Deserialize.Add(d, middleware.After)
+}
+
+func (d *deserializeBucketRegion) ID() string {
+	return "DeserializeBucketRegion"
+}
+
+func (d *deserializeBucketRegion) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
+	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
+) {
+	out, metadata, err = next.HandleDeserialize(ctx, in)
+	if err != nil {
+		return out, metadata, err
+	}
+
+	resp, ok := out.RawResponse.(*smithyhttp.Response)
+	if !ok {
+		return out, metadata, fmt.Errorf("unknown transport type %T", out.RawResponse)
+	}
+
+	d.BucketRegion = resp.Header.Get(bucketRegionHeader)
+
+	return out, metadata, err
+}
+
+// BucketNotFound indicates the bucket was not found in the partition when calling GetBucketRegion.
+type BucketNotFound interface {
+	error
+
+	isBucketNotFound()
+}
+
+type bucketNotFound struct{}
+
+func (b *bucketNotFound) Error() string {
+	return "bucket not found"
+}
+
+func (b *bucketNotFound) isBucketNotFound() {}
+
+var _ BucketNotFound = (*bucketNotFound)(nil)
diff --git a/feature/s3/manager/bucket_region_test.go b/feature/s3/manager/bucket_region_test.go
new file mode 100644
index 00000000000..dd56d7d1a69
--- /dev/null
+++ b/feature/s3/manager/bucket_region_test.go
@@ -0,0 +1,120 @@
+package manager
+
+import (
+	"context"
+	"errors"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"strconv"
+	"testing"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+var mockErrResponse = []byte(`<?xml version="1.0" encoding="UTF-8"?>
+<Error>
+	<Code>MockCode</Code>
+	<Message>The error message</Message>
+	<RequestId>4442587FB7D0A2F9</RequestId>
+</Error>`)
+
+func testSetupGetBucketRegionServer(region string, statusCode int, incHeader bool) *httptest.Server {
+	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		io.Copy(ioutil.Discard, r.Body)
+		if incHeader {
+			w.Header().Set(bucketRegionHeader, region)
+		}
+		if statusCode >= 300 {
+			w.Header().Set("Content-Length", strconv.Itoa(len(mockErrResponse)))
+			w.WriteHeader(statusCode)
+			w.Write(mockErrResponse)
+		} else {
+			w.WriteHeader(statusCode)
+		}
+	}))
+}
+
+var testGetBucketRegionCases = []struct {
+	RespRegion      string
+	StatusCode      int
+	ExpectReqRegion string
+}{
+	{
+		RespRegion: "bucket-region",
+		StatusCode: 301,
+	},
+	{
+		RespRegion: "bucket-region",
+		StatusCode: 403,
+	},
+	{
+		RespRegion: "bucket-region",
+		StatusCode: 200,
+	},
+	{
+		RespRegion:      "bucket-region",
+		StatusCode:      200,
+		ExpectReqRegion: "default-region",
+	},
+}
+
+func TestGetBucketRegion_Exists(t *testing.T) {
+	for i, c := range testGetBucketRegionCases {
+		server := testSetupGetBucketRegionServer(c.RespRegion, c.StatusCode, true)
+
+		client := s3.New(s3.Options{
+			EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.ResolverOptions) (aws.Endpoint, error) {
+				return aws.Endpoint{
+					URL: server.URL,
+				}, nil
+			}),
+		})
+
+		region, err := GetBucketRegion(context.Background(), client, "bucket", func(o *s3.Options) {
+			o.UsePathStyle = true
+		})
+		if err != nil {
+			t.Errorf("%d, expect no error, got %v", i, err)
+			goto closeServer
+		}
+		if e, a := c.RespRegion, region; e != a {
+			t.Errorf("%d, expect %q region, got %q", i, e, a)
+		}
+
+	closeServer:
+		server.Close()
+	}
+}
+
+func TestGetBucketRegion_NotExists(t *testing.T) {
+	server := testSetupGetBucketRegionServer("ignore-region", 404, false)
+	defer server.Close()
+
+	client := s3.New(s3.Options{
+		EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.ResolverOptions) (aws.Endpoint, error) {
+			return aws.Endpoint{
+				URL: server.URL,
+			}, nil
+		}),
+	})
+
+	region, err := GetBucketRegion(context.Background(), client, "bucket", func(o *s3.Options) {
+		o.UsePathStyle = true
+	})
+	if err == nil {
+		t.Fatalf("expect error, but did not get one")
+	}
+
+	var bnf BucketNotFound
+	if !errors.As(err, &bnf) {
+		t.Errorf("expect %T error, got %v", bnf, err)
+	}
+
+	if len(region) != 0 {
+		t.Errorf("expect region not to be set, got %q", region)
+	}
+}
diff --git a/feature/s3/manager/buffered_read_seeker.go b/feature/s3/manager/buffered_read_seeker.go
new file mode 100644
index 00000000000..e781aef610d
--- /dev/null
+++ b/feature/s3/manager/buffered_read_seeker.go
@@ -0,0 +1,79 @@
+package manager
+
+import (
+	"io"
+)
+
+// BufferedReadSeeker is a buffered io.ReadSeeker
+type BufferedReadSeeker struct {
+	r                 io.ReadSeeker
+	buffer            []byte
+	readIdx, writeIdx int
+}
+
+// NewBufferedReadSeeker returns a new BufferedReadSeeker.
+// If len(b) == 0 then the buffer will be initialized to 64 KiB.
+func NewBufferedReadSeeker(r io.ReadSeeker, b []byte) *BufferedReadSeeker {
+	if len(b) == 0 {
+		b = make([]byte, 64*1024)
+	}
+	return &BufferedReadSeeker{r: r, buffer: b}
+}
+
+func (b *BufferedReadSeeker) reset(r io.ReadSeeker) {
+	b.r = r
+	b.readIdx, b.writeIdx = 0, 0
+}
+
+// Read will read up to len(p) bytes into p and will return
+// the number of bytes read and any error that occurred.
+// If len(p) is greater than the buffer size, a single read request
+// will be issued to the underlying io.ReadSeeker for len(p) bytes.
+// A Read request will at most perform a single Read to the underlying
+// io.ReadSeeker, and may return < len(p) if serviced from the buffer.
+func (b *BufferedReadSeeker) Read(p []byte) (n int, err error) {
+	if len(p) == 0 {
+		return n, err
+	}
+
+	if b.readIdx == b.writeIdx {
+		if len(p) >= len(b.buffer) {
+			n, err = b.r.Read(p)
+			return n, err
+		}
+		b.readIdx, b.writeIdx = 0, 0
+
+		n, err = b.r.Read(b.buffer)
+		if n == 0 {
+			return n, err
+		}
+
+		b.writeIdx += n
+	}
+
+	n = copy(p, b.buffer[b.readIdx:b.writeIdx])
+	b.readIdx += n
+
+	return n, err
+}
+
+// Seek will position the underlying io.ReadSeeker to the given offset
+// and will clear the buffer.
+func (b *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) {
+	n, err := b.r.Seek(offset, whence)
+
+	b.reset(b.r)
+
+	return n, err
+}
+
+// ReadAt will read up to len(p) bytes at the given file offset.
+// This will result in the buffer being cleared.
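+// Because ReadAt is implemented as a Seek followed by a Read, it is not safe
+// to call concurrently with Read or Seek on the same BufferedReadSeeker.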
+func (b *BufferedReadSeeker) ReadAt(p []byte, off int64) (int, error) { + _, err := b.Seek(off, io.SeekStart) + if err != nil { + return 0, err + } + + return b.Read(p) +} diff --git a/feature/s3/manager/buffered_read_seeker_test.go b/feature/s3/manager/buffered_read_seeker_test.go new file mode 100644 index 00000000000..ed46668395f --- /dev/null +++ b/feature/s3/manager/buffered_read_seeker_test.go @@ -0,0 +1,79 @@ +package manager + +import ( + "bytes" + "io" + "testing" +) + +func TestBufferedReadSeekerRead(t *testing.T) { + expected := []byte("testData") + + readSeeker := NewBufferedReadSeeker(bytes.NewReader(expected), make([]byte, 4)) + + var ( + actual []byte + buffer = make([]byte, 2) + ) + + for { + n, err := readSeeker.Read(buffer) + actual = append(actual, buffer[:n]...) + if err != nil && err == io.EOF { + break + } else if err != nil { + t.Fatalf("failed to read from reader: %v", err) + } + } + + if !bytes.Equal(expected, actual) { + t.Errorf("expected %v, got %v", expected, actual) + } +} + +func TestBufferedReadSeekerSeek(t *testing.T) { + content := []byte("testData") + + readSeeker := NewBufferedReadSeeker(bytes.NewReader(content), make([]byte, 4)) + + _, err := readSeeker.Seek(4, io.SeekStart) + if err != nil { + t.Fatalf("failed to seek reader: %v", err) + } + + var ( + actual []byte + buffer = make([]byte, 4) + ) + + for { + n, err := readSeeker.Read(buffer) + actual = append(actual, buffer[:n]...) + if err != nil && err == io.EOF { + break + } else if err != nil { + t.Fatalf("failed to read from reader: %v", err) + } + } + + if e := []byte("Data"); !bytes.Equal(e, actual) { + t.Errorf("expected %v, got %v", e, actual) + } +} + +func TestBufferedReadSeekerReadAt(t *testing.T) { + content := []byte("testData") + + readSeeker := NewBufferedReadSeeker(bytes.NewReader(content), make([]byte, 2)) + + buffer := make([]byte, 4) + + _, err := readSeeker.ReadAt(buffer, 0) + if err != nil { + t.Fatalf("failed to seek reader: %v", err) + } + + if e := content[:4]; !bytes.Equal(e, buffer) { + t.Errorf("expected %v, got %v", e, buffer) + } +} diff --git a/feature/s3/manager/default_read_seeker_write_to.go b/feature/s3/manager/default_read_seeker_write_to.go new file mode 100644 index 00000000000..6d1dc6d2c42 --- /dev/null +++ b/feature/s3/manager/default_read_seeker_write_to.go @@ -0,0 +1,7 @@ +// +build !windows + +package manager + +func defaultUploadBufferProvider() ReadSeekerWriteToProvider { + return nil +} diff --git a/feature/s3/manager/default_read_seeker_write_to_windows.go b/feature/s3/manager/default_read_seeker_write_to_windows.go new file mode 100644 index 00000000000..1ae881c104a --- /dev/null +++ b/feature/s3/manager/default_read_seeker_write_to_windows.go @@ -0,0 +1,5 @@ +package manager + +func defaultUploadBufferProvider() ReadSeekerWriteToProvider { + return NewBufferedReadSeekerWriteToPool(1024 * 1024) +} diff --git a/feature/s3/manager/default_writer_read_from.go b/feature/s3/manager/default_writer_read_from.go new file mode 100644 index 00000000000..d5518145219 --- /dev/null +++ b/feature/s3/manager/default_writer_read_from.go @@ -0,0 +1,7 @@ +// +build !windows + +package manager + +func defaultDownloadBufferProvider() WriterReadFromProvider { + return nil +} diff --git a/feature/s3/manager/default_writer_read_from_windows.go b/feature/s3/manager/default_writer_read_from_windows.go new file mode 100644 index 00000000000..88887ff586e --- /dev/null +++ b/feature/s3/manager/default_writer_read_from_windows.go @@ -0,0 +1,5 @@ +package manager + +func 
defaultDownloadBufferProvider() WriterReadFromProvider {
+	return NewPooledBufferedWriterReadFromProvider(1024 * 1024)
+}
diff --git a/feature/s3/manager/doc.go b/feature/s3/manager/doc.go
new file mode 100644
index 00000000000..31171a69875
--- /dev/null
+++ b/feature/s3/manager/doc.go
@@ -0,0 +1,3 @@
+// Package manager provides utilities to upload and download objects from
+// S3 concurrently. Helpful when working with large objects.
+package manager
diff --git a/feature/s3/manager/download.go b/feature/s3/manager/download.go
new file mode 100644
index 00000000000..60e5a051f7c
--- /dev/null
+++ b/feature/s3/manager/download.go
@@ -0,0 +1,493 @@
+package manager
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"strconv"
+	"strings"
+	"sync"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/aws/middleware"
+	"github.com/aws/aws-sdk-go-v2/internal/awsutil"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+const userAgentKey = "S3Manager"
+
+// DefaultDownloadPartSize is the default range of bytes to get at a time when
+// using Download().
+const DefaultDownloadPartSize = 1024 * 1024 * 5
+
+// DefaultDownloadConcurrency is the default number of goroutines to spin up
+// when using Download().
+const DefaultDownloadConcurrency = 5
+
+// DefaultPartBodyMaxRetries is the default number of retries to make when a part fails to download.
+const DefaultPartBodyMaxRetries = 3
+
+type errReadingBody struct {
+	err error
+}
+
+func (e *errReadingBody) Error() string {
+	return fmt.Sprintf("failed to read part body: %v", e.err)
+}
+
+func (e *errReadingBody) Unwrap() error {
+	return e.err
+}
+
+// Downloader is the structure that calls Download(). It is safe to call Download()
+// on this structure for multiple objects and across concurrent goroutines.
+// Mutating the Downloader's properties is not safe to be done concurrently.
+type Downloader struct {
+	// The size (in bytes) to request from S3 for each part.
+	// The minimum allowed part size is 5MB, and if this value is set to zero,
+	// the DefaultDownloadPartSize value will be used.
+	//
+	// PartSize is ignored if the Range input parameter is provided.
+	PartSize int64
+
+	// PartBodyMaxRetries is the number of retry attempts to make for failed part downloads.
+	PartBodyMaxRetries int
+
+	// The number of goroutines to spin up in parallel when downloading parts.
+	// If this is set to zero, the DefaultDownloadConcurrency value will be used.
+	//
+	// Concurrency of 1 will download the parts sequentially.
+	//
+	// Concurrency is ignored if the Range input parameter is provided.
+	Concurrency int
+
+	// An S3 client to use when performing downloads.
+	S3 DownloadAPIClient
+
+	// List of client options that will be passed down to individual API
+	// operation requests made by the downloader.
+	ClientOptions []func(*s3.Options)
+
+	// Defines the buffer strategy used when downloading a part.
+	//
+	// If a WriterReadFromProvider is given, the Download manager
+	// will pass the io.WriterAt of the Download request to the provider
+	// and will use the returned WriterReadFrom from the provider as the
+	// destination writer when copying from the http response body.
+	BufferProvider WriterReadFromProvider
+}
+
+// WithDownloaderClientOptions appends to the Downloader's API request options.
+func WithDownloaderClientOptions(opts ...func(*s3.Options)) func(*Downloader) {
+	return func(d *Downloader) {
+		d.ClientOptions = append(d.ClientOptions, opts...)
+	}
+}
+
+// NewDownloader creates a new Downloader instance to download objects from
+// S3 in concurrent chunks. Pass in additional functional options to customize
+// the downloader behavior. Requires a DownloadAPIClient to make the S3
+// GetObject calls; the client returned by s3.NewFromConfig satisfies this
+// interface.
+//
+// Example:
+//	// Load AWS Config
+//	cfg, err := config.LoadDefaultConfig()
+//	if err != nil {
+//		panic(err)
+//	}
+//
+//	// Create an S3 client using the loaded configuration
+//	client := s3.NewFromConfig(cfg)
+//
+//	// Create a downloader passing it the S3 client
+//	downloader := s3manager.NewDownloader(client)
+//
+//	// Create a downloader with the client and custom downloader options
+//	downloader := s3manager.NewDownloader(client, func(d *s3manager.Downloader) {
+//		d.PartSize = 64 * 1024 * 1024 // 64MB per part
+//	})
+func NewDownloader(c DownloadAPIClient, options ...func(*Downloader)) *Downloader {
+	d := &Downloader{
+		S3:                 c,
+		PartSize:           DefaultDownloadPartSize,
+		PartBodyMaxRetries: DefaultPartBodyMaxRetries,
+		Concurrency:        DefaultDownloadConcurrency,
+		BufferProvider:     defaultDownloadBufferProvider(),
+	}
+	for _, option := range options {
+		option(d)
+	}
+
+	return d
+}
+
+// Download downloads an object in S3 and writes the payload into w
+// using concurrent GET requests. The n int64 returned is the size of the object downloaded
+// in bytes.
+//
+// The Context must not be nil. A nil Context will cause a panic. Use the
+// Context to add deadlines, timeouts, etc. Download may create sub-contexts
+// for individual underlying requests.
+//
+// Additional functional options can be provided to configure the individual
+// download. These options are copies of the Downloader instance Download is
+// called from. Modifying the options will not impact the original Downloader
+// instance. Use the WithDownloaderClientOptions helper function to pass in request
+// options that will be applied to all API operations made with this downloader.
+//
+// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
+// downloads, or an in-memory []byte wrapper such as aws.WriteAtBuffer.
+//
+// Specifying a Downloader.Concurrency of 1 will cause the Downloader to
+// download the parts from S3 sequentially.
+//
+// It is safe to call this method concurrently across goroutines.
+//
+// If the GetObjectInput's Range value is provided, the downloader will
+// perform a single GetObject request for that object's range. This will
+// cause the part size and concurrency configurations to be ignored.
+func (d Downloader) Download(ctx context.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
+	impl := downloader{w: w, in: input, cfg: d, ctx: ctx}
+
+	// Copy ClientOptions
+	clientOptions := make([]func(*s3.Options), 0, len(impl.cfg.ClientOptions)+1)
+	clientOptions = append(clientOptions, func(o *s3.Options) {
+		o.APIOptions = append(o.APIOptions, middleware.AddUserAgentKey(userAgentKey))
+	})
+	clientOptions = append(clientOptions, impl.cfg.ClientOptions...)
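+	// The user agent option is added first so that caller-supplied client
+	// options are applied after it.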
+	impl.cfg.ClientOptions = clientOptions
+
+	for _, option := range options {
+		option(&impl.cfg)
+	}
+
+	impl.partBodyMaxRetries = d.PartBodyMaxRetries
+
+	impl.totalBytes = -1
+	if impl.cfg.Concurrency == 0 {
+		impl.cfg.Concurrency = DefaultDownloadConcurrency
+	}
+
+	if impl.cfg.PartSize == 0 {
+		impl.cfg.PartSize = DefaultDownloadPartSize
+	}
+
+	return impl.download()
+}
+
+// downloader is the implementation structure used internally by Downloader.
+type downloader struct {
+	ctx context.Context
+	cfg Downloader
+
+	in *s3.GetObjectInput
+	w  io.WriterAt
+
+	wg sync.WaitGroup
+	m  sync.Mutex
+
+	pos        int64
+	totalBytes int64
+	written    int64
+	err        error
+
+	partBodyMaxRetries int
+}
+
+// download performs the implementation of the object download across ranged
+// GETs.
+func (d *downloader) download() (n int64, err error) {
+	// If range is specified fall back to single download of that range
+	// this enables the functionality of ranged gets with the downloader but
+	// at the cost of no multipart downloads.
+	if rng := aws.ToString(d.in.Range); len(rng) > 0 {
+		d.downloadRange(rng)
+		return d.written, d.err
+	}
+
+	// Spin off first worker to check additional header information
+	d.getChunk()
+
+	if total := d.getTotalBytes(); total >= 0 {
+		// Spin up workers
+		ch := make(chan dlchunk, d.cfg.Concurrency)
+
+		for i := 0; i < d.cfg.Concurrency; i++ {
+			d.wg.Add(1)
+			go d.downloadPart(ch)
+		}
+
+		// Assign work
+		for d.getErr() == nil {
+			if d.pos >= total {
+				break // We're finished queuing chunks
+			}
+
+			// Queue the next range of bytes to read.
+			ch <- dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
+			d.pos += d.cfg.PartSize
+		}
+
+		// Wait for completion
+		close(ch)
+		d.wg.Wait()
+	} else {
+		// Checking if we read anything new
+		for d.err == nil {
+			d.getChunk()
+		}
+
+		// We expect a 416 error letting us know we are done downloading the
+		// total bytes. Since we do not know the content's length, this will
+		// keep grabbing chunks of data until the range of bytes specified in
+		// the request is out of range of the content. Once this happens, a
+		// 416 should occur.
+		var responseError interface {
+			HTTPStatusCode() int
+		}
+		if errors.As(d.err, &responseError) {
+			if responseError.HTTPStatusCode() == http.StatusRequestedRangeNotSatisfiable {
+				d.err = nil
+			}
+		}
+	}
+
+	// Return error
+	return d.written, d.err
+}
+
+// downloadPart is an individual goroutine worker reading from the ch channel
+// and performing a GetObject request on the data with a given byte range.
+//
+// If this is the first worker, this operation also resolves the total number
+// of bytes to be read so that the worker manager knows when it is finished.
+func (d *downloader) downloadPart(ch chan dlchunk) {
+	defer d.wg.Done()
+	for {
+		chunk, ok := <-ch
+		if !ok {
+			break
+		}
+		if d.getErr() != nil {
+			// Drain the channel if there is an error, to prevent deadlocking
+			// of download producer.
+			continue
+		}
+
+		if err := d.downloadChunk(chunk); err != nil {
+			d.setErr(err)
+		}
+	}
+}
+
+// getChunk grabs a chunk of data from the body.
+// Not thread safe. Should only be used when grabbing data on a single thread.
+func (d *downloader) getChunk() {
+	if d.getErr() != nil {
+		return
+	}
+
+	chunk := dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
+	d.pos += d.cfg.PartSize
+
+	if err := d.downloadChunk(chunk); err != nil {
+		d.setErr(err)
+	}
+}
+
+// downloadRange downloads an Object given the passed in Byte-Range value.
+// The chunk used to download the range will be configured for that range.
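+// For example, a caller-provided Range of "bytes=0-1023" results in a single
+// GetObject request for only those 1024 bytes.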
+func (d *downloader) downloadRange(rng string) {
+	if d.getErr() != nil {
+		return
+	}
+
+	chunk := dlchunk{w: d.w, start: d.pos}
+	// Ranges specified will short circuit the multipart download
+	chunk.withRange = rng
+
+	if err := d.downloadChunk(chunk); err != nil {
+		d.setErr(err)
+	}
+
+	// Update the position based on the amount of data received.
+	d.pos = d.written
+}
+
+// downloadChunk downloads the chunk from S3
+func (d *downloader) downloadChunk(chunk dlchunk) error {
+	in := &s3.GetObjectInput{}
+	awsutil.Copy(in, d.in)
+
+	// Get the next byte range of data
+	in.Range = aws.String(chunk.ByteRange())
+
+	var n int64
+	var err error
+	for retry := 0; retry <= d.partBodyMaxRetries; retry++ {
+		n, err = d.tryDownloadChunk(in, &chunk)
+		if err == nil {
+			break
+		}
+		// Check if the returned error is an errReadingBody.
+		// If err is errReadingBody this indicates that an error
+		// occurred while copying the http response body.
+		// If this occurs we unwrap the err to set the underlying error
+		// and attempt any remaining retries.
+		if bodyErr, ok := err.(*errReadingBody); ok {
+			err = bodyErr.Unwrap()
+		} else {
+			return err
+		}
+
+		chunk.cur = 0
+
+		// TODO: Add Logging
+		//logMessage(d.cfg.S3, aws.LogDebugWithRequestRetries,
+		//	fmt.Sprintf("DEBUG: object part body download interrupted %s, err, %v, retrying attempt %d",
+		//		aws.StringValue(in.Key), err, retry))
+	}
+
+	d.incrWritten(n)
+
+	return err
+}
+
+func (d *downloader) tryDownloadChunk(in *s3.GetObjectInput, w io.Writer) (int64, error) {
+	cleanup := func() {}
+	if d.cfg.BufferProvider != nil {
+		w, cleanup = d.cfg.BufferProvider.GetReadFrom(w)
+	}
+	defer cleanup()
+
+	resp, err := d.cfg.S3.GetObject(d.ctx, in, d.cfg.ClientOptions...)
+	if err != nil {
+		return 0, err
+	}
+	d.setTotalBytes(resp) // Set total if not yet set.
+
+	n, err := io.Copy(w, resp.Body)
+	resp.Body.Close()
+	if err != nil {
+		return n, &errReadingBody{err: err}
+	}
+
+	return n, nil
+}
+
+// getTotalBytes is a thread-safe getter for retrieving the total byte status.
+func (d *downloader) getTotalBytes() int64 {
+	d.m.Lock()
+	defer d.m.Unlock()
+
+	return d.totalBytes
+}
+
+// setTotalBytes is a thread-safe setter for setting the total byte status.
+// Will extract the object's total bytes from the Content-Range if the file
+// will be chunked, or from Content-Length. Content-Length is used when the
+// response does not include a Content-Range, meaning the object was not
+// chunked. This occurs when the full file fits within the PartSize directive.
+func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
+	d.m.Lock()
+	defer d.m.Unlock()
+
+	if d.totalBytes >= 0 {
+		return
+	}
+
+	if resp.ContentRange == nil {
+		// ContentRange is nil when the full file contents are provided, and
+		// the object is not chunked. Use ContentLength instead.
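+		// A response without a Content-Range header carries the entire
+		// object, so its Content-Length is the object's total size.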
+		if resp.ContentLength != nil {
+			d.totalBytes = *resp.ContentLength
+			return
+		}
+	} else {
+		parts := strings.Split(*resp.ContentRange, "/")
+
+		total := int64(-1)
+		var err error
+		// Check whether a numbered total exists. If one does not exist, we
+		// will assume the total to be -1 (undefined), and sequentially
+		// download each chunk until hitting a 416 error.
+		totalStr := parts[len(parts)-1]
+		if totalStr != "*" {
+			total, err = strconv.ParseInt(totalStr, 10, 64)
+			if err != nil {
+				d.err = err
+				return
+			}
+		}
+
+		d.totalBytes = total
+	}
+}
+
+func (d *downloader) incrWritten(n int64) {
+	d.m.Lock()
+	defer d.m.Unlock()
+
+	d.written += n
+}
+
+// getErr is a thread-safe getter for the error object
+func (d *downloader) getErr() error {
+	d.m.Lock()
+	defer d.m.Unlock()
+
+	return d.err
+}
+
+// setErr is a thread-safe setter for the error object
+func (d *downloader) setErr(e error) {
+	d.m.Lock()
+	defer d.m.Unlock()
+
+	d.err = e
+}
+
+// dlchunk represents a single chunk of data to write by the worker routine.
+// This structure also implements an io.SectionReader style interface for
+// io.WriterAt, effectively making it an io.SectionWriter (which does not
+// exist).
+type dlchunk struct {
+	w     io.WriterAt
+	start int64
+	size  int64
+	cur   int64
+
+	// specifies the byte range the chunk should be downloaded with.
+	withRange string
+}
+
+// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
+// position to its end (or EOF).
+//
+// If a range is specified on the dlchunk, the size will be ignored when
+// writing, as the total size may not be known ahead of time.
+func (c *dlchunk) Write(p []byte) (n int, err error) {
+	if c.cur >= c.size && len(c.withRange) == 0 {
+		return 0, io.EOF
+	}
+
+	n, err = c.w.WriteAt(p, c.start+c.cur)
+	c.cur += int64(n)
+
+	return
+}
+
+// ByteRange returns an HTTP Byte-Range header value that should be used by
+// the client to request the chunk's range.
+func (c *dlchunk) ByteRange() string {
+	if len(c.withRange) != 0 {
+		return c.withRange
+	}
+
+	return fmt.Sprintf("bytes=%d-%d", c.start, c.start+c.size-1)
+}
diff --git a/feature/s3/manager/download_test.go b/feature/s3/manager/download_test.go
new file mode 100644
index 00000000000..7aa2de1ce33
--- /dev/null
+++ b/feature/s3/manager/download_test.go
@@ -0,0 +1,746 @@
+package manager_test
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"reflect"
+	"regexp"
+	"strconv"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+	managertesting "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing"
+	"github.com/aws/aws-sdk-go-v2/internal/awstesting"
+	"github.com/aws/aws-sdk-go-v2/internal/sdkio"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+type downloadCaptureClient struct {
+	GetObjectFn          func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
+	GetObjectInvocations int
+
+	RetrievedRanges []string
+
+	lock sync.Mutex
+}
+
+func (c *downloadCaptureClient) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+	c.lock.Lock()
+	defer c.lock.Unlock()
+
+	c.GetObjectInvocations++
+
+	if params.Range != nil {
+		c.RetrievedRanges = append(c.RetrievedRanges, aws.ToString(params.Range))
+	}
+
+	return c.GetObjectFn(ctx, params, optFns...)
+} + +var rangeValueRegex = regexp.MustCompile(`bytes=(\d+)-(\d+)`) + +func parseRange(rangeValue string) (start, fin int64) { + rng := rangeValueRegex.FindStringSubmatch(rangeValue) + start, _ = strconv.ParseInt(rng[1], 10, 64) + fin, _ = strconv.ParseInt(rng[2], 10, 64) + return start, fin +} + +func newDownloadRangeClient(data []byte) (*downloadCaptureClient, *int, *[]string) { + capture := &downloadCaptureClient{} + + capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + start, fin := parseRange(aws.ToString(params.Range)) + fin++ + + if fin >= int64(len(data)) { + fin = int64(len(data)) + } + + bodyBytes := data[start:fin] + + return &s3.GetObjectOutput{ + Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)), + ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", start, fin-1, len(data))), + ContentLength: aws.Int64(int64(len(bodyBytes))), + }, nil + } + + return capture, &capture.GetObjectInvocations, &capture.RetrievedRanges +} + +func newDownloadNonRangeClient(data []byte) (*downloadCaptureClient, *int) { + capture := &downloadCaptureClient{} + + capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + return &s3.GetObjectOutput{ + Body: ioutil.NopCloser(bytes.NewReader(data[:])), + ContentLength: aws.Int64(int64(len(data))), + }, nil + } + + return capture, &capture.GetObjectInvocations +} + +type mockHTTPStatusError struct { + StatusCode int +} + +func (m *mockHTTPStatusError) Error() string { + return fmt.Sprintf("http status code: %v", m.StatusCode) +} + +func (m *mockHTTPStatusError) HTTPStatusCode() int { + return m.StatusCode +} + +func newDownloadContentRangeTotalAnyClient(data []byte) (*downloadCaptureClient, *int) { + capture := &downloadCaptureClient{} + completed := false + + capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + if completed { + return nil, &mockHTTPStatusError{StatusCode: 416} + } + + start, fin := parseRange(aws.ToString(params.Range)) + fin++ + + if fin >= int64(len(data)) { + fin = int64(len(data)) + completed = true + } + + bodyBytes := data[start:fin] + + return &s3.GetObjectOutput{ + Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)), + ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/*", start, fin-1)), + }, nil + } + + return capture, &capture.GetObjectInvocations +} + +func newDownloadWithErrReaderClient(cases []testErrReader) (*downloadCaptureClient, *int) { + var index int + + c := &downloadCaptureClient{} + c.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + c := cases[index] + out := &s3.GetObjectOutput{ + Body: ioutil.NopCloser(&c), + ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", 0, c.Len-1, c.Len)), + ContentLength: aws.Int64(c.Len), + } + index++ + return out, nil + } + + return c, &c.GetObjectInvocations +} + +func TestDownloadOrder(t *testing.T) { + c, invocations, ranges := newDownloadRangeClient(buf12MB) + + d := manager.NewDownloader(c, func(d *manager.Downloader) { + d.Concurrency = 1 + }) + + w := aws.NewWriteAtBuffer(make([]byte, len(buf12MB))) + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := int64(len(buf12MB)), n; e != a { + t.Errorf("expect %d buffer length, got 
%d", e, a) + } + + if e, a := 3, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + + expectRngs := []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"} + if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v ranges, got %v", e, a) + } +} + +func TestDownloadZero(t *testing.T) { + c, invocations, ranges := newDownloadRangeClient([]byte{}) + + d := manager.NewDownloader(c) + w := &aws.WriteAtBuffer{} + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if n != 0 { + t.Errorf("expect 0 bytes read, got %d", n) + } + if e, a := 1, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + + expectRngs := []string{"bytes=0-5242879"} + if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v ranges, got %v", e, a) + } +} + +func TestDownloadSetPartSize(t *testing.T) { + c, invocations, ranges := newDownloadRangeClient([]byte{1, 2, 3}) + + d := manager.NewDownloader(c, func(d *manager.Downloader) { + d.Concurrency = 1 + d.PartSize = 1 + }) + w := &aws.WriteAtBuffer{} + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := int64(3), n; e != a { + t.Errorf("expect %d bytes read, got %d", e, a) + } + if e, a := 3, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + expectRngs := []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"} + if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v ranges, got %v", e, a) + } + expectBytes := []byte{1, 2, 3} + if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) { + t.Errorf("expect %v bytes, got %v", e, a) + } +} + +func TestDownloadError(t *testing.T) { + c, invocations, _ := newDownloadRangeClient([]byte{1, 2, 3}) + + num := 0 + orig := c.GetObjectFn + c.GetObjectFn = func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + out, err := orig(ctx, params, optFns...) 
+ num++ + if num > 1 { + return &s3.GetObjectOutput{}, fmt.Errorf("s3 service error") + } + return out, err + } + + d := manager.NewDownloader(c, func(d *manager.Downloader) { + d.Concurrency = 1 + d.PartSize = 1 + }) + w := &aws.WriteAtBuffer{} + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := "s3 service error", err.Error(); e != a { + t.Errorf("expect %s error code, got %s", e, a) + } + if e, a := int64(1), n; e != a { + t.Errorf("expect %d bytes read, got %d", e, a) + } + if e, a := 2, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + expectBytes := []byte{1} + if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) { + t.Errorf("expect %v bytes, got %v", e, a) + } +} + +func TestDownloadNonChunk(t *testing.T) { + c, invocations := newDownloadNonRangeClient(buf2MB) + + d := manager.NewDownloader(c, func(d *manager.Downloader) { + d.Concurrency = 1 + }) + w := &aws.WriteAtBuffer{} + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := int64(len(buf2MB)), n; e != a { + t.Errorf("expect %d bytes read, got %d", e, a) + } + if e, a := 1, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + + count := 0 + for _, b := range w.Bytes() { + count += int(b) + } + if count != 0 { + t.Errorf("expect 0 count, got %d", count) + } +} + +func TestDownloadNoContentRangeLength(t *testing.T) { + s, invocations, _ := newDownloadRangeClient(buf2MB) + + d := manager.NewDownloader(s, func(d *manager.Downloader) { + d.Concurrency = 1 + }) + w := &aws.WriteAtBuffer{} + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := int64(len(buf2MB)), n; e != a { + t.Errorf("expect %d bytes read, got %d", e, a) + } + if e, a := 1, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + + count := 0 + for _, b := range w.Bytes() { + count += int(b) + } + if count != 0 { + t.Errorf("expect 0 count, got %d", count) + } +} + +func TestDownloadContentRangeTotalAny(t *testing.T) { + s, invocations := newDownloadContentRangeTotalAnyClient(buf2MB) + + d := manager.NewDownloader(s, func(d *manager.Downloader) { + d.Concurrency = 1 + }) + w := &aws.WriteAtBuffer{} + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := int64(len(buf2MB)), n; e != a { + t.Errorf("expect %d bytes read, got %d", e, a) + } + if e, a := 2, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + + count := 0 + for _, b := range w.Bytes() { + count += int(b) + } + if count != 0 { + t.Errorf("expect 0 count, got %d", count) + } +} + +func TestDownloadPartBodyRetry_SuccessRetry(t *testing.T) { + c, invocations := newDownloadWithErrReaderClient([]testErrReader{ + {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF}, + {Buf: []byte("123"), Len: 3, Err: io.EOF}, + }) + + d := manager.NewDownloader(c, func(d *manager.Downloader) { + d.Concurrency = 1 + }) + + w := &aws.WriteAtBuffer{} + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + 
Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := int64(3), n; e != a { + t.Errorf("expect %d bytes read, got %d", e, a) + } + if e, a := 2, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + if e, a := "123", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } +} + +func TestDownloadPartBodyRetry_SuccessNoRetry(t *testing.T) { + c, invocations := newDownloadWithErrReaderClient([]testErrReader{ + {Buf: []byte("abc"), Len: 3, Err: io.EOF}, + }) + + d := manager.NewDownloader(c, func(d *manager.Downloader) { + d.Concurrency = 1 + }) + + w := &aws.WriteAtBuffer{} + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := int64(3), n; e != a { + t.Errorf("expect %d bytes read, got %d", e, a) + } + if e, a := 1, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + if e, a := "abc", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } +} + +func TestDownloadPartBodyRetry_FailRetry(t *testing.T) { + c, invocations := newDownloadWithErrReaderClient([]testErrReader{ + {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF}, + }) + + d := manager.NewDownloader(c, func(d *manager.Downloader) { + d.Concurrency = 1 + d.PartBodyMaxRetries = 0 + }) + + w := &aws.WriteAtBuffer{} + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := "unexpected EOF", err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %q error message to be in %q", e, a) + } + if e, a := int64(2), n; e != a { + t.Errorf("expect %d bytes read, got %d", e, a) + } + if e, a := 1, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + if e, a := "ab", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } +} + +func TestDownloadWithContextCanceled(t *testing.T) { + d := manager.NewDownloader(s3.New(s3.Options{})) + + params := s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("Key"), + } + + ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} + ctx.Error = fmt.Errorf("context canceled") + close(ctx.DoneCh) + + w := &aws.WriteAtBuffer{} + + _, err := d.Download(ctx, w, ¶ms) + if err == nil { + t.Fatalf("expected error, did not get one") + } + if e, a := "canceled", err.Error(); !strings.Contains(a, e) { + t.Errorf("expected error message to contain %q, but did not %q", e, a) + } +} + +func TestDownload_WithRange(t *testing.T) { + c, invocations, ranges := newDownloadRangeClient([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + + d := manager.NewDownloader(c, func(d *manager.Downloader) { + d.Concurrency = 10 // should be ignored + d.PartSize = 1 // should be ignored + }) + + w := &aws.WriteAtBuffer{} + n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + Range: aws.String("bytes=2-6"), + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := int64(5), n; e != a { + t.Errorf("expect %d bytes read, got %d", e, a) + } + if e, a := 1, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + expectRngs := []string{"bytes=2-6"} + if e, a := expectRngs, *ranges; 
!reflect.DeepEqual(e, a) {
+		t.Errorf("expect %v ranges, got %v", e, a)
+	}
+	expectBytes := []byte{2, 3, 4, 5, 6}
+	if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
+		t.Errorf("expect %v bytes, got %v", e, a)
+	}
+}
+
+type mockDownloadClient func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error)
+
+func (m mockDownloadClient) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+	return m(ctx, params, optFns...)
+}
+
+func TestDownload_WithFailure(t *testing.T) {
+	reqCount := int64(0)
+	startingByte := 0
+
+	client := mockDownloadClient(func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (out *s3.GetObjectOutput, err error) {
+		switch atomic.LoadInt64(&reqCount) {
+		case 1:
+			// Give a chance for the multipart chunks to be queued up
+			time.Sleep(1 * time.Second)
+			err = fmt.Errorf("some connection error")
+		default:
+			body := bytes.NewReader(make([]byte, manager.DefaultDownloadPartSize))
+			out = &s3.GetObjectOutput{
+				Body:          ioutil.NopCloser(body),
+				ContentLength: aws.Int64(int64(body.Len())),
+				ContentRange:  aws.String(fmt.Sprintf("bytes %d-%d/%d", startingByte, body.Len()-1, body.Len()*10)),
+			}
+
+			startingByte += body.Len()
+			if reqCount > 0 {
+				// sleep here to ensure context switching between goroutines
+				time.Sleep(25 * time.Millisecond)
+			}
+		}
+		atomic.AddInt64(&reqCount, 1)
+		return out, err
+	})
+
+	d := manager.NewDownloader(client, func(d *manager.Downloader) {
+		d.Concurrency = 2
+	})
+
+	w := &aws.WriteAtBuffer{}
+	params := s3.GetObjectInput{
+		Bucket: aws.String("Bucket"),
+		Key:    aws.String("Key"),
+	}
+
+	// Expect this request to exit quickly after failure
+	_, err := d.Download(context.Background(), w, &params)
+	if err == nil {
+		t.Fatalf("expect error, got none")
+	}
+
+	if atomic.LoadInt64(&reqCount) > 3 {
+		t.Errorf("expect no more than 3 requests, but received %d", reqCount)
+	}
+}
+
+func TestDownloadBufferStrategy(t *testing.T) {
+	cases := map[string]struct {
+		partSize     int64
+		strategy     *recordedWriterReadFromProvider
+		expectedSize int64
+	}{
+		"no strategy": {
+			partSize:     manager.DefaultDownloadPartSize,
+			expectedSize: 10 * sdkio.MebiByte,
+		},
+		"partSize modulo bufferSize == 0": {
+			partSize: 5 * sdkio.MebiByte,
+			strategy: &recordedWriterReadFromProvider{
+				WriterReadFromProvider: manager.NewPooledBufferedWriterReadFromProvider(int(sdkio.MebiByte)), // 1 MiB
+			},
+			expectedSize: 10 * sdkio.MebiByte, // 10 MiB
+		},
+		"partSize modulo bufferSize > 0": {
+			partSize: 5 * sdkio.MebiByte, // 5 MiB
+			strategy: &recordedWriterReadFromProvider{
+				WriterReadFromProvider: manager.NewPooledBufferedWriterReadFromProvider(2 * int(sdkio.MebiByte)), // 2 MiB
+			},
+			expectedSize: 10 * sdkio.MebiByte, // 10 MiB
+		},
+	}
+
+	for name, tCase := range cases {
+		t.Run(name, func(t *testing.T) {
+			expected := managertesting.GetTestBytes(int(tCase.expectedSize))
+
+			client, _, _ := newDownloadRangeClient(expected)
+
+			d := manager.NewDownloader(client, func(d *manager.Downloader) {
+				d.PartSize = tCase.partSize
+				if tCase.strategy != nil {
+					d.BufferProvider = tCase.strategy
+				}
+			})
+
+			buffer := aws.NewWriteAtBuffer(make([]byte, len(expected)))
+
+			n, err := d.Download(context.Background(), buffer, &s3.GetObjectInput{
+				Bucket: aws.String("bucket"),
+				Key:    aws.String("key"),
+			})
+			if err != nil {
+				t.Errorf("failed to download: %v", err)
+			}
+
+			if e, a := len(expected), int(n); e != a {
+				t.Errorf("expected %v, got %v downloaded bytes", e, a)
+			}
+
+			if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) {
+				t.Errorf("downloaded bytes did not match expected")
+			}
+
+			if tCase.strategy != nil {
+				if e, a := tCase.strategy.callbacksVended, tCase.strategy.callbacksExecuted; e != a {
+					t.Errorf("expected %v, got %v", e, a)
+				}
+			}
+		})
+	}
+}
+
+type testErrReader struct {
+	Buf []byte
+	Err error
+	Len int64
+
+	off int
+}
+
+func (r *testErrReader) Read(p []byte) (int, error) {
+	to := len(r.Buf) - r.off
+
+	n := copy(p, r.Buf[r.off:to])
+	r.off += n
+
+	if n < len(p) {
+		return n, r.Err
+	}
+
+	return n, nil
+}
+
+func TestDownloadBufferStrategy_Errors(t *testing.T) {
+	expected := managertesting.GetTestBytes(int(10 * sdkio.MebiByte))
+
+	client, _, _ := newDownloadRangeClient(expected)
+	strat := &recordedWriterReadFromProvider{
+		WriterReadFromProvider: manager.NewPooledBufferedWriterReadFromProvider(int(2 * sdkio.MebiByte)),
+	}
+
+	seenOps := make(map[string]struct{})
+	orig := client.GetObjectFn
+	client.GetObjectFn = func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+		out, err := orig(ctx, params, optFns...)
+
+		fingerPrint := fmt.Sprintf("%s/%s/%s", *params.Bucket, *params.Key, *params.Range)
+		if _, ok := seenOps[fingerPrint]; ok {
+			return out, err
+		}
+		seenOps[fingerPrint] = struct{}{}
+
+		_, _ = io.Copy(ioutil.Discard, out.Body)
+
+		out.Body = ioutil.NopCloser(&badReader{err: io.ErrUnexpectedEOF})
+
+		return out, err
+	}
+
+	d := manager.NewDownloader(client, func(d *manager.Downloader) {
+		d.PartSize = 5 * sdkio.MebiByte
+		d.BufferProvider = strat
+		d.Concurrency = 1
+	})
+
+	buffer := aws.NewWriteAtBuffer(make([]byte, len(expected)))
+
+	n, err := d.Download(context.Background(), buffer, &s3.GetObjectInput{
+		Bucket: aws.String("bucket"),
+		Key:    aws.String("key"),
+	})
+	if err != nil {
+		t.Errorf("failed to download: %v", err)
+	}
+
+	if e, a := len(expected), int(n); e != a {
+		t.Errorf("expected %v, got %v downloaded bytes", e, a)
+	}
+
+	if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) {
+		t.Errorf("downloaded bytes did not match expected")
+	}
+
+	if e, a := strat.callbacksVended, strat.callbacksExecuted; e != a {
+		t.Errorf("expected %v, got %v", e, a)
+	}
+}
+
+type recordedWriterReadFromProvider struct {
+	callbacksVended   uint32
+	callbacksExecuted uint32
+	manager.WriterReadFromProvider
+}
+
+func (r *recordedWriterReadFromProvider) GetReadFrom(writer io.Writer) (manager.WriterReadFrom, func()) {
+	w, cleanup := r.WriterReadFromProvider.GetReadFrom(writer)
+
+	atomic.AddUint32(&r.callbacksVended, 1)
+	return w, func() {
+		atomic.AddUint32(&r.callbacksExecuted, 1)
+		cleanup()
+	}
+}
+
+type badReader struct {
+	err error
+}
+
+func (b *badReader) Read(p []byte) (int, error) {
+	tb := managertesting.GetTestBytes(len(p))
+	copy(p, tb)
+
+	return len(p), b.err
+}
diff --git a/feature/s3/manager/examples_test.go b/feature/s3/manager/examples_test.go
new file mode 100644
index 00000000000..ef4162fdb20
--- /dev/null
+++ b/feature/s3/manager/examples_test.go
@@ -0,0 +1,69 @@
+package manager_test
+
+import (
+	"bytes"
+	"context"
+	"net/http"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+// ExampleNewUploader_overrideReadSeekerProvider gives an example of how a
+// custom ReadSeekerWriteToProvider can be provided to the Uploader to define
+// how parts will be buffered in memory.
+func ExampleNewUploader_overrideReadSeekerProvider() { + cfg, err := config.LoadDefaultConfig() + if err != nil { + panic(err) + } + + uploader := manager.NewUploader(s3.NewFromConfig(cfg), func(u *manager.Uploader) { + // Define a strategy that will buffer 25 MiB in memory + u.BufferProvider = manager.NewBufferedReadSeekerWriteToPool(25 * 1024 * 1024) + }) + + _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("examplebucket"), + Key: aws.String("largeobject"), + Body: bytes.NewReader([]byte("large_multi_part_upload")), + }) + if err != nil { + panic(err) + } +} + +// ExampleNewUploader_overrideTransport gives an example +// on how to override the default HTTP transport. This can +// be used to tune timeouts such as response headers, or +// write / read buffer usage when writing or reading respectively +// from the net/http transport. +func ExampleNewUploader_overrideTransport() { + cfg, err := config.LoadDefaultConfig() + if err != nil { + panic(err) + } + + client := s3.NewFromConfig(cfg, func(o *s3.Options) { + // Override Default Transport Values + o.HTTPClient = aws.NewBuildableHTTPClient().WithTransportOptions(func(tr *http.Transport) { + tr.ResponseHeaderTimeout = 1 * time.Second + tr.WriteBufferSize = 1024 * 1024 + tr.ReadBufferSize = 1024 * 1024 + }) + }) + + uploader := manager.NewUploader(client) + + _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("examplebucket"), + Key: aws.String("largeobject"), + Body: bytes.NewReader([]byte("large_multi_part_upload")), + }) + if err != nil { + panic(err) + } +} diff --git a/feature/s3/manager/go.mod b/feature/s3/manager/go.mod new file mode 100644 index 00000000000..99b133cbca1 --- /dev/null +++ b/feature/s3/manager/go.mod @@ -0,0 +1,21 @@ +module github.com/aws/aws-sdk-go-v2/feature/s3/manager + +go 1.15 + +require ( + github.com/aws/aws-sdk-go-v2 v0.26.0 + github.com/aws/aws-sdk-go-v2/config v0.1.1 + github.com/aws/aws-sdk-go-v2/service/s3 v0.26.0 + github.com/awslabs/smithy-go v0.1.2-0.20201012175301-b4d8737f29d1 + github.com/google/go-cmp v0.4.1 +) + +replace ( + github.com/aws/aws-sdk-go-v2 => ../../../ + github.com/aws/aws-sdk-go-v2/config => ../../../config/ + github.com/aws/aws-sdk-go-v2/credentials => ../../../credentials/ + github.com/aws/aws-sdk-go-v2/ec2imds => ../../../ec2imds + github.com/aws/aws-sdk-go-v2/service/internal/s3shared => ../../../service/internal/s3shared + github.com/aws/aws-sdk-go-v2/service/s3 => ../../../service/s3/ + github.com/aws/aws-sdk-go-v2/service/sts => ../../../service/sts +) diff --git a/feature/s3/manager/go.sum b/feature/s3/manager/go.sum new file mode 100644 index 00000000000..8ecbc067ee5 --- /dev/null +++ b/feature/s3/manager/go.sum @@ -0,0 +1,27 @@ +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v0.0.0-20200930084954-897dfb99530c h1:v1H0WQmb+pNOZ/xDXGT3wXn6aceSN3I2wqK0VpQM/ZM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v0.0.0-20200930084954-897dfb99530c/go.mod h1:GRJ/IvA6A00/2tAw9KgMTM8as5gAlNI0FVCKBc+aRnA= +github.com/awslabs/smithy-go v0.0.0-20200930175536-2cd7f70a8c2f/go.mod h1:hPOQwnmBLHsUphH13tVSjQhTAFma0/0XoZGbBcOuABI= +github.com/awslabs/smithy-go v0.0.0-20201009221937-21015eb9ec4b/go.mod h1:hPOQwnmBLHsUphH13tVSjQhTAFma0/0XoZGbBcOuABI= +github.com/awslabs/smithy-go v0.1.0/go.mod h1:hPOQwnmBLHsUphH13tVSjQhTAFma0/0XoZGbBcOuABI= +github.com/awslabs/smithy-go v0.1.1 h1:v1hUSAYf3w2ClKr58C+AtwoyPVoBjWyWT8thf7/VRtU= +github.com/awslabs/smithy-go v0.1.1/go.mod 
h1:hPOQwnmBLHsUphH13tVSjQhTAFma0/0XoZGbBcOuABI= +github.com/awslabs/smithy-go v0.1.2-0.20201012175301-b4d8737f29d1 h1:5eAoxqWUc2VMuT3ob/pUYCLliBYEk3dccw6P/reTuRY= +github.com/awslabs/smithy-go v0.1.2-0.20201012175301-b4d8737f29d1/go.mod h1:hPOQwnmBLHsUphH13tVSjQhTAFma0/0XoZGbBcOuABI= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.4.1 h1:/exdXoGamhu5ONeUJH0deniYLWYvQwW66yvlfiiKTu0= +github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/feature/s3/manager/integ_bucket_region_test.go b/feature/s3/manager/integ_bucket_region_test.go new file mode 100644 index 00000000000..d3f886619b7 --- /dev/null +++ b/feature/s3/manager/integ_bucket_region_test.go @@ -0,0 +1,25 @@ +// +build integration + +package manager_test + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +func TestGetBucketRegion(t *testing.T) { + expectRegion := integConfig.Region + + region, err := manager.GetBucketRegion(context.Background(), s3.NewFromConfig(integConfig), aws.ToString(bucketName)) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if e, a := expectRegion, region; e != a { + t.Errorf("expect %s bucket region, got %s", e, a) + } +} diff --git a/feature/s3/manager/integ_shared_test.go b/feature/s3/manager/integ_shared_test.go new file mode 100644 index 00000000000..68377617359 --- /dev/null +++ b/feature/s3/manager/integ_shared_test.go @@ -0,0 +1,104 @@ +// +build integration + +package manager_test + +import ( + "context" + "crypto/md5" + "flag" + "fmt" + "io" + "os" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/integration" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +var integConfig aws.Config + +func init() { + var err error + + integConfig, err = 
config.LoadDefaultConfig(config.WithDefaultRegion("us-west-2")) + if err != nil { + panic(err) + } +} + +var bucketName *string +var client *s3.Client + +func TestMain(m *testing.M) { + flag.Parse() + flag.CommandLine.Visit(func(f *flag.Flag) { + if !(f.Name == "run" || f.Name == "test.run") { + return + } + value := f.Value.String() + if value == `NONE` { + os.Exit(0) + } + }) + + client = s3.NewFromConfig(integConfig) + bucketName = aws.String(integration.GenerateBucketName()) + if err := integration.SetupBucket(client, *bucketName, integConfig.Region); err != nil { + panic(err) + } + + var result int + defer func() { + if err := integration.CleanupBucket(client, *bucketName); err != nil { + fmt.Fprintln(os.Stderr, err) + } + if r := recover(); r != nil { + fmt.Fprintln(os.Stderr, "S3 integration tests panicked,", r) + result = 1 + } + os.Exit(result) + }() + + result = m.Run() +} + +type dlwriter struct { + buf []byte +} + +func newDLWriter(size int) *dlwriter { + return &dlwriter{buf: make([]byte, size)} +} + +func (d dlwriter) WriteAt(p []byte, pos int64) (n int, err error) { + if pos >= int64(len(d.buf)) { + return 0, io.EOF + } + + written := 0 + for i, b := range p { + if pos+int64(i) >= int64(len(d.buf)) { + break + } + d.buf[pos+int64(i)] = b + written++ + } + return written, nil +} + +func validate(t *testing.T, key string, md5value string) { + mgr := manager.NewDownloader(client) + params := &s3.GetObjectInput{Bucket: bucketName, Key: &key} + + w := newDLWriter(1024 * 1024 * 20) + n, err := mgr.Download(context.Background(), w, params) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := md5value, fmt.Sprintf("%x", md5.Sum(w.buf[0:n])); e != a { + t.Errorf("expect %s md5 value, got %s", e, a) + } +}
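The download manager writes each retrieved part at its absolute offset, which is why the test sink above implements io.WriterAt rather than io.Writer: parts downloaded concurrently may complete out of order. A standalone sketch of the same idea (illustrative only, not part of this change):

```go
package main

import (
	"fmt"
	"io"
)

// fixedBuffer is a minimal io.WriterAt over a preallocated slice, mirroring
// the dlwriter helper above.
type fixedBuffer struct{ buf []byte }

func (f fixedBuffer) WriteAt(p []byte, off int64) (int, error) {
	if off >= int64(len(f.buf)) {
		return 0, io.EOF
	}
	return copy(f.buf[off:], p), nil
}

func main() {
	w := fixedBuffer{buf: make([]byte, 10)}
	// Concurrent part downloads may land out of order; writing at offsets
	// keeps the assembled object correct regardless.
	w.WriteAt([]byte("world"), 5)
	w.WriteAt([]byte("hello"), 0)
	fmt.Printf("%s\n", w.buf) // helloworld
}
```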
diff --git a/feature/s3/manager/integ_upload_test.go b/feature/s3/manager/integ_upload_test.go new file mode 100644 index 00000000000..fb34c6021b9 --- /dev/null +++ b/feature/s3/manager/integ_upload_test.go @@ -0,0 +1,99 @@ +// +build integration + +package manager_test + +import ( + "bytes" + "context" + "crypto/md5" + "errors" + "fmt" + "regexp" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/awslabs/smithy-go/middleware" +) + +var integBuf12MB = make([]byte, 1024*1024*12) +var integMD512MB = fmt.Sprintf("%x", md5.Sum(integBuf12MB)) + +func TestUploadConcurrently(t *testing.T) { + key := "12mb-1" + mgr := manager.NewUploader(client) + out, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: bucketName, + Key: &key, + Body: bytes.NewReader(integBuf12MB), + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if len(out.UploadID) == 0 { + t.Errorf("expect upload ID but was empty") + } + + re := regexp.MustCompile(`^https?://.+/` + key + `$`) + if e, a := re.String(), out.Location; !re.MatchString(a) { + t.Errorf("expect %s to match URL regexp %q, did not", e, a) + } + + validate(t, key, integMD512MB) +} + +type invalidateHash struct{} + +func (b *invalidateHash) ID() string { + return "s3manager:InvalidateHash" +} + +func (b *invalidateHash) RegisterMiddleware(stack *middleware.Stack) error { + return stack.Serialize.Add(b, middleware.After) +} + +func (b *invalidateHash) HandleSerialize(ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler) ( + out middleware.SerializeOutput, metadata middleware.Metadata, err error, +) { + if input, ok := in.Parameters.(*s3.UploadPartInput); ok && aws.ToInt32(input.PartNumber) == 2 { + ctx = v4.SetPayloadHash(ctx, "000") + } + + return next.HandleSerialize(ctx, in) +} + +func TestUploadFailCleanup(t *testing.T) { + key := "12mb-leave" + mgr := manager.NewUploader(client, func(u *manager.Uploader) { + u.LeavePartsOnError = false + u.ClientOptions = append(u.ClientOptions, func(options *s3.Options) { + options.APIOptions = append(options.APIOptions, (&invalidateHash{}).RegisterMiddleware) + }) + }) + _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: bucketName, + Key: &key, + Body: bytes.NewReader(integBuf12MB), + }) + if err == nil { + t.Fatalf("expect error, but did not get one") + } + + uploadID := "" + var uf manager.MultiUploadFailure + if !errors.As(err, &uf) { + t.Errorf("expect error to be a MultiUploadFailure, got %v", err) + } else if uploadID = uf.UploadID(); len(uploadID) == 0 { + t.Errorf("expect upload ID to not be empty, but was") + } + + _, err = client.ListParts(context.Background(), &s3.ListPartsInput{ + Bucket: bucketName, Key: &key, UploadId: &uploadID, + }) + if err == nil { + t.Errorf("expect error for list parts, but got none") + } +} diff --git a/feature/s3/manager/internal/integration/downloader/README.md b/feature/s3/manager/internal/integration/downloader/README.md new file mode 100644 index 00000000000..50d690c2060 --- /dev/null +++ b/feature/s3/manager/internal/integration/downloader/README.md @@ -0,0 +1,22 @@ +## Performance Utility + +Downloads a test file from an S3 bucket using the SDK's S3 download manager. Allows passing +in a custom configuration for the HTTP client and the SDK's download manager behavior. + +## Build +```sh +go test -tags "integration perftest" -c -o download.test ./feature/s3/manager/internal/integration/downloader +``` + +## Usage Example: +```sh +AWS_REGION=us-west-2 AWS_PROFILE=aws-go-sdk-team-test ./download.test \ -test.bench=.
\ +-test.benchmem \ +-test.benchtime 1x \ +-bucket aws-sdk-go-data \ +-client.idle-conns 1000 \ +-client.idle-conns-host 300 \ +-client.timeout.connect=1s \ +-client.timeout.response-header=1s +``` diff --git a/feature/s3/manager/internal/integration/downloader/client.go b/feature/s3/manager/internal/integration/downloader/client.go new file mode 100644 index 00000000000..6e2a4b1d616 --- /dev/null +++ b/feature/s3/manager/internal/integration/downloader/client.go @@ -0,0 +1,34 @@ +// +build integration,perftest + +package downloader + +import ( + "net" + "net/http" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" +) + +func NewHTTPClient(cfg ClientConfig) aws.HTTPClient { + return aws.NewBuildableHTTPClient().WithTransportOptions(func(tr *http.Transport) { + *tr = http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: cfg.Timeouts.Connect, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: cfg.MaxIdleConns, + MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost, + IdleConnTimeout: 90 * time.Second, + + DisableKeepAlives: !cfg.KeepAlive, + TLSHandshakeTimeout: cfg.Timeouts.TLSHandshake, + ExpectContinueTimeout: cfg.Timeouts.ExpectContinue, + ResponseHeaderTimeout: cfg.Timeouts.ResponseHeader, + + ReadBufferSize: cfg.ReadBufferSize, + WriteBufferSize: cfg.WriteBufferSize, + } + }) +} diff --git a/feature/s3/manager/internal/integration/downloader/config.go b/feature/s3/manager/internal/integration/downloader/config.go new file mode 100644 index 00000000000..fc4f5dd0018 --- /dev/null +++ b/feature/s3/manager/internal/integration/downloader/config.go @@ -0,0 +1,116 @@ +// +build integration,perftest + +package downloader + +import ( + "flag" + "net/http" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" +) + +type SDKConfig struct { + PartSize int64 + Concurrency int + BufferProvider manager.WriterReadFromProvider +} + +func (c *SDKConfig) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "sdk." + + flagset.Int64Var(&c.PartSize, prefix+"part-size", manager.DefaultDownloadPartSize, + "Specifies the `size` of parts of the object to download.") + flagset.IntVar(&c.Concurrency, prefix+"concurrency", manager.DefaultDownloadConcurrency, + "Specifies the number of parts to download `at once`.") +} + +func (c *SDKConfig) Validate() error { + return nil +} + +type ClientConfig struct { + KeepAlive bool + Timeouts Timeouts + + MaxIdleConns int + MaxIdleConnsPerHost int + + // Go 1.13 + ReadBufferSize int + WriteBufferSize int +} + +func (c *ClientConfig) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "client." 
+ + flagset.BoolVar(&c.KeepAlive, prefix+"http-keep-alive", true, + "Specifies if HTTP keep alive is enabled.") + + defTR := http.DefaultTransport.(*http.Transport) + + flagset.IntVar(&c.MaxIdleConns, prefix+"idle-conns", defTR.MaxIdleConns, + "Specifies max idle connection pool size.") + + flagset.IntVar(&c.MaxIdleConnsPerHost, prefix+"idle-conns-host", http.DefaultMaxIdleConnsPerHost, + "Specifies max idle connection pool per host, will be truncated by idle-conns.") + + flagset.IntVar(&c.ReadBufferSize, prefix+"read-buffer", defTR.ReadBufferSize, "size of the transport read buffer used") + flagset.IntVar(&c.WriteBufferSize, prefix+"write-buffer", defTR.WriteBufferSize, "size of the transport write buffer used") + + c.Timeouts.SetupFlags(prefix, flagset) +} + +func (c *ClientConfig) Validate() error { + var errs Errors + + if err := c.Timeouts.Validate(); err != nil { + errs = append(errs, err) + } + + if len(errs) != 0 { + return errs + } + return nil +} + +type Timeouts struct { + Connect time.Duration + TLSHandshake time.Duration + ExpectContinue time.Duration + ResponseHeader time.Duration +} + +func (c *Timeouts) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "timeout." + + flagset.DurationVar(&c.Connect, prefix+"connect", 30*time.Second, + "The `timeout` connecting to the remote host.") + + defTR := http.DefaultTransport.(*http.Transport) + + flagset.DurationVar(&c.TLSHandshake, prefix+"tls", defTR.TLSHandshakeTimeout, + "The `timeout` waiting for the TLS handshake to complete.") + + flagset.DurationVar(&c.ExpectContinue, prefix+"expect-continue", defTR.ExpectContinueTimeout, + "The `timeout` waiting for the server's first response after fully writing the request headers, when the request uses Expect: 100-continue.") + + flagset.DurationVar(&c.ResponseHeader, prefix+"response-header", defTR.ResponseHeaderTimeout, + "The `timeout` waiting for the response headers after the request is written.") +} + +func (c *Timeouts) Validate() error { + return nil +} + +type Errors []error + +func (es Errors) Error() string { + var buf strings.Builder + for _, e := range es { + buf.WriteString(e.Error()) + } + + return buf.String() +}
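The SetupFlags helpers chain a dotted prefix at every level, which is how the `-client.timeout.connect`-style flag names in the README are produced. A small standalone sketch of the same pattern (illustrative only, not part of this change):

```go
package main

import (
	"flag"
	"fmt"
	"time"
)

func main() {
	fs := flag.NewFlagSet("perftest", flag.ExitOnError)

	// Mirrors the chaining above: "" -> "client." -> "client.timeout.", so
	// the registered flag is named client.timeout.connect.
	prefix := "" + "client." + "timeout."

	var connect time.Duration
	fs.DurationVar(&connect, prefix+"connect", 30*time.Second,
		"The `timeout` connecting to the remote host.")

	fs.Parse([]string{"-client.timeout.connect=1s"})
	fmt.Println(connect) // 1s
}
```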
diff --git a/feature/s3/manager/internal/integration/downloader/main_test.go b/feature/s3/manager/internal/integration/downloader/main_test.go new file mode 100644 index 00000000000..38e275e00fb --- /dev/null +++ b/feature/s3/manager/internal/integration/downloader/main_test.go @@ -0,0 +1,277 @@ +// +build integration,perftest + +package downloader + +import ( + "context" + "flag" + "fmt" + "io" + "log" + "os" + "runtime" + "strconv" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/integration" + "github.com/aws/aws-sdk-go-v2/internal/awstesting" + "github.com/aws/aws-sdk-go-v2/internal/sdkio" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +var benchConfig BenchmarkConfig + +type BenchmarkConfig struct { + bucket string + tempdir string + clientConfig ClientConfig + sizes string + parts string + concurrency string + bufferSize string + uploadPartSize int64 +} + +func (b *BenchmarkConfig) SetupFlags(prefix string, flagSet *flag.FlagSet) { + flagSet.StringVar(&b.bucket, "bucket", "", "Bucket to use for benchmark") + flagSet.StringVar(&b.tempdir, "temp", os.TempDir(), "location to create temporary files") + + flagSet.StringVar(&b.sizes, "size", + fmt.Sprintf("%d,%d", + 5*sdkio.MebiByte, + 1*sdkio.GibiByte), "file sizes to benchmark, separated by comma") + + flagSet.StringVar(&b.parts, "part", + fmt.Sprintf("%d,%d,%d", + manager.DefaultDownloadPartSize, + 25*sdkio.MebiByte, + 100*sdkio.MebiByte), "part sizes to benchmark, separated by comma") + + flagSet.StringVar(&b.concurrency, "concurrency", + fmt.Sprintf("%d,%d,%d", + manager.DefaultDownloadConcurrency, + 2*manager.DefaultDownloadConcurrency, + 100), + "concurrences to benchmark, separated by comma") + + flagSet.StringVar(&b.bufferSize, "buffer", fmt.Sprintf("%d,%d", 0, 1*sdkio.MebiByte), "buffer sizes to benchmark, separated by comma") + flagSet.Int64Var(&b.uploadPartSize, "upload-part-size", 0, "upload part size, defaults to download part size if not specified") + b.clientConfig.SetupFlags(prefix, flagSet) +} + +func (b *BenchmarkConfig) BufferSizes() []int { + ints, err := b.stringToInt(b.bufferSize) + if err != nil { + panic(fmt.Sprintf("failed to parse buffer sizes: %v", err)) + } + + return ints +} + +func (b *BenchmarkConfig) FileSizes() []int64 { + ints, err := b.stringToInt64(b.sizes) + if err != nil { + panic(fmt.Sprintf("failed to parse file sizes: %v", err)) + } + + return ints +} + +func (b *BenchmarkConfig) PartSizes() []int64 { + ints, err := b.stringToInt64(b.parts) + if err != nil { + panic(fmt.Sprintf("failed to parse part sizes: %v", err)) + } + + return ints +} + +func (b *BenchmarkConfig) Concurrences() []int { + ints, err := b.stringToInt(b.concurrency) + if err != nil { + panic(fmt.Sprintf("failed to parse concurrences: %v", err)) + } + + return ints +} + +func (b *BenchmarkConfig) stringToInt(s string) ([]int, error) { + int64s, err := b.stringToInt64(s) + if err != nil { + return nil, err + } + + var ints []int + for i := range int64s { + ints = append(ints, int(int64s[i])) + } + + return ints, nil +} + +func (b *BenchmarkConfig) stringToInt64(s string) ([]int64, error) { + var sizes []int64 + + split := strings.Split(s, ",") + + for _, size := range split { + size = strings.Trim(size, " ") + i, err := strconv.ParseInt(size, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid integer %s: %v", size, err) + } + + sizes = append(sizes, i) + } + + return sizes, nil +} + +func BenchmarkDownload(b *testing.B) { + baseSdkConfig := SDKConfig{} + + for _, fileSize := range benchConfig.FileSizes() { + b.Run(fmt.Sprintf("%s File", integration.SizeToName(int(fileSize))), func(b *testing.B) { + for _, partSize := range benchConfig.PartSizes() { + if partSize > fileSize { + continue + } + uploadPartSize := getUploadPartSize(fileSize, benchConfig.uploadPartSize, partSize) + b.Run(fmt.Sprintf("%s PartSize", integration.SizeToName(int(partSize))), func(b *testing.B) { + b.Logf("setting up s3 test file") + key, err := setupDownloadTest(benchConfig.bucket, fileSize, uploadPartSize) + if err != nil { + b.Fatalf("failed to set up download test: %v", err) + } + for _, concurrency := range benchConfig.Concurrences() { + b.Run(fmt.Sprintf("%d Concurrency", concurrency), func(b *testing.B) { + for _, bufferSize := range benchConfig.BufferSizes() { + var name string + if bufferSize == 0 { + name = "unbuffered" + } else { + name = fmt.Sprintf("%s buffer", integration.SizeToName(bufferSize)) + } + b.Run(name, func(b *testing.B) { + sdkConfig := baseSdkConfig + sdkConfig.Concurrency = concurrency + sdkConfig.PartSize = partSize + if bufferSize > 0 { + sdkConfig.BufferProvider = manager.NewPooledBufferedWriterReadFromProvider(bufferSize) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchDownload(b, benchConfig.bucket, key, &awstesting.DiscardAt{}, sdkConfig, benchConfig.clientConfig) + } + }) + } + }) + } + b.Log("removing test file") + err =
teardownDownloadTest(benchConfig.bucket, key) + if err != nil { + b.Fatalf("failed to clean up test file: %v", err) + } + }) + } + }) + } +} + +func benchDownload(b *testing.B, bucket, key string, body io.WriterAt, sdkConfig SDKConfig, clientConfig ClientConfig) { + downloader := newDownloader(clientConfig, sdkConfig) + _, err := downloader.Download(context.Background(), body, &s3.GetObjectInput{ + Bucket: &bucket, + Key: &key, + }) + if err != nil { + b.Fatalf("failed to download object, %v", err) + } +} + +func TestMain(m *testing.M) { + if strings.EqualFold(os.Getenv("BUILD_ONLY"), "true") { + os.Exit(0) + } + benchConfig.SetupFlags("", flag.CommandLine) + flag.Parse() + os.Exit(m.Run()) +} + +func setupDownloadTest(bucket string, fileSize, partSize int64) (key string, err error) { + er := &awstesting.EndlessReader{} + lr := io.LimitReader(er, fileSize) + + key = integration.MustUUID() + + defaultConfig, err := config.LoadDefaultConfig() + if err != nil { + return "", err + } + + client := s3.NewFromConfig(defaultConfig) + + uploader := manager.NewUploader(client, func(u *manager.Uploader) { + u.PartSize = partSize + u.Concurrency = runtime.NumCPU() * 2 + }) + + _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: &bucket, + Body: lr, + Key: &key, + }) + if err != nil { + err = fmt.Errorf("failed to upload test object to s3: %w", err) + } + + return +} + +func teardownDownloadTest(bucket, key string) error { + defaultConfig, err := config.LoadDefaultConfig() + if err != nil { + log.Fatalf("failed to load config: %v", err) + } + + client := s3.NewFromConfig(defaultConfig) + + _, err = client.DeleteObject(context.Background(), &s3.DeleteObjectInput{Bucket: &bucket, Key: &key}) + return err +} + +func newDownloader(clientConfig ClientConfig, sdkConfig SDKConfig) *manager.Downloader { + defaultConfig, err := config.LoadDefaultConfig() + if err != nil { + log.Fatalf("failed to load config: %v", err) + } + + client := s3.NewFromConfig(defaultConfig, func(options *s3.Options) { + options.HTTPClient = NewHTTPClient(clientConfig) + }) + + downloader := manager.NewDownloader(client, func(d *manager.Downloader) { + d.PartSize = sdkConfig.PartSize + d.Concurrency = sdkConfig.Concurrency + d.BufferProvider = sdkConfig.BufferProvider + }) + + return downloader +} + +func getUploadPartSize(fileSize, uploadPartSize, downloadPartSize int64) int64 { + partSize := uploadPartSize + + if partSize == 0 { + partSize = downloadPartSize + } + if fileSize/partSize > int64(manager.MaxUploadParts) { + partSize = (fileSize / int64(manager.MaxUploadParts)) + 1 + } + + return partSize +}
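getUploadPartSize above mirrors the manager's own part-size adjustment: when an object would need more than MaxUploadParts parts, the part size is raised so the part count stays within the cap. A standalone sketch of the arithmetic (illustrative only; the constant matches manager.MaxUploadParts):

```go
package main

import "fmt"

const maxUploadParts = 10000 // matches manager.MaxUploadParts

func adjustPartSize(fileSize, partSize int64) int64 {
	if fileSize/partSize > maxUploadParts {
		// +1 absorbs the remainder from integer division so the final,
		// short part does not push the count past the cap.
		partSize = fileSize/maxUploadParts + 1
	}
	return partSize
}

func main() {
	const gib int64 = 1 << 30
	// A 100 GiB object at the 5 MiB minimum would need 20,480 parts, so
	// the part size is raised to roughly 10.2 MiB (10,737,419 bytes).
	fmt.Println(adjustPartSize(100*gib, 5*1024*1024))
}
```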
diff --git a/feature/s3/manager/internal/integration/integration.go b/feature/s3/manager/internal/integration/integration.go new file mode 100644 index 00000000000..ded44f39ae2 --- /dev/null +++ b/feature/s3/manager/internal/integration/integration.go @@ -0,0 +1,204 @@ +package integration + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "io/ioutil" + "net/http" + "os" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + smithyrand "github.com/awslabs/smithy-go/rand" +) + +var uuid = smithyrand.NewUUID(rand.Reader) + +// MustUUID returns a UUID string or panics +func MustUUID() string { + uuid, err := uuid.GetUUID() + if err != nil { + panic(err) + } + return uuid +} + +// CreateFileOfSize will return an *os.File that is of size bytes +func CreateFileOfSize(dir string, size int64) (*os.File, error) { + file, err := ioutil.TempFile(dir, "s3integration") + if err != nil { + return nil, err + } + + err = file.Truncate(size) + if err != nil { + file.Close() + os.Remove(file.Name()) + return nil, err + } + + return file, nil +} + +// SizeToName returns a human-readable string for the given size in bytes +func SizeToName(size int) string { + units := []string{"B", "KB", "MB", "GB"} + i := 0 + for size >= 1024 { + size /= 1024 + i++ + } + + if i > len(units)-1 { + i = len(units) - 1 + } + + return fmt.Sprintf("%d%s", size, units[i]) +} + +// BucketPrefix is the root prefix of integration test buckets. +const BucketPrefix = "aws-sdk-go-v2-integration" + +// GenerateBucketName returns a unique bucket name. +func GenerateBucketName() string { + var id [16]byte + _, err := rand.Read(id[:]) + if err != nil { + panic(err) + } + + return fmt.Sprintf("%s-%x", + BucketPrefix, id) +} + +// SetupBucket creates a test bucket for the integration tests. +func SetupBucket(client *s3.Client, bucketName, region string) (err error) { + fmt.Println("Setup: Creating test bucket,", bucketName) + _, err = client.CreateBucket(context.Background(), &s3.CreateBucketInput{ + Bucket: &bucketName, + CreateBucketConfiguration: &types.CreateBucketConfiguration{ + LocationConstraint: types.BucketLocationConstraint(region), + }, + }) + if err != nil { + return fmt.Errorf("failed to create bucket %s, %v", bucketName, err) + } + + fmt.Println("Setup: Waiting for bucket to exist,", bucketName) + err = waitUntilBucketExists(context.Background(), client, &s3.HeadBucketInput{Bucket: &bucketName}) + if err != nil { + return fmt.Errorf("failed waiting for bucket %s to be created, %v", + bucketName, err) + } + + return nil +} + +func waitUntilBucketExists(ctx context.Context, client *s3.Client, params *s3.HeadBucketInput) error { + for i := 0; i < 20; i++ { + _, err := client.HeadBucket(ctx, params) + if err == nil { + return nil + } + + var httpErr interface{ HTTPStatusCode() int } + + if !errors.As(err, &httpErr) { + return err + } + + if httpErr.HTTPStatusCode() == http.StatusMovedPermanently || httpErr.HTTPStatusCode() == http.StatusForbidden { + return nil + } + + if httpErr.HTTPStatusCode() != http.StatusNotFound { + return err + } + + time.Sleep(5 * time.Second) + } + return fmt.Errorf("exceeded max attempts waiting for bucket %s to exist", aws.ToString(params.Bucket)) +}
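waitUntilBucketExists relies on SDK operation errors exposing their HTTP status code through an interface assertion: 404 means keep polling, while 301 or 403 mean the bucket exists but is not directly accessible from this client. A minimal standalone illustration of that pattern (the concrete error type here is a stand-in, not the SDK's):

```go
package main

import (
	"errors"
	"fmt"
)

// statusErr is a stand-in for an SDK response error carrying an HTTP status.
type statusErr struct{ code int }

func (e *statusErr) Error() string       { return fmt.Sprintf("http status %d", e.code) }
func (e *statusErr) HTTPStatusCode() int { return e.code }

func main() {
	err := fmt.Errorf("operation failed: %w", &statusErr{code: 404})

	// errors.As can target an anonymous interface, unwrapping until it
	// finds an error in the chain that implements HTTPStatusCode.
	var httpErr interface{ HTTPStatusCode() int }
	if errors.As(err, &httpErr) {
		fmt.Println("status:", httpErr.HTTPStatusCode()) // status: 404
	}
}
```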
+// CleanupBucket deletes the contents of an S3 bucket, before deleting the bucket +// itself. +func CleanupBucket(client *s3.Client, bucketName string) error { + var errs []error + + { + fmt.Println("TearDown: Deleting objects from test bucket,", bucketName) + input := &s3.ListObjectsV2Input{Bucket: &bucketName} + for { + listObjectsV2, err := client.ListObjectsV2(context.Background(), input) + if err != nil { + return fmt.Errorf("failed to list objects, %w", err) + } + + var toDelete types.Delete + for _, content := range listObjectsV2.Contents { + obj := content + toDelete.Objects = append(toDelete.Objects, &types.ObjectIdentifier{Key: obj.Key}) + } + + deleteObjects, err := client.DeleteObjects(context.Background(), &s3.DeleteObjectsInput{ + Bucket: &bucketName, + Delete: &toDelete, + }) + if err != nil { + errs = append(errs, err) + break + } + for _, deleteError := range deleteObjects.Errors { + errs = append(errs, fmt.Errorf("failed to delete %s, %s", aws.ToString(deleteError.Key), aws.ToString(deleteError.Message))) + } + + if aws.ToBool(listObjectsV2.IsTruncated) { + input.ContinuationToken = listObjectsV2.NextContinuationToken + } else { + break + } + } + } + + { + fmt.Println("TearDown: Deleting partial uploads from test bucket,", bucketName) + + input := &s3.ListMultipartUploadsInput{Bucket: &bucketName} + for { + uploads, err := client.ListMultipartUploads(context.Background(), input) + if err != nil { + return fmt.Errorf("failed to list multipart objects, %w", err) + } + + for _, upload := range uploads.Uploads { + if _, err := client.AbortMultipartUpload(context.Background(), &s3.AbortMultipartUploadInput{ + Bucket: &bucketName, + Key: upload.Key, + UploadId: upload.UploadId, + }); err != nil { + errs = append(errs, err) + } + } + + if aws.ToBool(uploads.IsTruncated) { + input.KeyMarker = uploads.NextKeyMarker + input.UploadIdMarker = uploads.NextUploadIdMarker + } else { + break + } + } + } + + if len(errs) != 0 { + return fmt.Errorf("failed to delete objects, %v", errs) + } + + fmt.Println("TearDown: Deleting test bucket,", bucketName) + if _, err := client.DeleteBucket(context.Background(), &s3.DeleteBucketInput{Bucket: &bucketName}); err != nil { + return fmt.Errorf("failed to delete test bucket %s, %w", bucketName, err) + } + + return nil +} diff --git a/feature/s3/manager/internal/integration/uploader/README.md b/feature/s3/manager/internal/integration/uploader/README.md new file mode 100644 index 00000000000..6f12095fd35 --- /dev/null +++ b/feature/s3/manager/internal/integration/uploader/README.md @@ -0,0 +1,22 @@ +## Performance Utility + +Uploads a file to an S3 bucket using the SDK's S3 upload manager. Allows passing +in a custom configuration for the HTTP client and the SDK's upload manager behavior. + +## Build +```sh +go test -tags "integration perftest" -c -o uploader.test ./feature/s3/manager/internal/integration/uploader +``` + +## Usage Example: +```sh +AWS_REGION=us-west-2 AWS_PROFILE=aws-go-sdk-team-test ./uploader.test \ -test.bench=.
\ +-test.benchmem \ +-test.benchtime 1x \ +-bucket aws-sdk-go-data \ +-client.idle-conns 1000 \ +-client.idle-conns-host 300 \ +-client.timeout.connect=1s \ +-client.timeout.response-header=1s +``` diff --git a/feature/s3/manager/internal/integration/uploader/client.go b/feature/s3/manager/internal/integration/uploader/client.go new file mode 100644 index 00000000000..e409c61f890 --- /dev/null +++ b/feature/s3/manager/internal/integration/uploader/client.go @@ -0,0 +1,31 @@ +// +build integration,perftest + +package uploader + +import ( + "net" + "net/http" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" +) + +func NewHTTPClient(cfg ClientConfig) aws.HTTPClient { + return aws.NewBuildableHTTPClient().WithTransportOptions(func(transport *http.Transport) { + *transport = http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: cfg.Timeouts.Connect, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: cfg.MaxIdleConns, + MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost, + IdleConnTimeout: 90 * time.Second, + + DisableKeepAlives: !cfg.KeepAlive, + TLSHandshakeTimeout: cfg.Timeouts.TLSHandshake, + ExpectContinueTimeout: cfg.Timeouts.ExpectContinue, + ResponseHeaderTimeout: cfg.Timeouts.ResponseHeader, + } + }) +} diff --git a/feature/s3/manager/internal/integration/uploader/config.go b/feature/s3/manager/internal/integration/uploader/config.go new file mode 100644 index 00000000000..a89f43de068 --- /dev/null +++ b/feature/s3/manager/internal/integration/uploader/config.go @@ -0,0 +1,109 @@ +// +build integration,perftest + +package uploader + +import ( + "flag" + "net/http" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" +) + +type SDKConfig struct { + PartSize int64 + Concurrency int + BufferProvider manager.ReadSeekerWriteToProvider +} + +func (c *SDKConfig) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "sdk." + + flagset.Int64Var(&c.PartSize, prefix+"part-size", manager.DefaultUploadPartSize, + "Specifies the `size` of parts of the object to upload.") + flagset.IntVar(&c.Concurrency, prefix+"concurrency", manager.DefaultUploadConcurrency, + "Specifies the number of parts to upload `at once`.") +} + +func (c *SDKConfig) Validate() error { + return nil +} + +type ClientConfig struct { + KeepAlive bool + Timeouts Timeouts + + MaxIdleConns int + MaxIdleConnsPerHost int +} + +func (c *ClientConfig) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "client." + + flagset.BoolVar(&c.KeepAlive, prefix+"http-keep-alive", true, + "Specifies if HTTP keep alive is enabled.") + + defTR := http.DefaultTransport.(*http.Transport) + + flagset.IntVar(&c.MaxIdleConns, prefix+"idle-conns", defTR.MaxIdleConns, + "Specifies max idle connection pool size.") + + flagset.IntVar(&c.MaxIdleConnsPerHost, prefix+"idle-conns-host", http.DefaultMaxIdleConnsPerHost, + "Specifies max idle connection pool per host, will be truncated by idle-conns.") + + c.Timeouts.SetupFlags(prefix, flagset) +} + +func (c *ClientConfig) Validate() error { + var errs Errors + + if err := c.Timeouts.Validate(); err != nil { + errs = append(errs, err) + } + + if len(errs) != 0 { + return errs + } + return nil +} + +type Timeouts struct { + Connect time.Duration + TLSHandshake time.Duration + ExpectContinue time.Duration + ResponseHeader time.Duration +} + +func (c *Timeouts) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "timeout." 
+ + flagset.DurationVar(&c.Connect, prefix+"connect", 30*time.Second, + "The `timeout` connecting to the remote host.") + + defTR := http.DefaultTransport.(*http.Transport) + + flagset.DurationVar(&c.TLSHandshake, prefix+"tls", defTR.TLSHandshakeTimeout, + "The `timeout` waiting for the TLS handshake to complete.") + + flagset.DurationVar(&c.ExpectContinue, prefix+"expect-continue", defTR.ExpectContinueTimeout, + "The `timeout` waiting for the server's first response after fully writing the request headers, when the request uses Expect: 100-continue.") + + flagset.DurationVar(&c.ResponseHeader, prefix+"response-header", defTR.ResponseHeaderTimeout, + "The `timeout` waiting for the response headers after the request is written.") +} + +func (c *Timeouts) Validate() error { + return nil +} + +type Errors []error + +func (es Errors) Error() string { + var buf strings.Builder + for _, e := range es { + buf.WriteString(e.Error()) + } + + return buf.String() +} diff --git a/feature/s3/manager/internal/integration/uploader/main_test.go b/feature/s3/manager/internal/integration/uploader/main_test.go new file mode 100644 index 00000000000..d333288f055 --- /dev/null +++ b/feature/s3/manager/internal/integration/uploader/main_test.go @@ -0,0 +1,231 @@ +// +build integration,perftest + +package uploader + +import ( + "context" + "flag" + "fmt" + "io" + "log" + "os" + "strconv" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/integration" + "github.com/aws/aws-sdk-go-v2/internal/awstesting" + "github.com/aws/aws-sdk-go-v2/internal/sdkio" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +func newUploader(clientConfig ClientConfig, sdkConfig SDKConfig) *manager.Uploader { + defaultConfig, err := config.LoadDefaultConfig() + if err != nil { + log.Fatalf("failed to load config: %v", err) + } + + client := s3.NewFromConfig(defaultConfig, func(o *s3.Options) { + o.HTTPClient = NewHTTPClient(clientConfig) + }) + + uploader := manager.NewUploader(client, func(u *manager.Uploader) { + u.PartSize = sdkConfig.PartSize + u.Concurrency = sdkConfig.Concurrency + u.BufferProvider = sdkConfig.BufferProvider + }) + + return uploader +} + +func getUploadPartSize(fileSize, uploadPartSize int64) int64 { + partSize := uploadPartSize + + if fileSize/partSize > int64(manager.MaxUploadParts) { + partSize = (fileSize / int64(manager.MaxUploadParts)) + 1 + } + + return partSize +} + +var benchConfig BenchmarkConfig + +type BenchmarkConfig struct { + bucket string + tempdir string + clientConfig ClientConfig + sizes string + parts string + concurrency string + bufferSize string +} + +func (b *BenchmarkConfig) SetupFlags(prefix string, flagSet *flag.FlagSet) { + flagSet.StringVar(&b.bucket, "bucket", "", "Bucket to use for benchmark") + flagSet.StringVar(&b.tempdir, "temp", os.TempDir(), "location to create temporary files") + + flagSet.StringVar(&b.sizes, "size", + fmt.Sprintf("%d,%d", + 5*sdkio.MebiByte, + 1*sdkio.GibiByte), "file sizes to benchmark, separated by comma") + + flagSet.StringVar(&b.parts, "part", + fmt.Sprintf("%d,%d,%d", + manager.DefaultUploadPartSize, + 25*sdkio.MebiByte, + 100*sdkio.MebiByte), "part sizes to benchmark, separated by comma") + + flagSet.StringVar(&b.concurrency, "concurrency", + fmt.Sprintf("%d,%d,%d", + manager.DefaultUploadConcurrency, + 2*manager.DefaultUploadConcurrency, + 100), + "concurrences to benchmark, separated by comma") + + flagSet.StringVar(&b.bufferSize, "buffer", fmt.Sprintf("%d,%d", 0, 1*sdkio.MebiByte), "buffer sizes to benchmark, separated by comma")
benchmark separated comma") + b.clientConfig.SetupFlags(prefix, flagSet) +} + +func (b *BenchmarkConfig) BufferSizes() []int { + ints, err := b.stringToInt(b.bufferSize) + if err != nil { + panic(fmt.Sprintf("failed to parse file sizes: %v", err)) + } + + return ints +} + +func (b *BenchmarkConfig) FileSizes() []int64 { + ints, err := b.stringToInt64(b.sizes) + if err != nil { + panic(fmt.Sprintf("failed to parse file sizes: %v", err)) + } + + return ints +} + +func (b *BenchmarkConfig) PartSizes() []int64 { + ints, err := b.stringToInt64(b.parts) + if err != nil { + panic(fmt.Sprintf("failed to parse part sizes: %v", err)) + } + + return ints +} + +func (b *BenchmarkConfig) Concurrences() []int { + ints, err := b.stringToInt(b.concurrency) + if err != nil { + panic(fmt.Sprintf("failed to parse part sizes: %v", err)) + } + + return ints +} + +func (b *BenchmarkConfig) stringToInt(s string) ([]int, error) { + int64s, err := b.stringToInt64(s) + if err != nil { + return nil, err + } + + var ints []int + for i := range int64s { + ints = append(ints, int(int64s[i])) + } + + return ints, nil +} + +func (b *BenchmarkConfig) stringToInt64(s string) ([]int64, error) { + var sizes []int64 + + split := strings.Split(s, ",") + + for _, size := range split { + size = strings.Trim(size, " ") + i, err := strconv.ParseInt(size, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid integer %s: %v", size, err) + } + + sizes = append(sizes, i) + } + + return sizes, nil +} + +func BenchmarkUpload(b *testing.B) { + baseSdkConfig := SDKConfig{} + + for _, fileSize := range benchConfig.FileSizes() { + b.Run(fmt.Sprintf("%s File", integration.SizeToName(int(fileSize))), func(b *testing.B) { + for _, concurrency := range benchConfig.Concurrences() { + b.Run(fmt.Sprintf("%d Concurrency", concurrency), func(b *testing.B) { + for _, partSize := range benchConfig.PartSizes() { + if partSize > fileSize { + continue + } + partSize = getUploadPartSize(fileSize, partSize) + b.Run(fmt.Sprintf("%s PartSize", integration.SizeToName(int(partSize))), func(b *testing.B) { + for _, bufferSize := range benchConfig.BufferSizes() { + var name string + if bufferSize == 0 { + name = "Unbuffered" + } else { + name = fmt.Sprintf("%s Buffer", integration.SizeToName(bufferSize)) + } + b.Run(name, func(b *testing.B) { + sdkConfig := baseSdkConfig + + sdkConfig.Concurrency = concurrency + sdkConfig.PartSize = partSize + if bufferSize > 0 { + sdkConfig.BufferProvider = manager.NewBufferedReadSeekerWriteToPool(bufferSize) + } + + for i := 0; i < b.N; i++ { + for { + b.ResetTimer() + reader := aws.ReadSeekCloser(io.LimitReader(&awstesting.EndlessReader{}, fileSize)) + err := benchUpload(b, benchConfig.bucket, integration.MustUUID(), reader, sdkConfig, benchConfig.clientConfig) + if err != nil { + b.Logf("upload failed, retrying: %v", err) + continue + } + break + } + } + }) + } + }) + } + }) + } + }) + } +} + +func benchUpload(b *testing.B, bucket, key string, reader io.ReadSeeker, sdkConfig SDKConfig, clientConfig ClientConfig) error { + uploader := newUploader(clientConfig, sdkConfig) + _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: &bucket, + Key: &key, + Body: reader, + }) + if err != nil { + return err + } + return nil +} + +func TestMain(m *testing.M) { + if strings.EqualFold(os.Getenv("BUILD_ONLY"), "true") { + os.Exit(0) + } + benchConfig.SetupFlags("", flag.CommandLine) + flag.Parse() + os.Exit(m.Run()) +} diff --git a/feature/s3/manager/internal/testing/endpoints.go 
b/feature/s3/manager/internal/testing/endpoints.go new file mode 100644 index 00000000000..aa2f62ed72c --- /dev/null +++ b/feature/s3/manager/internal/testing/endpoints.go @@ -0,0 +1,14 @@ +package testing + +import ( + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +// EndpointResolverFunc is a mock s3 endpoint resolver that wraps the given function +type EndpointResolverFunc func(region string, options s3.ResolverOptions) (aws.Endpoint, error) + +// ResolveEndpoint returns the results from the wrapped function. +func (m EndpointResolverFunc) ResolveEndpoint(region string, options s3.ResolverOptions) (aws.Endpoint, error) { + return m(region, options) +} diff --git a/feature/s3/manager/internal/testing/rand.go b/feature/s3/manager/internal/testing/rand.go new file mode 100644 index 00000000000..2c8d27e194f --- /dev/null +++ b/feature/s3/manager/internal/testing/rand.go @@ -0,0 +1,28 @@ +package testing + +import ( + "fmt" + "math/rand" + + "github.com/aws/aws-sdk-go-v2/internal/sdkio" +) + +var randBytes = func() []byte { + rr := rand.New(rand.NewSource(0)) + b := make([]byte, 10*sdkio.MebiByte) + + if _, err := rr.Read(b); err != nil { + panic(fmt.Sprintf("failed to read random bytes, %v", err)) + } + return b +}() + +// GetTestBytes returns a pseudo-random []byte of length size +func GetTestBytes(size int) []byte { + if len(randBytes) >= size { + return randBytes[:size] + } + + b := append(randBytes, GetTestBytes(size-len(randBytes))...) + return b +} diff --git a/feature/s3/manager/internal/testing/upload.go b/feature/s3/manager/internal/testing/upload.go new file mode 100644 index 00000000000..067f35e5dd2 --- /dev/null +++ b/feature/s3/manager/internal/testing/upload.go @@ -0,0 +1,196 @@ +package testing + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "sync" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +// UploadLoggingClient is a mock client that can be used to record and stub responses for testing the s3manager.Uploader. 
+type UploadLoggingClient struct { + Invocations []string + Params []interface{} + + ConsumeBody bool + + PutObjectFn func(*UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) + UploadPartFn func(*UploadLoggingClient, *s3.UploadPartInput) (*s3.UploadPartOutput, error) + CreateMultipartUploadFn func(*UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) + CompleteMultipartUploadFn func(*UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) + AbortMultipartUploadFn func(*UploadLoggingClient, *s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error) + + ignoredOperations []string + + PartNum int + m sync.Mutex +} + +func (u *UploadLoggingClient) simulateHTTPClientOption(optFns ...func(*s3.Options)) error { + o := s3.Options{ + HTTPClient: httpDoFunc(func(request *http.Request) (*http.Response, error) { + return &http.Response{ + Request: request, + }, nil + }), + } + + for _, fn := range optFns { + fn(&o) + } + + _, err := o.HTTPClient.Do(&http.Request{URL: &url.URL{ + Scheme: "https", + Host: "mock.amazonaws.com", + Path: "/key", + RawQuery: "foo=bar", + }}) + if err != nil { + return err + } + + return nil +} + +type httpDoFunc func(*http.Request) (*http.Response, error) + +func (f httpDoFunc) Do(r *http.Request) (*http.Response, error) { + return f(r) +} + +func (u *UploadLoggingClient) traceOperation(name string, params interface{}) { + if contains(u.ignoredOperations, name) { + return + } + + u.Invocations = append(u.Invocations, name) + u.Params = append(u.Params, params) +} + +// PutObject is the S3 PutObject API. +func (u *UploadLoggingClient) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { + u.m.Lock() + defer u.m.Unlock() + + if u.ConsumeBody { + io.Copy(ioutil.Discard, params.Body) + } + + u.traceOperation("PutObject", params) + if err := u.simulateHTTPClientOption(optFns...); err != nil { + return nil, err + } + + if u.PutObjectFn != nil { + return u.PutObjectFn(u, params) + } + + return &s3.PutObjectOutput{ + VersionId: aws.String("VERSION-ID"), + }, nil +} + +// UploadPart is the S3 UploadPart API. +func (u *UploadLoggingClient) UploadPart(ctx context.Context, params *s3.UploadPartInput, optFns ...func(*s3.Options)) (*s3.UploadPartOutput, error) { + u.m.Lock() + defer u.m.Unlock() + + if u.ConsumeBody { + io.Copy(ioutil.Discard, params.Body) + } + + u.traceOperation("UploadPart", params) + if err := u.simulateHTTPClientOption(optFns...); err != nil { + return nil, err + } + + u.PartNum++ + + if u.UploadPartFn != nil { + return u.UploadPartFn(u, params) + } + + return &s3.UploadPartOutput{ + ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum)), + }, nil +} + +// CreateMultipartUpload is the S3 CreateMultipartUpload API. +func (u *UploadLoggingClient) CreateMultipartUpload(ctx context.Context, params *s3.CreateMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) { + u.m.Lock() + defer u.m.Unlock() + + u.traceOperation("CreateMultipartUpload", params) + if err := u.simulateHTTPClientOption(optFns...); err != nil { + return nil, err + } + + if u.CreateMultipartUploadFn != nil { + return u.CreateMultipartUploadFn(u, params) + } + + return &s3.CreateMultipartUploadOutput{ + UploadId: aws.String("UPLOAD-ID"), + }, nil +} + +// CompleteMultipartUpload is the S3 CompleteMultipartUpload API. 
+func (u *UploadLoggingClient) CompleteMultipartUpload(ctx context.Context, params *s3.CompleteMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) { + u.m.Lock() + defer u.m.Unlock() + + u.traceOperation("CompleteMultipartUpload", params) + if err := u.simulateHTTPClientOption(optFns...); err != nil { + return nil, err + } + + if u.CompleteMultipartUploadFn != nil { + return u.CompleteMultipartUploadFn(u, params) + } + + return &s3.CompleteMultipartUploadOutput{ + Location: aws.String("http://location"), + VersionId: aws.String("VERSION-ID"), + }, nil +} + +// AbortMultipartUpload is the S3 AbortMultipartUpload API. +func (u *UploadLoggingClient) AbortMultipartUpload(ctx context.Context, params *s3.AbortMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) { + u.m.Lock() + defer u.m.Unlock() + + u.traceOperation("AbortMultipartUpload", params) + if err := u.simulateHTTPClientOption(optFns...); err != nil { + return nil, err + } + + if u.AbortMultipartUploadFn != nil { + return u.AbortMultipartUploadFn(u, params) + } + + return &s3.AbortMultipartUploadOutput{}, nil +} + +// NewUploadLoggingClient returns a new UploadLoggingClient. +func NewUploadLoggingClient(ignoreOps []string) (*UploadLoggingClient, *[]string, *[]interface{}) { + client := &UploadLoggingClient{ + ignoredOperations: ignoreOps, + } + + return client, &client.Invocations, &client.Params +} + +func contains(src []string, s string) bool { + for _, v := range src { + if s == v { + return true + } + } + return false +} diff --git a/feature/s3/manager/pool.go b/feature/s3/manager/pool.go new file mode 100644 index 00000000000..6b93a3bc443 --- /dev/null +++ b/feature/s3/manager/pool.go @@ -0,0 +1,251 @@ +package manager + +import ( + "context" + "fmt" + "sync" +) + +type byteSlicePool interface { + Get(context.Context) (*[]byte, error) + Put(*[]byte) + ModifyCapacity(int) + SliceSize() int64 + Close() +} + +type maxSlicePool struct { + // allocator is defined as a function pointer to allow + // for test cases to instrument custom tracers when allocations + // occur. 
+ allocator sliceAllocator + + slices chan *[]byte + allocations chan struct{} + capacityChange chan struct{} + + max int + sliceSize int64 + + mtx sync.RWMutex +} + +func newMaxSlicePool(sliceSize int64) *maxSlicePool { + p := &maxSlicePool{sliceSize: sliceSize} + p.allocator = p.newSlice + + return p +} + +var errZeroCapacity = fmt.Errorf("get called on zero capacity pool") + +func (p *maxSlicePool) Get(ctx context.Context) (*[]byte, error) { + // check if context is canceled before attempting to get a slice + // this ensures priority is given to the cancel case first + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + p.mtx.RLock() + + for { + select { + case bs, ok := <-p.slices: + p.mtx.RUnlock() + if !ok { + // attempt to get on a zero capacity pool + return nil, errZeroCapacity + } + return bs, nil + case <-ctx.Done(): + p.mtx.RUnlock() + return nil, ctx.Err() + default: + // pass + } + + select { + case _, ok := <-p.allocations: + p.mtx.RUnlock() + if !ok { + // attempt to get on a zero capacity pool + return nil, errZeroCapacity + } + return p.allocator(), nil + case <-ctx.Done(): + p.mtx.RUnlock() + return nil, ctx.Err() + default: + // In the event that there are no slices or allocations available, wait for a capacity change. + // This prevents some deadlock situations that can occur around sync.RWMutex. + // When a lock request occurs on ModifyCapacity, no new readers are allowed to acquire a read lock. + // By releasing the read lock here and waiting for a notification, we prevent a deadlock situation where + // Get could hold the read lock indefinitely waiting for capacity, ModifyCapacity is waiting for a write lock, + // and a Put is blocked trying to get a read-lock which is blocked by ModifyCapacity. + + // Short-circuit if the pool capacity is zero. + if p.max == 0 { + p.mtx.RUnlock() + return nil, errZeroCapacity + } + + // Since we will be releasing the read-lock we need to take the reference to the channel. + // Since channels are references we will still get notified if slices are added, or if + // the channel is closed due to a capacity modification. This specifically avoids a data race condition + // where ModifyCapacity both closes a channel and initializes a new one while we don't have a read-lock. + c := p.capacityChange + + p.mtx.RUnlock() + + select { + case <-c: + p.mtx.RLock() + case <-ctx.Done(): + return nil, ctx.Err() + } + } + } +} + +func (p *maxSlicePool) Put(bs *[]byte) { + p.mtx.RLock() + defer p.mtx.RUnlock() + + if p.max == 0 { + return + } + + select { + case p.slices <- bs: + p.notifyCapacity() + default: + // If the channel is already full when attempting to add the slice then we drop the slice. + // The logic here is to prevent a deadlock situation if the channel is already at max capacity. + // This also allows us to reap allocations that are returned but no longer needed. + } +}
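The comments above describe the core trade-off of maxSlicePool: a hard capacity bound with context-aware blocking, rather than sync.Pool's unbounded best-effort reuse. A drastically simplified, standalone analogue of the same blocking behavior (illustrative only; the real pool additionally supports resizing capacity and allocating slices lazily):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// chanPool hands out reusable slices from a bounded channel; Get blocks
// until a slice is free or the context is done.
type chanPool struct{ slices chan *[]byte }

func newChanPool(capacity int, size int64) *chanPool {
	p := &chanPool{slices: make(chan *[]byte, capacity)}
	for i := 0; i < capacity; i++ {
		bs := make([]byte, size)
		p.slices <- &bs
	}
	return p
}

func (p *chanPool) Get(ctx context.Context) (*[]byte, error) {
	select {
	case bs := <-p.slices:
		return bs, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

func (p *chanPool) Put(bs *[]byte) { p.slices <- bs }

func main() {
	pool := newChanPool(1, 1024)

	bs, _ := pool.Get(context.Background())
	fmt.Println("got slice of len", len(*bs))

	// With the only slice checked out, a second Get times out instead of
	// deadlocking, which is what the context plumbing above buys.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()
	if _, err := pool.Get(ctx); err != nil {
		fmt.Println("second get:", err)
	}

	pool.Put(bs)
}
```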
+func (p *maxSlicePool) ModifyCapacity(delta int) { + if delta == 0 { + return + } + + p.mtx.Lock() + defer p.mtx.Unlock() + + p.max += delta + + if p.max == 0 { + p.empty() + return + } + + if p.capacityChange != nil { + close(p.capacityChange) + } + p.capacityChange = make(chan struct{}, p.max) + + origAllocations := p.allocations + p.allocations = make(chan struct{}, p.max) + + newAllocs := len(origAllocations) + delta + for i := 0; i < newAllocs; i++ { + p.allocations <- struct{}{} + } + + if origAllocations != nil { + close(origAllocations) + } + + origSlices := p.slices + p.slices = make(chan *[]byte, p.max) + if origSlices == nil { + return + } + + close(origSlices) + for bs := range origSlices { + select { + case p.slices <- bs: + default: + // If the new channel blocks while adding slices from the old channel + // then we drop the slice. The logic here is to prevent a deadlock situation + // if the new channel has a smaller capacity than the old. + } + } +} + +func (p *maxSlicePool) notifyCapacity() { + select { + case p.capacityChange <- struct{}{}: + default: + // This *shouldn't* happen as the channel is both buffered to the max pool capacity size and is resized + // on capacity modifications. This is just a safety to ensure that a blocking situation can't occur. + } +} + +func (p *maxSlicePool) SliceSize() int64 { + return p.sliceSize +} + +func (p *maxSlicePool) Close() { + p.mtx.Lock() + defer p.mtx.Unlock() + p.empty() +} + +func (p *maxSlicePool) empty() { + p.max = 0 + + if p.capacityChange != nil { + close(p.capacityChange) + p.capacityChange = nil + } + + if p.allocations != nil { + close(p.allocations) + for range p.allocations { + // drain channel + } + p.allocations = nil + } + + if p.slices != nil { + close(p.slices) + for range p.slices { + // drain channel + } + p.slices = nil + } +} + +func (p *maxSlicePool) newSlice() *[]byte { + bs := make([]byte, p.sliceSize) + return &bs +} + +type returnCapacityPoolCloser struct { + byteSlicePool + returnCapacity int +} + +func (n *returnCapacityPoolCloser) ModifyCapacity(delta int) { + if delta > 0 { + n.returnCapacity = -1 * delta + } + n.byteSlicePool.ModifyCapacity(delta) +} + +func (n *returnCapacityPoolCloser) Close() { + if n.returnCapacity < 0 { + n.byteSlicePool.ModifyCapacity(n.returnCapacity) + } +} + +type sliceAllocator func() *[]byte + +var newByteSlicePool = func(sliceSize int64) byteSlicePool { + return newMaxSlicePool(sliceSize) +} diff --git a/feature/s3/manager/pool_test.go b/feature/s3/manager/pool_test.go new file mode 100644 index 00000000000..9c6302e9b39 --- /dev/null +++ b/feature/s3/manager/pool_test.go @@ -0,0 +1,197 @@ +package manager + +import ( + "context" + "sync" + "sync/atomic" + "testing" +) + +func TestMaxSlicePool(t *testing.T) { + pool := newMaxSlicePool(0) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + // increase pool capacity by 2 + pool.ModifyCapacity(2) + + // remove 2 items + bsOne, err := pool.Get(context.Background()) + if err != nil { + t.Errorf("failed to get slice from pool: %v", err) + } + bsTwo, err := pool.Get(context.Background()) + if err != nil { + t.Errorf("failed to get slice from pool: %v", err) + } + + done := make(chan struct{}) + go func() { + defer close(done) + + // attempt to remove a 3rd in parallel + bs, err := pool.Get(context.Background()) + if err != nil { + t.Errorf("failed to get slice from pool: %v", err) + } + pool.Put(bs) + + // attempt to remove a 4th that has been canceled + ctx, cancel
:= context.WithCancel(context.Background()) + cancel() + bs, err = pool.Get(ctx) + if err == nil { + pool.Put(bs) + t.Errorf("expected no slice to be returned") + return + } + }() + + pool.Put(bsOne) + + <-done + + pool.ModifyCapacity(-1) + + pool.Put(bsTwo) + + pool.ModifyCapacity(-1) + + // any excess returns should drop + rando := make([]byte, 0) + pool.Put(&rando) + }() + } + wg.Wait() + + if e, a := 0, len(pool.slices); e != a { + t.Errorf("expected %v, got %v", e, a) + } + if e, a := 0, len(pool.allocations); e != a { + t.Errorf("expected %v, got %v", e, a) + } + if e, a := 0, pool.max; e != a { + t.Errorf("expected %v, got %v", e, a) + } + + _, err := pool.Get(context.Background()) + if err == nil { + t.Errorf("expected error on zero capacity pool") + } + + pool.Close() +} + +func TestPoolShouldPreferAllocatedSlicesOverNewAllocations(t *testing.T) { + pool := newMaxSlicePool(0) + defer pool.Close() + + // Prepare pool: make it so that pool contains 1 allocated slice and 1 allocation permit + pool.ModifyCapacity(2) + initialSlice, err := pool.Get(context.Background()) + if err != nil { + t.Errorf("failed to get slice from pool: %v", err) + } + pool.Put(initialSlice) + + for i := 0; i < 100; i++ { + newSlice, err := pool.Get(context.Background()) + if err != nil { + t.Errorf("failed to get slice from pool: %v", err) + return + } + + if newSlice != initialSlice { + t.Errorf("pool allocated a new slice despite it having pre-allocated one") + return + } + pool.Put(newSlice) + } +} + +type recordedPartPool struct { + recordedAllocs uint64 + recordedGets uint64 + recordedOutstanding int64 + *maxSlicePool +} + +func newRecordedPartPool(sliceSize int64) *recordedPartPool { + sp := newMaxSlicePool(sliceSize) + + rp := &recordedPartPool{} + + allocator := sp.allocator + sp.allocator = func() *[]byte { + atomic.AddUint64(&rp.recordedAllocs, 1) + return allocator() + } + + rp.maxSlicePool = sp + + return rp +} + +func (r *recordedPartPool) Get(ctx context.Context) (*[]byte, error) { + atomic.AddUint64(&r.recordedGets, 1) + atomic.AddInt64(&r.recordedOutstanding, 1) + return r.maxSlicePool.Get(ctx) +} + +func (r *recordedPartPool) Put(b *[]byte) { + atomic.AddInt64(&r.recordedOutstanding, -1) + r.maxSlicePool.Put(b) +} + +func swapByteSlicePool(f func(sliceSize int64) byteSlicePool) func() { + orig := newByteSlicePool + + newByteSlicePool = f + + return func() { + newByteSlicePool = orig + } +} + +type syncSlicePool struct { + sync.Pool + sliceSize int64 +} + +func newSyncSlicePool(sliceSize int64) *syncSlicePool { + p := &syncSlicePool{sliceSize: sliceSize} + p.New = func() interface{} { + bs := make([]byte, p.sliceSize) + return &bs + } + return p +} + +func (s *syncSlicePool) Get(ctx context.Context) (*[]byte, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + return s.Pool.Get().(*[]byte), nil + } +} + +func (s *syncSlicePool) Put(bs *[]byte) { + s.Pool.Put(bs) +} + +func (s *syncSlicePool) ModifyCapacity(_ int) { + return +} + +func (s *syncSlicePool) SliceSize() int64 { + return s.sliceSize +} + +func (s *syncSlicePool) Close() { + return +} diff --git a/feature/s3/manager/read_seeker_write_to.go b/feature/s3/manager/read_seeker_write_to.go new file mode 100644 index 00000000000..ce117c32a13 --- /dev/null +++ b/feature/s3/manager/read_seeker_write_to.go @@ -0,0 +1,65 @@ +package manager + +import ( + "io" + "sync" +) + +// ReadSeekerWriteTo defines an interface implementing io.WriteTo and io.ReadSeeker +type ReadSeekerWriteTo interface { + io.ReadSeeker + 
io.WriterTo +} + +// BufferedReadSeekerWriteTo wraps a BufferedReadSeeker with an io.WriterTo +// implementation. +type BufferedReadSeekerWriteTo struct { + *BufferedReadSeeker +} + +// WriteTo writes to the given io.Writer from BufferedReadSeeker until there's no more data to write or +// an error occurs. Returns the number of bytes written and any error encountered during the write. +func (b *BufferedReadSeekerWriteTo) WriteTo(writer io.Writer) (int64, error) { + return io.Copy(writer, b.BufferedReadSeeker) +} + +// ReadSeekerWriteToProvider provides an implementation of io.WriterTo for an io.ReadSeeker +type ReadSeekerWriteToProvider interface { + GetWriteTo(seeker io.ReadSeeker) (r ReadSeekerWriteTo, cleanup func()) +} + +// BufferedReadSeekerWriteToPool uses a sync.Pool to create and reuse +// []byte slices for buffering parts in memory +type BufferedReadSeekerWriteToPool struct { + pool sync.Pool +} + +// NewBufferedReadSeekerWriteToPool will return a new BufferedReadSeekerWriteToPool that will create +// a pool of reusable buffers. If size is less than 64 KiB then the buffer +// will default to 64 KiB. Reason: io.Copy between a writer and reader that implement +// neither io.ReaderFrom nor io.WriterTo falls back to copying in 32 KiB chunks. +func NewBufferedReadSeekerWriteToPool(size int) *BufferedReadSeekerWriteToPool { + if size < 65536 { + size = 65536 + } + + return &BufferedReadSeekerWriteToPool{ + pool: sync.Pool{New: func() interface{} { + return make([]byte, size) + }}, + } +} + +// GetWriteTo will wrap the provided io.ReadSeeker with a BufferedReadSeekerWriteTo. +// The provided cleanup must be called after operations have been completed on the +// returned io.ReadSeekerWriteTo in order to signal the return of resources to the pool. +func (p *BufferedReadSeekerWriteToPool) GetWriteTo(seeker io.ReadSeeker) (r ReadSeekerWriteTo, cleanup func()) { + buffer := p.pool.Get().([]byte) + + r = &BufferedReadSeekerWriteTo{BufferedReadSeeker: NewBufferedReadSeeker(seeker, buffer)} + cleanup = func() { + p.pool.Put(buffer) + } + + return r, cleanup +}
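The provider contract is: obtain a WriteTo-capable wrapper for a part's io.ReadSeeker, stream it, then call cleanup to return the buffer to the pool. A minimal sketch of a caller (illustrative; the uploader does this internally per part, and this assumes the manager package as added by this change):

```go
package main

import (
	"os"
	"strings"

	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
)

func main() {
	// Sizes below 64 KiB are rounded up to the 64 KiB floor described above.
	pool := manager.NewBufferedReadSeekerWriteToPool(0)

	part := strings.NewReader("part payload")

	w, cleanup := pool.GetWriteTo(part)
	// cleanup must run after the part is fully written so the buffer
	// returns to the pool for reuse by later parts.
	defer cleanup()

	// WriteTo drains the wrapped ReadSeeker through the pooled buffer.
	if _, err := w.WriteTo(os.Stdout); err != nil {
		panic(err)
	}
}
```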
diff --git a/feature/s3/manager/shared_test.go b/feature/s3/manager/shared_test.go new file mode 100644 index 00000000000..ab2deef300d --- /dev/null +++ b/feature/s3/manager/shared_test.go @@ -0,0 +1,4 @@ +package manager_test + +var buf12MB = make([]byte, 1024*1024*12) +var buf2MB = make([]byte, 1024*1024*2) diff --git a/feature/s3/manager/upload.go b/feature/s3/manager/upload.go new file mode 100644 index 00000000000..90aad5a9fd4 --- /dev/null +++ b/feature/s3/manager/upload.go @@ -0,0 +1,685 @@ +package manager + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "sort" + "sync" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/internal/awsutil" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" +) + +// MaxUploadParts is the maximum allowed number of parts in a multi-part upload +// on Amazon S3. +const MaxUploadParts int32 = 10000 + +// MinUploadPartSize is the minimum allowed part size when uploading a part to +// Amazon S3. +const MinUploadPartSize int64 = 1024 * 1024 * 5 + +// DefaultUploadPartSize is the default part size to buffer chunks of a +// payload into. +const DefaultUploadPartSize = MinUploadPartSize + +// DefaultUploadConcurrency is the default number of goroutines to spin up when +// using Upload(). +const DefaultUploadConcurrency = 5 + +// A MultiUploadFailure wraps a failed S3 multipart upload. An error returned +// will satisfy this interface when a multipart upload failed to upload all +// chunks to S3. In the case of a failure the UploadID is needed to operate on +// the chunks, if any, which were uploaded. +// +// Example: +// +// u := s3manager.NewUploader(client) +// output, err := u.Upload(context.Background(), input) +// if err != nil { +// var multierr s3manager.MultiUploadFailure +// if errors.As(err, &multierr) { +// fmt.Printf("upload failure UploadID=%s, %s\n", multierr.UploadID(), multierr.Error()) +// } else { +// fmt.Printf("upload failure, %s\n", err.Error()) +// } +// } +// +type MultiUploadFailure interface { + error + + // UploadID returns the upload id for the S3 multipart upload that failed. + UploadID() string +} + +// A multiUploadError wraps the upload ID of a failed s3 multipart upload. +// +// Should be used for an error that occurred failing an S3 multipart upload, +// when an upload ID is available. If an upload ID is not available, a more +// relevant error should be returned instead. +type multiUploadError struct { + err error + + // ID for multipart upload which failed. + uploadID string +} + +// Error returns the string representation of the error. +// +// Satisfies the error interface. +func (m *multiUploadError) Error() string { + var extra string + if m.err != nil { + extra = fmt.Sprintf(", cause: %s", m.err.Error()) + } + return fmt.Sprintf("upload multipart failed, upload id: %s%s", m.uploadID, extra) +} + +// Unwrap returns the underlying error that caused the upload failure +func (m *multiUploadError) Unwrap() error { + return m.err +} + +// UploadID returns the id of the S3 upload which failed. +func (m *multiUploadError) UploadID() string { + return m.uploadID +} + +// UploadOutput represents a response from the Upload() call. +type UploadOutput struct { + // The URL where the object was uploaded to. + Location string + + // The version of the object that was uploaded. Will only be populated if + // the S3 Bucket is versioned. If the bucket is not versioned this field + // will not be set. + VersionID *string + + // The ID for a multipart upload to S3. In the case of an error the error + // can be cast to the MultiUploadFailure interface to extract the upload ID. + UploadID string +} + +// WithUploaderRequestOptions appends to the Uploader's API client options. +func WithUploaderRequestOptions(opts ...func(*s3.Options)) func(*Uploader) { + return func(u *Uploader) { + u.ClientOptions = append(u.ClientOptions, opts...) + } +}
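WithUploaderRequestOptions attaches s3.Options mutators at construction time; every API call the uploader makes (PutObject, CreateMultipartUpload, UploadPart, and so on) then sees them. A short sketch (hypothetical settings; assumes a loadable default config):

```go
package main

import (
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)

func main() {
	cfg, err := config.LoadDefaultConfig()
	if err != nil {
		panic(err)
	}

	client := s3.NewFromConfig(cfg)

	// Every API operation issued by this uploader applies the option below.
	uploader := manager.NewUploader(client, manager.WithUploaderRequestOptions(
		func(o *s3.Options) {
			o.Region = "us-west-2" // hypothetical per-uploader override
		},
	))
	_ = uploader
}
```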
+// The Uploader structure that calls Upload(). It is safe to call Upload()
+// on this structure for multiple objects and across concurrent goroutines.
+// Mutating the Uploader's properties is not safe to be done concurrently.
+type Uploader struct {
+	// The buffer size (in bytes) to use when buffering data into chunks and
+	// sending them as parts to S3. The minimum allowed part size is 5MB, and
+	// if this value is set to zero, the DefaultUploadPartSize value will be used.
+	PartSize int64
+
+	// The number of goroutines to spin up in parallel per call to Upload when
+	// sending parts. If this is set to zero, the DefaultUploadConcurrency value
+	// will be used.
+	//
+	// The concurrency pool is not shared between calls to Upload.
+	Concurrency int
+
+	// Setting this value to true will cause the SDK to avoid calling
+	// AbortMultipartUpload on a failure, leaving all successfully uploaded
+	// parts on S3 for manual recovery.
+	//
+	// Note that storing parts of an incomplete multipart upload counts towards
+	// space usage on S3 and will add additional costs if not cleaned up.
+	LeavePartsOnError bool
+
+	// MaxUploadParts is the max number of parts which will be uploaded to S3.
+	// Will be used to calculate the part size of the object to be uploaded.
+	// E.g: a 5GB file, with MaxUploadParts set to 100, will be uploaded as
+	// 100 parts of 50MB each, subject to the service limit of MaxUploadParts
+	// (10,000 parts).
+	//
+	// MaxUploadParts must not be used to limit the total number of bytes uploaded.
+	// Use a type like io.LimitReader (https://golang.org/pkg/io/#LimitedReader)
+	// instead. An io.LimitReader is helpful when uploading an unbounded reader
+	// to S3, and you know its maximum size; see the sketch after NewUploader
+	// below. Otherwise the reader's io.EOF returned error must be used to
+	// signal end of stream.
+	//
+	// Defaults to the package const MaxUploadParts value.
+	MaxUploadParts int32
+
+	// The client to use when uploading to S3.
+	S3 UploadAPIClient
+
+	// List of request options that will be passed down to individual API
+	// operation requests made by the uploader.
+	ClientOptions []func(*s3.Options)
+
+	// Defines the buffer strategy used when uploading a part.
+	BufferProvider ReadSeekerWriteToProvider
+
+	// partPool allows for the reuse of streaming payload part buffers between upload calls.
+	partPool byteSlicePool
+}
+
+// NewUploader creates a new Uploader instance to upload objects to S3. Pass in
+// additional functional options to customize the uploader's behavior. Requires
+// a client that satisfies the UploadAPIClient interface, such as the s3.Client
+// created by s3.NewFromConfig.
+//
+// Example:
+//	// Load AWS Config
+//	cfg, err := config.LoadDefaultConfig()
+//	if err != nil {
+//		panic(err)
+//	}
+//
+//	// Create an S3 Client with the config
+//	client := s3.NewFromConfig(cfg)
+//
+//	// Create an uploader passing it the client
+//	uploader := s3manager.NewUploader(client)
+//
+//	// Create an uploader with the client and custom options
+//	uploader := s3manager.NewUploader(client, func(u *s3manager.Uploader) {
+//		u.PartSize = 64 * 1024 * 1024 // 64MB per part
+//	})
+func NewUploader(client UploadAPIClient, options ...func(*Uploader)) *Uploader {
+	u := &Uploader{
+		S3:                client,
+		PartSize:          DefaultUploadPartSize,
+		Concurrency:       DefaultUploadConcurrency,
+		LeavePartsOnError: false,
+		MaxUploadParts:    MaxUploadParts,
+		BufferProvider:    defaultUploadBufferProvider(),
+	}
+
+	for _, option := range options {
+		option(u)
+	}
+
+	u.partPool = newByteSlicePool(u.PartSize)
+
+	return u
+}
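+// Illustrative sketch of the io.LimitReader guidance from the MaxUploadParts
+// documentation above. "client", "stream", and "maxSize" are assumed to
+// exist; stream is an unbounded io.Reader whose maximum size is known.
+//
+//	uploader := manager.NewUploader(client)
+//	_, err := uploader.Upload(context.TODO(), &s3.PutObjectInput{
+//		Bucket: aws.String("bucket"),
+//		Key:    aws.String("key"),
+//		Body:   io.LimitReader(stream, maxSize),
+//	})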
+// Upload uploads an object to S3, intelligently buffering large
+// files into smaller chunks and sending them in parallel across multiple
+// goroutines. You can configure the buffer size and concurrency through the
+// Uploader parameters.
+//
+// Additional functional options can be provided to configure the individual
+// upload. These options are applied to a copy of the Uploader instance Upload
+// is called from; modifying them will not impact the original Uploader.
+//
+// Use the WithUploaderRequestOptions helper function to pass in request
+// options that will be applied to all API operations made with this uploader.
+//
+// It is safe to call this method concurrently across goroutines.
+func (u Uploader) Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*Uploader)) (*UploadOutput, error) {
+	i := uploader{in: input, cfg: u, ctx: ctx}
+
+	// Copy ClientOptions
+	clientOptions := make([]func(*s3.Options), 0, len(i.cfg.ClientOptions)+1)
+	clientOptions = append(clientOptions, func(o *s3.Options) {
+		o.APIOptions = append(o.APIOptions, middleware.AddUserAgentKey(userAgentKey))
+	})
+	clientOptions = append(clientOptions, i.cfg.ClientOptions...)
+	i.cfg.ClientOptions = clientOptions
+
+	for _, opt := range opts {
+		opt(&i.cfg)
+	}
+
+	return i.upload()
+}
+
+// internal structure to manage an upload to S3.
+type uploader struct {
+	ctx context.Context
+	cfg Uploader
+
+	in *s3.PutObjectInput
+
+	readerPos int64 // current reader position
+	totalSize int64 // set to -1 if the size is not known
+}
+
+// internal logic for deciding whether to upload a single part or use a
+// multipart upload.
+func (u *uploader) upload() (*UploadOutput, error) {
+	if err := u.init(); err != nil {
+		return nil, fmt.Errorf("unable to initialize upload: %w", err)
+	}
+	defer u.cfg.partPool.Close()
+
+	if u.cfg.PartSize < MinUploadPartSize {
+		return nil, fmt.Errorf("part size must be at least %d bytes", MinUploadPartSize)
+	}
+
+	// Do one read to determine if we have more than one part
+	reader, _, cleanup, err := u.nextReader()
+	if err == io.EOF { // single part
+		return u.singlePart(reader, cleanup)
+	} else if err != nil {
+		cleanup()
+		return nil, fmt.Errorf("read upload data failed: %w", err)
+	}
+
+	mu := multiuploader{uploader: u}
+	return mu.upload(reader, cleanup)
+}
+
+// init will initialize all default options.
+func (u *uploader) init() error {
+	if u.cfg.Concurrency == 0 {
+		u.cfg.Concurrency = DefaultUploadConcurrency
+	}
+	if u.cfg.PartSize == 0 {
+		u.cfg.PartSize = DefaultUploadPartSize
+	}
+	if u.cfg.MaxUploadParts == 0 {
+		u.cfg.MaxUploadParts = MaxUploadParts
+	}
+
+	// Try to get the total size for some optimizations
+	if err := u.initSize(); err != nil {
+		return err
+	}
+
+	// If PartSize was changed or partPool was never set up then we need to
+	// allocate a new pool so that we return []byte slices of the correct size
+	poolCap := u.cfg.Concurrency + 1
+	if u.cfg.partPool == nil || u.cfg.partPool.SliceSize() != u.cfg.PartSize {
+		u.cfg.partPool = newByteSlicePool(u.cfg.PartSize)
+		u.cfg.partPool.ModifyCapacity(poolCap)
+	} else {
+		u.cfg.partPool = &returnCapacityPoolCloser{byteSlicePool: u.cfg.partPool}
+		u.cfg.partPool.ModifyCapacity(poolCap)
+	}
+
+	return nil
+}
+
+// initSize tries to detect the total stream size, setting u.totalSize. If
+// the size is not known, totalSize is set to -1.
+func (u *uploader) initSize() error {
+	u.totalSize = -1
+
+	switch r := u.in.Body.(type) {
+	case io.Seeker:
+		n, err := aws.SeekerLen(r)
+		if err != nil {
+			return err
+		}
+		u.totalSize = n
+
+		// Try to adjust partSize if it is too small and account for
+		// integer division truncation.
+		if u.totalSize/u.cfg.PartSize >= int64(u.cfg.MaxUploadParts) {
+			// Add one to the part size to account for remainders
+			// during the size calculation. e.g. odd number of bytes.
+			u.cfg.PartSize = (u.totalSize / int64(u.cfg.MaxUploadParts)) + 1
+		}
+	}
+
+	return nil
+}
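+// Worked example of the adjustment above, with illustrative values: for a
+// 100 GiB seekable body, a 5 MiB part size, and the 10,000 part limit,
+// totalSize/PartSize is 20,480 parts, which exceeds the limit. PartSize is
+// therefore raised to 100 GiB / 10,000 + 1 byte (roughly 10.2 MiB), so the
+// body fits in exactly 10,000 parts with the final part absorbing the
+// remainder.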
+// nextReader returns a seekable reader representing the next packet of data.
+// This operation increases the shared u.readerPos counter, but note that it
+// does not need to be wrapped in a mutex because nextReader is only called
+// from the main goroutine.
+func (u *uploader) nextReader() (io.ReadSeeker, int, func(), error) {
+	switch r := u.in.Body.(type) {
+	case readerAtSeeker:
+		var err error
+
+		n := u.cfg.PartSize
+		if u.totalSize >= 0 {
+			bytesLeft := u.totalSize - u.readerPos
+
+			if bytesLeft <= u.cfg.PartSize {
+				err = io.EOF
+				n = bytesLeft
+			}
+		}
+
+		var (
+			reader  io.ReadSeeker
+			cleanup func()
+		)
+
+		reader = io.NewSectionReader(r, u.readerPos, n)
+		if u.cfg.BufferProvider != nil {
+			reader, cleanup = u.cfg.BufferProvider.GetWriteTo(reader)
+		} else {
+			cleanup = func() {}
+		}
+
+		u.readerPos += n
+
+		return reader, int(n), cleanup, err
+
+	default:
+		part, err := u.cfg.partPool.Get(u.ctx)
+		if err != nil {
+			return nil, 0, func() {}, err
+		}
+
+		n, err := readFillBuf(r, *part)
+		u.readerPos += int64(n)
+
+		cleanup := func() {
+			u.cfg.partPool.Put(part)
+		}
+
+		return bytes.NewReader((*part)[0:n]), n, cleanup, err
+	}
+}
+
+// readFillBuf reads from r into b until b is full or r returns an error,
+// returning the number of bytes read.
+func readFillBuf(r io.Reader, b []byte) (offset int, err error) {
+	for offset < len(b) && err == nil {
+		var n int
+		n, err = r.Read(b[offset:])
+		offset += n
+	}
+
+	return offset, err
+}
+
+// singlePart contains upload logic for uploading a single chunk via
+// a regular PutObject request. Multipart requests require at least two
+// parts, or at least 5MB of data.
+func (u *uploader) singlePart(r io.ReadSeeker, cleanup func()) (*UploadOutput, error) {
+	defer cleanup()
+
+	params := &s3.PutObjectInput{}
+	awsutil.Copy(params, u.in)
+	params.Body = r
+
+	// Wrap the HTTP client to capture the object's upload location (URL),
+	// since the PutObject response does not include it.
+	var locationRecorder recordLocationClient
+	out, err := u.cfg.S3.PutObject(u.ctx, params, append(u.cfg.ClientOptions, locationRecorder.WrapClient())...)
+	if err != nil {
+		return nil, err
+	}
+
+	return &UploadOutput{
+		Location:  locationRecorder.location,
+		VersionID: out.VersionId,
+	}, nil
+}
+
+type httpClient interface {
+	Do(r *http.Request) (*http.Response, error)
+}
+
+// recordLocationClient wraps the client's HTTP transport to record the final
+// request URL, which is reported back as the upload's Location.
+type recordLocationClient struct {
+	httpClient
+	location string
+}
+
+func (c *recordLocationClient) WrapClient() func(o *s3.Options) {
+	return func(o *s3.Options) {
+		c.httpClient = o.HTTPClient
+		o.HTTPClient = c
+	}
+}
+
+func (c *recordLocationClient) Do(r *http.Request) (resp *http.Response, err error) {
+	resp, err = c.httpClient.Do(r)
+	if err != nil {
+		return resp, err
+	}
+
+	if resp.Request != nil && resp.Request.URL != nil {
+		url := *resp.Request.URL
+		url.RawQuery = ""
+		c.location = url.String()
+	}
+
+	return resp, err
+}
+
+// internal structure to manage a specific multipart upload to S3.
+type multiuploader struct {
+	*uploader
+	wg       sync.WaitGroup
+	m        sync.Mutex
+	err      error
+	uploadID string
+	parts    completedParts
+}
+
+// keeps track of a single chunk of data being sent to S3.
+type chunk struct {
+	buf     io.ReadSeeker
+	num     int32
+	cleanup func()
+}
+
+// completedParts is a wrapper to make parts sortable by their part number,
+// since S3 requires this list to be sent in sorted order.
+type completedParts []*types.CompletedPart

+func (a completedParts) Len() int           { return len(a) }
+func (a completedParts) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].PartNumber }
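+// Illustrative sketch: sort.Sort over completedParts yields the ascending
+// part-number ordering that CompleteMultipartUpload requires.
+//
+//	parts := completedParts{
+//		&types.CompletedPart{PartNumber: aws.Int32(3)},
+//		&types.CompletedPart{PartNumber: aws.Int32(1)},
+//		&types.CompletedPart{PartNumber: aws.Int32(2)},
+//	}
+//	sort.Sort(parts) // parts are now ordered 1, 2, 3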
+// upload will perform a multipart upload using the firstBuf buffer containing
+// the first chunk of data.
+func (u *multiuploader) upload(firstBuf io.ReadSeeker, cleanup func()) (*UploadOutput, error) {
+	params := &s3.CreateMultipartUploadInput{}
+	awsutil.Copy(params, u.in)
+
+	// Create the multipart upload
+	var locationRecorder recordLocationClient
+	resp, err := u.cfg.S3.CreateMultipartUpload(u.ctx, params, append(u.cfg.ClientOptions, locationRecorder.WrapClient())...)
+	if err != nil {
+		cleanup()
+		return nil, err
+	}
+	u.uploadID = *resp.UploadId
+
+	// Create the workers
+	ch := make(chan chunk, u.cfg.Concurrency)
+	for i := 0; i < u.cfg.Concurrency; i++ {
+		u.wg.Add(1)
+		go u.readChunk(ch)
+	}
+
+	// Send part 1 to the workers
+	var num int32 = 1
+	ch <- chunk{buf: firstBuf, num: num, cleanup: cleanup}
+
+	// Read and queue the rest of the parts
+	for u.geterr() == nil && err == nil {
+		var (
+			reader       io.ReadSeeker
+			nextChunkLen int
+			ok           bool
+		)
+
+		reader, nextChunkLen, cleanup, err = u.nextReader()
+		ok, err = u.shouldContinue(num, nextChunkLen, err)
+		if !ok {
+			cleanup()
+			if err != nil {
+				u.seterr(err)
+			}
+			break
+		}
+
+		num++
+
+		ch <- chunk{buf: reader, num: num, cleanup: cleanup}
+	}
+
+	// Close the channel, wait for workers, and complete upload
+	close(ch)
+	u.wg.Wait()
+	complete := u.complete()
+
+	if err := u.geterr(); err != nil {
+		return nil, &multiUploadError{
+			err:      err,
+			uploadID: u.uploadID,
+		}
+	}
+
+	return &UploadOutput{
+		Location:  locationRecorder.location,
+		VersionID: complete.VersionId,
+		UploadID:  u.uploadID,
+	}, nil
+}
+
+func (u *multiuploader) shouldContinue(part int32, nextChunkLen int, err error) (bool, error) {
+	if err != nil && err != io.EOF {
+		return false, fmt.Errorf("read multipart upload data failed, %w", err)
+	}
+
+	if nextChunkLen == 0 {
+		// No need to upload an empty part. If the file was empty to begin
+		// with, an empty single-part upload would have been created and a
+		// multipart upload never started.
+		return false, nil
+	}
+
+	part++
+	// This upload exceeded maximum number of supported parts, error now.
+	if part > u.cfg.MaxUploadParts || part > MaxUploadParts {
+		var msg string
+		if part > u.cfg.MaxUploadParts {
+			msg = fmt.Sprintf("exceeded total allowed configured MaxUploadParts (%d). Adjust PartSize to fit in this limit",
+				u.cfg.MaxUploadParts)
+		} else {
+			msg = fmt.Sprintf("exceeded total allowed S3 limit MaxUploadParts (%d). Adjust PartSize to fit in this limit",
+				MaxUploadParts)
+		}
+		return false, fmt.Errorf(msg)
+	}
+
+	return true, err
+}
+
+// readChunk runs in worker goroutines to pull chunks off of the ch channel
+// and send() them as UploadPart requests.
+func (u *multiuploader) readChunk(ch chan chunk) {
+	defer u.wg.Done()
+	for {
+		data, ok := <-ch
+
+		if !ok {
+			break
+		}
+
+		if u.geterr() == nil {
+			if err := u.send(data); err != nil {
+				u.seterr(err)
+			}
+		}
+
+		data.cleanup()
+	}
+}
+
+// send performs an UploadPart request and keeps track of the completed
+// part information.
+func (u *multiuploader) send(c chunk) error {
+	params := &s3.UploadPartInput{
+		Bucket:               u.in.Bucket,
+		Key:                  u.in.Key,
+		Body:                 c.buf,
+		UploadId:             &u.uploadID,
+		SSECustomerAlgorithm: u.in.SSECustomerAlgorithm,
+		SSECustomerKey:       u.in.SSECustomerKey,
+		PartNumber:           &c.num,
+	}
+
+	resp, err := u.cfg.S3.UploadPart(u.ctx, params, u.cfg.ClientOptions...)
+ if err != nil { + return err + } + + n := c.num + completed := &types.CompletedPart{ETag: resp.ETag, PartNumber: &n} + + u.m.Lock() + u.parts = append(u.parts, completed) + u.m.Unlock() + + return nil +} + +// geterr is a thread-safe getter for the error object +func (u *multiuploader) geterr() error { + u.m.Lock() + defer u.m.Unlock() + + return u.err +} + +// seterr is a thread-safe setter for the error object +func (u *multiuploader) seterr(e error) { + u.m.Lock() + defer u.m.Unlock() + + u.err = e +} + +// fail will abort the multipart unless LeavePartsOnError is set to true. +func (u *multiuploader) fail() { + if u.cfg.LeavePartsOnError { + return + } + + params := &s3.AbortMultipartUploadInput{ + Bucket: u.in.Bucket, + Key: u.in.Key, + UploadId: &u.uploadID, + } + _, err := u.cfg.S3.AbortMultipartUpload(u.ctx, params, u.cfg.ClientOptions...) + if err != nil { + // TODO: Add logging + //logMessage(u.cfg.S3, aws.LogDebug, fmt.Sprintf("failed to abort multipart upload, %v", err)) + _ = err + } +} + +// complete successfully completes a multipart upload and returns the response. +func (u *multiuploader) complete() *s3.CompleteMultipartUploadOutput { + if u.geterr() != nil { + u.fail() + return nil + } + + // Parts must be sorted in PartNumber order. + sort.Sort(u.parts) + + params := &s3.CompleteMultipartUploadInput{ + Bucket: u.in.Bucket, + Key: u.in.Key, + UploadId: &u.uploadID, + MultipartUpload: &types.CompletedMultipartUpload{Parts: u.parts}, + } + resp, err := u.cfg.S3.CompleteMultipartUpload(u.ctx, params, u.cfg.ClientOptions...) + if err != nil { + u.seterr(err) + u.fail() + } + + return resp +} + +type readerAtSeeker interface { + io.ReaderAt + io.ReadSeeker +} diff --git a/feature/s3/manager/upload_internal_test.go b/feature/s3/manager/upload_internal_test.go new file mode 100644 index 00000000000..03088bf0df8 --- /dev/null +++ b/feature/s3/manager/upload_internal_test.go @@ -0,0 +1,320 @@ +package manager + +import ( + "bytes" + "context" + "fmt" + "strconv" + "sync" + "sync/atomic" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing" + "github.com/aws/aws-sdk-go-v2/internal/sdkio" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +type testReader struct { + br *bytes.Reader + m sync.Mutex +} + +func (r *testReader) Read(p []byte) (n int, err error) { + r.m.Lock() + defer r.m.Unlock() + return r.br.Read(p) +} + +func TestUploadByteSlicePool(t *testing.T) { + cases := map[string]struct { + PartSize int64 + FileSize int64 + Concurrency int + ExAllocations uint64 + }{ + "single part, single concurrency": { + PartSize: sdkio.MebiByte * 5, + FileSize: sdkio.MebiByte * 5, + ExAllocations: 2, + Concurrency: 1, + }, + "multi-part, single concurrency": { + PartSize: sdkio.MebiByte * 5, + FileSize: sdkio.MebiByte * 10, + ExAllocations: 2, + Concurrency: 1, + }, + "multi-part, multiple concurrency": { + PartSize: sdkio.MebiByte * 5, + FileSize: sdkio.MebiByte * 20, + ExAllocations: 3, + Concurrency: 2, + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + var p *recordedPartPool + + unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool { + p = newRecordedPartPool(sliceSize) + return p + }) + defer unswap() + + client, _, _ := s3testing.NewUploadLoggingClient(nil) + + uploader := NewUploader(client, func(u *Uploader) { + u.PartSize = tt.PartSize + u.Concurrency = tt.Concurrency + }) + + expected := s3testing.GetTestBytes(int(tt.FileSize)) + _, err := 
uploader.Upload(context.Background(), &s3.PutObjectInput{
+				Bucket: aws.String("bucket"),
+				Key:    aws.String("key"),
+				Body:   &testReader{br: bytes.NewReader(expected)},
+			})
+			if err != nil {
+				t.Errorf("expected no error, but got %v", err)
+			}
+
+			if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
+				t.Fatalf("expected zero outstanding pool parts, got %d", v)
+			}
+
+			gets, allocs := atomic.LoadUint64(&p.recordedGets), atomic.LoadUint64(&p.recordedAllocs)
+
+			t.Logf("total gets %v, total allocations %v", gets, allocs)
+			if e, a := tt.ExAllocations, allocs; a > e {
+				t.Errorf("expected %v allocations, got %v", e, a)
+			}
+		})
+	}
+}
+
+func TestUploadByteSlicePool_Failures(t *testing.T) {
+	const (
+		putObject               = "PutObject"
+		createMultipartUpload   = "CreateMultipartUpload"
+		uploadPart              = "UploadPart"
+		completeMultipartUpload = "CompleteMultipartUpload"
+	)
+
+	cases := map[string]struct {
+		PartSize   int64
+		FileSize   int64
+		Operations []string
+	}{
+		"single part": {
+			PartSize: sdkio.MebiByte * 5,
+			FileSize: sdkio.MebiByte * 4,
+			Operations: []string{
+				putObject,
+			},
+		},
+		"multi-part": {
+			PartSize: sdkio.MebiByte * 5,
+			FileSize: sdkio.MebiByte * 10,
+			Operations: []string{
+				createMultipartUpload,
+				uploadPart,
+				completeMultipartUpload,
+			},
+		},
+	}
+
+	for name, tt := range cases {
+		t.Run(name, func(t *testing.T) {
+			for _, operation := range tt.Operations {
+				t.Run(operation, func(t *testing.T) {
+					var p *recordedPartPool
+
+					unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
+						p = newRecordedPartPool(sliceSize)
+						return p
+					})
+					defer unswap()
+
+					client, _, _ := s3testing.NewUploadLoggingClient(nil)
+
+					switch operation {
+					case putObject:
+						client.PutObjectFn = func(*s3testing.UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
+							return nil, fmt.Errorf("put object failure")
+						}
+					case createMultipartUpload:
+						client.CreateMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) {
+							return nil, fmt.Errorf("create multipart upload failure")
+						}
+					case uploadPart:
+						client.UploadPartFn = func(*s3testing.UploadLoggingClient, *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+							return nil, fmt.Errorf("upload part failure")
+						}
+					case completeMultipartUpload:
+						client.CompleteMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) {
+							return nil, fmt.Errorf("complete multipart upload failure")
+						}
+					}
+
+					uploader := NewUploader(client, func(u *Uploader) {
+						u.Concurrency = 1
+						u.PartSize = tt.PartSize
+					})
+
+					expected := s3testing.GetTestBytes(int(tt.FileSize))
+					_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+						Bucket: aws.String("bucket"),
+						Key:    aws.String("key"),
+						Body:   &testReader{br: bytes.NewReader(expected)},
+					})
+					if err == nil {
+						t.Fatalf("expected error but got none")
+					}
+
+					if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
+						t.Fatalf("expected zero outstanding pool parts, got %d", v)
+					}
+				})
+			}
+		})
+	}
+}
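+// A note on the ExAllocations expectations above: init() sizes the part pool
+// to Concurrency+1 slices (the parts in flight plus the one being read), so a
+// single-concurrency upload is expected to allocate at most 2 part buffers
+// and a two-worker upload at most 3, regardless of file size.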
+func TestUploadByteSlicePoolConcurrentMultiPartSize(t *testing.T) {
+	var (
+		pools []*recordedPartPool
+		mtx   sync.Mutex
+	)
+
+	unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
+		mtx.Lock()
+		defer mtx.Unlock()
+		b := newRecordedPartPool(sliceSize)
+		pools = append(pools, b)
+		return b
+	})
+	defer unswap()
+
+	client, _, _ := s3testing.NewUploadLoggingClient(nil)
+
+	uploader := NewUploader(client, func(u *Uploader) {
+		u.PartSize = 5 * sdkio.MebiByte
+		u.Concurrency = 2
+	})
+
+	var wg sync.WaitGroup
+	for i := 0; i < 2; i++ {
+		wg.Add(2)
+		go func() {
+			defer wg.Done()
+			expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
+			_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+				Bucket: aws.String("bucket"),
+				Key:    aws.String("key"),
+				Body:   &testReader{br: bytes.NewReader(expected)},
+			})
+			if err != nil {
+				t.Errorf("expected no error, but got %v", err)
+			}
+		}()
+		go func() {
+			defer wg.Done()
+			expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
+			_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+				Bucket: aws.String("bucket"),
+				Key:    aws.String("key"),
+				Body:   &testReader{br: bytes.NewReader(expected)},
+			}, func(u *Uploader) {
+				u.PartSize = 6 * sdkio.MebiByte
+			})
+			if err != nil {
+				t.Errorf("expected no error, but got %v", err)
+			}
+		}()
+	}
+
+	wg.Wait()
+
+	if e, a := 3, len(pools); e != a {
+		t.Errorf("expected %v, got %v", e, a)
+	}
+
+	for _, p := range pools {
+		if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
+			t.Fatalf("expected zero outstanding pool parts, got %d", v)
+		}
+
+		t.Logf("total gets %v, total allocations %v",
+			atomic.LoadUint64(&p.recordedGets),
+			atomic.LoadUint64(&p.recordedAllocs))
+	}
+}
+
+func BenchmarkPools(b *testing.B) {
+	cases := []struct {
+		PartSize      int64
+		FileSize      int64
+		Concurrency   int
+		ExAllocations uint64
+	}{
+		0: {
+			PartSize:    sdkio.MebiByte * 5,
+			FileSize:    sdkio.MebiByte * 5,
+			Concurrency: 1,
+		},
+		1: {
+			PartSize:    sdkio.MebiByte * 5,
+			FileSize:    sdkio.MebiByte * 10,
+			Concurrency: 1,
+		},
+		2: {
+			PartSize:    sdkio.MebiByte * 5,
+			FileSize:    sdkio.MebiByte * 20,
+			Concurrency: 2,
+		},
+		3: {
+			PartSize:    sdkio.MebiByte * 5,
+			FileSize:    sdkio.MebiByte * 250,
+			Concurrency: 10,
+		},
+	}
+
+	client, _, _ := s3testing.NewUploadLoggingClient(nil)
+
+	pools := map[string]func(sliceSize int64) byteSlicePool{
+		"sync.Pool": func(sliceSize int64) byteSlicePool {
+			return newSyncSlicePool(sliceSize)
+		},
+		"custom": func(sliceSize int64) byteSlicePool {
+			return newMaxSlicePool(sliceSize)
+		},
+	}
+
+	for name, poolFunc := range pools {
+		b.Run(name, func(b *testing.B) {
+			unswap := swapByteSlicePool(poolFunc)
+			defer unswap()
+			for i, c := range cases {
+				b.Run(strconv.Itoa(i), func(b *testing.B) {
+					uploader := NewUploader(client, func(u *Uploader) {
+						u.PartSize = c.PartSize
+						u.Concurrency = c.Concurrency
+					})
+
+					expected := s3testing.GetTestBytes(int(c.FileSize))
+					b.ResetTimer()
+					_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+						Bucket: aws.String("bucket"),
+						Key:    aws.String("key"),
+						Body:   &testReader{br: bytes.NewReader(expected)},
+					})
+					if err != nil {
+						b.Fatalf("expected no error, but got %v", err)
+					}
+				})
+			}
+		})
+	}
+}
diff --git a/feature/s3/manager/upload_test.go b/feature/s3/manager/upload_test.go
new file mode 100644
index 00000000000..17a42c5336d
--- /dev/null
+++ b/feature/s3/manager/upload_test.go
@@ -0,0 +1,1134 @@
+package manager_test
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"reflect"
+	"regexp"
+	"sort"
+	"strconv"
+	"strings"
+	"testing"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/aws/retry"
+	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+	s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing"
+	"github.com/aws/aws-sdk-go-v2/internal/awstesting"
+	"github.com/aws/aws-sdk-go-v2/internal/sdk"
"github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/google/go-cmp/cmp" +) + +// getReaderLength discards the bytes from reader and returns the length +func getReaderLength(r io.Reader) int64 { + n, _ := io.Copy(ioutil.Discard, r) + return n +} + +func TestUploadOrderMulti(t *testing.T) { + c, invocations, args := s3testing.NewUploadLoggingClient(nil) + u := manager.NewUploader(c) + + resp, err := u.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key - value"), + Body: bytes.NewReader(buf12MB), + ServerSideEncryption: types.ServerSideEncryptionAwsKms, + SSEKMSKeyId: aws.String("KmsId"), + ContentType: aws.String("content/type"), + }) + + if err != nil { + t.Errorf("Expected no error but received %v", err) + } + + if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", + "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 { + t.Error(err) + } + + if e, a := `https://mock.amazonaws.com/key`, resp.Location; e != a { + t.Errorf("expect %q, got %q", e, a) + } + + if "UPLOAD-ID" != resp.UploadID { + t.Errorf("expect %q, got %q", "UPLOAD-ID", resp.UploadID) + } + + if "VERSION-ID" != *resp.VersionID { + t.Errorf("expect %q, got %q", "VERSION-ID", *resp.VersionID) + } + + // Validate input values + + // UploadPart + for i := 1; i < 4; i++ { + v := aws.ToString((*args)[i].(*s3.UploadPartInput).UploadId) + if "UPLOAD-ID" != v { + t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v) + } + } + + // CompleteMultipartUpload + v := aws.ToString((*args)[4].(*s3.CompleteMultipartUploadInput).UploadId) + if "UPLOAD-ID" != v { + t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v) + } + + parts := (*args)[4].(*s3.CompleteMultipartUploadInput).MultipartUpload.Parts + + for i := 0; i < 3; i++ { + num := aws.ToInt32(parts[i].PartNumber) + etag := aws.ToString(parts[i].ETag) + + if int32(i+1) != num { + t.Errorf("expect %d, got %d", i+1, num) + } + + if matched, err := regexp.MatchString(`^ETAG\d+$`, etag); !matched || err != nil { + t.Errorf("Failed regexp expression `^ETAG\\d+$`") + } + } + + // Custom headers + cmu := (*args)[0].(*s3.CreateMultipartUploadInput) + + if e, a := types.ServerSideEncryptionAwsKms, cmu.ServerSideEncryption; e != a { + t.Errorf("expect %q, got %q", e, a) + } + + if e, a := "KmsId", aws.ToString(cmu.SSEKMSKeyId); e != a { + t.Errorf("expect %q, got %q", e, a) + } + + if e, a := "content/type", aws.ToString(cmu.ContentType); e != a { + t.Errorf("expect %q, got %q", e, a) + } +} + +func TestUploadOrderMultiDifferentPartSize(t *testing.T) { + s, ops, args := s3testing.NewUploadLoggingClient(nil) + mgr := manager.NewUploader(s, func(u *manager.Uploader) { + u.PartSize = 1024 * 1024 * 7 + u.Concurrency = 1 + }) + _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: bytes.NewReader(buf12MB), + }) + + if err != nil { + t.Errorf("expect no error, got %v", err) + } + + vals := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"} + if !reflect.DeepEqual(vals, *ops) { + t.Errorf("expect %v, got %v", vals, *ops) + } + + // Part lengths + if len := getReaderLength((*args)[1].(*s3.UploadPartInput).Body); 1024*1024*7 != len { + t.Errorf("expect %d, got %d", 1024*1024*7, len) + } + if len := getReaderLength((*args)[2].(*s3.UploadPartInput).Body); 1024*1024*5 != len { + t.Errorf("expect %d, got %d", 1024*1024*5, len) + } +} + +func 
+func TestUploadIncreasePartSize(t *testing.T) {
+	s, invocations, args := s3testing.NewUploadLoggingClient(nil)
+	mgr := manager.NewUploader(s, func(u *manager.Uploader) {
+		u.Concurrency = 1
+		u.MaxUploadParts = 2
+	})
+	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+		Bucket: aws.String("Bucket"),
+		Key:    aws.String("Key"),
+		Body:   bytes.NewReader(buf12MB),
+	})
+
+	if err != nil {
+		t.Errorf("expect no error, got %v", err)
+	}
+
+	if int64(manager.DefaultUploadPartSize) != mgr.PartSize {
+		t.Errorf("expect %d, got %d", manager.DefaultUploadPartSize, mgr.PartSize)
+	}
+
+	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
+		t.Error(diff)
+	}
+
+	// Part lengths
+	if len := getReaderLength((*args)[1].(*s3.UploadPartInput).Body); (1024*1024*6)+1 != len {
+		t.Errorf("expect %d, got %d", (1024*1024*6)+1, len)
+	}
+
+	if len := getReaderLength((*args)[2].(*s3.UploadPartInput).Body); (1024*1024*6)-1 != len {
+		t.Errorf("expect %d, got %d", (1024*1024*6)-1, len)
+	}
+}
+
+func TestUploadFailIfPartSizeTooSmall(t *testing.T) {
+	mgr := manager.NewUploader(s3.New(s3.Options{}), func(u *manager.Uploader) {
+		u.PartSize = 5
+	})
+	resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+		Bucket: aws.String("Bucket"),
+		Key:    aws.String("Key"),
+		Body:   bytes.NewReader(buf12MB),
+	})
+
+	if resp != nil {
+		t.Errorf("Expected response to be nil, but received %v", resp)
+	}
+
+	if err == nil {
+		t.Errorf("Expected error, but received nil")
+	}
+
+	if e, a := "part size must be at least", err.Error(); !strings.Contains(a, e) {
+		t.Errorf("expect %v to be in %v", e, a)
+	}
+}
+
+func TestUploadOrderSingle(t *testing.T) {
+	client, invocations, params := s3testing.NewUploadLoggingClient(nil)
+	mgr := manager.NewUploader(client)
+	resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+		Bucket:               aws.String("Bucket"),
+		Key:                  aws.String("Key - value"),
+		Body:                 bytes.NewReader(buf2MB),
+		ServerSideEncryption: types.ServerSideEncryptionAwsKms,
+		SSEKMSKeyId:          aws.String("KmsId"),
+		ContentType:          aws.String("content/type"),
+	})
+
+	if err != nil {
+		t.Errorf("expect no error but received %v", err)
+	}
+
+	if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
+		t.Error(diff)
+	}
+
+	if e, a := `https://mock.amazonaws.com/key`, resp.Location; e != a {
+		t.Errorf("expect %q, got %q", e, a)
+	}
+
+	if e := "VERSION-ID"; e != *resp.VersionID {
+		t.Errorf("expect %q, got %q", e, *resp.VersionID)
+	}
+
+	if len(resp.UploadID) > 0 {
+		t.Errorf("expect empty string, got %q", resp.UploadID)
+	}
+
+	putObjectInput := (*params)[0].(*s3.PutObjectInput)
+
+	if e, a := types.ServerSideEncryptionAwsKms, putObjectInput.ServerSideEncryption; e != a {
+		t.Errorf("expect %q, got %q", e, a)
+	}
+
+	if e, a := "KmsId", aws.ToString(putObjectInput.SSEKMSKeyId); e != a {
+		t.Errorf("expect %q, got %q", e, a)
+	}
+
+	if e, a := "content/type", aws.ToString(putObjectInput.ContentType); e != a {
+		t.Errorf("Expected %q, but received %q", e, a)
+	}
+}
+
+func TestUploadOrderSingleFailure(t *testing.T) {
+	client, ops, _ := s3testing.NewUploadLoggingClient(nil)
+
+	client.PutObjectFn = func(*s3testing.UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
+		return nil, fmt.Errorf("put object failure")
+	}
+
+	mgr := manager.NewUploader(client)
+	resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+		Bucket: aws.String("Bucket"),
+		Key:    aws.String("Key"),
+		Body: 
bytes.NewReader(buf2MB), + }) + + if err == nil { + t.Error("expect error, got nil") + } + + if diff := cmp.Diff([]string{"PutObject"}, *ops); len(diff) > 0 { + t.Error(diff) + } + + if resp != nil { + t.Errorf("expect response to be nil, got %v", resp) + } +} + +func TestUploadOrderZero(t *testing.T) { + c, invocations, params := s3testing.NewUploadLoggingClient(nil) + mgr := manager.NewUploader(c) + resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: bytes.NewReader(make([]byte, 0)), + }) + + if err != nil { + t.Errorf("expect no error, got %v", err) + } + + if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 { + t.Error(diff) + } + + if len(resp.Location) == 0 { + t.Error("expect Location to not be empty") + } + + if len(resp.UploadID) > 0 { + t.Errorf("expect empty string, got %q", resp.UploadID) + } + + if e, a := int64(0), getReaderLength((*params)[0].(*s3.PutObjectInput).Body); e != a { + t.Errorf("Expected %d, but received %d", e, a) + } +} + +func TestUploadOrderMultiFailure(t *testing.T) { + c, invocations, _ := s3testing.NewUploadLoggingClient(nil) + + c.UploadPartFn = func(u *s3testing.UploadLoggingClient, _ *s3.UploadPartInput) (*s3.UploadPartOutput, error) { + if u.PartNum == 2 { + return nil, fmt.Errorf("an unexpected error") + } + return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil + } + + mgr := manager.NewUploader(c, func(u *manager.Uploader) { + u.Concurrency = 1 + }) + _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: bytes.NewReader(buf12MB), + }) + + if err == nil { + t.Error("expect error, got nil") + } + + if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *invocations); len(diff) > 0 { + t.Error(diff) + } +} + +func TestUploadOrderMultiFailureOnComplete(t *testing.T) { + c, invocations, _ := s3testing.NewUploadLoggingClient(nil) + + c.CompleteMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) { + return nil, fmt.Errorf("complete multipart error") + } + + mgr := manager.NewUploader(c, func(u *manager.Uploader) { + u.Concurrency = 1 + }) + _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: bytes.NewReader(buf12MB), + }) + + if err == nil { + t.Error("expect error, got nil") + } + + if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", + "CompleteMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 { + t.Error(diff) + } +} + +func TestUploadOrderMultiFailureOnCreate(t *testing.T) { + c, invocations, _ := s3testing.NewUploadLoggingClient(nil) + + c.CreateMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) { + return nil, fmt.Errorf("create multipart upload failure") + } + + mgr := manager.NewUploader(c) + _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: bytes.NewReader(make([]byte, 1024*1024*12)), + }) + + if err == nil { + t.Error("expect error, got nil") + } + + if diff := cmp.Diff([]string{"CreateMultipartUpload"}, *invocations); len(diff) > 0 { + t.Error(diff) + } +} + +func TestUploadOrderMultiFailureLeaveParts(t *testing.T) { + 
c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+
+	c.UploadPartFn = func(u *s3testing.UploadLoggingClient, _ *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+		if u.PartNum == 2 {
+			return nil, fmt.Errorf("upload part failure")
+		}
+		return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil
+	}
+
+	mgr := manager.NewUploader(c, func(u *manager.Uploader) {
+		u.Concurrency = 1
+		u.LeavePartsOnError = true
+	})
+	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+		Bucket: aws.String("Bucket"),
+		Key:    aws.String("Key"),
+		Body:   bytes.NewReader(make([]byte, 1024*1024*12)),
+	})
+
+	if err == nil {
+		t.Error("expect error, got nil")
+	}
+
+	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *invocations); len(diff) > 0 {
+		t.Error(diff)
+	}
+}
+
+// failreader succeeds on reads until the configured number of reads has been
+// reached, then fails.
+type failreader struct {
+	times     int
+	failCount int
+}
+
+func (f *failreader) Read(b []byte) (int, error) {
+	f.failCount++
+	if f.failCount >= f.times {
+		return 0, fmt.Errorf("random failure")
+	}
+	return len(b), nil
+}
+
+func TestUploadOrderReadFail1(t *testing.T) {
+	c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+	mgr := manager.NewUploader(c)
+	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+		Bucket: aws.String("Bucket"),
+		Key:    aws.String("Key"),
+		Body:   &failreader{times: 1},
+	})
+	if err == nil {
+		t.Fatalf("expect error to not be nil")
+	}
+
+	if e, a := "random failure", err.Error(); !strings.Contains(a, e) {
+		t.Errorf("expect %v, got %v", e, a)
+	}
+
+	if diff := cmp.Diff([]string(nil), *invocations); len(diff) > 0 {
+		t.Error(diff)
+	}
+}
+
+func TestUploadOrderReadFail2(t *testing.T) {
+	c, invocations, _ := s3testing.NewUploadLoggingClient([]string{"UploadPart"})
+	mgr := manager.NewUploader(c, func(u *manager.Uploader) {
+		u.Concurrency = 1
+	})
+	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+		Bucket: aws.String("Bucket"),
+		Key:    aws.String("Key"),
+		Body:   &failreader{times: 2},
+	})
+	if err == nil {
+		t.Fatalf("expect error to not be nil")
+	}
+
+	if e, a := "random failure", err.Error(); !strings.Contains(a, e) {
+		t.Errorf("expect %v, got %q", e, a)
+	}
+
+	if diff := cmp.Diff([]string{"CreateMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
+		t.Error(diff)
+	}
+}
+
+type sizedReader struct {
+	size int
+	cur  int
+	err  error
+}
+
+func (s *sizedReader) Read(p []byte) (n int, err error) {
+	if s.cur >= s.size {
+		if s.err == nil {
+			s.err = io.EOF
+		}
+		return 0, s.err
+	}
+
+	n = len(p)
+	s.cur += len(p)
+	if s.cur > s.size {
+		n -= s.cur - s.size
+	}
+
+	return n, err
+}
+
+func TestUploadOrderMultiBufferedReader(t *testing.T) {
+	c, invocations, params := s3testing.NewUploadLoggingClient(nil)
+	mgr := manager.NewUploader(c)
+	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+		Bucket: aws.String("Bucket"),
+		Key:    aws.String("Key"),
+		Body:   &sizedReader{size: 1024 * 1024 * 12},
+	})
+	if err != nil {
+		t.Errorf("expect no error, got %v", err)
+	}
+
+	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart",
+		"UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
+		t.Error(diff)
+	}
+
+	// Part lengths
+	var parts []int64
+	for i := 1; i <= 3; i++ {
+		parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body))
+	}
+	sort.Slice(parts, func(i, j int) bool {
+		return parts[i] < parts[j]
+	})
+
+	if diff := cmp.Diff([]int64{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 {
t.Error(diff) + } +} + +func TestUploadOrderMultiBufferedReaderPartial(t *testing.T) { + c, invocations, params := s3testing.NewUploadLoggingClient(nil) + mgr := manager.NewUploader(c) + _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: &sizedReader{size: 1024 * 1024 * 12, err: io.EOF}, + }) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + + if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", + "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 { + t.Error(diff) + } + + // Part lengths + var parts []int64 + for i := 1; i <= 3; i++ { + parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body)) + } + sort.Slice(parts, func(i, j int) bool { + return parts[i] < parts[j] + }) + + if diff := cmp.Diff([]int64{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 { + t.Error(diff) + } +} + +// TestUploadOrderMultiBufferedReaderEOF tests the edge case where the +// file size is the same as part size. +func TestUploadOrderMultiBufferedReaderEOF(t *testing.T) { + c, invocations, params := s3testing.NewUploadLoggingClient(nil) + mgr := manager.NewUploader(c) + _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: &sizedReader{size: 1024 * 1024 * 10, err: io.EOF}, + }) + + if err != nil { + t.Errorf("expect no error, got %v", err) + } + + if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 { + t.Error(diff) + } + + // Part lengths + var parts []int64 + for i := 1; i <= 2; i++ { + parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body)) + } + sort.Slice(parts, func(i, j int) bool { + return parts[i] < parts[j] + }) + + if diff := cmp.Diff([]int64{1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 { + t.Error(diff) + } +} + +func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) { + c, invocations, _ := s3testing.NewUploadLoggingClient([]string{"UploadPart"}) + mgr := manager.NewUploader(c, func(u *manager.Uploader) { + u.Concurrency = 1 + u.MaxUploadParts = 2 + }) + resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: &sizedReader{size: 1024 * 1024 * 12}, + }) + if err == nil { + t.Fatal("expect error, got nil") + } + + if resp != nil { + t.Errorf("expect nil, got %v", resp) + } + + if diff := cmp.Diff([]string{"CreateMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 { + t.Error(diff) + } + + if !strings.Contains(err.Error(), "configured MaxUploadParts (2)") { + t.Errorf("expect 'configured MaxUploadParts (2)', got %q", err.Error()) + } +} + +func TestUploadOrderSingleBufferedReader(t *testing.T) { + c, invocations, _ := s3testing.NewUploadLoggingClient(nil) + mgr := manager.NewUploader(c) + resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: &sizedReader{size: 1024 * 1024 * 2}, + }) + + if err != nil { + t.Errorf("expect no error, got %v", err) + } + + if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 { + t.Error(diff) + } + + if len(resp.Location) == 0 { + t.Error("expect a value in Location") + } + + if len(resp.UploadID) > 0 { + t.Errorf("expect no value, got %q", resp.UploadID) + } +} + +func 
TestUploadZeroLenObject(t *testing.T) { + client, invocations, _ := s3testing.NewUploadLoggingClient(nil) + + mgr := manager.NewUploader(client) + resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: strings.NewReader(""), + }) + + if err != nil { + t.Errorf("expect no error but received %v", err) + } + if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 { + t.Errorf("expect request to have been made, but was not, %v", diff) + } + + // TODO: not needed? + if len(resp.Location) == 0 { + t.Error("expect a non-empty string value for Location") + } + + if len(resp.UploadID) > 0 { + t.Errorf("expect empty string, but received %q", resp.UploadID) + } +} + +type testIncompleteReader struct { + Size int64 + read int64 +} + +func (r *testIncompleteReader) Read(p []byte) (n int, err error) { + r.read += int64(len(p)) + if r.read >= r.Size { + return int(r.read - r.Size), io.ErrUnexpectedEOF + } + return len(p), nil +} + +func TestUploadUnexpectedEOF(t *testing.T) { + c, invocations, _ := s3testing.NewUploadLoggingClient(nil) + mgr := manager.NewUploader(c, func(u *manager.Uploader) { + u.Concurrency = 1 + u.PartSize = manager.MinUploadPartSize + }) + _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: &testIncompleteReader{ + Size: manager.MinUploadPartSize + 1, + }, + }) + if err == nil { + t.Error("expect error, got nil") + } + + // Ensure upload started. + if e, a := "CreateMultipartUpload", (*invocations)[0]; e != a { + t.Errorf("expect %q, got %q", e, a) + } + + // Part may or may not be sent because of timing of sending parts and + // reading next part in upload manager. Just check for the last abort. 
+ if e, a := "AbortMultipartUpload", (*invocations)[len(*invocations)-1]; e != a { + t.Errorf("expect %q, got %q", e, a) + } +} + +func TestSSE(t *testing.T) { + client, _, _ := s3testing.NewUploadLoggingClient(nil) + client.UploadPartFn = func(u *s3testing.UploadLoggingClient, params *s3.UploadPartInput) (*s3.UploadPartOutput, error) { + if params.SSECustomerAlgorithm == nil { + t.Fatal("SSECustomerAlgoritm should not be nil") + } + if params.SSECustomerKey == nil { + t.Fatal("SSECustomerKey should not be nil") + } + return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil + } + + mgr := manager.NewUploader(client, func(u *manager.Uploader) { + u.Concurrency = 5 + }) + + _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + SSECustomerAlgorithm: aws.String("AES256"), + SSECustomerKey: aws.String("foo"), + Body: bytes.NewBuffer(make([]byte, 1024*1024*10)), + }) + + if err != nil { + t.Fatal("Expected no error, but received" + err.Error()) + } +} + +func TestUploadWithContextCanceled(t *testing.T) { + u := manager.NewUploader(s3.New(s3.Options{ + UsePathStyle: true, + })) + + params := s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: bytes.NewReader(make([]byte, 0)), + } + + ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} + ctx.Error = fmt.Errorf("context canceled") + close(ctx.DoneCh) + + _, err := u.Upload(ctx, ¶ms) + if err == nil { + t.Fatalf("expect error, got nil") + } + + if e, a := "canceled", err.Error(); !strings.Contains(a, e) { + t.Errorf("expected error message to contain %q, but did not %q", e, a) + } +} + +// S3 Uploader incorrectly fails an upload if the content being uploaded +// has a size of MinPartSize * MaxUploadParts. 
+// Github: aws/aws-sdk-go#2557 +func TestUploadMaxPartsEOF(t *testing.T) { + c, invocations, _ := s3testing.NewUploadLoggingClient(nil) + mgr := manager.NewUploader(c, func(u *manager.Uploader) { + u.Concurrency = 1 + u.PartSize = manager.DefaultUploadPartSize + u.MaxUploadParts = 2 + }) + f := bytes.NewReader(make([]byte, int(mgr.PartSize)*int(mgr.MaxUploadParts))) + + r1 := io.NewSectionReader(f, 0, manager.DefaultUploadPartSize) + r2 := io.NewSectionReader(f, manager.DefaultUploadPartSize, 2*manager.DefaultUploadPartSize) + body := io.MultiReader(r1, r2) + + _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("Bucket"), + Key: aws.String("Key"), + Body: body, + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + expectOps := []string{ + "CreateMultipartUpload", + "UploadPart", + "UploadPart", + "CompleteMultipartUpload", + } + if diff := cmp.Diff(expectOps, *invocations); len(diff) > 0 { + t.Error(diff) + } +} + +func createTempFile(t *testing.T, size int64) (*os.File, func(*testing.T), error) { + file, err := ioutil.TempFile(os.TempDir(), aws.SDKName+t.Name()) + if err != nil { + return nil, nil, err + } + filename := file.Name() + if err := file.Truncate(size); err != nil { + return nil, nil, err + } + + return file, + func(t *testing.T) { + if err := file.Close(); err != nil { + t.Errorf("failed to close temp file, %s, %v", filename, err) + } + if err := os.Remove(filename); err != nil { + t.Errorf("failed to remove temp file, %s, %v", filename, err) + } + }, + nil +} + +func buildFailHandlers(tb testing.TB, parts, retry int) []http.Handler { + handlers := make([]http.Handler, parts) + for i := 0; i < len(handlers); i++ { + handlers[i] = &failPartHandler{ + tb: tb, + failsRemaining: retry, + successHandler: successPartHandler{tb: tb}, + } + } + + return handlers +} + +func TestUploadRetry(t *testing.T) { + const numParts, retries = 3, 10 + + testFile, testFileCleanup, err := createTempFile(t, manager.DefaultUploadPartSize*numParts) + if err != nil { + t.Fatalf("failed to create test file, %v", err) + } + defer testFileCleanup(t) + + cases := map[string]struct { + Body io.Reader + PartHandlers func(testing.TB) []http.Handler + }{ + "bytes.Buffer": { + Body: bytes.NewBuffer(make([]byte, manager.DefaultUploadPartSize*numParts)), + PartHandlers: func(tb testing.TB) []http.Handler { + return buildFailHandlers(tb, numParts, retries) + }, + }, + "bytes.Reader": { + Body: bytes.NewReader(make([]byte, manager.DefaultUploadPartSize*numParts)), + PartHandlers: func(tb testing.TB) []http.Handler { + return buildFailHandlers(tb, numParts, retries) + }, + }, + "os.File": { + Body: testFile, + PartHandlers: func(tb testing.TB) []http.Handler { + return buildFailHandlers(tb, numParts, retries) + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + restoreSleep := sdk.TestingUseNoOpSleep() + defer restoreSleep() + + mux := newMockS3UploadServer(t, c.PartHandlers(t)) + server := httptest.NewServer(mux) + defer server.Close() + + client := s3.New(s3.Options{ + EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.ResolverOptions) (aws.Endpoint, error) { + return aws.Endpoint{ + URL: server.URL, + }, nil + }), + UsePathStyle: true, + Retryer: retry.NewStandard(func(o *retry.StandardOptions) { + o.MaxAttempts = retries + 1 + }), + }) + + uploader := manager.NewUploader(client) + _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("bucket"), + Key: 
aws.String("key"), + Body: c.Body, + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + }) + } +} + +func TestUploadBufferStrategy(t *testing.T) { + cases := map[string]struct { + PartSize int64 + Size int64 + Strategy manager.ReadSeekerWriteToProvider + callbacks int + }{ + "NoBuffer": { + PartSize: manager.DefaultUploadPartSize, + Strategy: nil, + }, + "SinglePart": { + PartSize: manager.DefaultUploadPartSize, + Size: manager.DefaultUploadPartSize, + Strategy: &recordedBufferProvider{size: int(manager.DefaultUploadPartSize)}, + callbacks: 1, + }, + "MultiPart": { + PartSize: manager.DefaultUploadPartSize, + Size: manager.DefaultUploadPartSize * 2, + Strategy: &recordedBufferProvider{size: int(manager.DefaultUploadPartSize)}, + callbacks: 2, + }, + } + + for name, tCase := range cases { + t.Run(name, func(t *testing.T) { + client, _, _ := s3testing.NewUploadLoggingClient(nil) + client.ConsumeBody = true + + uploader := manager.NewUploader(client, func(u *manager.Uploader) { + u.PartSize = tCase.PartSize + u.BufferProvider = tCase.Strategy + u.Concurrency = 1 + }) + + expected := s3testing.GetTestBytes(int(tCase.Size)) + _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + Body: bytes.NewReader(expected), + }) + if err != nil { + t.Fatalf("failed to upload file: %v", err) + } + + switch strat := tCase.Strategy.(type) { + case *recordedBufferProvider: + if !bytes.Equal(expected, strat.content) { + t.Errorf("content buffered did not match expected") + } + if tCase.callbacks != strat.callbackCount { + t.Errorf("expected %v, got %v callbacks", tCase.callbacks, strat.callbackCount) + } + } + }) + } +} + +type mockS3UploadServer struct { + *http.ServeMux + + tb testing.TB + partHandler []http.Handler +} + +func newMockS3UploadServer(tb testing.TB, partHandler []http.Handler) *mockS3UploadServer { + s := &mockS3UploadServer{ + ServeMux: http.NewServeMux(), + partHandler: partHandler, + tb: tb, + } + + s.HandleFunc("/", s.handleRequest) + + return s +} + +func (s mockS3UploadServer) handleRequest(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + _, hasUploads := r.URL.Query()["uploads"] + + switch { + case r.Method == "POST" && hasUploads: + // CreateMultipartUpload + w.Header().Set("Content-Length", strconv.Itoa(len(createUploadResp))) + w.Write([]byte(createUploadResp)) + + case r.Method == "PUT": + // UploadPart + partNumStr := r.URL.Query().Get("partNumber") + id, err := strconv.Atoi(partNumStr) + if err != nil { + failRequest(w, 400, "BadRequest", + fmt.Sprintf("unable to parse partNumber, %q, %v", + partNumStr, err)) + return + } + id-- + if id < 0 || id >= len(s.partHandler) { + failRequest(w, 400, "BadRequest", + fmt.Sprintf("invalid partNumber %v", id)) + return + } + s.partHandler[id].ServeHTTP(w, r) + + case r.Method == "POST": + // CompleteMultipartUpload + w.Header().Set("Content-Length", strconv.Itoa(len(completeUploadResp))) + w.Write([]byte(completeUploadResp)) + + case r.Method == "DELETE": + // AbortMultipartUpload + w.Header().Set("Content-Length", strconv.Itoa(len(abortUploadResp))) + w.WriteHeader(200) + w.Write([]byte(abortUploadResp)) + + default: + failRequest(w, 400, "BadRequest", + fmt.Sprintf("invalid request %v %v", r.Method, r.URL)) + } +} + +func failRequest(w http.ResponseWriter, status int, code, msg string) { + msg = fmt.Sprintf(baseRequestErrorResp, code, msg) + w.Header().Set("Content-Length", strconv.Itoa(len(msg))) + w.WriteHeader(status) + 
w.Write([]byte(msg))
+}
+
+type successPartHandler struct {
+	tb testing.TB
+}
+
+func (h successPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	defer r.Body.Close()
+
+	n, err := io.Copy(ioutil.Discard, r.Body)
+	if err != nil {
+		failRequest(w, 400, "BadRequest",
+			fmt.Sprintf("failed to read body, %v", err))
+		return
+	}
+
+	contLenStr := r.Header.Get("Content-Length")
+	expectLen, err := strconv.ParseInt(contLenStr, 10, 64)
+	if err != nil {
+		h.tb.Logf("expect content-length, got %q, %v", contLenStr, err)
+		failRequest(w, 400, "BadRequest",
+			fmt.Sprintf("unable to get content-length %v", err))
+		return
+	}
+	if e, a := expectLen, n; e != a {
+		h.tb.Logf("expect %v read, got %v", e, a)
+		failRequest(w, 400, "BadRequest",
+			fmt.Sprintf(
+				"content-length and body do not match, %v, %v", e, a))
+		return
+	}
+
+	w.Header().Set("Content-Length", strconv.Itoa(len(uploadPartResp)))
+	w.Write([]byte(uploadPartResp))
+}
+
+type failPartHandler struct {
+	tb testing.TB
+
+	failsRemaining int
+	successHandler http.Handler
+}
+
+func (h *failPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	defer r.Body.Close()
+
+	if h.failsRemaining == 0 && h.successHandler != nil {
+		h.successHandler.ServeHTTP(w, r)
+		return
+	}
+
+	io.Copy(ioutil.Discard, r.Body)
+
+	failRequest(w, 500, "InternalException",
+		fmt.Sprintf("mock error, partNumber %v", r.URL.Query().Get("partNumber")))
+
+	h.failsRemaining--
+}
+
+type recordedBufferProvider struct {
+	content       []byte
+	size          int
+	callbackCount int
+}
+
+func (r *recordedBufferProvider) GetWriteTo(seeker io.ReadSeeker) (manager.ReadSeekerWriteTo, func()) {
+	b := make([]byte, r.size)
+	w := &manager.BufferedReadSeekerWriteTo{BufferedReadSeeker: manager.NewBufferedReadSeeker(seeker, b)}
+
+	return w, func() {
+		r.content = append(r.content, b...)
+		r.callbackCount++
+	}
+}
+
+const createUploadResp = `<CreateMultipartUploadResponse>
+  <Bucket>bucket</Bucket>
+  <Key>key</Key>
+  <UploadId>abc123</UploadId>
+</CreateMultipartUploadResponse>`
+
+const uploadPartResp = `<UploadPartResponse>
+  <ETag>key</ETag>
+</UploadPartResponse>`
+
+const baseRequestErrorResp = `<Error>
+  <Code>%s</Code>
+  <Message>%s</Message>
+  <RequestId>request-id</RequestId>
+  <HostId>host-id</HostId>
+</Error>`
+
+const completeUploadResp = `<CompleteMultipartUploadResponse>
+  <Bucket>bucket</Bucket>
+  <Key>key</Key>
+  <ETag>key</ETag>
+  <Location>https://bucket.us-west-2.amazonaws.com/key</Location>
+  <UploadId>abc123</UploadId>
+</CompleteMultipartUploadResponse>`
+
+const abortUploadResp = `<AbortMultipartUploadResponse></AbortMultipartUploadResponse>`
diff --git a/feature/s3/manager/writer_read_from.go b/feature/s3/manager/writer_read_from.go
new file mode 100644
index 00000000000..3df983a652a
--- /dev/null
+++ b/feature/s3/manager/writer_read_from.go
@@ -0,0 +1,75 @@
+package manager
+
+import (
+	"bufio"
+	"io"
+	"sync"
+
+	"github.com/aws/aws-sdk-go-v2/internal/sdkio"
+)
+
+// WriterReadFrom defines an interface implementing io.Writer and io.ReaderFrom
+type WriterReadFrom interface {
+	io.Writer
+	io.ReaderFrom
+}
+
+// WriterReadFromProvider provides an implementation of io.ReaderFrom for the given io.Writer
+type WriterReadFromProvider interface {
+	GetReadFrom(writer io.Writer) (w WriterReadFrom, cleanup func())
+}
+
+type bufferedWriter interface {
+	WriterReadFrom
+	Flush() error
+	Reset(io.Writer)
+}
+
+type bufferedReadFrom struct {
+	bufferedWriter
+}
+
+func (b *bufferedReadFrom) ReadFrom(r io.Reader) (int64, error) {
+	n, err := b.bufferedWriter.ReadFrom(r)
+	if flushErr := b.Flush(); flushErr != nil && err == nil {
+		err = flushErr
+	}
+	return n, err
+}
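+// Illustrative sketch: because ReadFrom above flushes after every call, a copy
+// through a pooled *bufio.Writer behaves like an unbuffered write once it
+// returns. "dst" and "src" are assumed to be an io.Writer and an io.Reader.
+//
+//	provider := manager.NewPooledBufferedWriterReadFromProvider(0) // 0 selects the default size
+//	w, cleanup := provider.GetReadFrom(dst)
+//	defer cleanup()
+//	if _, err := w.ReadFrom(src); err != nil {
+//		log.Fatal(err)
+//	}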
+// PooledBufferedReadFromProvider is a WriterReadFromProvider that uses a
+// sync.Pool to manage allocation and reuse of *bufio.Writer structures.
+type PooledBufferedReadFromProvider struct {
+	pool sync.Pool
+}
+
+// NewPooledBufferedWriterReadFromProvider returns a new PooledBufferedReadFromProvider.
+// Size is used to control the size of the underlying *bufio.Writer created for
+// calls to GetReadFrom.
+func NewPooledBufferedWriterReadFromProvider(size int) *PooledBufferedReadFromProvider {
+	if size < int(32*sdkio.KibiByte) {
+		size = int(64 * sdkio.KibiByte)
+	}
+
+	return &PooledBufferedReadFromProvider{
+		pool: sync.Pool{
+			New: func() interface{} {
+				return &bufferedReadFrom{bufferedWriter: bufio.NewWriterSize(nil, size)}
+			},
+		},
+	}
+}
+
+// GetReadFrom takes an io.Writer and wraps it with a type which satisfies the WriterReadFrom
+// interface. Additionally a cleanup function is provided which must be called after usage of the WriterReadFrom
+// has been completed in order to allow the reuse of the *bufio.Writer.
+func (p *PooledBufferedReadFromProvider) GetReadFrom(writer io.Writer) (r WriterReadFrom, cleanup func()) {
+	buffer := p.pool.Get().(*bufferedReadFrom)
+	buffer.Reset(writer)
+	r = buffer
+	cleanup = func() {
+		buffer.Reset(nil) // Reset to nil writer to release reference
+		p.pool.Put(buffer)
+	}
+	return r, cleanup
+}
diff --git a/feature/s3/manager/writer_read_from_test.go b/feature/s3/manager/writer_read_from_test.go
new file mode 100644
index 00000000000..4f59f68cdc3
--- /dev/null
+++ b/feature/s3/manager/writer_read_from_test.go
@@ -0,0 +1,73 @@
+package manager
+
+import (
+	"fmt"
+	"io"
+	"reflect"
+	"testing"
+)
+
+type testBufioWriter struct {
+	ReadFromN   int64
+	ReadFromErr error
+	FlushReturn error
+}
+
+func (t testBufioWriter) Write(p []byte) (n int, err error) {
+	panic("unused")
+}
+
+func (t testBufioWriter) ReadFrom(r io.Reader) (n int64, err error) {
+	return t.ReadFromN, t.ReadFromErr
+}
+
+func (t testBufioWriter) Flush() error {
+	return t.FlushReturn
+}
+
+func (t *testBufioWriter) Reset(io.Writer) {
+	panic("unused")
+}
+
+func TestBufferedReadFromFlusher_ReadFrom(t *testing.T) {
+	cases := map[string]struct {
+		w           testBufioWriter
+		expectedErr error
+	}{
+		"no errors": {},
+		"error returned from underlying ReadFrom": {
+			w: testBufioWriter{
+				ReadFromN:   42,
+				ReadFromErr: fmt.Errorf("readfrom"),
+			},
+			expectedErr: fmt.Errorf("readfrom"),
+		},
+		"error returned from Flush": {
+			w: testBufioWriter{
+				ReadFromN:   7,
+				FlushReturn: fmt.Errorf("flush"),
+			},
+			expectedErr: fmt.Errorf("flush"),
+		},
+		"error returned from ReadFrom and Flush": {
+			w: testBufioWriter{
+				ReadFromN:   1337,
+				ReadFromErr: fmt.Errorf("readfrom"),
+				FlushReturn: fmt.Errorf("flush"),
+			},
+			expectedErr: fmt.Errorf("readfrom"),
+		},
+	}
+
+	for name, tCase := range cases {
+		t.Log(name)
+		readFromFlusher := bufferedReadFrom{bufferedWriter: &tCase.w}
+		n, err := readFromFlusher.ReadFrom(nil)
+		if e, a := tCase.w.ReadFromN, n; e != a {
+			t.Errorf("expected %v bytes, got %v", e, a)
+		}
+		if e, a := tCase.expectedErr, err; !reflect.DeepEqual(e, a) {
+			t.Errorf("expected error %v. got %v", e, a)
+		}
+	}
+}