diff --git a/Cargo.toml b/Cargo.toml index 3e0ac98..2b42d66 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,8 @@ serde_bytes = { version = "0.11", features = ["std"] } serde_with = { version = "1.5", default_features = false } openssl = { version = "0.10", optional = true } tss-esapi = { version = "6.1", optional = true } +aws-sdk-kms = { version = "0.16", optional = true } +tokio = { version = "1.20", features = ["rt"], optional = true } [dependencies.serde] version = "1.0" @@ -23,8 +25,10 @@ features = ["derive"] [dev-dependencies] hex = "0.4" +tokio = { version = "1.20", features = ["macros"] } [features] default = ["key_openssl_pkey"] key_openssl_pkey = ["openssl"] key_tpm = ["tss-esapi", "openssl"] +key_kms = ["aws-sdk-kms", "tokio"] diff --git a/src/crypto/kms.rs b/src/crypto/kms.rs new file mode 100644 index 0000000..84ca645 --- /dev/null +++ b/src/crypto/kms.rs @@ -0,0 +1,237 @@ +//! KMS implementation for cryptography + +use openssl::{ + bn::BigNum, + ecdsa::EcdsaSig, + hash::MessageDigest, + pkey::{PKey, Public}, +}; +use tokio::runtime::Runtime; + +use aws_sdk_kms::{ + error::{VerifyError, VerifyErrorKind}, + model::{MessageType, SigningAlgorithmSpec}, + Blob, Client, SdkError, +}; + +use crate::{ + crypto::{ec_curve_to_parameters, SigningPrivateKey, SigningPublicKey}, + error::CoseError, + sign::SignatureAlgorithm, +}; + +/// A reference to an AWS KMS key and client +pub struct KmsKey { + client: Client, + key_id: String, + + sig_alg: SignatureAlgorithm, + + public_key: Option>, + + runtime: Runtime, +} + +impl KmsKey { + fn new_runtime() -> Runtime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("Error creating tokio runtime") + } + + /// Create a new KmsKey, using the specified client and key_id. + /// + /// The sig_alg needs to be valid for the specified key. + /// This version will use the KMS Verify call to verify signatures. + /// + /// AWS Permissions required on the specified key: + /// - Sign (for creating new signatures) + /// - Verify (for verifying existing signatures) + pub fn new( + client: Client, + key_id: String, + sig_alg: SignatureAlgorithm, + ) -> Result { + Ok(KmsKey { + client, + key_id, + sig_alg, + + public_key: None, + + runtime: Self::new_runtime(), + }) + } + + /// Create a new KmsKey, using the specified client and key_id. + /// + /// The sig_alg needs to be valid for the specified key. + /// This version will use local signature verification. + /// If no public key is passed in, the key will be retrieved with GetPublicKey. + /// + /// AWS Permissions required on the specified key: + /// - Sign (for creating new signatures) + /// - GetPublicKey (to get the public key if it wasn't passed in) + #[cfg(feature = "key_openssl_pkey")] + pub fn new_with_public_key( + client: Client, + key_id: String, + public_key: Option>, + ) -> Result { + let runtime = Self::new_runtime(); + let public_key = match public_key { + Some(key) => key, + None => { + // Retrieve public key from AWS + let request = client.get_public_key().key_id(key_id.clone()).send(); + + let public_key = runtime + .block_on(request) + .map_err(CoseError::AwsGetPublicKeyError)? + .public_key + .ok_or_else(|| { + CoseError::UnsupportedError("No public key returned".to_string()) + })?; + + PKey::public_key_from_der(public_key.as_ref()).map_err(CoseError::SignatureError)? + } + }; + + let curve_name = public_key + .ec_key() + .map_err(|_| CoseError::UnsupportedError("Non-EC keys are not supported".to_string()))? + .group() + .curve_name() + .ok_or_else(|| { + CoseError::UnsupportedError("Anonymous EC keys are not supported".to_string()) + })?; + let sig_alg = ec_curve_to_parameters(curve_name)?.0; + + Ok(KmsKey { + client, + key_id, + + sig_alg, + public_key: Some(public_key), + + runtime, + }) + } + + fn get_sig_alg_spec(&self) -> SigningAlgorithmSpec { + match self.sig_alg { + SignatureAlgorithm::ES256 => SigningAlgorithmSpec::EcdsaSha256, + SignatureAlgorithm::ES384 => SigningAlgorithmSpec::EcdsaSha384, + SignatureAlgorithm::ES512 => SigningAlgorithmSpec::EcdsaSha512, + } + } + + #[cfg(feature = "key_openssl_pkey")] + fn verify_with_public_key(&self, data: &[u8], signature: &[u8]) -> Result { + self.public_key.as_ref().unwrap().verify(data, signature) + } +} + +impl SigningPublicKey for KmsKey { + fn get_parameters(&self) -> Result<(SignatureAlgorithm, MessageDigest), CoseError> { + Ok((self.sig_alg, self.sig_alg.suggested_message_digest())) + } + + fn verify(&self, data: &[u8], signature: &[u8]) -> Result { + if self.public_key.is_some() { + #[cfg(feature = "key_openssl_pkey")] + return self.verify_with_public_key(data, signature); + + #[cfg(not(feature = "key_openssl_pkey"))] + panic!("Would have been impossible to get public_key set"); + } else { + // Call KMS to verify + + // Recover the R and S factors from the signature contained in the object + let (bytes_r, bytes_s) = signature.split_at(self.sig_alg.key_length()); + + let r = BigNum::from_slice(&bytes_r).map_err(CoseError::SignatureError)?; + let s = BigNum::from_slice(&bytes_s).map_err(CoseError::SignatureError)?; + + let sig = EcdsaSig::from_private_components(r, s).map_err(CoseError::SignatureError)?; + let sig = sig.to_der().map_err(CoseError::SignatureError)?; + + let request = self + .client + .verify() + .key_id(self.key_id.clone()) + .message(Blob::new(data.to_vec())) + .message_type(MessageType::Digest) + .signing_algorithm(self.get_sig_alg_spec()) + .signature(Blob::new(sig)) + .send(); + + let reply = self.runtime.block_on(request); + + match reply { + Ok(v) => Ok(v.signature_valid), + Err(SdkError::ServiceError { + err: + VerifyError { + kind: VerifyErrorKind::KmsInvalidSignatureException(_), + .. + }, + .. + }) => Ok(false), + Err(e) => Err(CoseError::AwsVerifyError(e)), + } + } + } +} + +impl SigningPrivateKey for KmsKey { + fn sign(&self, data: &[u8]) -> Result, CoseError> { + let request = self + .client + .sign() + .key_id(self.key_id.clone()) + .message(Blob::new(data.to_vec())) + .message_type(MessageType::Digest) + .signing_algorithm(self.get_sig_alg_spec()) + .send(); + + let signature = self + .runtime + .block_on(request) + .map_err(CoseError::AwsSignError)? + .signature + .ok_or_else(|| CoseError::UnsupportedError("No signature returned".to_string()))?; + + let signature = + EcdsaSig::from_der(signature.as_ref()).map_err(CoseError::SignatureError)?; + + let key_length = self.sig_alg.key_length(); + + // The spec defines the signature as: + // Signature = I2OSP(R, n) | I2OSP(S, n), where n = ceiling(key_length / 8) + // The Signer interface doesn't provide this, so this will use EcdsaSig interface instead + // and concatenate R and S. + // See https://tools.ietf.org/html/rfc8017#section-4.1 for details. + let bytes_r = signature.r().to_vec(); + let bytes_s = signature.s().to_vec(); + + // These should *never* exceed ceiling(key_length / 8) + assert!(bytes_r.len() <= key_length); + assert!(bytes_s.len() <= key_length); + + let mut signature_bytes = vec![0u8; key_length * 2]; + + // This is big-endian encoding so padding might be added at the start if the factor is + // too short. + let offset_copy = key_length - bytes_r.len(); + signature_bytes[offset_copy..offset_copy + bytes_r.len()].copy_from_slice(&bytes_r); + + // This is big-endian encoding so padding might be added at the start if the factor is + // too short. + let offset_copy = key_length - bytes_s.len() + key_length; + signature_bytes[offset_copy..offset_copy + bytes_s.len()].copy_from_slice(&bytes_s); + + Ok(signature_bytes) + } +} diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index e680e0c..6155bba 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -14,6 +14,8 @@ mod openssl; #[cfg(feature = "openssl")] pub use self::openssl::Openssl; +#[cfg(feature = "key_kms")] +pub mod kms; #[cfg(feature = "key_openssl_pkey")] mod openssl_pkey; #[cfg(feature = "key_tpm")] diff --git a/src/error.rs b/src/error.rs index 7b23cc1..6f4df79 100644 --- a/src/error.rs +++ b/src/error.rs @@ -33,6 +33,15 @@ pub enum CoseError { /// TPM error occured #[cfg(feature = "key_tpm")] TpmError(tss_esapi::Error), + /// AWS sign error occured + #[cfg(feature = "key_kms")] + AwsSignError(aws_sdk_kms::SdkError), + /// AWS verify error occured + #[cfg(feature = "key_kms")] + AwsVerifyError(aws_sdk_kms::SdkError), + /// AWS GetPublicKey error occured + #[cfg(all(feature = "key_kms", feature = "key_openssl_pkey"))] + AwsGetPublicKeyError(aws_sdk_kms::SdkError), } impl fmt::Display for CoseError { @@ -51,6 +60,12 @@ impl fmt::Display for CoseError { CoseError::EncryptionError(e) => write!(f, "Encryption error: {}", e), #[cfg(feature = "key_tpm")] CoseError::TpmError(e) => write!(f, "TPM error: {}", e), + #[cfg(feature = "key_kms")] + CoseError::AwsSignError(e) => write!(f, "AWS sign error: {}", e), + #[cfg(feature = "key_kms")] + CoseError::AwsVerifyError(e) => write!(f, "AWS verify error: {}", e), + #[cfg(all(feature = "key_kms", feature = "key_openssl_pkey"))] + CoseError::AwsGetPublicKeyError(e) => write!(f, "AWS GetPublicKey error: {}", e), } } } diff --git a/src/sign.rs b/src/sign.rs index b5533ad..182a04b 100644 --- a/src/sign.rs +++ b/src/sign.rs @@ -460,7 +460,7 @@ mod tests { fn map_serialization() { // Empty map let map: HeaderMap = HeaderMap::new(); - assert_eq!(map_to_empty_or_serialized(&map).unwrap(), []); + assert_eq!(map_to_empty_or_serialized(&map).unwrap(), [] as [u8; 0]); // Checks that the body_protected field will be serialized correctly let map: HeaderMap = SignatureAlgorithm::ES256.into(); @@ -1288,4 +1288,149 @@ mod tests { } } } + + #[cfg(feature = "key_kms")] + mod kms { + use std::str::FromStr; + + use super::TEXT; + use crate::{crypto::kms::KmsKey, sign::*}; + + use std::env; + + #[tokio::test] + async fn cose_sign_kms() { + let config = aws_config::from_env().load().await; + let kms_client = aws_sdk_kms::Client::new(&config); + + tokio::task::spawn_blocking(|| { + let key_id = + env::var("AWS_KMS_TEST_KEY_ARN").expect("Please set AWS_KMS_TEST_KEY_ARN"); + + let sig_alg = env::var("TEST_KEY_SIG_ALG").expect("Please set TEST_KEY_SIG_ALG"); + let sig_alg = + SignatureAlgorithm::from_str(&sig_alg).expect("Invalid TEST_KEY_SIG_ALG"); + + let kms_key = + KmsKey::new(kms_client, key_id, sig_alg).expect("Error building kms_key"); + + let mut map = HeaderMap::new(); + map.insert(CborValue::Integer(4), CborValue::Bytes(b"11".to_vec())); + let cose_doc1 = CoseSign1::new(TEXT, &map, &kms_key).unwrap(); + let tagged_bytes = cose_doc1.as_bytes(true).unwrap(); + + // Tag 6.18 should be present + assert_eq!(tagged_bytes[0], 6 << 5 | 18); + let cose_doc2 = CoseSign1::from_bytes(&tagged_bytes).unwrap(); + + assert_eq!( + cose_doc1.get_payload(None).unwrap(), + cose_doc2.get_payload(Some(&kms_key)).unwrap() + ); + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn cose_sign_kms_invalid_signature() { + let config = aws_config::from_env().load().await; + let kms_client = aws_sdk_kms::Client::new(&config); + + tokio::task::spawn_blocking(|| { + let key_id = + env::var("AWS_KMS_TEST_KEY_ARN").expect("Please set AWS_KMS_TEST_KEY_ARN"); + + let sig_alg = env::var("TEST_KEY_SIG_ALG").expect("Please set TEST_KEY_SIG_ALG"); + let sig_alg = + SignatureAlgorithm::from_str(&sig_alg).expect("Invalid TEST_KEY_SIG_ALG"); + let kms_key = + KmsKey::new(kms_client, key_id, sig_alg).expect("Error building kms_key"); + + let mut map = HeaderMap::new(); + map.insert(CborValue::Integer(4), CborValue::Bytes(b"11".to_vec())); + let mut cose_doc1 = CoseSign1::new(TEXT, &map, &kms_key).unwrap(); + + // Mangle the signature + cose_doc1.signature[0] ^= 0xff; + + let tagged_bytes = cose_doc1.as_bytes(true).unwrap(); + let cose_doc2 = CoseSign1::from_bytes(&tagged_bytes).unwrap(); + + match cose_doc2.get_payload(Some(&kms_key)) { + Ok(_) => panic!("Did not fail"), + Err(CoseError::UnverifiedSignature) => {} + Err(e) => { + panic!("Unexpected error: {:?}", e) + } + } + }) + .await + .unwrap(); + } + + #[cfg(feature = "key_openssl_pkey")] + #[tokio::test] + async fn cose_sign_kms_public_key() { + let config = aws_config::from_env().load().await; + let kms_client = aws_sdk_kms::Client::new(&config); + + let key_id = env::var("AWS_KMS_TEST_KEY_ARN").expect("Please set AWS_KMS_TEST_KEY_ARN"); + + tokio::task::spawn_blocking(|| { + let kms_key = KmsKey::new_with_public_key(kms_client, key_id, None) + .expect("Error building kms_key"); + + let mut map = HeaderMap::new(); + map.insert(CborValue::Integer(4), CborValue::Bytes(b"11".to_vec())); + let cose_doc1 = CoseSign1::new(TEXT, &map, &kms_key).unwrap(); + let tagged_bytes = cose_doc1.as_bytes(true).unwrap(); + + // Tag 6.18 should be present + assert_eq!(tagged_bytes[0], 6 << 5 | 18); + let cose_doc2 = CoseSign1::from_bytes(&tagged_bytes).unwrap(); + + assert_eq!( + cose_doc1.get_payload(None).unwrap(), + cose_doc2.get_payload(Some(&kms_key)).unwrap() + ); + }) + .await + .unwrap(); + } + + #[cfg(feature = "key_openssl_pkey")] + #[tokio::test] + async fn cose_sign_kms_public_key_invalid_signature() { + let config = aws_config::from_env().load().await; + let kms_client = aws_sdk_kms::Client::new(&config); + + let key_id = env::var("AWS_KMS_TEST_KEY_ARN").expect("Please set AWS_KMS_TEST_KEY_ARN"); + + tokio::task::spawn_blocking(|| { + let kms_key = KmsKey::new_with_public_key(kms_client, key_id, None) + .expect("Error building kms_key"); + + let mut map = HeaderMap::new(); + map.insert(CborValue::Integer(4), CborValue::Bytes(b"11".to_vec())); + let mut cose_doc1 = CoseSign1::new(TEXT, &map, &kms_key).unwrap(); + + // Mangle the signature + cose_doc1.signature[0] ^= 0xff; + + let tagged_bytes = cose_doc1.as_bytes(true).unwrap(); + let cose_doc2 = CoseSign1::from_bytes(&tagged_bytes).unwrap(); + + match cose_doc2.get_payload(Some(&kms_key)) { + Ok(_) => panic!("Did not fail"), + Err(CoseError::UnverifiedSignature) => {} + Err(e) => { + panic!("Unexpected error: {:?}", e) + } + } + }) + .await + .unwrap(); + } + } }