Skip to content

Commit

Permalink
More federation tests (#59)
Browse files Browse the repository at this point in the history
* clean up and refactoring docker tests

* docker test refactoring and another federated test

* add .dockerignore

* add license note to .dockerignore

* make build workflow cancel upon new push

* more streamlining of docker tests

* more refactoring, more tests

* file renaming

* clean up
  • Loading branch information
kkohbrok authored Aug 8, 2023
1 parent 5f1b52a commit 14d41c6
Show file tree
Hide file tree
Showing 37 changed files with 986 additions and 565 deletions.
11 changes: 11 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH <[email protected]>
#
# SPDX-License-Identifier: AGPL-3.0-or-later

.env
target/
*/target/
tests/
Dockerfile
scripts/
migrations/
4 changes: 4 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

name: Rust

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

on:
push:
branches: ["main"]
Expand Down
2 changes: 2 additions & 0 deletions apiclient/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub mod qs_api;
/// encryption should be enabled when used in production and there is no load
/// balancer or reverse proxy in front of the server that terminates TLS
/// connections.
#[derive(Clone)]
pub enum TransportEncryption {
On,
Off,
Expand Down Expand Up @@ -66,6 +67,7 @@ impl std::fmt::Display for DomainOrAddress {

// ApiClient is a wrapper around a reqwest client.
// It exposes a single function for each API endpoint.
#[derive(Clone)]
pub struct ApiClient {
client: Client,
domain_or_address: DomainOrAddress,
Expand Down
27 changes: 2 additions & 25 deletions backend/src/auth_service/credentials/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
//
// SPDX-License-Identifier: AGPL-3.0-or-later

use std::collections::HashMap;

use mls_assist::{
openmls::prelude::{HashType, OpenMlsCrypto, OpenMlsCryptoProvider, SignatureScheme},
openmls_rust_crypto::OpenMlsRustCrypto,
Expand Down Expand Up @@ -44,7 +42,7 @@ use self::keys::ClientVerifyingKey;

use super::AsClientId;

#[derive(Clone, Debug, PartialEq, Eq, TlsDeserializeBytes, TlsSerialize, TlsSize)]
#[derive(Clone, Debug, PartialEq, Eq, TlsDeserializeBytes, TlsSerialize, TlsSize, Hash)]
pub struct CredentialFingerprint {
value: Vec<u8>,
}
Expand Down Expand Up @@ -263,7 +261,7 @@ pub struct VerifiableAsIntermediateCredential {
}

impl VerifiableAsIntermediateCredential {
pub fn fingerprint(&self) -> &CredentialFingerprint {
pub fn signer_fingerprint(&self) -> &CredentialFingerprint {
&self.credential.signer_fingerprint
}
}
Expand Down Expand Up @@ -408,27 +406,6 @@ impl ClientCredential {
pub fn verifying_key(&self) -> &ClientVerifyingKey {
&self.payload.csr.verifying_key
}

pub fn decrypt_and_verify(
ear_key: &ClientCredentialEarKey,
ciphertext: &EncryptedClientCredential,
as_intermediate_credentials: &HashMap<Fqdn, Vec<AsIntermediateCredential>>,
) -> Result<Self, ClientCredentialProcessingError> {
let verifiable_credential = VerifiableClientCredential::decrypt(ear_key, ciphertext)
.map_err(|_| ClientCredentialProcessingError::DecryptionError)?;
let as_credential = as_intermediate_credentials
.get(&verifiable_credential.domain())
.expect("Could not find AS credentials for domain")
.iter()
.find(|as_cred| {
&as_cred.fingerprint().unwrap() == verifiable_credential.signer_fingerprint()
})
.ok_or(ClientCredentialProcessingError::NoMatchingAsCredential)?;
let client_credential = verifiable_credential
.verify(as_credential.verifying_key())
.map_err(|_| ClientCredentialProcessingError::VerificationError)?;
Ok(client_credential)
}
}

#[derive(Debug, Clone)]
Expand Down
25 changes: 18 additions & 7 deletions backend/src/ds/add_users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,25 +163,35 @@ impl DsGroupState {
.zip(params.encrypted_welcome_attribution_infos.into_iter())
{
let fqdn = key_package_batch.homeserver_domain().clone();

let key_package_batch: KeyPackageBatch<VERIFIED> =
if let Some(verifying_key) = verifying_keys.get(&fqdn) {
key_package_batch
.verify(verifying_key)
.map_err(|_| AddUsersError::InvalidKeyPackageBatch)?
key_package_batch.verify(verifying_key).map_err(|e| {
tracing::warn!(
"Error verifying key package batch with pre-fetched key: {:?}",
e
);
AddUsersError::InvalidKeyPackageBatch
})?
} else {
let verifying_key = qs_provider
.verifying_key(&fqdn)
.verifying_key(fqdn.clone())
.await
.map_err(|_| AddUsersError::FailedToObtainVerifyingKey)?;
let kpb = key_package_batch
.verify(&verifying_key)
.map_err(|_| AddUsersError::InvalidKeyPackageBatch)?;
let kpb = key_package_batch.verify(&verifying_key).map_err(|e| {
tracing::warn!(
"Error verifying key package batch with freshly fetched key: {:?}",
e
);
AddUsersError::InvalidKeyPackageBatch
})?;
verifying_keys.insert(fqdn, verifying_key);
kpb
};

// Validate freshness of the batch.
if key_package_batch.has_expired(KEYPACKAGEBATCH_EXPIRATION_DAYS) {
tracing::warn!("Key package batch has expired");
return Err(AddUsersError::InvalidKeyPackageBatch);
}

Expand All @@ -193,6 +203,7 @@ impl DsGroupState {
// KeyPackages belonging to one user in the tree.
key_packages.push(added_client);
} else {
tracing::warn!("Incomplete KeyPackageBatch");
return Err(AddUsersError::InvalidKeyPackageBatch);
}
}
Expand Down
25 changes: 6 additions & 19 deletions backend/src/messages/client_as_out.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
//
// SPDX-License-Identifier: AGPL-3.0-or-later

use std::collections::HashMap;

use mls_assist::openmls::prelude::GroupId;
use tls_codec::{DeserializeBytes, Serialize, TlsDeserializeBytes, TlsSerialize, TlsSize};

use crate::{
auth_service::{
credentials::{
keys::AsIntermediateVerifyingKey, AsCredential, AsIntermediateCredential,
ClientCredential, CredentialFingerprint, ExpirationData,
VerifiableAsIntermediateCredential, VerifiableClientCredential,
keys::AsIntermediateVerifyingKey, AsCredential, ClientCredential,
CredentialFingerprint, ExpirationData, VerifiableAsIntermediateCredential,
VerifiableClientCredential,
},
errors::AsVerificationError,
storage_provider_trait::{AsEphemeralStorageProvider, AsStorageProvider},
Expand All @@ -34,7 +32,6 @@ use crate::{
},
ConnectionDecryptionKey, ConnectionEncryptionKey, RatchetEncryptionKey,
},
qs::Fqdn,
};

use super::{
Expand Down Expand Up @@ -444,24 +441,14 @@ impl ConnectionEstablishmentPackageIn {
&self.payload.sender_client_credential
}

pub fn verify_all(
pub fn verify(
self,
as_intermediate_credentials: &HashMap<Fqdn, Vec<AsIntermediateCredential>>,
verifying_key: &AsIntermediateVerifyingKey,
) -> ConnectionEstablishmentPackageTbs {
let as_intermediate_credentials = as_intermediate_credentials
.get(&self.payload.sender_client_credential.domain())
.unwrap();
let as_credential = as_intermediate_credentials
.iter()
.find(|as_cred| {
&as_cred.fingerprint().unwrap()
== self.payload.sender_client_credential.signer_fingerprint()
})
.unwrap();
let sender_client_credential: ClientCredential = self
.payload
.sender_client_credential
.verify(as_credential.verifying_key())
.verify(verifying_key)
.unwrap();
ConnectionEstablishmentPackageTbs {
sender_client_credential,
Expand Down
9 changes: 8 additions & 1 deletion backend/src/messages/qs_qs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,18 @@ use crate::qs::Fqdn;

use super::{intra_backend::DsFanOutMessage, MlsInfraVersion};

#[derive(TlsSerialize, TlsDeserializeBytes, TlsSize)]
#[repr(u8)]
pub enum QsToQsPayload {
FanOutMessageRequest(DsFanOutMessage),
VerificationKeyRequest,
}

#[derive(TlsSerialize, TlsDeserializeBytes, TlsSize)]
pub struct QsToQsMessage {
pub protocol_version: MlsInfraVersion,
pub sender: Fqdn,
pub recipient: Fqdn,
pub fan_out_message: DsFanOutMessage,
pub payload: QsToQsPayload,
// TODO: Signature
}
62 changes: 57 additions & 5 deletions backend/src/qs/ds_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,19 @@ use tls_codec::Serialize;

use crate::{
crypto::hpke::HpkeDecryptable,
messages::{intra_backend::DsFanOutMessage, qs_qs::QsToQsMessage, MlsInfraVersion},
messages::{
intra_backend::DsFanOutMessage,
qs_qs::{QsToQsMessage, QsToQsPayload},
MlsInfraVersion,
},
};

use super::{
errors::QsEnqueueError, network_provider_trait::NetworkProvider,
storage_provider_trait::QsStorageProvider, ClientConfig, Qs, WebsocketNotifier,
errors::{QsEnqueueError, QsVerifyingKeyError},
network_provider_trait::NetworkProvider,
qs_api::FederatedProcessingResult,
storage_provider_trait::QsStorageProvider,
ClientConfig, Fqdn, Qs, QsVerifyingKey, WebsocketNotifier,
};

impl Qs {
Expand All @@ -35,7 +42,7 @@ impl Qs {
protocol_version: MlsInfraVersion::Alpha,
sender: own_domain.clone(),
recipient: message.client_reference.client_homeserver_domain.clone(),
fan_out_message: message.clone(),
payload: QsToQsPayload::FanOutMessageRequest(message.clone()),
};
let serialized_message = qs_to_qs_message
.tls_serialize_detached()
Expand All @@ -46,7 +53,14 @@ impl Qs {
message.client_reference.client_homeserver_domain,
)
.await
.map_err(QsEnqueueError::NetworkError)?
.map_err(QsEnqueueError::NetworkError)
.and_then(|result| {
if matches!(result, FederatedProcessingResult::Ok) {
Ok(())
} else {
Err(QsEnqueueError::InvalidResponse)
}
})?
} else {
let decryption_key = storage_provider
.load_decryption_key()
Expand Down Expand Up @@ -80,4 +94,42 @@ impl Qs {
// TODO: client now has new ratchet key, store it in the storage
// provider.
}

/// Fetch the verifying key of the server with the given domain
#[tracing::instrument(skip_all, err)]
pub async fn verifying_key<S: QsStorageProvider, N: NetworkProvider>(
storage_provider: &S,
network_provider: &N,
domain: Fqdn,
) -> Result<QsVerifyingKey, QsVerifyingKeyError> {
let own_domain = storage_provider.own_domain().await;
let verifying_key = if domain != own_domain {
let qs_to_qs_message = QsToQsMessage {
protocol_version: MlsInfraVersion::Alpha,
sender: own_domain.clone(),
recipient: domain.clone(),
payload: QsToQsPayload::VerificationKeyRequest,
};
let serialized_message = qs_to_qs_message
.tls_serialize_detached()
.map_err(|_| QsVerifyingKeyError::LibraryError)?;
let result = network_provider
.deliver(serialized_message, domain)
.await
.map_err(|_| QsVerifyingKeyError::InvalidResponse)?;
if let FederatedProcessingResult::VerifyingKey(verifying_key) = result {
verifying_key
} else {
return Err(QsVerifyingKeyError::InvalidResponse);
}
} else {
storage_provider
.load_signing_key()
.await
.map_err(|_| QsVerifyingKeyError::StorageError)?
.verifying_key()
.clone()
};
Ok(verifying_key)
}
}
18 changes: 14 additions & 4 deletions backend/src/qs/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ pub enum QsEnqueueError<S: QsStorageProvider, N: NetworkProvider> {
/// Unrecoverable implementation error
#[error("Library Error")]
LibraryError,
/// Invalid response
#[error("Invalid response")]
InvalidResponse,
}

/// Error enqueuing a fanned-out message.
Expand Down Expand Up @@ -197,6 +200,7 @@ pub enum QsClientKeyPackageError {
#[derive(Error, Debug, Clone, TlsSerialize, TlsDeserializeBytes, TlsSize)]
#[repr(u8)]
pub enum QsKeyPackageBatchError {
/// Library error
#[error("Library Error")]
LibraryError,
/// Decryption error
Expand All @@ -213,16 +217,22 @@ pub enum QsKeyPackageBatchError {
#[derive(Error, Debug, Clone, TlsSerialize, TlsDeserializeBytes, TlsSize)]
#[repr(u8)]
pub enum QsVerifyingKeyError {
/// Error retrieving user key packages
#[error("Error retrieving user key packages")]
/// Library error
#[error("Library Error")]
LibraryError,
/// Error retrieving verifying key
#[error("Error retrieving verifying key")]
StorageError,
/// Invalid response from remote QS
#[error("Invalid response from remote QS")]
InvalidResponse,
}

#[derive(Error, Debug, Clone, TlsSerialize, TlsDeserializeBytes, TlsSize)]
#[repr(u8)]
pub enum QsEncryptionKeyError {
/// Error retrieving user key packages
#[error("Error retrieving user key packages")]
/// Error retrieving encryption key
#[error("Error retrieving encryption key")]
StorageError,
}

Expand Down
8 changes: 7 additions & 1 deletion backend/src/qs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ pub trait QsConnector: Sync + Send + std::fmt::Debug + 'static {
type EnqueueError: std::fmt::Debug;
type VerifyingKeyError;
async fn dispatch(&self, message: DsFanOutMessage) -> Result<(), Self::EnqueueError>;
async fn verifying_key(&self, fqdn: &Fqdn) -> Result<QsVerifyingKey, Self::VerifyingKeyError>;
async fn verifying_key(&self, domain: Fqdn) -> Result<QsVerifyingKey, Self::VerifyingKeyError>;
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -403,6 +403,12 @@ impl From<&str> for Fqdn {
}
}

impl From<String> for Fqdn {
fn from(domain: String) -> Self {
domain.as_str().into()
}
}

#[derive(
Clone,
Debug,
Expand Down
8 changes: 6 additions & 2 deletions backend/src/qs/network_provider_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ use async_trait::async_trait;
use std::error::Error;
use std::fmt::Debug;

use super::Fqdn;
use super::{qs_api::FederatedProcessingResult, Fqdn};

#[async_trait]
pub trait NetworkProvider: Sync + Send + Debug + 'static {
type NetworkError: Error + Debug + Clone;

async fn deliver(&self, bytes: Vec<u8>, destination: Fqdn) -> Result<(), Self::NetworkError>;
async fn deliver(
&self,
bytes: Vec<u8>,
destination: Fqdn,
) -> Result<FederatedProcessingResult, Self::NetworkError>;
}
Loading

0 comments on commit 14d41c6

Please sign in to comment.