Skip to content

Commit

Permalink
Add CloudKMS attestation to STET client
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 559139481
Change-Id: Ia607d54666a38a7d4f40580b06206c0e0886b358
  • Loading branch information
jessieqliu authored and copybara-github committed Aug 22, 2023
1 parent 35fc5b2 commit cd01d04
Show file tree
Hide file tree
Showing 7 changed files with 354 additions and 80 deletions.
3 changes: 3 additions & 0 deletions client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ go_library(
importpath = "github.com/GoogleCloudPlatform/stet/client",
deps = [
"//client/cloudkms",
"//client/confidentialspace",
"//client/jwt",
"//client/securesession",
"//client/shares",
Expand All @@ -53,6 +54,7 @@ go_test(
embed = [":client"],
deps = [
"//client/cloudkms",
"//client/confidentialspace",
"//client/securesession",
"//client/shares",
"//client/testutil",
Expand All @@ -61,5 +63,6 @@ go_test(
"@com_github_googleapis_gax_go_v2//:go_default_library",
"@go_googleapis//google/cloud/kms/v1:kms_go_proto",
"@org_golang_google_protobuf//proto",
"@org_golang_google_protobuf//types/known/wrapperspb",
],
)
65 changes: 51 additions & 14 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"strings"

"github.com/GoogleCloudPlatform/stet/client/cloudkms"
"github.com/GoogleCloudPlatform/stet/client/confidentialspace"
"github.com/GoogleCloudPlatform/stet/client/jwt"
"github.com/GoogleCloudPlatform/stet/client/securesession"
"github.com/GoogleCloudPlatform/stet/client/shares"
Expand Down Expand Up @@ -210,9 +211,15 @@ func getKekURIMetadata(ctx context.Context, kmsClient cloudkms.Client, kekInfo *
// list of wrapped shares, and a list of key URIs used for shares that were
// wrapped by communicating with an external KMS (these lists might not
// correspond one-to-one if some shares are wrapped via asymmetric key).
func (c *StetClient) wrapShares(ctx context.Context, unwrappedShares [][]byte, kekInfos []*configpb.KekInfo, keys *configpb.AsymmetricKeys) (wrappedShares []*configpb.WrappedShare, keyURIs []string, err error) {
if len(unwrappedShares) != len(kekInfos) {
return nil, nil, fmt.Errorf("number of shares to wrap (%d) does not match number of KEKs (%d)", len(unwrappedShares), len(kekInfos))
type sharesOpts struct {
kekInfos []*configpb.KekInfo
asymmetricKeys *configpb.AsymmetricKeys
confSpaceConfig *confidentialspace.Config
}

func (c *StetClient) wrapShares(ctx context.Context, unwrappedShares [][]byte, opts sharesOpts) (wrappedShares []*configpb.WrappedShare, keyURIs []string, err error) {
if len(unwrappedShares) != len(opts.kekInfos) {
return nil, nil, fmt.Errorf("number of shares to wrap (%d) does not match number of KEKs (%d)", len(unwrappedShares), len(opts.kekInfos))
}

var kmsClients *cloudkms.ClientFactory
Expand All @@ -228,11 +235,11 @@ func (c *StetClient) wrapShares(ctx context.Context, unwrappedShares [][]byte, k
Hash: shares.HashShare(share),
}

kek := kekInfos[i]
kek := opts.kekInfos[i]

switch x := kek.KekType.(type) {
case *configpb.KekInfo_RsaFingerprint:
key, err := PublicKeyForRSAFingerprint(kek, keys)
key, err := PublicKeyForRSAFingerprint(kek, opts.asymmetricKeys)
if err != nil {
return nil, nil, fmt.Errorf("failed to find public key for RSA fingerprint: %w", err)
}
Expand All @@ -243,7 +250,12 @@ func (c *StetClient) wrapShares(ctx context.Context, unwrappedShares [][]byte, k
}

case *configpb.KekInfo_KekUri:
// Configure CloudKMS Client, with Confidential Space credentials if applicable.
creds := ""
if opts.confSpaceConfig != nil {
creds = opts.confSpaceConfig.FindMatchingCredentials(kek.GetKekUri(), configpb.CredentialMode_ENCRYPT_ONLY_MODE)
}

kmsClient, err := kmsClients.Client(ctx, creds)
if err != nil {
return nil, nil, fmt.Errorf("error initializing Cloud KMS Client with credentials \"%v\": %v", creds, err)
Expand Down Expand Up @@ -293,10 +305,11 @@ func (c *StetClient) wrapShares(ctx context.Context, unwrappedShares [][]byte, k
}

// unwrapAndValidateShares decrypts the given wrapped share based on its URI.
func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares []*configpb.WrappedShare, kekInfos []*configpb.KekInfo, keys *configpb.AsymmetricKeys) ([]shares.UnwrappedShare, error) {
if len(wrappedShares) != len(kekInfos) {
return nil, fmt.Errorf("number of shares to unwrap (%d) does not match number of KEKs (%d)", len(wrappedShares), len(kekInfos))
func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares []*configpb.WrappedShare, opts sharesOpts) ([]shares.UnwrappedShare, error) {
if len(wrappedShares) != len(opts.kekInfos) {
return nil, fmt.Errorf("number of shares to unwrap (%d) does not match number of KEKs (%d)", len(wrappedShares), len(opts.kekInfos))
}

var kmsClients *cloudkms.ClientFactory
if c.testKMSClients != nil {
kmsClients = c.testKMSClients
Expand All @@ -313,11 +326,11 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares
for i, wrapped := range wrappedShares {
glog.Infof("Attempting to unwrap share #%v", i+1)
unwrapped := shares.UnwrappedShare{}
kek := kekInfos[i]
kek := opts.kekInfos[i]

switch x := kek.KekType.(type) {
case *configpb.KekInfo_RsaFingerprint:
key, err := PrivateKeyForRSAFingerprint(kek, keys)
key, err := PrivateKeyForRSAFingerprint(kek, opts.asymmetricKeys)
if err != nil {
glog.Warningf("Failed to find public key for RSA fingerprint: %v", err)
continue
Expand All @@ -330,7 +343,12 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares
}

case *configpb.KekInfo_KekUri:
// Configure CloudKMS Client, with Confidential Space credentials if applicable.
creds := ""
if opts.confSpaceConfig != nil {
creds = opts.confSpaceConfig.FindMatchingCredentials(kek.GetKekUri(), configpb.CredentialMode_DECRYPT_ONLY_MODE)
}

kmsClient, err := kmsClients.Client(ctx, creds)
if err != nil {
return nil, fmt.Errorf("error initializing Cloud KMS Client with credentials \"%v\": %v", creds, err)
Expand Down Expand Up @@ -386,7 +404,8 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares
}

// Encrypt generates a DEK and creates EncryptedData in accordance with the EKM encryption protocol.
func (c *StetClient) Encrypt(ctx context.Context, input io.Reader, output io.Writer, config *configpb.EncryptConfig, keys *configpb.AsymmetricKeys, blobID string) (*StetMetadata, error) {
func (c *StetClient) Encrypt(ctx context.Context, input io.Reader, output io.Writer, stetConfig *configpb.StetConfig, blobID string) (*StetMetadata, error) {
config := stetConfig.GetEncryptConfig()
if config == nil {
return nil, fmt.Errorf("nil EncryptConfig passed to Encrypt()")
}
Expand All @@ -407,7 +426,15 @@ func (c *StetClient) Encrypt(ctx context.Context, input io.Reader, output io.Wri
metadata := &configpb.Metadata{BlobId: blobID, KeyConfig: keyCfg}

var keyURIs []string
metadata.Shares, keyURIs, err = c.wrapShares(ctx, shares, keyCfg.GetKekInfos(), keys)
opts := sharesOpts{
kekInfos: keyCfg.GetKekInfos(),
asymmetricKeys: stetConfig.GetAsymmetricKeys(),
}
if csConfigs := stetConfig.GetConfidentialSpaceConfigs(); csConfigs != nil {
opts.confSpaceConfig = confidentialspace.NewConfig(csConfigs)
}

metadata.Shares, keyURIs, err = c.wrapShares(ctx, shares, opts)
if err != nil {
return nil, fmt.Errorf("error wrapping shares: %v", err)
}
Expand Down Expand Up @@ -442,11 +469,13 @@ func (c *StetClient) Encrypt(ctx context.Context, input io.Reader, output io.Wri
KeyUris: keyURIs,
BlobID: metadata.GetBlobId(),
}, nil

}

// Decrypt writes the decrypted data to the `output` writer, and returns the
// key URIs used during decryption and the blob ID decrypted.
func (c *StetClient) Decrypt(ctx context.Context, input io.Reader, output io.Writer, config *configpb.DecryptConfig, keys *configpb.AsymmetricKeys) (*StetMetadata, error) {
func (c *StetClient) Decrypt(ctx context.Context, input io.Reader, output io.Writer, stetConfig *configpb.StetConfig) (*StetMetadata, error) {
config := stetConfig.GetDecryptConfig()
if config == nil {
return nil, fmt.Errorf("nil DecryptConfig passed to Decrypt()")
}
Expand All @@ -471,7 +500,15 @@ func (c *StetClient) Decrypt(ctx context.Context, input io.Reader, output io.Wri
}

// Unwrap shares and validate.
unwrappedShares, err := c.unwrapAndValidateShares(ctx, metadata.GetShares(), matchingKeyConfig.GetKekInfos(), keys)
opts := sharesOpts{
kekInfos: matchingKeyConfig.GetKekInfos(),
asymmetricKeys: stetConfig.GetAsymmetricKeys(),
}
if csConfigs := stetConfig.GetConfidentialSpaceConfigs(); csConfigs != nil {
opts.confSpaceConfig = confidentialspace.NewConfig(csConfigs)
}

unwrappedShares, err := c.unwrapAndValidateShares(ctx, metadata.GetShares(), opts)
if err != nil {
return nil, fmt.Errorf("error unwrapping and validating shares: %v", err)
}
Expand Down
Loading

0 comments on commit cd01d04

Please sign in to comment.