diff --git a/client/BUILD b/client/BUILD index 526c3c9..2ef9244 100644 --- a/client/BUILD +++ b/client/BUILD @@ -30,6 +30,7 @@ go_library( importpath = "github.com/GoogleCloudPlatform/stet/client", deps = [ "//client/cloudkms", + "//client/confidentialspace", "//client/jwt", "//client/securesession", "//client/shares", @@ -53,6 +54,7 @@ go_test( embed = [":client"], deps = [ "//client/cloudkms", + "//client/confidentialspace", "//client/securesession", "//client/shares", "//client/testutil", @@ -61,5 +63,6 @@ go_test( "@com_github_googleapis_gax_go_v2//:go_default_library", "@go_googleapis//google/cloud/kms/v1:kms_go_proto", "@org_golang_google_protobuf//proto", + "@org_golang_google_protobuf//types/known/wrapperspb", ], ) diff --git a/client/client.go b/client/client.go index 1a1580f..3dd3927 100644 --- a/client/client.go +++ b/client/client.go @@ -27,6 +27,7 @@ import ( "strings" "github.com/GoogleCloudPlatform/stet/client/cloudkms" + "github.com/GoogleCloudPlatform/stet/client/confidentialspace" "github.com/GoogleCloudPlatform/stet/client/jwt" "github.com/GoogleCloudPlatform/stet/client/securesession" "github.com/GoogleCloudPlatform/stet/client/shares" @@ -210,9 +211,15 @@ func getKekURIMetadata(ctx context.Context, kmsClient cloudkms.Client, kekInfo * // list of wrapped shares, and a list of key URIs used for shares that were // wrapped by communicating with an external KMS (these lists might not // correspond one-to-one if some shares are wrapped via asymmetric key). -func (c *StetClient) wrapShares(ctx context.Context, unwrappedShares [][]byte, kekInfos []*configpb.KekInfo, keys *configpb.AsymmetricKeys) (wrappedShares []*configpb.WrappedShare, keyURIs []string, err error) { - if len(unwrappedShares) != len(kekInfos) { - return nil, nil, fmt.Errorf("number of shares to wrap (%d) does not match number of KEKs (%d)", len(unwrappedShares), len(kekInfos)) +type sharesOpts struct { + kekInfos []*configpb.KekInfo + asymmetricKeys *configpb.AsymmetricKeys + confSpaceConfig *confidentialspace.Config +} + +func (c *StetClient) wrapShares(ctx context.Context, unwrappedShares [][]byte, opts sharesOpts) (wrappedShares []*configpb.WrappedShare, keyURIs []string, err error) { + if len(unwrappedShares) != len(opts.kekInfos) { + return nil, nil, fmt.Errorf("number of shares to wrap (%d) does not match number of KEKs (%d)", len(unwrappedShares), len(opts.kekInfos)) } var kmsClients *cloudkms.ClientFactory @@ -228,11 +235,11 @@ func (c *StetClient) wrapShares(ctx context.Context, unwrappedShares [][]byte, k Hash: shares.HashShare(share), } - kek := kekInfos[i] + kek := opts.kekInfos[i] switch x := kek.KekType.(type) { case *configpb.KekInfo_RsaFingerprint: - key, err := PublicKeyForRSAFingerprint(kek, keys) + key, err := PublicKeyForRSAFingerprint(kek, opts.asymmetricKeys) if err != nil { return nil, nil, fmt.Errorf("failed to find public key for RSA fingerprint: %w", err) } @@ -243,7 +250,12 @@ func (c *StetClient) wrapShares(ctx context.Context, unwrappedShares [][]byte, k } case *configpb.KekInfo_KekUri: + // Configure CloudKMS Client, with Confidential Space credentials if applicable. creds := "" + if opts.confSpaceConfig != nil { + creds = opts.confSpaceConfig.FindMatchingCredentials(kek.GetKekUri(), configpb.CredentialMode_ENCRYPT_ONLY_MODE) + } + kmsClient, err := kmsClients.Client(ctx, creds) if err != nil { return nil, nil, fmt.Errorf("error initializing Cloud KMS Client with credentials \"%v\": %v", creds, err) @@ -293,10 +305,11 @@ func (c *StetClient) wrapShares(ctx context.Context, unwrappedShares [][]byte, k } // unwrapAndValidateShares decrypts the given wrapped share based on its URI. -func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares []*configpb.WrappedShare, kekInfos []*configpb.KekInfo, keys *configpb.AsymmetricKeys) ([]shares.UnwrappedShare, error) { - if len(wrappedShares) != len(kekInfos) { - return nil, fmt.Errorf("number of shares to unwrap (%d) does not match number of KEKs (%d)", len(wrappedShares), len(kekInfos)) +func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares []*configpb.WrappedShare, opts sharesOpts) ([]shares.UnwrappedShare, error) { + if len(wrappedShares) != len(opts.kekInfos) { + return nil, fmt.Errorf("number of shares to unwrap (%d) does not match number of KEKs (%d)", len(wrappedShares), len(opts.kekInfos)) } + var kmsClients *cloudkms.ClientFactory if c.testKMSClients != nil { kmsClients = c.testKMSClients @@ -313,11 +326,11 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares for i, wrapped := range wrappedShares { glog.Infof("Attempting to unwrap share #%v", i+1) unwrapped := shares.UnwrappedShare{} - kek := kekInfos[i] + kek := opts.kekInfos[i] switch x := kek.KekType.(type) { case *configpb.KekInfo_RsaFingerprint: - key, err := PrivateKeyForRSAFingerprint(kek, keys) + key, err := PrivateKeyForRSAFingerprint(kek, opts.asymmetricKeys) if err != nil { glog.Warningf("Failed to find public key for RSA fingerprint: %v", err) continue @@ -330,7 +343,12 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares } case *configpb.KekInfo_KekUri: + // Configure CloudKMS Client, with Confidential Space credentials if applicable. creds := "" + if opts.confSpaceConfig != nil { + creds = opts.confSpaceConfig.FindMatchingCredentials(kek.GetKekUri(), configpb.CredentialMode_DECRYPT_ONLY_MODE) + } + kmsClient, err := kmsClients.Client(ctx, creds) if err != nil { return nil, fmt.Errorf("error initializing Cloud KMS Client with credentials \"%v\": %v", creds, err) @@ -386,7 +404,8 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares } // Encrypt generates a DEK and creates EncryptedData in accordance with the EKM encryption protocol. -func (c *StetClient) Encrypt(ctx context.Context, input io.Reader, output io.Writer, config *configpb.EncryptConfig, keys *configpb.AsymmetricKeys, blobID string) (*StetMetadata, error) { +func (c *StetClient) Encrypt(ctx context.Context, input io.Reader, output io.Writer, stetConfig *configpb.StetConfig, blobID string) (*StetMetadata, error) { + config := stetConfig.GetEncryptConfig() if config == nil { return nil, fmt.Errorf("nil EncryptConfig passed to Encrypt()") } @@ -407,7 +426,15 @@ func (c *StetClient) Encrypt(ctx context.Context, input io.Reader, output io.Wri metadata := &configpb.Metadata{BlobId: blobID, KeyConfig: keyCfg} var keyURIs []string - metadata.Shares, keyURIs, err = c.wrapShares(ctx, shares, keyCfg.GetKekInfos(), keys) + opts := sharesOpts{ + kekInfos: keyCfg.GetKekInfos(), + asymmetricKeys: stetConfig.GetAsymmetricKeys(), + } + if csConfigs := stetConfig.GetConfidentialSpaceConfigs(); csConfigs != nil { + opts.confSpaceConfig = confidentialspace.NewConfig(csConfigs) + } + + metadata.Shares, keyURIs, err = c.wrapShares(ctx, shares, opts) if err != nil { return nil, fmt.Errorf("error wrapping shares: %v", err) } @@ -442,11 +469,13 @@ func (c *StetClient) Encrypt(ctx context.Context, input io.Reader, output io.Wri KeyUris: keyURIs, BlobID: metadata.GetBlobId(), }, nil + } // Decrypt writes the decrypted data to the `output` writer, and returns the // key URIs used during decryption and the blob ID decrypted. -func (c *StetClient) Decrypt(ctx context.Context, input io.Reader, output io.Writer, config *configpb.DecryptConfig, keys *configpb.AsymmetricKeys) (*StetMetadata, error) { +func (c *StetClient) Decrypt(ctx context.Context, input io.Reader, output io.Writer, stetConfig *configpb.StetConfig) (*StetMetadata, error) { + config := stetConfig.GetDecryptConfig() if config == nil { return nil, fmt.Errorf("nil DecryptConfig passed to Decrypt()") } @@ -471,7 +500,15 @@ func (c *StetClient) Decrypt(ctx context.Context, input io.Reader, output io.Wri } // Unwrap shares and validate. - unwrappedShares, err := c.unwrapAndValidateShares(ctx, metadata.GetShares(), matchingKeyConfig.GetKekInfos(), keys) + opts := sharesOpts{ + kekInfos: matchingKeyConfig.GetKekInfos(), + asymmetricKeys: stetConfig.GetAsymmetricKeys(), + } + if csConfigs := stetConfig.GetConfidentialSpaceConfigs(); csConfigs != nil { + opts.confSpaceConfig = confidentialspace.NewConfig(csConfigs) + } + + unwrappedShares, err := c.unwrapAndValidateShares(ctx, metadata.GetShares(), opts) if err != nil { return nil, fmt.Errorf("error unwrapping and validating shares: %v", err) } diff --git a/client/client_test.go b/client/client_test.go index f51a723..2b3b481 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -23,6 +23,7 @@ import ( "testing" "github.com/GoogleCloudPlatform/stet/client/cloudkms" + confspace "github.com/GoogleCloudPlatform/stet/client/confidentialspace" "github.com/GoogleCloudPlatform/stet/client/securesession" "github.com/GoogleCloudPlatform/stet/client/shares" "github.com/GoogleCloudPlatform/stet/client/testutil" @@ -33,6 +34,7 @@ import ( configpb "github.com/GoogleCloudPlatform/stet/proto/config_go_proto" kmsrpb "google.golang.org/genproto/googleapis/cloud/kms/v1" kmsspb "google.golang.org/genproto/googleapis/cloud/kms/v1" + wrapperspb "google.golang.org/protobuf/types/known/wrapperspb" ) // Fake version of secure session client, used to communicate with external EKM. @@ -407,7 +409,8 @@ func TestWrapSharesIndividually(t *testing.T) { }, } - wrappedShares, _, err := stetClient.wrapShares(ctx, [][]byte{testShare}, ki, &configpb.AsymmetricKeys{}) + opts := sharesOpts{kekInfos: ki, asymmetricKeys: &configpb.AsymmetricKeys{}} + wrappedShares, _, err := stetClient.wrapShares(ctx, [][]byte{testShare}, opts) if err != nil { t.Fatalf("wrapShares returned with error: %v", err) @@ -461,7 +464,8 @@ func TestWrapUnwrapShareAsymmetricKey(t *testing.T) { } var stetClient StetClient - wrappedShares, keyURIs, err := stetClient.wrapShares(ctx, [][]byte{testShare}, ki, keys) + opts := sharesOpts{kekInfos: ki, asymmetricKeys: keys} + wrappedShares, keyURIs, err := stetClient.wrapShares(ctx, [][]byte{testShare}, opts) if err != nil { t.Fatalf("wrapShares returned with error: %v", err) @@ -479,7 +483,7 @@ func TestWrapUnwrapShareAsymmetricKey(t *testing.T) { t.Fatalf("wrapShares(ctx, %s, %v) expected to return 0 key URIs, got %v", testShare, ki, len(keyURIs)) } - unwrappedShares, err := stetClient.unwrapAndValidateShares(ctx, wrappedShares, ki, keys) + unwrappedShares, err := stetClient.unwrapAndValidateShares(ctx, wrappedShares, opts) if err != nil { t.Fatalf("unwrapAndValidateShares returned with error: %v", err) @@ -570,13 +574,14 @@ func TestWrapUnwrapShareAsymmetricKeyError(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { var stetClient StetClient - wrappedShares, _, err := stetClient.wrapShares(ctx, testCase.unwrappedShares, testCase.kekInfos, testCase.asymmetricKeys) + opts := sharesOpts{kekInfos: testCase.kekInfos, asymmetricKeys: testCase.asymmetricKeys} + wrappedShares, _, err := stetClient.wrapShares(ctx, testCase.unwrappedShares, opts) if err == nil && testCase.errorOnWrap { t.Errorf("wrapShares(%s, %s) expected to return error, but did not", testCase.unwrappedShares, testCase.kekInfos) } - _, err = stetClient.unwrapAndValidateShares(ctx, wrappedShares, testCase.kekInfos, testCase.asymmetricKeys) + _, err = stetClient.unwrapAndValidateShares(ctx, wrappedShares, opts) if err == nil { t.Errorf("unwrapAndValidateShares(%s, %s, %v) expected to return error, but did not", wrappedShares, testCase.kekInfos, testCase.asymmetricKeys) @@ -615,7 +620,8 @@ func TestWrapSharesWithMultipleShares(t *testing.T) { fakeSecureSessionClient: &fakeSecureSessionClient{}, } - wrapped, uris, err := stetClient.wrapShares(ctx, sharesList, kekInfoList, &configpb.AsymmetricKeys{}) + wrapOpts := sharesOpts{kekInfos: kekInfoList, asymmetricKeys: &configpb.AsymmetricKeys{}} + wrapped, uris, err := stetClient.wrapShares(ctx, sharesList, wrapOpts) if err != nil { t.Fatalf("wrapShares(%s, %s) returned with error %v", sharesList, kekInfoList, err) @@ -640,6 +646,104 @@ func TestWrapSharesWithMultipleShares(t *testing.T) { } } +func TestWrapSharesWithConfidentialSpace(t *testing.T) { + ctx := context.Background() + tokenFile := testutil.CreateTempTokenFile(t) + + // Define three test KEKs, each of which should map to a different KMS client. + keks := []struct { + kekURI string + plaintext []byte + expectedSuffix []byte + }{ + {"gcp-kms://test-kek-0", []byte("Share 0"), []byte("-with-credentials")}, + {"gcp-kms://test-kek-1", []byte("Share 1"), []byte("-wip-only-credentials")}, + {"gcp-kms://test-kek-2", []byte("Share 2"), []byte("-no-credentials")}, + } + + // Define credentials for only the KEKs that require them. + csProto := &configpb.ConfidentialSpaceConfigs{ + KekCredentials: []*configpb.KekCredentialConfig{ + { + // A set of credentials. + KekUriPattern: keks[0].kekURI, + WipName: "test WIP name", + ServiceAccount: "test@system.gserviceaccount.com", + }, + { + // Same credentials, but without service account. + KekUriPattern: keks[1].kekURI, + WipName: "test WIP name", + }, + }, + } + + createFakeKMSClient := func(index int) *testutil.FakeKeyManagementClient { + return &testutil.FakeKeyManagementClient{ + GetCryptoKeyFunc: func(_ context.Context, req *kmsspb.GetCryptoKeyRequest, _ ...gax.CallOption) (*kmsrpb.CryptoKey, error) { + return &kmsrpb.CryptoKey{ + Primary: &kmsrpb.CryptoKeyVersion{ + Name: req.GetName(), + State: kmsrpb.CryptoKeyVersion_ENABLED, + ProtectionLevel: kmsrpb.ProtectionLevel_SOFTWARE, + }, + }, nil + }, + EncryptFunc: func(_ context.Context, req *kmsspb.EncryptRequest, _ ...gax.CallOption) (*kmsspb.EncryptResponse, error) { + wrappedShare := append(req.GetPlaintext(), keks[index].expectedSuffix...) + + return &kmsspb.EncryptResponse{ + Name: req.GetName(), + Ciphertext: wrappedShare, + CiphertextCrc32C: wrapperspb.Int64(int64(testutil.CRC32C(wrappedShare))), + VerifiedPlaintextCrc32C: true, + }, nil + }, + } + } + + // Define a fake Client for each KEK credentials (including no credentials). + kmsClients := cloudkms.ClientFactory{ + CredsMap: map[string]cloudkms.Client{ + confspace.CreateJSONCredentials(csProto.GetKekCredentials()[0], tokenFile): createFakeKMSClient(0), + confspace.CreateJSONCredentials(csProto.GetKekCredentials()[1], tokenFile): createFakeKMSClient(1), + "": createFakeKMSClient(2), + }, + } + + var kekInfos []*configpb.KekInfo + var shares [][]byte + for i := 0; i < len(keks); i++ { + shares = append(shares, keks[i].plaintext) + kekInfos = append(kekInfos, &configpb.KekInfo{ + KekType: &configpb.KekInfo_KekUri{KekUri: keks[i].kekURI}, + }) + } + + client := &StetClient{testKMSClients: &kmsClients} + + opts := sharesOpts{ + kekInfos: kekInfos, + asymmetricKeys: &configpb.AsymmetricKeys{}, + confSpaceConfig: confspace.NewConfigWithTokenFile(csProto, tokenFile), + } + wrappedShares, keyURIs, err := client.wrapShares(ctx, shares, opts) + if err != nil { + t.Fatalf("wrapShares returned with error %v", err) + } + if len(keyURIs) != len(shares) { + t.Fatalf("wrapShares did not return the expected number of keyURIs. Got %v, want %v", len(keyURIs), len(shares)) + } + + for i := 0; i < len(keks); i++ { + i := i + expectedShare := append(shares[i], keks[i].expectedSuffix...) + if !bytes.Equal(wrappedShares[i].GetShare(), expectedShare) { + t.Errorf("wrapShares did not return the expected wrapped share for share %v. Got %s, want %s", i, wrappedShares[i].GetShare(), expectedShare) + } + } +} + func TestWrapSharesError(t *testing.T) { testCases := []struct { name string @@ -756,7 +860,8 @@ func TestWrapSharesError(t *testing.T) { }, fakeSecureSessionClient: testCase.fakeSSClient, } - _, _, err := stetClient.wrapShares(ctx, testCase.unwrappedShares, testCase.kekInfos, &configpb.AsymmetricKeys{}) + opts := sharesOpts{kekInfos: testCase.kekInfos, asymmetricKeys: &configpb.AsymmetricKeys{}} + _, _, err := stetClient.wrapShares(ctx, testCase.unwrappedShares, opts) if err == nil { t.Errorf("wrapShares(%s, %s) expected to return error, but did not", testCase.unwrappedShares, testCase.kekInfos) @@ -817,13 +922,13 @@ func TestUnwrapAndValidateSharesIndividually(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - unwrappedShares, err := stetClient.unwrapAndValidateShares(ctx, testCase.wrappedShare, [](*configpb.KekInfo){ - &configpb.KekInfo{ - KekType: &configpb.KekInfo_KekUri{ - KekUri: testCase.uri, - }, + opts := sharesOpts{ + kekInfos: []*configpb.KekInfo{ + &configpb.KekInfo{KekType: &configpb.KekInfo_KekUri{KekUri: testCase.uri}}, }, - }, &configpb.AsymmetricKeys{}) + asymmetricKeys: &configpb.AsymmetricKeys{}, + } + unwrappedShares, err := stetClient.unwrapAndValidateShares(ctx, testCase.wrappedShare, opts) if err != nil { t.Fatalf("unwrapAndValidateShares returned with error: %v", err) @@ -840,6 +945,102 @@ func TestUnwrapAndValidateSharesIndividually(t *testing.T) { } } +func TestUnwrapAndValidateSharesWithConfidentialSpace(t *testing.T) { + ctx := context.Background() + tokenFile := testutil.CreateTempTokenFile(t) + + // Define three test KEKs, each of which should map to a different KMS client. + keks := []struct { + kekURI string + ciphertext []byte + expectedSuffix []byte + }{ + {"gcp-kms://test-kek-0", []byte("Share 0"), []byte("-with-credentials")}, + {"gcp-kms://test-kek-1", []byte("Share 1"), []byte("-wip-only-credentials")}, + {"gcp-kms://test-kek-2", []byte("Share 2"), []byte("-no-credentials")}, + } + + // Define credentials for only the KEKs that require them. + csProto := &configpb.ConfidentialSpaceConfigs{ + KekCredentials: []*configpb.KekCredentialConfig{ + { + // A set of credentials. + KekUriPattern: keks[0].kekURI, + WipName: "test WIP name", + ServiceAccount: "test@system.gserviceaccount.com", + }, + { + // Same credentials, but without service account. + KekUriPattern: keks[1].kekURI, + WipName: "test WIP name", + }, + }, + } + + createFakeKMSClient := func(index int) *testutil.FakeKeyManagementClient { + return &testutil.FakeKeyManagementClient{ + GetCryptoKeyFunc: func(_ context.Context, req *kmsspb.GetCryptoKeyRequest, _ ...gax.CallOption) (*kmsrpb.CryptoKey, error) { + return &kmsrpb.CryptoKey{ + Primary: &kmsrpb.CryptoKeyVersion{ + Name: req.GetName(), + State: kmsrpb.CryptoKeyVersion_ENABLED, + ProtectionLevel: kmsrpb.ProtectionLevel_SOFTWARE, + }, + }, nil + }, + DecryptFunc: func(ctx context.Context, req *kmsspb.DecryptRequest, opts ...gax.CallOption) (*kmsspb.DecryptResponse, error) { + unwrappedShare := append(req.GetCiphertext(), keks[index].expectedSuffix...) + + return &kmsspb.DecryptResponse{ + Plaintext: unwrappedShare, + PlaintextCrc32C: wrapperspb.Int64(int64(testutil.CRC32C(unwrappedShare))), + }, nil + }, + } + } + + // Define a fake Client for each KEK credentials (including no credentials). + kmsClients := cloudkms.ClientFactory{ + CredsMap: map[string]cloudkms.Client{ + confspace.CreateJSONCredentials(csProto.GetKekCredentials()[0], tokenFile): createFakeKMSClient(0), + confspace.CreateJSONCredentials(csProto.GetKekCredentials()[1], tokenFile): createFakeKMSClient(1), + "": createFakeKMSClient(2), + }, + } + + var kekInfos []*configpb.KekInfo + var wrapped []*configpb.WrappedShare + for i := 0; i < len(keks); i++ { + wrapped = append(wrapped, &configpb.WrappedShare{ + Share: keks[i].ciphertext, + Hash: shares.HashShare(append(keks[i].ciphertext, keks[i].expectedSuffix...)), + }) + kekInfos = append(kekInfos, &configpb.KekInfo{ + KekType: &configpb.KekInfo_KekUri{KekUri: keks[i].kekURI}, + }) + } + + client := &StetClient{testKMSClients: &kmsClients} + + opts := sharesOpts{ + kekInfos: kekInfos, + asymmetricKeys: &configpb.AsymmetricKeys{}, + confSpaceConfig: confspace.NewConfigWithTokenFile(csProto, tokenFile), + } + unwrappedShares, err := client.unwrapAndValidateShares(ctx, wrapped, opts) + if err != nil { + t.Fatalf("wrapShares returned with error %v", err) + } + + for i := 0; i < len(keks); i++ { + i := i + expectedShare := append(wrapped[i].GetShare(), keks[i].expectedSuffix...) + if !bytes.Equal(unwrappedShares[i].Share, expectedShare) { + t.Errorf("wrapShares did not return the expected wrapped share for share %v. Got %s, want %s", i, unwrappedShares[i].Share, expectedShare) + } + } +} + func TestUnwrapAndValidateSharesWithMultipleShares(t *testing.T) { // Create lists of shares and kekInfos of appropriate length. share := []byte("expected unwrapped share") @@ -881,7 +1082,8 @@ func TestUnwrapAndValidateSharesWithMultipleShares(t *testing.T) { fakeSecureSessionClient: &fakeSecureSessionClient{}, } - unwrapped, err := stetClient.unwrapAndValidateShares(ctx, wrappedSharesList, kekInfoList, &configpb.AsymmetricKeys{}) + opts := sharesOpts{kekInfos: kekInfoList, asymmetricKeys: &configpb.AsymmetricKeys{}} + unwrapped, err := stetClient.unwrapAndValidateShares(ctx, wrappedSharesList, opts) if err != nil { t.Fatalf("wrapShares returned with error %v", err) @@ -983,7 +1185,9 @@ func TestUnwrapAndValidateSharesError(t *testing.T) { }, fakeSecureSessionClient: testCase.fakeSSClient, } - shares, err := stetClient.unwrapAndValidateShares(ctx, testCase.wrappedShares, testCase.kekInfos, &configpb.AsymmetricKeys{}) + + opts := sharesOpts{kekInfos: testCase.kekInfos, asymmetricKeys: &configpb.AsymmetricKeys{}} + shares, err := stetClient.unwrapAndValidateShares(ctx, testCase.wrappedShares, opts) if testCase.expectedErrSubstr != "" && err == nil { t.Errorf("unwrapAndValidateShares(context.Background(), %s, %s) expected to return error, but did not", testCase.wrappedShares, testCase.kekInfos) @@ -1020,12 +1224,13 @@ func TestWrapAndUnwrapWorkflow(t *testing.T) { fakeSecureSessionClient: &fakeSecureSessionClient{}, } - wrapped, _, err := stetClient.wrapShares(ctx, sharesList, kekInfoList, &configpb.AsymmetricKeys{}) + opts := sharesOpts{kekInfos: kekInfoList, asymmetricKeys: &configpb.AsymmetricKeys{}} + wrapped, _, err := stetClient.wrapShares(ctx, sharesList, opts) if err != nil { t.Fatalf("wrapShares(context.Background(), %v, %v, {}) returned with error %v", sharesList, kekInfoList, err) } - unwrapped, err := stetClient.unwrapAndValidateShares(ctx, wrapped, kekInfoList, &configpb.AsymmetricKeys{}) + unwrapped, err := stetClient.unwrapAndValidateShares(ctx, wrapped, opts) if err != nil { t.Errorf("unwrapAndValidateShares(context.Background(), %v, %v, {}) returned with error %v", wrapped, kekInfoList, err) } @@ -1053,12 +1258,10 @@ func TestEncryptAndDecryptWithNoSplitSucceeds(t *testing.T) { KeySplittingAlgorithm: &configpb.KeyConfig_NoSplit{true}, } - encryptConfig := &configpb.EncryptConfig{ - KeyConfig: keyConfig, - } - - decryptConfig := &configpb.DecryptConfig{ - KeyConfigs: []*configpb.KeyConfig{keyConfig}, + stetConfig := &configpb.StetConfig{ + EncryptConfig: &configpb.EncryptConfig{KeyConfig: keyConfig}, + DecryptConfig: &configpb.DecryptConfig{KeyConfigs: []*configpb.KeyConfig{keyConfig}}, + AsymmetricKeys: &configpb.AsymmetricKeys{}, } testCases := []struct { @@ -1088,30 +1291,30 @@ func TestEncryptAndDecryptWithNoSplitSucceeds(t *testing.T) { plaintextBuf := bytes.NewReader(tc.plaintext) var ciphertextBuf bytes.Buffer - if _, err := stetClient.Encrypt(ctx, plaintextBuf, &ciphertextBuf, encryptConfig, &configpb.AsymmetricKeys{}, testBlobID); err != nil { - t.Errorf("Encrypt(ctx, %v, buf, %v, {}, %v) returned error \"%v\", want no error", tc.plaintext, encryptConfig, testBlobID, err) + if _, err := stetClient.Encrypt(ctx, plaintextBuf, &ciphertextBuf, stetConfig, testBlobID); err != nil { + t.Errorf("Encrypt(ctx, %v, buf, %v, {}, %v) returned error \"%v\", want no error", tc.plaintext, stetConfig.GetEncryptConfig(), testBlobID, err) } // Decrypt the returned data and verify fields. var output bytes.Buffer - decryptedMd, err := stetClient.Decrypt(ctx, &ciphertextBuf, &output, decryptConfig, &configpb.AsymmetricKeys{}) + decryptedMd, err := stetClient.Decrypt(ctx, &ciphertextBuf, &output, stetConfig) if err != nil { - t.Fatalf("Error calling client.Decrypt(ctx, buf, buf, %v, {}): %v", decryptConfig, err) + t.Fatalf("Error calling client.Decrypt(ctx, buf, buf, %v, {}): %v", stetConfig.GetDecryptConfig(), err) } if decryptedMd.BlobID != testBlobID { - t.Errorf("Decrypt(ctx, input, output, %v, {}) does not contain the expected blob ID. Got %v, want %v", decryptConfig, decryptedMd.BlobID, testBlobID) + t.Errorf("Decrypt(ctx, input, output, %v, {}) does not contain the expected blob ID. Got %v, want %v", stetConfig.GetDecryptConfig(), decryptedMd.BlobID, testBlobID) } if len(decryptedMd.KeyUris) != len(keyConfig.GetKekInfos()) { - t.Fatalf("Decrypt(ctx, input, output, %v, {}) does not have the expected number of key URIS. Got %v, want %v", decryptConfig, len(decryptedMd.KeyUris), len(keyConfig.GetKekInfos())) + t.Fatalf("Decrypt(ctx, input, output, %v, {}) does not have the expected number of key URIS. Got %v, want %v", stetConfig.GetDecryptConfig(), len(decryptedMd.KeyUris), len(keyConfig.GetKekInfos())) } if decryptedMd.KeyUris[0] != kekInfo.GetKekUri() { - t.Errorf("Decrypt(ctx, input, output, %v, {}) does not contain the expected key URI. Got { %v }, want { %v }", decryptConfig, decryptedMd.KeyUris[0], kekInfo.GetKekUri()) + t.Errorf("Decrypt(ctx, input, output, %v, {}) does not contain the expected key URI. Got { %v }, want { %v }", stetConfig.GetDecryptConfig(), decryptedMd.KeyUris[0], kekInfo.GetKekUri()) } if !bytes.Equal(output.Bytes(), tc.plaintext) { - t.Errorf("Decrypt(ctx, input, output, %v, {}) returned ciphertext that does not match original plaintext. Got %v, want %v.", decryptConfig, output.Bytes(), tc.plaintext) + t.Errorf("Decrypt(ctx, input, output, %v, {}) returned ciphertext that does not match original plaintext. Got %v, want %v.", stetConfig.GetDecryptConfig(), output.Bytes(), tc.plaintext) } }) } @@ -1129,8 +1332,9 @@ func TestEncryptFailsForNoSplitWithTooManyKekInfos(t *testing.T) { KeySplittingAlgorithm: &configpb.KeyConfig_NoSplit{true}, } - encryptConfig := configpb.EncryptConfig{ - KeyConfig: &keyConfig, + stetConfig := &configpb.StetConfig{ + EncryptConfig: &configpb.EncryptConfig{KeyConfig: &keyConfig}, + AsymmetricKeys: &configpb.AsymmetricKeys{}, } plaintext := []byte("This is data to be encrypted.") @@ -1144,7 +1348,7 @@ func TestEncryptFailsForNoSplitWithTooManyKekInfos(t *testing.T) { plaintextBuf := bytes.NewReader(plaintext) var ciphertextBuf bytes.Buffer - if _, err := stetClient.Encrypt(ctx, plaintextBuf, &ciphertextBuf, &encryptConfig, &configpb.AsymmetricKeys{}, testBlobID); err == nil { + if _, err := stetClient.Encrypt(ctx, plaintextBuf, &ciphertextBuf, stetConfig, testBlobID); err == nil { t.Errorf("Encrypt with no split option and more than one KekInfo in the KeyConfig should return an error") } } @@ -1166,12 +1370,12 @@ func TestEncryptAndDecryptWithShamirSucceeds(t *testing.T) { KeySplittingAlgorithm: &configpb.KeyConfig_Shamir{shamirConfig}, } - encryptConfig := &configpb.EncryptConfig{ - KeyConfig: keyConfig, - } - - decryptConfig := &configpb.DecryptConfig{ - KeyConfigs: []*configpb.KeyConfig{keyConfig}, + stetConfig := &configpb.StetConfig{ + EncryptConfig: &configpb.EncryptConfig{KeyConfig: keyConfig}, + DecryptConfig: &configpb.DecryptConfig{ + KeyConfigs: []*configpb.KeyConfig{keyConfig}, + }, + AsymmetricKeys: &configpb.AsymmetricKeys{}, } testCases := []struct { @@ -1206,13 +1410,13 @@ func TestEncryptAndDecryptWithShamirSucceeds(t *testing.T) { t.Run(tc.name, func(t *testing.T) { plaintextBuf := bytes.NewReader(tc.plaintext) var ciphertextBuf bytes.Buffer - if _, err := stetClient.Encrypt(ctx, plaintextBuf, &ciphertextBuf, encryptConfig, &configpb.AsymmetricKeys{}, testBlobID); err != nil { + if _, err := stetClient.Encrypt(ctx, plaintextBuf, &ciphertextBuf, stetConfig, testBlobID); err != nil { t.Fatalf("Encrypt did not complete successfully: %v", err) } // Decrypt the returned data and verify fields. var output bytes.Buffer - decryptedMd, err := stetClient.Decrypt(ctx, &ciphertextBuf, &output, decryptConfig, &configpb.AsymmetricKeys{}) + decryptedMd, err := stetClient.Decrypt(ctx, &ciphertextBuf, &output, stetConfig) if err != nil { t.Fatalf("Error decrypting data: %v", err) } @@ -1250,8 +1454,10 @@ func TestEncryptFailsForInvalidShamirConfiguration(t *testing.T) { KeySplittingAlgorithm: &configpb.KeyConfig_Shamir{&shamirConfig}, } - encryptConfig := configpb.EncryptConfig{ - KeyConfig: &keyConfig, + stetConfig := &configpb.StetConfig{ + EncryptConfig: &configpb.EncryptConfig{ + KeyConfig: &keyConfig, + }, } plaintext := []byte("This is data to be encrypted.") @@ -1271,7 +1477,7 @@ func TestEncryptFailsForInvalidShamirConfiguration(t *testing.T) { plaintextBuf := bytes.NewReader(plaintext) var ciphertextBuf bytes.Buffer - if _, err := stetClient.Encrypt(ctx, plaintextBuf, &ciphertextBuf, &encryptConfig, &configpb.AsymmetricKeys{}, testBlobID); err == nil { + if _, err := stetClient.Encrypt(ctx, plaintextBuf, &ciphertextBuf, stetConfig, testBlobID); err == nil { t.Errorf("Encrypt expected to fail due to invalid Shamir's Secret Sharing configuration.") } } @@ -1290,12 +1496,13 @@ func TestEncryptGeneratesUUIDForBlobID(t *testing.T) { KeySplittingAlgorithm: &configpb.KeyConfig_Shamir{&shamirConfig}, } - encryptConfig := configpb.EncryptConfig{ - KeyConfig: &keyConfig, - } - - decryptConfig := &configpb.DecryptConfig{ - KeyConfigs: []*configpb.KeyConfig{&keyConfig}, + stetConfig := &configpb.StetConfig{ + EncryptConfig: &configpb.EncryptConfig{ + KeyConfig: &keyConfig, + }, + DecryptConfig: &configpb.DecryptConfig{ + KeyConfigs: []*configpb.KeyConfig{&keyConfig}, + }, } plaintext := []byte("This is data to be encrypted.") @@ -1319,14 +1526,14 @@ func TestEncryptGeneratesUUIDForBlobID(t *testing.T) { plaintextBuf := bytes.NewReader(plaintext) var ciphertextBuf bytes.Buffer - encryptedMd, err := stetClient.Encrypt(ctx, plaintextBuf, &ciphertextBuf, &encryptConfig, &configpb.AsymmetricKeys{}, "") + encryptedMd, err := stetClient.Encrypt(ctx, plaintextBuf, &ciphertextBuf, stetConfig, "") if err != nil { t.Fatalf("Encrypt expected to succeed, but failed with: %v", err.Error()) } // Decrypt to ensure the data can still be decrypted based on the blob ID in the metadata. var output bytes.Buffer - decryptedMd, err := stetClient.Decrypt(ctx, &ciphertextBuf, &output, decryptConfig, &configpb.AsymmetricKeys{}) + decryptedMd, err := stetClient.Decrypt(ctx, &ciphertextBuf, &output, stetConfig) if err != nil { t.Fatalf("Error decrypting data: %v", err) } @@ -1348,7 +1555,9 @@ func TestEncryptFailsWithNilConfig(t *testing.T) { plaintextBuf := bytes.NewReader([]byte("This is data to be encrypted.")) var ciphertextBuf bytes.Buffer - if _, err := stetClient.Encrypt(context.Background(), plaintextBuf, &ciphertextBuf, nil, &configpb.AsymmetricKeys{}, ""); err == nil { + + stetConfig := &configpb.StetConfig{EncryptConfig: nil} + if _, err := stetClient.Encrypt(context.Background(), plaintextBuf, &ciphertextBuf, stetConfig, ""); err == nil { t.Errorf("Encrypt expected to fail due to nil EncryptConfig.") } } @@ -1487,8 +1696,13 @@ func TestDecryptErrors(t *testing.T) { } input.Write(ciphertext) + stetConfig := &configpb.StetConfig{ + DecryptConfig: tc.config, + AsymmetricKeys: &configpb.AsymmetricKeys{}, + } + var output bytes.Buffer - if _, err := stetClient.Decrypt(ctx, &input, &output, tc.config, &configpb.AsymmetricKeys{}); err == nil { + if _, err := stetClient.Decrypt(ctx, &input, &output, stetConfig); err == nil { t.Errorf("Got no error, want error related to %q.", tc.errSubstr) } }) diff --git a/client/confidentialspace/BUILD b/client/confidentialspace/BUILD index a31a90e..0bf7f35 100644 --- a/client/confidentialspace/BUILD +++ b/client/confidentialspace/BUILD @@ -32,5 +32,8 @@ go_test( name = "confidentialspace_test", srcs = ["confidentialspace_test.go"], embed = [":confidentialspace"], - deps = ["//proto:config_go_proto"], + deps = [ + "//client/testutil", + "//proto:config_go_proto", + ], ) diff --git a/client/confidentialspace/confidentialspace_test.go b/client/confidentialspace/confidentialspace_test.go index 3a066ae..9ee51f1 100644 --- a/client/confidentialspace/confidentialspace_test.go +++ b/client/confidentialspace/confidentialspace_test.go @@ -18,6 +18,8 @@ import ( "os" "testing" + "github.com/GoogleCloudPlatform/stet/client/testutil" + configpb "github.com/GoogleCloudPlatform/stet/proto/config_go_proto" ) @@ -55,7 +57,7 @@ func TestFileExists(t *testing.T) { func TestFindMatchingCredentials(t *testing.T) { // Create token file. - tokenFile := createTempTokenFile(t) + tokenFile := testutil.CreateTempTokenFile(t) testCfg := &configpb.ConfidentialSpaceConfigs{ KekCredentials: []*configpb.KekCredentialConfig{ @@ -149,7 +151,7 @@ func TestFindMatchingCredentialsWithoutConfidentialSpace(t *testing.T) { // Tests scenarios where we don't expect FindMatchingCredentials to return a match. func TestFindMatchingCredentialsWithoutMatch(t *testing.T) { // Create token file. - tokenFile := createTempTokenFile(t) + tokenFile := testutil.CreateTempTokenFile(t) testCfg := &configpb.ConfidentialSpaceConfigs{ KekCredentials: []*configpb.KekCredentialConfig{ diff --git a/client/testutil/testutil.go b/client/testutil/testutil.go index 87a05db..3e359f4 100644 --- a/client/testutil/testutil.go +++ b/client/testutil/testutil.go @@ -18,6 +18,8 @@ package testutil import ( "context" "hash/crc32" + "os" + "testing" "cloud.google.com/go/kms/apiv1" "github.com/googleapis/gax-go/v2" @@ -52,7 +54,20 @@ var ( TestSoftwareKEKURI = gcpKMSPrefix + TestSoftwareKEKName ) -func crc32c(data []byte) uint32 { +// CreateTempTokenFile creates a temp directory/file as a stand-in for the attestation token. +func CreateTempTokenFile(t *testing.T) string { + // Create token file. + tempDir := t.TempDir() + tokenFile := tempDir + "test_token" + if err := os.WriteFile(tokenFile, []byte("test token"), 0755); err != nil { + t.Fatalf("Error creating token file at %v: %v", tokenFile, err) + } + + return tokenFile +} + +// CRC32C returns the Castagnoli CRC32 checksum of the given data. +func CRC32C(data []byte) uint32 { t := crc32.MakeTable(crc32.Castagnoli) return crc32.Checksum(data, t) } @@ -125,7 +140,7 @@ func ValidEncryptResponse(req *kmsspb.EncryptRequest) *kmsspb.EncryptResponse { return &kmsspb.EncryptResponse{ Name: req.GetName(), Ciphertext: wrappedShare, - CiphertextCrc32C: wrapperspb.Int64(int64(crc32c(wrappedShare))), + CiphertextCrc32C: wrapperspb.Int64(int64(CRC32C(wrappedShare))), VerifiedPlaintextCrc32C: true, } } @@ -163,7 +178,7 @@ func ValidDecryptResponse(req *kmsspb.DecryptRequest) *kmsspb.DecryptResponse { return &kmsspb.DecryptResponse{ Plaintext: unwrappedShare, - PlaintextCrc32C: wrapperspb.Int64(int64(crc32c(unwrappedShare))), + PlaintextCrc32C: wrapperspb.Int64(int64(CRC32C(unwrappedShare))), } } diff --git a/cmd/stet/main.go b/cmd/stet/main.go index 0c0eeb2..eea1c1f 100644 --- a/cmd/stet/main.go +++ b/cmd/stet/main.go @@ -159,7 +159,7 @@ func (e *encryptCmd) SetFlags(f *flag.FlagSet) { f.BoolVar(&e.quiet, "quiet", false, "Suppress logging output.") } -func (e *encryptCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { +func (e *encryptCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...any) subcommands.ExitStatus { yamlBytes, err := os.ReadFile(e.configFile) if err != nil { glog.Errorf("Failed to read config file: %v", err.Error()) @@ -227,7 +227,7 @@ func (e *encryptCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interfac Version: version, } - md, err := c.Encrypt(ctx, inFile, outFile, stetConfig.GetEncryptConfig(), stetConfig.GetAsymmetricKeys(), e.blobID) + md, err := c.Encrypt(ctx, inFile, outFile, stetConfig, e.blobID) if err != nil { glog.Errorf("Failed to encrypt plaintext: %v", err.Error()) return subcommands.ExitFailure @@ -325,7 +325,7 @@ func (d *decryptCmd) SetFlags(f *flag.FlagSet) { f.BoolVar(&d.quiet, "quiet", false, "Suppress logging output.") } -func (d *decryptCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { +func (d *decryptCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...any) subcommands.ExitStatus { yamlBytes, err := os.ReadFile(d.configFile) if err != nil { glog.Errorf("Failed to read config file: %v", err.Error()) @@ -392,7 +392,7 @@ func (d *decryptCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interfac Version: version, } - md, err := c.Decrypt(ctx, inFile, outFile, stetConfig.GetDecryptConfig(), stetConfig.GetAsymmetricKeys()) + md, err := c.Decrypt(ctx, inFile, outFile, stetConfig) if err != nil { glog.Errorf("Failed to decrypt ciphertext: %v", err.Error()) return subcommands.ExitFailure @@ -430,7 +430,7 @@ func (*versionCmd) Name() string { return "version" } func (*versionCmd) Synopsis() string { return "prints the current version" } func (*versionCmd) Usage() string { return "Usage: stet version" } func (*versionCmd) SetFlags(*flag.FlagSet) {} -func (*versionCmd) Execute(context.Context, *flag.FlagSet, ...interface{}) subcommands.ExitStatus { +func (*versionCmd) Execute(context.Context, *flag.FlagSet, ...any) subcommands.ExitStatus { if version == "" { fmt.Println("Version: development") } else {