From 1d018af3bfab9f3a1c761be9382403a4038a4478 Mon Sep 17 00:00:00 2001 From: Alan Protasio Date: Wed, 14 Jun 2023 07:39:30 -0700 Subject: [PATCH] Fix multipart upload with encryption context (#1838) --- api-put-object-multipart.go | 5 +- api-put-object_test.go | 108 ++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/api-put-object-multipart.go b/api-put-object-multipart.go index 85d6c70a2..5f117afa4 100644 --- a/api-put-object-multipart.go +++ b/api-put-object-multipart.go @@ -389,8 +389,9 @@ func (c *Client) completeMultipartUpload(ctx context.Context, bucketName, object headers := opts.Header() if s3utils.IsAmazonEndpoint(*c.endpointURL) { - headers.Del(encrypt.SseKmsKeyID) // Remove X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id not supported in CompleteMultipartUpload - headers.Del(encrypt.SseGenericHeader) // Remove X-Amz-Server-Side-Encryption not supported in CompleteMultipartUpload + headers.Del(encrypt.SseKmsKeyID) // Remove X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id not supported in CompleteMultipartUpload + headers.Del(encrypt.SseGenericHeader) // Remove X-Amz-Server-Side-Encryption not supported in CompleteMultipartUpload + headers.Del(encrypt.SseEncryptionContext) // Remove X-Amz-Server-Side-Encryption-Context not supported in CompleteMultipartUpload } // Instantiate all the complete multipart buffer. diff --git a/api-put-object_test.go b/api-put-object_test.go index 2867f3a19..ca35b12ed 100644 --- a/api-put-object_test.go +++ b/api-put-object_test.go @@ -17,7 +17,13 @@ package minio import ( + "context" + "encoding/base64" + "net/http" + "reflect" "testing" + + "github.com/minio/minio-go/v7/pkg/encrypt" ) func TestPutObjectOptionsValidate(t *testing.T) { @@ -61,3 +67,105 @@ func TestPutObjectOptionsValidate(t *testing.T) { } } } + +type InterceptRouteTripper struct { + request *http.Request +} + +func (i *InterceptRouteTripper) RoundTrip(request *http.Request) (*http.Response, error) { + i.request = request + return &http.Response{StatusCode: 200}, nil +} + +func Test_SSEHeaders(t *testing.T) { + rt := &InterceptRouteTripper{} + c, err := New("s3.amazonaws.com", &Options{ + Transport: rt, + }) + + if err != nil { + t.Error(err) + } + + testCases := map[string]struct { + sse func() encrypt.ServerSide + initiateMultipartUploadHeaders http.Header + headerNotAllowedAfterInit []string + }{ + "noEncryption": { + sse: func() encrypt.ServerSide { return nil }, + initiateMultipartUploadHeaders: http.Header{}, + }, + "sse": { + sse: func() encrypt.ServerSide { + s, err := encrypt.NewSSEKMS("keyId", nil) + if err != nil { + t.Error(err) + } + return s + }, + initiateMultipartUploadHeaders: http.Header{ + encrypt.SseGenericHeader: []string{"aws:kms"}, + encrypt.SseKmsKeyID: []string{"keyId"}, + }, + headerNotAllowedAfterInit: []string{encrypt.SseGenericHeader, encrypt.SseKmsKeyID, encrypt.SseEncryptionContext}, + }, + "sse with context": { + sse: func() encrypt.ServerSide { + s, err := encrypt.NewSSEKMS("keyId", "context") + if err != nil { + t.Error(err) + } + return s + }, + initiateMultipartUploadHeaders: http.Header{ + encrypt.SseGenericHeader: []string{"aws:kms"}, + encrypt.SseKmsKeyID: []string{"keyId"}, + encrypt.SseEncryptionContext: []string{base64.StdEncoding.EncodeToString([]byte("\"context\""))}, + }, + headerNotAllowedAfterInit: []string{encrypt.SseGenericHeader, encrypt.SseKmsKeyID, encrypt.SseEncryptionContext}, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + opts := PutObjectOptions{ + ServerSideEncryption: tc.sse(), + } + c.bucketLocCache.Set("test", "region") + c.initiateMultipartUpload(context.Background(), "test", "test", opts) + for s, vls := range tc.initiateMultipartUploadHeaders { + if !reflect.DeepEqual(rt.request.Header[s], vls) { + t.Errorf("Header %v are not equal, want: %v got %v", s, vls, rt.request.Header[s]) + } + } + + _, err := c.uploadPart(context.Background(), uploadPartParams{ + bucketName: "test", + objectName: "test", + partNumber: 1, + uploadID: "upId", + sse: opts.ServerSideEncryption, + }) + + if err != nil { + t.Error(err) + } + + for _, k := range tc.headerNotAllowedAfterInit { + if rt.request.Header.Get(k) != "" { + t.Errorf("header %v should not be set", k) + } + } + + c.completeMultipartUpload(context.Background(), "test", "test", "upId", completeMultipartUpload{}, opts) + + for _, k := range tc.headerNotAllowedAfterInit { + if rt.request.Header.Get(k) != "" { + t.Errorf("header %v should not be set", k) + } + } + }) + } + +}