From b7e6a8ddd448832025c2fa30d19e8d91b1dfd77f Mon Sep 17 00:00:00 2001 From: Konrad Kohbrok Date: Wed, 19 Jul 2023 14:56:45 +0200 Subject: [PATCH] federation prep --- backend/src/auth_service/mod.rs | 45 ++- backend/src/auth_service/username.rs | 24 -- backend/src/messages/intra_backend.rs | 6 +- backend/src/messages/mod.rs | 1 + backend/src/messages/qs_qs.rs | 15 + backend/src/qs/client_api/key_packages.rs | 2 +- backend/src/qs/dns_provider_trait.rs | 17 + backend/src/qs/ds_api.rs | 34 +- backend/src/qs/errors.rs | 10 +- backend/src/qs/mod.rs | 37 +- backend/src/qs/network_provider_trait.rs | 16 + backend/src/qs/qs_api.rs | 3 + backend/src/qs/storage_provider_trait.rs | 6 +- coreclient/src/conversations/store.rs | 4 +- coreclient/src/groups/mod.rs | 10 +- coreclient/src/types/mod.rs | 19 + coreclient/src/users/mod.rs | 114 +++--- coreclient/src/users/process.rs | 18 +- server/Cargo.toml | 1 + server/src/lib.rs | 1 + server/src/main.rs | 8 +- server/src/network_provider.rs | 76 ++++ server/src/storage_provider/memory/qs.rs | 14 +- .../storage_provider/memory/qs_connector.rs | 19 +- .../src/storage_provider/memory/tests/qs.rs | 2 +- server/tests/mod.rs | 202 +++++----- server/tests/qs/ws.rs | 10 +- server/tests/utils/mod.rs | 13 +- server/tests/utils/setup.rs | 344 +++++++++++++----- 29 files changed, 716 insertions(+), 355 deletions(-) delete mode 100644 backend/src/auth_service/username.rs create mode 100644 backend/src/messages/qs_qs.rs create mode 100644 backend/src/qs/dns_provider_trait.rs create mode 100644 backend/src/qs/network_provider_trait.rs create mode 100644 backend/src/qs/qs_api.rs create mode 100644 server/src/network_provider.rs diff --git a/backend/src/auth_service/mod.rs b/backend/src/auth_service/mod.rs index 20126e16..a7387495 100644 --- a/backend/src/auth_service/mod.rs +++ b/backend/src/auth_service/mod.rs @@ -10,7 +10,10 @@ use opaque_ke::{ RegistrationResponse, RegistrationUpload, ServerRegistration, }; use serde::{Deserialize, Serialize}; -use tls_codec::{TlsDeserializeBytes, TlsSerialize, TlsSize}; +use tls_codec::{ + DeserializeBytes as TlsDeserializeTrait, Serialize as TlsSerializeTrait, TlsDeserializeBytes, + TlsSerialize, TlsSize, +}; use crate::{ crypto::{ratchet::QueueRatchet, OpaqueCiphersuite, RandomnessError, RatchetEncryptionKey}, @@ -25,6 +28,7 @@ use crate::{ client_as_out::VerifiableClientToAsMessage, EncryptedAsQueueMessage, }, + qs::Fqdn, }; use self::{ @@ -42,7 +46,6 @@ pub mod invitations; pub mod key_packages; pub mod registration; pub mod storage_provider_trait; -pub mod username; /* Actions: @@ -146,39 +149,53 @@ impl AsUserRecord { )] pub struct UserName { pub(crate) user_name: Vec, + pub(crate) domain: Fqdn, } impl From> for UserName { fn from(value: Vec) -> Self { - Self { user_name: value } + Self::tls_deserialize_exact(&value).unwrap() } } +// TODO: This string processing is way too simplistic, but it should do for now. impl From<&str> for UserName { fn from(value: &str) -> Self { - Self { - user_name: value.as_bytes().to_vec(), - } + let mut split_name = value.split('@'); + let name = split_name.next().unwrap(); + // UserNames MUST be qualified + let domain = split_name.next().unwrap(); + assert!(split_name.next().is_none()); + let domain = domain.into(); + let user_name = name.as_bytes().to_vec(); + Self { user_name, domain } } } impl UserName { - pub fn as_bytes(&self) -> &[u8] { - &self.user_name + pub fn to_bytes(&self) -> Vec { + self.tls_serialize_detached().unwrap() + } + + pub fn domain(&self) -> Fqdn { + self.domain.clone() } } impl From for UserName { fn from(value: String) -> Self { - Self { - user_name: value.into_bytes(), - } + value.as_str().into() } } -impl ToString for UserName { - fn to_string(&self) -> String { - String::from_utf8(self.user_name.clone()).unwrap() +impl std::fmt::Display for UserName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}@{}", + String::from_utf8_lossy(&self.user_name), + self.domain + ) } } diff --git a/backend/src/auth_service/username.rs b/backend/src/auth_service/username.rs deleted file mode 100644 index f22cbb8e..00000000 --- a/backend/src/auth_service/username.rs +++ /dev/null @@ -1,24 +0,0 @@ -// SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH -// -// SPDX-License-Identifier: AGPL-3.0-or-later - -pub struct Username { - text: String, -} - -impl Username { - pub fn from_text(text: &str) -> Result { - // TODO: validate username - Ok(Username { - text: text.to_string(), - }) - } - - pub fn as_str(&self) -> &str { - self.text.as_str() - } -} - -pub enum UsernameValidationError { - Invalid, -} diff --git a/backend/src/messages/intra_backend.rs b/backend/src/messages/intra_backend.rs index 83e24b1f..e3b6b47d 100644 --- a/backend/src/messages/intra_backend.rs +++ b/backend/src/messages/intra_backend.rs @@ -6,7 +6,7 @@ //! passed internally within the backend. use mls_assist::messages::SerializedMlsMessage; -use tls_codec::{TlsDeserializeBytes, TlsSize}; +use tls_codec::{TlsDeserializeBytes, TlsSerialize, TlsSize}; use crate::qs::QsClientReference; @@ -14,13 +14,13 @@ use super::client_ds::{EventMessage, QsQueueMessagePayload}; // === DS to QS === -#[derive(TlsDeserializeBytes, TlsSize)] +#[derive(TlsSerialize, TlsDeserializeBytes, TlsSize)] pub struct DsFanOutMessage { pub payload: DsFanOutPayload, pub client_reference: QsClientReference, } -#[derive(Clone, TlsDeserializeBytes, TlsSize)] +#[derive(Clone, TlsSerialize, TlsDeserializeBytes, TlsSize)] #[repr(u8)] pub enum DsFanOutPayload { QueueMessage(QsQueueMessagePayload), diff --git a/backend/src/messages/mod.rs b/backend/src/messages/mod.rs index 4e267d8a..037de840 100644 --- a/backend/src/messages/mod.rs +++ b/backend/src/messages/mod.rs @@ -19,6 +19,7 @@ pub mod client_ds_out; pub mod client_qs; pub mod client_qs_out; pub mod intra_backend; +pub mod qs_qs; #[derive( Serialize, diff --git a/backend/src/messages/qs_qs.rs b/backend/src/messages/qs_qs.rs new file mode 100644 index 00000000..ee96a4ee --- /dev/null +++ b/backend/src/messages/qs_qs.rs @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +use crate::qs::Fqdn; + +use super::{intra_backend::DsFanOutMessage, MlsInfraVersion}; + +pub struct QsToQsMessage { + protocol_version: MlsInfraVersion, + sender: Fqdn, + recipient: Fqdn, + fan_out_message: DsFanOutMessage, + // TODO: Signature +} diff --git a/backend/src/qs/client_api/key_packages.rs b/backend/src/qs/client_api/key_packages.rs index e0aee232..13e3018f 100644 --- a/backend/src/qs/client_api/key_packages.rs +++ b/backend/src/qs/client_api/key_packages.rs @@ -129,7 +129,7 @@ impl Qs { .map_err(|_| QsKeyPackageBatchError::StorageError)?; let key_package_batch_tbs = KeyPackageBatchTbs { - homeserver_domain: config.fqdn.clone(), + homeserver_domain: config.domain.clone(), key_package_refs, time_of_signature: TimeStamp::now(), }; diff --git a/backend/src/qs/dns_provider_trait.rs b/backend/src/qs/dns_provider_trait.rs new file mode 100644 index 00000000..db215119 --- /dev/null +++ b/backend/src/qs/dns_provider_trait.rs @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +use async_trait::async_trait; +use std::error::Error; +use std::fmt::Debug; +use std::net::SocketAddr; + +use super::Fqdn; + +#[async_trait] +pub trait DnsProvider: Sync + Send + 'static { + type DnsError: Error + Debug + Clone; + + async fn resolve(&self, fqdn: &Fqdn) -> Result; +} diff --git a/backend/src/qs/ds_api.rs b/backend/src/qs/ds_api.rs index 0a85b2e7..5527f23a 100644 --- a/backend/src/qs/ds_api.rs +++ b/backend/src/qs/ds_api.rs @@ -2,11 +2,13 @@ // // SPDX-License-Identifier: AGPL-3.0-or-later +use tls_codec::Serialize; + use crate::{crypto::hpke::HpkeDecryptable, messages::intra_backend::DsFanOutMessage}; use super::{ - errors::QsEnqueueError, storage_provider_trait::QsStorageProvider, ClientConfig, Qs, - WebsocketNotifier, + errors::QsEnqueueError, network_provider_trait::NetworkProvider, + storage_provider_trait::QsStorageProvider, ClientConfig, Qs, WebsocketNotifier, }; impl Qs { @@ -15,15 +17,33 @@ impl Qs { /// quickly. It can attempt to do the full fanout and return potential /// failed transmissions to the DS. /// - /// This endpoint is used for enqueining - /// messages in both local and remote queues, depending on the FQDN of the - /// client. For now, only local queues are supported. + /// This endpoint is used for enqueining messages in both local and remote + /// queues, depending on the FQDN of the client. #[tracing::instrument(skip_all, err)] - pub async fn enqueue_message( + pub async fn enqueue_message( storage_provider: &S, websocket_notifier: &W, + network_provider: &N, message: DsFanOutMessage, - ) -> Result<(), QsEnqueueError> { + ) -> Result<(), QsEnqueueError> { + if message.client_reference.client_homeserver_domain != storage_provider.own_domain().await + { + tracing::info!( + "Domains differ. Destination domain: {:?}, own domain: {:?}", + message.client_reference.client_homeserver_domain, + storage_provider.own_domain().await + ); + let serialized_message = message + .tls_serialize_detached() + .map_err(|_| QsEnqueueError::LibraryError)?; + network_provider + .deliver( + serialized_message, + message.client_reference.client_homeserver_domain, + ) + .await + .map_err(QsEnqueueError::NetworkError)?; + } let decryption_key = storage_provider .load_decryption_key() .await diff --git a/backend/src/qs/errors.rs b/backend/src/qs/errors.rs index 44f271bf..d04db71f 100644 --- a/backend/src/qs/errors.rs +++ b/backend/src/qs/errors.rs @@ -4,7 +4,7 @@ use crate::crypto::DecryptionError; -use super::storage_provider_trait::QsStorageProvider; +use super::{network_provider_trait::NetworkProvider, storage_provider_trait::QsStorageProvider}; use thiserror::Error; use tls_codec::{TlsDeserializeBytes, TlsSerialize, TlsSize}; @@ -12,7 +12,7 @@ use tls_codec::{TlsDeserializeBytes, TlsSerialize, TlsSize}; /// Error fetching a message from the QS. #[derive(Error, Debug, Clone)] -pub enum QsEnqueueError { +pub enum QsEnqueueError { /// Couldn't find the requested queue. #[error("Couldn't find the requested queue")] QueueNotFound, @@ -22,9 +22,15 @@ pub enum QsEnqueueError { /// An error ocurred enqueueing in a fan out queue #[error(transparent)] EnqueueError(#[from] EnqueueError), + /// An error ocurred while sending a message to the network + #[error("An error ocurred while sending a message to the network")] + NetworkError(N::NetworkError), /// Storage provider error #[error("Storage provider error")] StorageError, + /// Unrecoverable implementation error + #[error("Library Error")] + LibraryError, } /// Error enqueuing a fanned-out message. diff --git a/backend/src/qs/mod.rs b/backend/src/qs/mod.rs index de4d7625..eee8bd17 100644 --- a/backend/src/qs/mod.rs +++ b/backend/src/qs/mod.rs @@ -61,6 +61,8 @@ //! smaller than the smalles requested one and responds with the requested //! messages. +use std::fmt::{Display, Formatter}; + use crate::{ crypto::{ ear::{ @@ -95,8 +97,10 @@ use self::errors::SealError; pub mod client_api; pub mod client_record; +pub mod dns_provider_trait; pub mod ds_api; pub mod errors; +pub mod network_provider_trait; pub mod storage_provider_trait; pub mod user_record; @@ -360,7 +364,7 @@ impl VerifyingKey for QsVerifyingKey {} #[derive(Debug, Clone)] pub struct QsConfig { - pub fqdn: Fqdn, + pub domain: Fqdn, } #[derive(Debug)] @@ -379,7 +383,36 @@ pub struct Qs {} Hash, Debug, )] -pub struct Fqdn {} +pub struct Fqdn { + // TODO: We should probably use a more restrictive type here. + domain: Vec, +} + +impl Display for Fqdn { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", String::from_utf8_lossy(&self.domain)) + } +} + +impl Fqdn { + pub fn new(domain: String) -> Self { + Self { + domain: domain.into_bytes(), + } + } + + pub fn as_bytes(&self) -> &[u8] { + &self.domain + } +} + +impl From<&str> for Fqdn { + fn from(domain: &str) -> Self { + Self { + domain: domain.as_bytes().to_vec(), + } + } +} #[derive( Clone, diff --git a/backend/src/qs/network_provider_trait.rs b/backend/src/qs/network_provider_trait.rs new file mode 100644 index 00000000..b21dbdbd --- /dev/null +++ b/backend/src/qs/network_provider_trait.rs @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +use async_trait::async_trait; +use std::error::Error; +use std::fmt::Debug; + +use super::Fqdn; + +#[async_trait] +pub trait NetworkProvider: Sync + Send + Debug + 'static { + type NetworkError: Error + Debug + Clone; + + async fn deliver(&self, bytes: Vec, destination: Fqdn) -> Result<(), Self::NetworkError>; +} diff --git a/backend/src/qs/qs_api.rs b/backend/src/qs/qs_api.rs new file mode 100644 index 00000000..fafc7ac6 --- /dev/null +++ b/backend/src/qs/qs_api.rs @@ -0,0 +1,3 @@ +// SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH +// +// SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/backend/src/qs/storage_provider_trait.rs b/backend/src/qs/storage_provider_trait.rs index 21a8c79d..17e3fc18 100644 --- a/backend/src/qs/storage_provider_trait.rs +++ b/backend/src/qs/storage_provider_trait.rs @@ -10,8 +10,8 @@ use tls_codec::{DeserializeBytes, Serialize, Size}; use crate::messages::{FriendshipToken, QueueMessage}; use super::{ - client_record::QsClientRecord, user_record::QsUserRecord, ClientIdDecryptionKey, QsClientId, - QsConfig, QsEncryptedAddPackage, QsSigningKey, QsUserId, + client_record::QsClientRecord, user_record::QsUserRecord, ClientIdDecryptionKey, Fqdn, + QsClientId, QsConfig, QsEncryptedAddPackage, QsSigningKey, QsUserId, }; /// Storage provider trait for the QS. @@ -35,6 +35,8 @@ pub trait QsStorageProvider: Sync + Send + Debug + 'static { type LoadConfigError: Error + Debug + Clone; + async fn own_domain(&self) -> Fqdn; + // === USERS === /// Returns a new unique user ID. diff --git a/coreclient/src/conversations/store.rs b/coreclient/src/conversations/store.rs index 1a4d4f94..1ce6d32c 100644 --- a/coreclient/src/conversations/store.rs +++ b/coreclient/src/conversations/store.rs @@ -37,9 +37,7 @@ impl ConversationStore { id: uuid_bytes.clone(), group_id: uuid_bytes.clone(), status: ConversationStatus::Active, - conversation_type: ConversationType::UnconfirmedConnection( - user_name.as_bytes().to_vec(), - ), + conversation_type: ConversationType::UnconfirmedConnection(user_name.to_bytes()), last_used: Timestamp::now().as_u64(), attributes, }; diff --git a/coreclient/src/groups/mod.rs b/coreclient/src/groups/mod.rs index caefac79..c7ed747f 100644 --- a/coreclient/src/groups/mod.rs +++ b/coreclient/src/groups/mod.rs @@ -1282,10 +1282,10 @@ impl Group { } /// Returns the [`AsClientId`] of the clients owned by the given user. - pub(crate) fn user_client_ids(&self, user_name: UserName) -> Vec { + pub(crate) fn user_client_ids(&self, user_name: &UserName) -> Vec { let mut user_clients = vec![]; for (_index, (cred, _sek)) in self.client_information.iter() { - if cred.identity().user_name() == user_name { + if &cred.identity().user_name() == user_name { user_clients.push(cred.identity()) } } @@ -1418,7 +1418,7 @@ pub(crate) fn application_message_to_conversation_messages( application_message: ApplicationMessage, ) -> Vec { vec![new_conversation_message(Message::Content(ContentMessage { - sender: sender.identity().user_name().as_bytes().to_vec(), + sender: sender.identity().user_name().to_bytes(), content: MessageContentType::tls_deserialize( &mut application_message.into_bytes().as_slice(), ) @@ -1456,8 +1456,8 @@ pub(crate) fn staged_commit_to_conversation_messages( }; let event_message = format!( "{} added {} to the conversation", - String::from_utf8_lossy(get_user_name(client_information, sender).as_bytes()), - String::from_utf8_lossy(get_user_name(client_information, free_index).as_bytes()) + get_user_name(client_information, sender), + get_user_name(client_information, free_index) ); event_message_from_string(event_message) }) diff --git a/coreclient/src/types/mod.rs b/coreclient/src/types/mod.rs index 3947575c..f4b55b2f 100644 --- a/coreclient/src/types/mod.rs +++ b/coreclient/src/types/mod.rs @@ -5,6 +5,7 @@ use std::collections::HashSet; use openmls::prelude::GroupId; +use phnxbackend::auth_service::UserName; use tls_codec::{TlsDeserialize, TlsSerialize, TlsSize}; //use phnxbackend::auth_service::UserName; use uuid::Uuid; @@ -121,6 +122,24 @@ pub struct InactiveConversation { pub past_members: HashSet, } +impl InactiveConversation { + pub fn new(past_members: HashSet) -> Self { + Self { + past_members: past_members + .iter() + .map(|s| s.to_string()) + .collect::>(), + } + } + + pub fn past_members(&self) -> HashSet { + self.past_members + .iter() + .map(|s| UserName::from(s.clone())) + .collect() + } +} + #[derive(PartialEq, Debug, Clone)] pub enum ConversationType { // A connection conversation that is not yet confirmed by the other party. diff --git a/coreclient/src/users/mod.rs b/coreclient/src/users/mod.rs index 21587c4e..a0aba5d1 100644 --- a/coreclient/src/users/mod.rs +++ b/coreclient/src/users/mod.rs @@ -45,8 +45,8 @@ use phnxbackend::{ FriendshipToken, MlsInfraVersion, QueueMessage, }, qs::{ - AddPackage, ClientConfig, ClientIdEncryptionKey, Fqdn, QsClientId, QsClientReference, - QsUserId, QsVerifyingKey, + AddPackage, ClientConfig, ClientIdEncryptionKey, QsClientId, QsClientReference, QsUserId, + QsVerifyingKey, }, }; use rand::rngs::OsRng; @@ -111,13 +111,12 @@ pub struct SelfUser { impl SelfUser { /// Create a new user with the given name and a fresh set of credentials. pub async fn new( - user_name: &str, + user_name: UserName, password: &str, address: SocketAddr, notification_hub: NotificationHub, ) -> Self { log::debug!("Creating new user {}", user_name); - let user_name: UserName = UserName::from(user_name.to_string()); let crypto_backend = OpenMlsRustCrypto::default(); // Let's turn TLS off for now. let api_client = ApiClient::initialize(address, TransportEncryption::Off).unwrap(); @@ -167,8 +166,9 @@ impl SelfUser { // Complete the OPAQUE registration. let address = api_client.address().clone().to_string(); + let user_name_bytes = user_name.to_bytes(); let identifiers = Identifiers { - client: Some(user_name.as_bytes()), + client: Some(&user_name_bytes), server: Some(address.as_bytes()), }; let response_parameters = ClientRegistrationFinishParameters::new(identifiers, None); @@ -212,7 +212,7 @@ impl SelfUser { let connection_decryption_key = ConnectionDecryptionKey::generate().unwrap(); // Mutable, because we need to access the leaf signers later. - let mut key_store = MemoryUserKeyStore { + let key_store = MemoryUserKeyStore { signing_key, as_queue_decryption_key, as_queue_ratchet: as_initial_ratchet_secret.clone().try_into().unwrap(), @@ -284,22 +284,32 @@ impl SelfUser { .await .unwrap(); + let mut user = Self { + crypto_backend, + api_client, + user_name, + conversation_store: ConversationStore::default(), + group_store: GroupStore::default(), + key_store, + qs_user_id: create_user_record_response.user_id, + qs_client_id: create_user_record_response.client_id, + qs_client_sequence_number_start: 0, + as_client_sequence_number_start: 0, + contacts: HashMap::default(), + partial_contacts: HashMap::default(), + notification_hub, + }; + let mut qs_add_packages = vec![]; for _ in 0..ADD_PACKAGES { // TODO: Which key do we need to use for encryption here? Probably // the client credential ear key, since friends need to be able to // decrypt it. We might want to use a separate key, though. - let (kp, signature_ear_key, leaf_signer) = SelfUser::::generate_keypackage( - &crypto_backend, - &key_store.signing_key, - &create_user_record_response.client_id, - Some(&key_store.push_token_ear_key), - &key_store.qs_client_id_encryption_key, - ); + let (kp, signature_ear_key, leaf_signer) = user.generate_keypackage(); let esek = signature_ear_key - .encrypt(&key_store.signature_ear_key_wrapper_key) + .encrypt(&user.key_store.signature_ear_key_wrapper_key) .unwrap(); - key_store.leaf_signers.insert( + user.key_store.leaf_signers.insert( leaf_signer.credential().verifying_key().clone(), (leaf_signer, signature_ear_key), ); @@ -309,42 +319,25 @@ impl SelfUser { } // Upload add packages - api_client + user.api_client .qs_publish_key_packages( - create_user_record_response.client_id.clone(), + user.qs_client_id.clone(), qs_add_packages, - key_store.add_package_ear_key.clone(), - &key_store.qs_client_signing_key, + user.key_store.add_package_ear_key.clone(), + &user.key_store.qs_client_signing_key, ) .await .unwrap(); - Self { - crypto_backend, - api_client, - user_name, - conversation_store: ConversationStore::default(), - group_store: GroupStore::default(), - key_store, - qs_user_id: create_user_record_response.user_id, - qs_client_id: create_user_record_response.client_id, - qs_client_sequence_number_start: 0, - as_client_sequence_number_start: 0, - contacts: HashMap::default(), - partial_contacts: HashMap::default(), - notification_hub, - } + user } pub(crate) fn generate_keypackage( - crypto_backend: &impl OpenMlsCryptoProvider, - signing_key: &ClientSigningKey, - qs_client_id: &QsClientId, - push_token_ear_key_option: Option<&PushTokenEarKey>, - qs_encryption_key: &ClientIdEncryptionKey, + &self, ) -> (KeyPackage, SignatureEarKey, InfraCredentialSigningKey) { let signature_ear_key = SignatureEarKey::random().unwrap(); - let leaf_signer = InfraCredentialSigningKey::generate(signing_key, &signature_ear_key); + let leaf_signer = + InfraCredentialSigningKey::generate(&self.key_store.signing_key, &signature_ear_key); let credential_with_key = CredentialWithKey { credential: leaf_signer.credential().clone().into(), signature_key: leaf_signer.credential().verifying_key().clone(), @@ -357,12 +350,18 @@ impl SelfUser { Some(&SUPPORTED_CREDENTIALS), ); let sealed_reference = ClientConfig { - client_id: qs_client_id.clone(), - push_token_ear_key: push_token_ear_key_option.cloned(), + client_id: self.qs_client_id.clone(), + push_token_ear_key: Some(self.key_store.push_token_ear_key.clone()), } - .encrypt(qs_encryption_key, &[], &[]); + .encrypt(&self.key_store.qs_client_id_encryption_key, &[], &[]); let client_reference = QsClientReference { - client_homeserver_domain: Fqdn {}, + client_homeserver_domain: self + .key_store + .signing_key + .credential() + .identity() + .user_name() + .domain(), sealed_reference, }; let extension = Extension::Unknown( @@ -378,7 +377,7 @@ impl SelfUser { ciphersuite: CIPHERSUITE, version: ProtocolVersion::Mls10, }, - crypto_backend, + &self.crypto_backend, &leaf_signer, credential_with_key, ) @@ -433,7 +432,7 @@ impl SelfUser { pub async fn invite_users( &mut self, conversation_id: Uuid, - invited_users: &[&str], + invited_users: &[UserName], ) -> Result<(), CorelibError> { let conversation = self .conversation_store @@ -487,7 +486,7 @@ impl SelfUser { pub async fn remove_users( &mut self, conversation_id: Uuid, - target_users: &[&str], + target_users: &[UserName], ) -> Result<(), CorelibError> { let conversation = self .conversation_store @@ -499,8 +498,8 @@ impl SelfUser { .get_group_mut(&group_id.as_group_id()) .unwrap(); let mut clients = vec![]; - for &user_name in target_users { - let mut user_clients = group.user_client_ids(user_name.into()); + for user_name in target_users { + let mut user_clients = group.user_client_ids(user_name); clients.append(&mut user_clients); } let params = group.remove(&self.crypto_backend, clients).unwrap(); @@ -565,7 +564,7 @@ impl SelfUser { // Store message locally let message = Message::Content(ContentMessage { content: message, - sender: self.user_name.as_bytes().to_vec(), + sender: self.user_name.to_bytes(), }); let conversation_message = new_conversation_message(message); self.conversation_store @@ -586,8 +585,7 @@ impl SelfUser { Ok(conversation_message) } - pub async fn add_contact(&mut self, user_name: &str) { - let user_name: UserName = user_name.to_string().into(); + pub async fn add_contact(&mut self, user_name: &UserName) { let params = UserConnectionPackagesParams { user_name: user_name.clone(), }; @@ -705,7 +703,7 @@ impl SelfUser { conversation_id, friendship_package_ear_key, }; - self.partial_contacts.insert(user_name, contact); + self.partial_contacts.insert(user_name.clone(), contact); // Encrypt the connection establishment package for each connection and send it off. for connection_package in verified_connection_packages { @@ -883,7 +881,7 @@ impl SelfUser { } .encrypt(&self.key_store.qs_client_id_encryption_key, &[], &[]); QsClientReference { - client_homeserver_domain: Fqdn {}, + client_homeserver_domain: self.user_name().domain(), sealed_reference, } } @@ -898,17 +896,17 @@ impl SelfUser { } /// Returns None if there is no conversation with the given id. - pub fn group_members(&self, conversation_id: Uuid) -> Option> { + pub fn group_members(&self, conversation_id: Uuid) -> Option> { self.group(conversation_id).map(|group| { group .members() .iter() - .map(|member| member.to_string()) + .map(|member| member.clone()) .collect() }) } - pub fn pending_removes(&self, conversation_id: Uuid) -> Option> { + pub fn pending_removes(&self, conversation_id: Uuid) -> Option> { self.group(conversation_id).map(|group| { group .mls_group() @@ -916,7 +914,7 @@ impl SelfUser { .filter_map(|proposal| match proposal.proposal() { Proposal::Remove(rp) => group .client_by_index(rp.removed().usize()) - .map(|c| c.user_name().to_string()), + .map(|c| c.user_name()), _ => None, }) .collect() diff --git a/coreclient/src/users/process.rs b/coreclient/src/users/process.rs index 32a66b72..7156a5a7 100644 --- a/coreclient/src/users/process.rs +++ b/coreclient/src/users/process.rs @@ -11,7 +11,6 @@ use phnxbackend::messages::client_ds::{ }; use phnxbackend::messages::client_ds_out::ExternalCommitInfoIn; use phnxbackend::messages::QueueMessage; -use phnxbackend::qs::{ClientConfig, Fqdn, QsClientReference}; use tls_codec::DeserializeBytes; use super::*; @@ -122,7 +121,7 @@ impl SelfUser { { // Check if it was an external commit and if the user name matches if matches!(sender, Sender::NewMemberCommit) - && &sender_credential.identity().user_name().as_bytes() + && &sender_credential.identity().user_name().to_bytes() == user_name { // Load up the partial contact and decrypt the friendship package @@ -400,19 +399,8 @@ impl SelfUser { self.contacts.insert(user_name, contact); // TODO: Send conversation message to UI. - let sealed_reference = ClientConfig { - client_id: self.qs_client_id.clone(), - push_token_ear_key: Some(self.key_store.push_token_ear_key.clone()), - } - .encrypt( - &self.key_store.qs_client_id_encryption_key, - &[], - &[], - ); - let qs_client_reference = QsClientReference { - client_homeserver_domain: Fqdn {}, - sealed_reference, - }; + let qs_client_reference = self.create_own_client_reference(); + // Send the confirmation by way of commit and group info to the DS. self.api_client .ds_join_connection_group( diff --git a/server/Cargo.toml b/server/Cargo.toml index afee33fa..c259b829 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -45,6 +45,7 @@ privacypass = { git = "https://github.com/raphaelrobert/privacypass" } privacypass-middleware = { git = "https://github.com/phnx-im/pp-middleware" } opaque-ke = { version = "3.0.0-pre.1", features = ["argon2"]} tls_codec = { workspace = true } +reqwest = { version = "^0.11", features = ["json", "rustls-tls-webpki-roots", "brotli"] } [dependencies.sqlx] optional = true diff --git a/server/src/lib.rs b/server/src/lib.rs index bfe4788e..4eb80c50 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -6,6 +6,7 @@ pub mod configurations; pub mod endpoints; +pub mod network_provider; pub mod storage_provider; pub mod telemetry; diff --git a/server/src/main.rs b/server/src/main.rs index b34ffce0..9645346f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -9,6 +9,7 @@ use phnxbackend::qs::Fqdn; use phnxserver::{ configurations::*, endpoints::qs::ws::DispatchWebsocketNotifier, + network_provider::MockNetworkProvider, run, storage_provider::memory::{ auth_service::{EphemeralAsStorage, MemoryAsStorage}, @@ -34,15 +35,18 @@ async fn main() -> std::io::Result<()> { configuration.application.host, configuration.application.port ); let listener = TcpListener::bind(address).expect("Failed to bind to random port."); + let domain: Fqdn = configuration.application.host.as_str().into(); + let network_provider = Arc::new(MockNetworkProvider::new()); let ds_storage_provider = MemoryDsStorage::new(); - let qs_storage_provider = Arc::new(MemStorageProvider::default()); - let as_storage_provider = MemoryAsStorage::new(Fqdn {}, SignatureScheme::ED25519).unwrap(); + let qs_storage_provider = Arc::new(MemStorageProvider::new(domain.clone())); + let as_storage_provider = MemoryAsStorage::new(domain, SignatureScheme::ED25519).unwrap(); let as_ephemeral_storage_provider = EphemeralAsStorage::default(); let ws_dispatch_notifier = DispatchWebsocketNotifier::default_addr(); let qs_connector = MemoryEnqueueProvider { storage: qs_storage_provider.clone(), notifier: ws_dispatch_notifier.clone(), + network: network_provider, }; // Start the server diff --git a/server/src/network_provider.rs b/server/src/network_provider.rs new file mode 100644 index 00000000..616a4426 --- /dev/null +++ b/server/src/network_provider.rs @@ -0,0 +1,76 @@ +// SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +use std::{collections::HashMap, sync::Mutex}; + +use crate::endpoints::ENDPOINT_QS; +use async_trait::async_trait; +use phnxbackend::qs::{network_provider_trait::NetworkProvider, Fqdn}; +use reqwest::Client; +use thiserror::Error; + +#[derive(Debug, Error, Clone)] +pub enum MockNetworkError {} + +#[derive(Debug)] +pub enum TransportEncryption { + On, + Off, +} + +#[derive(Debug)] +pub struct MockNetworkProvider { + backend_ports: Mutex>, + client: Client, + transport_encryption: TransportEncryption, +} + +impl MockNetworkProvider { + pub fn new() -> Self { + Self { + backend_ports: Mutex::new(HashMap::new()), + client: Client::new(), + transport_encryption: TransportEncryption::Off, + } + } + + pub fn add_port(&self, fqdn: Fqdn, port: u16) { + self.backend_ports.lock().unwrap().insert(fqdn, port); + } +} + +#[async_trait] +impl NetworkProvider for MockNetworkProvider { + type NetworkError = MockNetworkError; + + async fn deliver(&self, bytes: Vec, destination: Fqdn) -> Result<(), Self::NetworkError> { + let transport_encryption = match self.transport_encryption { + TransportEncryption::On => "s", + TransportEncryption::Off => "", + }; + tracing::info!("Currently registered {:?}", self.backend_ports); + tracing::info!("Sending to {:?}", destination); + let port = self + .backend_ports + .lock() + .unwrap() + .get(&destination) + .unwrap() + .to_owned(); + // For now, we don't resolve the actual hostname and just send to + // localhost. + let destination = "localhost"; + let url = format!( + "http{}://{}:{}{}", + transport_encryption, destination, port, ENDPOINT_QS + ); + match self.client.post(url).body(bytes).send().await { + // For now we don't care about the response. + Ok(_response) => (), + // TODO: We only care about the happy path for now. + Err(e) => panic!("Error: {}", e), + } + Ok(()) + } +} diff --git a/server/src/storage_provider/memory/qs.rs b/server/src/storage_provider/memory/qs.rs index 6199c5cc..8d05dea5 100644 --- a/server/src/storage_provider/memory/qs.rs +++ b/server/src/storage_provider/memory/qs.rs @@ -14,7 +14,7 @@ use phnxbackend::{ messages::{FriendshipToken, QueueMessage}, qs::{ client_record::QsClientRecord, storage_provider_trait::QsStorageProvider, - user_record::QsUserRecord, ClientIdDecryptionKey, QsClientId, QsConfig, + user_record::QsUserRecord, ClientIdDecryptionKey, Fqdn, QsClientId, QsConfig, QsEncryptedAddPackage, QsSigningKey, QsUserId, }, }; @@ -48,11 +48,9 @@ pub struct MemStorageProvider { config: QsConfig, } -impl Default for MemStorageProvider { - fn default() -> Self { - let config = QsConfig { - fqdn: phnxbackend::qs::Fqdn {}, - }; +impl MemStorageProvider { + pub fn new(domain: Fqdn) -> Self { + let config = QsConfig { domain }; let client_id_decryption_key = ClientIdDecryptionKey::generate().unwrap(); let signing_key = QsSigningKey::generate().unwrap(); let users = RwLock::new(HashMap::new()); @@ -88,6 +86,10 @@ impl QsStorageProvider for MemStorageProvider { type LoadConfigError = LoadConfigError; + async fn own_domain(&self) -> Fqdn { + self.config.domain.clone() + } + async fn create_user( &self, user_record: QsUserRecord, diff --git a/server/src/storage_provider/memory/qs_connector.rs b/server/src/storage_provider/memory/qs_connector.rs index 7d9a6c2a..4beb5ddc 100644 --- a/server/src/storage_provider/memory/qs_connector.rs +++ b/server/src/storage_provider/memory/qs_connector.rs @@ -8,26 +8,33 @@ use async_trait::async_trait; use phnxbackend::{ messages::intra_backend::DsFanOutMessage, qs::{ - errors::QsEnqueueError, storage_provider_trait::QsStorageProvider, Fqdn, Qs, QsConnector, - QsVerifyingKey, + errors::QsEnqueueError, network_provider_trait::NetworkProvider, + storage_provider_trait::QsStorageProvider, Fqdn, Qs, QsConnector, QsVerifyingKey, }, }; use crate::endpoints::qs::ws::DispatchWebsocketNotifier; #[derive(Debug)] -pub struct MemoryEnqueueProvider { +pub struct MemoryEnqueueProvider { pub storage: Arc, pub notifier: DispatchWebsocketNotifier, + pub network: Arc, } #[async_trait] -impl QsConnector for MemoryEnqueueProvider { - type EnqueueError = QsEnqueueError; +impl QsConnector for MemoryEnqueueProvider { + type EnqueueError = QsEnqueueError; type VerifyingKeyError = T::LoadSigningKeyError; async fn dispatch(&self, message: DsFanOutMessage) -> Result<(), Self::EnqueueError> { - Qs::enqueue_message(self.storage.deref(), &self.notifier, message).await + Qs::enqueue_message( + self.storage.deref(), + &self.notifier, + self.network.deref(), + message, + ) + .await } async fn verifying_key(&self, _fqdn: &Fqdn) -> Result { diff --git a/server/src/storage_provider/memory/tests/qs.rs b/server/src/storage_provider/memory/tests/qs.rs index d8ddf76e..b916cdc8 100644 --- a/server/src/storage_provider/memory/tests/qs.rs +++ b/server/src/storage_provider/memory/tests/qs.rs @@ -23,7 +23,7 @@ use crate::storage_provider::memory::qs::MemStorageProvider; // Unit tests for MemStorageProvider #[actix_rt::test] async fn qs_mem_provider() { - let provider = MemStorageProvider::default(); + let provider = MemStorageProvider::new("example.com".into()); // Set up a user record let user_record = QsUserRecord::new( diff --git a/server/tests/mod.rs b/server/tests/mod.rs index 07e4465b..27049aa4 100644 --- a/server/tests/mod.rs +++ b/server/tests/mod.rs @@ -5,8 +5,11 @@ mod qs; mod utils; +use std::sync::Arc; + use phnxapiclient::{ApiClient, TransportEncryption}; +use phnxserver::network_provider::MockNetworkProvider; use utils::setup::TestBackend; pub use utils::*; @@ -14,7 +17,8 @@ pub use utils::*; #[tracing::instrument(name = "Test WS", skip_all)] async fn health_check_works() { tracing::info!("Tracing: Spawning websocket connection task"); - let (address, _ws_dispatch) = spawn_app().await; + let network_provider = Arc::new(MockNetworkProvider::new()); + let (address, _ws_dispatch) = spawn_app("example.com".into(), network_provider).await; tracing::info!("Server started: {}", address.to_string()); @@ -26,124 +30,133 @@ async fn health_check_works() { assert!(client.health_check().await); } +const ALICE: &str = "alice@example.com"; +const BOB: &str = "bob@example.com"; +const CHARLIE: &str = "charlie@example.com"; +const DAVE: &str = "dave@example.com"; + #[actix_rt::test] #[tracing::instrument(name = "Connect users test", skip_all)] async fn connect_users() { - let mut setup = TestBackend::new().await; - setup.add_user("alice").await; - setup.add_user("bob").await; - setup.connect_users("alice", "bob").await; + let mut setup = TestBackend::single().await; + setup.add_user(ALICE).await; + setup.add_user(BOB).await; + setup.connect_users(ALICE, BOB).await; } #[actix_rt::test] #[tracing::instrument(name = "Send message test", skip_all)] async fn send_message() { - let mut setup = TestBackend::new().await; - setup.add_user("alice").await; - setup.add_user("bob").await; - let conversation_id = setup.connect_users("alice", "bob").await; - setup.send_message(conversation_id, "alice", &["bob"]).await; - setup.send_message(conversation_id, "bob", &["alice"]).await; + tracing::info!("Setting up setup"); + let mut setup = TestBackend::single().await; + tracing::info!("Creating users"); + setup.add_user(ALICE).await; + tracing::info!("Created alice"); + setup.add_user(BOB).await; + let conversation_id = setup.connect_users(ALICE, BOB).await; + setup.send_message(conversation_id, ALICE, vec![BOB]).await; + setup.send_message(conversation_id, BOB, vec![ALICE]).await; } #[actix_rt::test] #[tracing::instrument(name = "Create group test", skip_all)] async fn create_group() { - let mut setup = TestBackend::new().await; - setup.add_user("alice").await; - setup.create_group("alice").await; + let mut setup = TestBackend::single().await; + setup.add_user(ALICE).await; + setup.create_group(ALICE).await; } #[actix_rt::test] #[tracing::instrument(name = "Invite to group test", skip_all)] async fn invite_to_group() { - let mut setup = TestBackend::new().await; - setup.add_user("alice").await; - setup.add_user("bob").await; - setup.add_user("charlie").await; - setup.connect_users("alice", "bob").await; - setup.connect_users("alice", "charlie").await; - let conversation_id = setup.create_group("alice").await; + let mut setup = TestBackend::single().await; + setup.add_user(ALICE).await; + setup.add_user(BOB).await; + setup.add_user(CHARLIE).await; + setup.connect_users(ALICE, BOB).await; + setup.connect_users(ALICE, CHARLIE).await; + let conversation_id = setup.create_group(ALICE).await; setup - .invite_to_group(conversation_id, "alice", &["bob", "charlie"]) + .invite_to_group(conversation_id, ALICE, vec![BOB, CHARLIE]) .await; } #[actix_rt::test] #[tracing::instrument(name = "Invite to group test", skip_all)] async fn update_group() { - let mut setup = TestBackend::new().await; - setup.add_user("alice").await; - setup.add_user("bob").await; - setup.add_user("charlie").await; - setup.connect_users("alice", "bob").await; - setup.connect_users("alice", "charlie").await; - let conversation_id = setup.create_group("alice").await; + let mut setup = TestBackend::single().await; + setup.add_user(ALICE).await; + setup.add_user(BOB).await; + setup.add_user(CHARLIE).await; + setup.connect_users(ALICE, BOB).await; + setup.connect_users(ALICE, CHARLIE).await; + let conversation_id = setup.create_group(ALICE).await; setup - .invite_to_group(conversation_id, "alice", &["bob", "charlie"]) + .invite_to_group(conversation_id, ALICE, vec![BOB, CHARLIE]) .await; - setup.update_group(conversation_id, "bob").await + setup.update_group(conversation_id, BOB).await } #[actix_rt::test] #[tracing::instrument(name = "Invite to group test", skip_all)] async fn remove_from_group() { - let mut setup = TestBackend::new().await; - setup.add_user("alice").await; - setup.add_user("bob").await; - setup.add_user("charlie").await; - setup.add_user("dave").await; - setup.connect_users("alice", "bob").await; - setup.connect_users("alice", "charlie").await; - setup.connect_users("alice", "dave").await; - let conversation_id = setup.create_group("alice").await; + let mut setup = TestBackend::single().await; + setup.add_user(ALICE).await; + setup.add_user(BOB).await; + setup.add_user(CHARLIE).await; + setup.add_user(DAVE).await; + setup.connect_users(ALICE, BOB).await; + setup.connect_users(ALICE, CHARLIE).await; + setup.connect_users(ALICE, DAVE).await; + let conversation_id = setup.create_group(ALICE).await; setup - .invite_to_group(conversation_id, "alice", &["bob", "charlie", "dave"]) + .invite_to_group(conversation_id, ALICE, vec![BOB, CHARLIE, DAVE]) .await; setup - .remove_from_group(conversation_id, "charlie", &["alice", "bob"]) + .remove_from_group(conversation_id, CHARLIE, vec![ALICE, BOB]) .await } #[actix_rt::test] #[tracing::instrument(name = "Invite to group test", skip_all)] async fn leave_group() { - let mut setup = TestBackend::new().await; - setup.add_user("alice").await; - setup.add_user("bob").await; - setup.connect_users("alice", "bob").await; - let conversation_id = setup.create_group("alice").await; + let mut setup = TestBackend::single().await; + setup.add_user(ALICE).await; + setup.add_user(BOB).await; + setup.connect_users(ALICE, BOB).await; + let conversation_id = setup.create_group(ALICE).await; setup - .invite_to_group(conversation_id, "alice", &["bob"]) + .invite_to_group(conversation_id, ALICE, vec![BOB]) .await; - setup.leave_group(conversation_id, "alice").await; + setup.leave_group(conversation_id, ALICE).await; } #[actix_rt::test] #[tracing::instrument(name = "Invite to group test", skip_all)] async fn delete_group() { - let mut setup = TestBackend::new().await; - setup.add_user("alice").await; - setup.add_user("bob").await; - setup.connect_users("alice", "bob").await; - let conversation_id = setup.create_group("alice").await; + let mut setup = TestBackend::single().await; + setup.add_user(ALICE).await; + setup.add_user(BOB).await; + setup.connect_users(ALICE, BOB).await; + let conversation_id = setup.create_group(ALICE).await; setup - .invite_to_group(conversation_id, "alice", &["bob"]) + .invite_to_group(conversation_id, ALICE, vec![BOB]) .await; - setup.delete_group(conversation_id, "bob").await; + setup.delete_group(conversation_id, BOB).await; } #[actix_rt::test] #[tracing::instrument(name = "Create user", skip_all)] async fn create_user() { - let mut setup = TestBackend::new().await; - setup.add_user("alice").await; + let mut setup = TestBackend::single().await; + setup.add_user(ALICE).await; } #[actix_rt::test] #[tracing::instrument(name = "Inexistant endpoint", skip_all)] async fn inexistant_endpoint() { - let (address, _ws_dispatch) = spawn_app().await; + let network_provider = Arc::new(MockNetworkProvider::new()); + let (address, _ws_dispatch) = spawn_app("example.com".into(), network_provider).await; // Initialize the client let client = ApiClient::initialize(address, TransportEncryption::Off) @@ -156,80 +169,80 @@ async fn inexistant_endpoint() { #[actix_rt::test] #[tracing::instrument(name = "Full cycle", skip_all)] async fn full_cycle() { - let mut setup = TestBackend::new().await; + let mut setup = TestBackend::single().await; // Create alice and bob - setup.add_user("alice").await; - setup.add_user("bob").await; + setup.add_user(ALICE).await; + setup.add_user(BOB).await; // Connect them - let conversation_alice_bob = setup.connect_users("alice", "bob").await; + let conversation_alice_bob = setup.connect_users(ALICE, BOB).await; // Test the connection conversation by sending messages back and forth. setup - .send_message(conversation_alice_bob, "alice", &["bob"]) + .send_message(conversation_alice_bob, ALICE, vec![BOB]) .await; setup - .send_message(conversation_alice_bob, "bob", &["alice"]) + .send_message(conversation_alice_bob, BOB, vec![ALICE]) .await; // Create an independent group and invite bob. - let conversation_id = setup.create_group("alice").await; + let conversation_id = setup.create_group(ALICE).await; setup - .invite_to_group(conversation_id, "alice", &["bob"]) + .invite_to_group(conversation_id, ALICE, vec![BOB]) .await; // Create chalie, connect him with alice and invite him to the group. - setup.add_user("charlie").await; - setup.connect_users("alice", "charlie").await; + setup.add_user(CHARLIE).await; + setup.connect_users(ALICE, CHARLIE).await; setup - .invite_to_group(conversation_id, "alice", &["charlie"]) + .invite_to_group(conversation_id, ALICE, vec![CHARLIE]) .await; // Add dave, connect him with charlie and invite him to the group. Then have dave remove alice and bob. - setup.add_user("dave").await; - setup.connect_users("charlie", "dave").await; + setup.add_user(DAVE).await; + setup.connect_users(CHARLIE, DAVE).await; setup - .invite_to_group(conversation_id, "charlie", &["dave"]) + .invite_to_group(conversation_id, CHARLIE, vec![DAVE]) .await; setup - .send_message(conversation_id, "alice", &["charlie", "bob", "dave"]) + .send_message(conversation_id, ALICE, vec![CHARLIE, BOB, DAVE]) .await; setup - .remove_from_group(conversation_id, "dave", &["alice", "bob"]) + .remove_from_group(conversation_id, DAVE, vec![ALICE, BOB]) .await; - setup.leave_group(conversation_id, "charlie").await; + setup.leave_group(conversation_id, CHARLIE).await; - setup.delete_group(conversation_id, "dave").await + setup.delete_group(conversation_id, DAVE).await } #[actix_rt::test] async fn benchmarks() { - let mut setup = TestBackend::new().await; + let mut setup = TestBackend::single().await; const NUM_USERS: usize = 10; const NUM_MESSAGES: usize = 10; // Create alice - setup.add_user("alice").await; + setup.add_user(ALICE).await; // Create bob - setup.add_user("bob").await; + setup.add_user(BOB).await; // Create many different bobs let bobs: Vec = (0..NUM_USERS) - .map(|i| format!("bob{}", i)) + .map(|i| format!("bob{}@example.com", i)) .collect::>(); // Measure the time it takes to create all the users let start = std::time::Instant::now(); - for bob in &bobs { - setup.add_user(&bob).await; + for bob in bobs.clone() { + setup.add_user(bob).await; } let elapsed = start.elapsed(); println!( @@ -240,8 +253,8 @@ async fn benchmarks() { // Measure the time it takes to connect all bobs with alice let start = std::time::Instant::now(); - for bob in &bobs { - setup.connect_users("alice", &bob).await; + for bob in bobs.clone() { + setup.connect_users(ALICE, bob).await; } let elapsed = start.elapsed(); println!( @@ -251,13 +264,13 @@ async fn benchmarks() { ); // Connect them - let conversation_alice_bob = setup.connect_users("alice", "bob").await; + let conversation_alice_bob = setup.connect_users(ALICE, BOB).await; // Measure the time it takes to send a message let start = std::time::Instant::now(); for _ in 0..NUM_MESSAGES { setup - .send_message(conversation_alice_bob, "alice", &["bob"]) + .send_message(conversation_alice_bob, ALICE, vec![BOB]) .await; } let elapsed = start.elapsed(); @@ -268,13 +281,13 @@ async fn benchmarks() { ); // Create an independent group - let conversation_id = setup.create_group("alice").await; + let conversation_id = setup.create_group(ALICE).await; // Measure the time it takes to invite a user let start = std::time::Instant::now(); - for bob in &bobs { + for bob in bobs.clone() { setup - .invite_to_group(conversation_id, "alice", &[&bob]) + .invite_to_group(conversation_id, ALICE, vec![bob]) .await; } let elapsed = start.elapsed(); @@ -288,14 +301,7 @@ async fn benchmarks() { let start = std::time::Instant::now(); for _ in 0..NUM_MESSAGES { setup - .send_message( - conversation_id, - "alice", - bobs.iter() - .map(|s| s.as_str()) - .collect::>() - .as_slice(), - ) + .send_message(conversation_id, ALICE, bobs.clone()) .await; } let elapsed = start.elapsed(); diff --git a/server/tests/qs/ws.rs b/server/tests/qs/ws.rs index f97e1d5f..185fd768 100644 --- a/server/tests/qs/ws.rs +++ b/server/tests/qs/ws.rs @@ -2,9 +2,11 @@ // // SPDX-License-Identifier: AGPL-3.0-or-later +use std::sync::Arc; + use phnxapiclient::{qs_api::ws::WsEvent, ApiClient, TransportEncryption}; use phnxbackend::qs::{QsClientId, WebsocketNotifier, WsNotification}; -use phnxserver::endpoints::qs::ws::QsWsMessage; +use phnxserver::{endpoints::qs::ws::QsWsMessage, network_provider::MockNetworkProvider}; use super::*; @@ -12,7 +14,8 @@ use super::*; #[actix_rt::test] #[tracing::instrument(name = "Test WS Reconnect", skip_all)] async fn test_ws_reconnect() { - let (address, _ws_dispatch) = spawn_app().await; + let network_provider = Arc::new(MockNetworkProvider::new()); + let (address, _ws_dispatch) = spawn_app("example.com".into(), network_provider).await; let client_id = QsClientId::random(); @@ -48,7 +51,8 @@ async fn test_ws_reconnect() { #[actix_rt::test] #[tracing::instrument(name = "Test WS Sending", skip_all)] async fn test_ws_sending() { - let (address, ws_dispatch) = spawn_app().await; + let network_provider = Arc::new(MockNetworkProvider::new()); + let (address, ws_dispatch) = spawn_app("example.com".into(), network_provider).await; let client_id = QsClientId::random(); diff --git a/server/tests/utils/mod.rs b/server/tests/utils/mod.rs index b57227a3..5a9a008b 100644 --- a/server/tests/utils/mod.rs +++ b/server/tests/utils/mod.rs @@ -17,6 +17,7 @@ use phnxbackend::qs::Fqdn; use phnxserver::{ configurations::get_configuration, endpoints::qs::ws::DispatchWebsocketNotifier, + network_provider::MockNetworkProvider, run, storage_provider::memory::{ auth_service::{EphemeralAsStorage, MemoryAsStorage}, @@ -45,7 +46,10 @@ static TRACING: Lazy<()> = Lazy::new(|| { /// Start the server and initialize the database connection. Returns the /// address and a DispatchWebsocketNotifier to dispatch notofication over the /// websocket. -pub async fn spawn_app() -> (SocketAddr, DispatchWebsocketNotifier) { +pub async fn spawn_app( + domain: Fqdn, + network_provider: Arc, +) -> (SocketAddr, DispatchWebsocketNotifier) { // Initialize tracing subscription only once. Lazy::force(&TRACING); @@ -62,14 +66,16 @@ pub async fn spawn_app() -> (SocketAddr, DispatchWebsocketNotifier) { let ws_dispatch_notifier = DispatchWebsocketNotifier::default_addr(); let ds_storage_provider = MemoryDsStorage::new(); - let qs_storage_provider = Arc::new(MemStorageProvider::default()); + let qs_storage_provider = Arc::new(MemStorageProvider::new(domain.clone())); - let as_storage_provider = MemoryAsStorage::new(Fqdn {}, SignatureScheme::ED25519).unwrap(); + let as_storage_provider = + MemoryAsStorage::new(domain.clone(), SignatureScheme::ED25519).unwrap(); let as_ephemeral_storage_provider = EphemeralAsStorage::default(); let qs_connector = MemoryEnqueueProvider { storage: qs_storage_provider.clone(), notifier: ws_dispatch_notifier.clone(), + network: network_provider.clone(), }; // Start the server @@ -88,5 +94,6 @@ pub async fn spawn_app() -> (SocketAddr, DispatchWebsocketNotifier) { tokio::spawn(server); // Return the address + network_provider.add_port(domain, address.port()); (address, ws_dispatch_notifier) } diff --git a/server/tests/utils/setup.rs b/server/tests/utils/setup.rs index ba949995..34bacfc4 100644 --- a/server/tests/utils/setup.rs +++ b/server/tests/utils/setup.rs @@ -9,6 +9,7 @@ use std::{ }; use opaque_ke::rand::{rngs::OsRng, Rng}; +use phnxbackend::{auth_service::UserName, qs::Fqdn}; use phnxcoreclient::{ notifications::{Notifiable, NotificationHub}, types::{ @@ -17,6 +18,7 @@ use phnxcoreclient::{ }, users::SelfUser, }; +use phnxserver::network_provider::MockNetworkProvider; use uuid::Uuid; use crate::spawn_app; @@ -54,12 +56,18 @@ pub struct TestUser { } impl TestUser { - pub async fn new(user_name: &str, address: SocketAddr) -> Self { + pub async fn new(user_name: &UserName, address: SocketAddr) -> Self { let mut notification_hub = NotificationHub::::default(); let notifier = TestNotifier::new(); notification_hub.add_sink(notifier.notifier()); - let user = SelfUser::new(user_name, user_name, address, notification_hub).await; + let user = SelfUser::new( + user_name.clone(), + &user_name.to_string(), + address, + notification_hub, + ) + .await; Self { user, notifier } } @@ -69,25 +77,35 @@ impl TestUser { } pub struct TestBackend { - pub users: HashMap, - pub groups: HashMap>, + pub users: HashMap, + pub groups: HashMap>, pub address: SocketAddr, + pub domain: Fqdn, } impl TestBackend { - pub async fn new() -> Self { - let (address, _ws_dispatch) = spawn_app().await; + pub async fn single() -> Self { + let network_provider = Arc::new(MockNetworkProvider::new()); + let domain = "example.com".into(); + TestBackend::new(domain, network_provider).await + } + + async fn new(domain: Fqdn, network_provider: Arc) -> Self { + let (address, _ws_dispatch) = spawn_app(domain.clone(), network_provider).await; Self { users: HashMap::new(), address, groups: HashMap::new(), + domain, } } - pub async fn add_user(&mut self, user_name: &str) { + pub async fn add_user(&mut self, user_name: impl Into) { + tracing::info!("Turning string into user name"); + let user_name = user_name.into(); tracing::info!("Creating {user_name}"); - let user = TestUser::new(user_name, self.address).await; - self.users.insert(user_name.to_owned(), user); + let user = TestUser::new(&user_name, self.address).await; + self.users.insert(user_name, user); } pub fn flush_notifications(&mut self) { @@ -98,14 +116,19 @@ impl TestBackend { /// This has the updater commit an update, but without the checks ensuring /// that the group state remains unchanged. - pub async fn commit_to_proposals(&mut self, conversation_id: Uuid, updater_name: &str) { + pub async fn commit_to_proposals( + &mut self, + conversation_id: Uuid, + updater_name: impl Into, + ) { + let updater_name = &updater_name.into(); tracing::info!( "{} performs an update in group {}", updater_name, conversation_id ); - let test_updater = self.users.get_mut(updater_name).unwrap(); + let test_updater = self.users.get_mut(&updater_name).unwrap(); let updater = &mut test_updater.user; let pending_removes = updater.pending_removes(conversation_id).unwrap(); @@ -114,7 +137,7 @@ impl TestBackend { updater.update(conversation_id).await; let group_members_after = updater.group_members(conversation_id).unwrap(); - let difference: HashSet = group_members_before + let difference: HashSet = group_members_before .difference(&group_members_after) .map(|s| s.to_owned()) .collect(); @@ -145,14 +168,14 @@ impl TestBackend { let conversation_after = group_member.conversation(conversation_id).unwrap(); assert!( conversation_after.status - == ConversationStatus::Inactive(InactiveConversation { - past_members: group_members_before - }) + == ConversationStatus::Inactive(InactiveConversation::new( + group_members_before + )) ); } else { // ... if not, it should remove the members to be removed. let group_members_after = group_member.group_members(conversation_id).unwrap(); - let difference: HashSet = group_members_before + let difference: HashSet = group_members_before .difference(&group_members_after) .map(|s| s.to_owned()) .collect(); @@ -161,7 +184,8 @@ impl TestBackend { } } - pub async fn update_group(&mut self, conversation_id: Uuid, updater_name: &str) { + pub async fn update_group(&mut self, conversation_id: Uuid, updater_name: impl Into) { + let updater_name = &updater_name.into(); tracing::info!( "{} performs an update in group {}", updater_name, @@ -196,18 +220,23 @@ impl TestBackend { } } - pub async fn connect_users(&mut self, user1_name: &str, user2_name: &str) -> Uuid { + pub async fn connect_users( + &mut self, + user1_name: impl Into, + user2_name: impl Into, + ) -> Uuid { + let user1_name = user1_name.into(); + let user2_name = user2_name.into(); tracing::info!("Connecting users {} and {}", user1_name, user2_name); - let test_user1 = self.users.get_mut(user1_name).unwrap(); + let test_user1 = self.users.get_mut(&user1_name).unwrap(); let user1 = &mut test_user1.user; let user1_partial_contacts_before = user1.partial_contacts(); let user1_conversations_before = user1.get_conversations(); - tracing::info!("{} adds {} as a contact", user1_name, user2_name); user1.add_contact(&user2_name).await; let mut user1_partial_contacts_after = user1.partial_contacts(); let new_user_position = user1_partial_contacts_after .iter() - .position(|c| &c.user_name.to_string() == user2_name) + .position(|c| c.user_name == user2_name) .expect("User 2 should be in the partial contacts list of user 1"); // If we remove the new user, the partial contact lists should be the same. user1_partial_contacts_after.remove(new_user_position); @@ -220,13 +249,13 @@ impl TestBackend { let mut user1_conversations_after = user1.get_conversations(); let new_conversation_position = user1_conversations_after .iter() - .position(|c| &c.attributes.title == user2_name) + .position(|c| c.attributes.title == user2_name.to_string()) .expect("User 1 should have created a new conversation"); let conversation = user1_conversations_after.remove(new_conversation_position); assert!(conversation.status == ConversationStatus::Active); assert!( conversation.conversation_type - == ConversationType::UnconfirmedConnection(user2_name.as_bytes().to_vec()) + == ConversationType::UnconfirmedConnection(user2_name.to_bytes()) ); user1_conversations_before .into_iter() @@ -236,7 +265,7 @@ impl TestBackend { }); let user1_conversation_id = conversation.id.clone().as_uuid(); - let test_user2 = self.users.get_mut(user2_name).unwrap(); + let test_user2 = self.users.get_mut(&user2_name).unwrap(); let user2 = &mut test_user2.user; let user2_contacts_before = user2.contacts(); let user2_conversations_before = user2.get_conversations(); @@ -248,7 +277,7 @@ impl TestBackend { let mut user2_contacts_after = user2.contacts(); let new_contact_position = user2_contacts_after .iter() - .position(|c| &c.user_name.to_string() == user1_name) + .position(|c| c.user_name == user1_name) .expect("User 1 should be in the partial contacts list of user 2"); // If we remove the new user, the partial contact lists should be the same. user2_contacts_after.remove(new_contact_position); @@ -262,13 +291,12 @@ impl TestBackend { let mut user2_conversations_after = user2.get_conversations(); let new_conversation_position = user2_conversations_after .iter() - .position(|c| &c.attributes.title == user1_name) + .position(|c| c.attributes.title == user1_name.to_string()) .expect("User 2 should have created a new conversation"); let conversation = user2_conversations_after.remove(new_conversation_position); assert!(conversation.status == ConversationStatus::Active); assert!( - conversation.conversation_type - == ConversationType::Connection(user1_name.as_bytes().to_vec()) + conversation.conversation_type == ConversationType::Connection(user1_name.to_bytes()) ); user2_conversations_before .into_iter() @@ -279,7 +307,7 @@ impl TestBackend { let user2_conversation_id = conversation.id.as_uuid(); let user2_user_name = user2.user_name().clone(); - let test_user1 = self.users.get_mut(user1_name).unwrap(); + let test_user1 = self.users.get_mut(&user1_name).unwrap(); let user1 = &mut test_user1.user; let user1_contacts_before: HashSet<_> = user1 .contacts() @@ -307,13 +335,12 @@ impl TestBackend { let mut user1_conversations_after = user1.get_conversations(); let new_conversation_position = user1_conversations_after .iter() - .position(|c| &c.attributes.title == &user2_name) + .position(|c| &c.attributes.title == &user2_name.to_string()) .expect("User 1 should have created a new conversation"); let conversation = user1_conversations_after.remove(new_conversation_position); assert!(conversation.status == ConversationStatus::Active); assert!( - conversation.conversation_type - == ConversationType::Connection(user2_name.as_bytes().to_vec()) + conversation.conversation_type == ConversationType::Connection(user2_name.to_bytes()) ); user1_conversations_before .into_iter() @@ -325,12 +352,20 @@ impl TestBackend { debug_assert_eq!(user1_conversation_id, user2_conversation_id); // Send messages both ways to ensure it works. - self.send_message(user1_conversation_id, user1_name, &[user2_name]) - .await; - self.send_message(user1_conversation_id, user2_name, &[user1_name]) - .await; - - let member_set: HashSet = [user1_name.to_string(), user2_name.to_string()].into(); + self.send_message( + user1_conversation_id, + user1_name.clone(), + vec![user2_name.clone()], + ) + .await; + self.send_message( + user1_conversation_id, + user2_name.clone(), + vec![user1_name.clone()], + ) + .await; + + let member_set: HashSet = [user1_name, user2_name].into(); assert_eq!(member_set.len(), 2); self.groups.insert(user1_conversation_id, member_set); user1_conversation_id @@ -342,17 +377,26 @@ impl TestBackend { pub async fn send_message( &mut self, conversation_id: Uuid, - sender_name: &str, - recipient_names: &[&str], + sender_name: impl Into, + recipient_names: Vec>, ) { + let sender_name = sender_name.into(); + let recipient_names: Vec = recipient_names + .into_iter() + .map(|name| name.into()) + .collect::>(); + let recipient_strings = recipient_names + .iter() + .map(|n| n.to_string()) + .collect::>(); tracing::info!( "{} sends a message to {}", sender_name, - recipient_names.join(", ") + recipient_strings.join(", ") ); let message: Vec = OsRng.gen::<[u8; 32]>().to_vec(); let orig_message = MessageContentType::Text(phnxcoreclient::types::TextMessage { message }); - let test_sender = self.users.get_mut(sender_name).unwrap(); + let test_sender = self.users.get_mut(&sender_name).unwrap(); let sender = &mut test_sender.user; // Before sending a message, the sender must first fetch and process its QS messages. @@ -374,13 +418,13 @@ impl TestBackend { assert_eq!( message.message, Message::Content(ContentMessage { - sender: test_sender.user.user_name().as_bytes().to_vec(), + sender: test_sender.user.user_name().to_bytes(), content: orig_message.clone() }) ); - for recipient_name in recipient_names { - let recipient = self.users.get_mut(recipient_name.to_owned()).unwrap(); + for recipient_name in &recipient_names { + let recipient = self.users.get_mut(recipient_name).unwrap(); let recipient_user = &mut recipient.user; // Flush notifications //let _recipient_notifications = recipient.notifier.notifications(); @@ -402,7 +446,7 @@ impl TestBackend { assert_eq!( message.conversation_message.message, Message::Content(ContentMessage { - sender: sender_user_name.as_bytes().to_vec(), + sender: sender_user_name.to_bytes(), content: orig_message.clone() }) ); @@ -411,8 +455,9 @@ impl TestBackend { self.flush_notifications(); } - pub async fn create_group(&mut self, user_name: &str) -> Uuid { - let test_user = self.users.get_mut(user_name).unwrap(); + pub async fn create_group(&mut self, user_name: impl Into) -> Uuid { + let user_name = user_name.into(); + let test_user = self.users.get_mut(&user_name).unwrap(); let user = &mut test_user.user; let user_conversations_before = user.get_conversations(); @@ -434,7 +479,7 @@ impl TestBackend { assert_eq!(before.id, after.id); }); self.flush_notifications(); - let member_set: HashSet = [user_name.to_string()].into(); + let member_set: HashSet = [user_name].into(); assert_eq!(member_set.len(), 1); self.groups.insert(conversation_id, member_set); @@ -446,10 +491,19 @@ impl TestBackend { pub async fn invite_to_group( &mut self, conversation_id: Uuid, - inviter_name: &str, - invitee_names: &[&str], + inviter_name: impl Into, + invitee_names: Vec>, ) { - let test_inviter = self.users.get_mut(inviter_name).unwrap(); + let inviter_name = inviter_name.into(); + let invitee_names: Vec = invitee_names + .into_iter() + .map(|name| name.into()) + .collect::>(); + let invitee_strings = invitee_names + .iter() + .map(|n| n.to_string()) + .collect::>(); + let test_inviter = self.users.get_mut(&inviter_name).unwrap(); let inviter = &mut test_inviter.user; // Before inviting anyone to a group, the inviter must first fetch and @@ -464,36 +518,32 @@ impl TestBackend { tracing::info!( "{} invites {} to the group with id {}", inviter_name, - invitee_names.join(", "), + invitee_strings.join(", "), conversation_id ); // Perform the invite operation and check that the invitees are now in the group. - let inviter_group_members_before: HashSet = inviter + let inviter_group_members_before: HashSet = inviter .group_members(conversation_id) .expect("Error getting group members."); inviter - .invite_users(conversation_id, invitee_names) + .invite_users(conversation_id, &invitee_names) .await .expect("Error inviting users."); - let inviter_group_members_after: HashSet = inviter + let inviter_group_members_after = inviter .group_members(conversation_id) .expect("Error getting group members."); let new_members = inviter_group_members_after .difference(&inviter_group_members_before) - .map(|name| name.to_owned()) - .collect::>(); - let invitee_set = invitee_names - .iter() - .map(|&name| name.to_owned()) .collect::>(); + let invitee_set = invitee_names.iter().collect::>(); assert_eq!(new_members, invitee_set); // Now that the invitation is out, have the invitees and all other group // members fetch and process QS messages. - for &invitee_name in invitee_names { + for invitee_name in &invitee_names { let test_invitee = self.users.get_mut(invitee_name).unwrap(); let invitee = &mut test_invitee.user; let invitee_conversations_before = invitee.get_conversations(); @@ -524,7 +574,7 @@ impl TestBackend { let group_members = self.groups.get_mut(&conversation_id).unwrap(); for group_member_name in group_members.iter() { // Skip the sender - if group_member_name == inviter_name { + if group_member_name == &inviter_name { continue; } let test_group_member = self.users.get_mut(group_member_name).unwrap(); @@ -540,32 +590,28 @@ impl TestBackend { let group_members_after = group_member.group_members(conversation_id).unwrap(); let new_members = group_members_after .difference(&group_members_before) - .map(|name| name.to_owned()) - .collect::>(); - let invitee_set = invitee_names - .iter() - .map(|&name| name.to_owned()) .collect::>(); + let invitee_set = invitee_names.iter().collect::>(); assert_eq!(new_members, invitee_set) } - for invitee_name in invitee_names { - let unique_member = group_members.insert(invitee_name.to_string()); + for invitee_name in &invitee_names { + let unique_member = group_members.insert(invitee_name.clone()); assert!(unique_member == true); } self.flush_notifications(); // Now send messages to check that the group works properly. This also // ensures that everyone involved has picked up their messages from the // QS and that notifications are flushed. - self.send_message(conversation_id, inviter_name, invitee_names) + self.send_message(conversation_id, inviter_name.clone(), invitee_names.clone()) .await; - for &invitee_name in invitee_names { + for invitee_name in &invitee_names { let recipients: Vec<_> = invitee_names .iter() - .filter(|&&name| name != invitee_name) + .filter(|&name| name != invitee_name) .chain([&inviter_name].into_iter()) .map(|name| name.to_owned()) .collect(); - self.send_message(conversation_id, invitee_name, recipients.as_slice()) + self.send_message(conversation_id, invitee_name.clone(), recipients) .await; } } @@ -575,10 +621,19 @@ impl TestBackend { pub async fn remove_from_group( &mut self, conversation_id: Uuid, - remover_name: &str, - removed_names: &[&str], + remover_name: impl Into, + removed_names: Vec>, ) { - let test_remover = self.users.get_mut(remover_name).unwrap(); + let remover_name = remover_name.into(); + let removed_names: Vec = removed_names + .into_iter() + .map(|name| name.into()) + .collect::>(); + let removed_strings = removed_names + .iter() + .map(|n| n.to_string()) + .collect::>(); + let test_remover = self.users.get_mut(&remover_name).unwrap(); let remover = &mut test_remover.user; // Before removing anyone from a group, the remover must first fetch and @@ -593,22 +648,22 @@ impl TestBackend { tracing::info!( "{} removes {} from the group with id {}", remover_name, - removed_names.join(", "), + removed_strings.join(", "), conversation_id ); // Perform the remove operation and check that the removed are not in // the group anymore. - let remover_group_members_before: HashSet = remover + let remover_group_members_before = remover .group_members(conversation_id) .expect("Error getting group members."); remover - .remove_users(conversation_id, removed_names) + .remove_users(conversation_id, &removed_names) .await .expect("Error removing users."); - let remover_group_members_after: HashSet = remover + let remover_group_members_after = remover .group_members(conversation_id) .expect("Error getting group members."); let removed_members = remover_group_members_before @@ -617,11 +672,11 @@ impl TestBackend { .collect::>(); let removed_set = removed_names .iter() - .map(|&name| name.to_owned()) + .map(|name| name.to_owned()) .collect::>(); assert_eq!(removed_members, removed_set); - for &removed_name in removed_names { + for removed_name in &removed_names { let test_removed = self.users.get_mut(removed_name).unwrap(); let removed = &mut test_removed.user; let removed_conversations_before = removed.get_conversations(); @@ -643,8 +698,7 @@ impl TestBackend { )); assert!(conversation.id.as_uuid() == conversation_id); if let ConversationStatus::Inactive(inactive_status) = &conversation.status { - let inactive_status_members: HashSet<_> = - inactive_status.past_members.clone().into_iter().collect(); + let inactive_status_members: HashSet = inactive_status.past_members(); assert_eq!(inactive_status_members, past_members); } else { panic!("Conversation should be inactive.") @@ -665,14 +719,14 @@ impl TestBackend { assert!(!error) } let group_members = self.groups.get_mut(&conversation_id).unwrap(); - for &removed_name in removed_names { + for removed_name in &removed_names { let remove_successful = group_members.remove(removed_name); assert!(remove_successful == true); } // Now have the rest of the group pick up and process their messages. for group_member_name in group_members.iter() { // Skip the remover - if group_member_name == remover_name { + if group_member_name == &remover_name { continue; } let test_group_member = self.users.get_mut(group_member_name).unwrap(); @@ -692,7 +746,7 @@ impl TestBackend { .collect::>(); let removed_set = removed_names .iter() - .map(|&name| name.to_owned()) + .map(|name| name.to_owned()) .collect::>(); assert_eq!(removed_members, removed_set) } @@ -701,13 +755,14 @@ impl TestBackend { } /// Has the leaver leave the given group. - pub async fn leave_group(&mut self, conversation_id: Uuid, leaver_name: &str) { + pub async fn leave_group(&mut self, conversation_id: Uuid, leaver_name: impl Into) { + let leaver_name = leaver_name.into(); tracing::info!( "{} leaves the group with id {}", leaver_name, conversation_id ); - let test_leaver = self.users.get_mut(leaver_name).unwrap(); + let test_leaver = self.users.get_mut(&leaver_name).unwrap(); let leaver = &mut test_leaver.user; // Perform the leave operation. @@ -721,7 +776,7 @@ impl TestBackend { let mut random_member_iter = group_members.iter(); let mut random_member_name = random_member_iter.next().unwrap(); // Ensure that the random member isn't the leaver. - if random_member_name == leaver_name { + if random_member_name == &leaver_name { random_member_name = random_member_iter.next().unwrap() } let test_random_member = self.users.get_mut(random_member_name).unwrap(); @@ -739,22 +794,23 @@ impl TestBackend { // pick up and process their messages. This also tests that group // members were removed correctly from the local group and that the // leaver has turned its conversation inactive. - self.commit_to_proposals(conversation_id, random_member_name) + self.commit_to_proposals(conversation_id, random_member_name.clone()) .await; let group_members = self.groups.get_mut(&conversation_id).unwrap(); - group_members.remove(leaver_name); + group_members.remove(&leaver_name); self.flush_notifications(); } - pub async fn delete_group(&mut self, conversation_id: Uuid, deleter_name: &str) { + pub async fn delete_group(&mut self, conversation_id: Uuid, deleter_name: impl Into) { + let deleter_name = deleter_name.into(); tracing::info!( "{} deletes the group with id {}", deleter_name, conversation_id ); - let test_deleter = self.users.get_mut(deleter_name).unwrap(); + let test_deleter = self.users.get_mut(&deleter_name).unwrap(); let deleter = &mut test_deleter.user; // Before removing anyone from a group, the remover must first fetch and @@ -779,8 +835,7 @@ impl TestBackend { let deleter_conversation_after = deleter.conversation(conversation_id).unwrap(); if let ConversationStatus::Inactive(inactive_status) = &deleter_conversation_after.status { - let inactive_status_members: HashSet<_> = - inactive_status.past_members.clone().into_iter().collect(); + let inactive_status_members: HashSet<_> = inactive_status.past_members(); assert_eq!(inactive_status_members, past_members); } else { panic!("Conversation should be inactive.") @@ -788,7 +843,7 @@ impl TestBackend { for group_member_name in self.groups.get(&conversation_id).unwrap().iter() { // Skip the deleter - if group_member_name == deleter_name { + if group_member_name == &deleter_name { continue; } let test_group_member = self.users.get_mut(group_member_name).unwrap(); @@ -814,8 +869,7 @@ impl TestBackend { if let ConversationStatus::Inactive(inactive_status) = &group_member_conversation_after.status { - let inactive_status_members: HashSet<_> = - inactive_status.past_members.clone().into_iter().collect(); + let inactive_status_members: HashSet<_> = inactive_status.past_members(); assert_eq!(inactive_status_members, past_members); } else { panic!("Conversation should be inactive.") @@ -826,3 +880,93 @@ impl TestBackend { self.flush_notifications(); } } + +pub struct TestBed { + pub network_provider: Arc, + pub backends: HashMap, +} + +impl TestBed { + pub fn new() -> Self { + Self { + network_provider: Arc::new(MockNetworkProvider::new()), + backends: HashMap::new(), + } + } + + pub async fn new_backend(&mut self, domain: Fqdn) { + let backend = TestBackend::new(domain.clone(), self.network_provider.clone()).await; + self.backends.insert(domain, backend); + } + + pub async fn add_user(&mut self, user_name: impl Into) { + let user_name = user_name.into(); + let domain = user_name.domain(); + let backend = self.backends.get_mut(&domain).unwrap(); + backend.add_user(user_name).await; + } + + pub async fn connect_users( + &mut self, + domain: Fqdn, + user1_name: impl Into, + user2_name: impl Into, + ) { + let backend = self.backends.get_mut(&domain).unwrap(); + backend.connect_users(user1_name, user2_name).await; + } + + pub async fn send_message( + &mut self, + domain: Fqdn, + conversation_id: Uuid, + sender_name: impl Into, + recipient_names: Vec>, + ) { + let backend = self.backends.get_mut(&domain).unwrap(); + backend + .send_message(conversation_id, sender_name, recipient_names) + .await; + } + + pub async fn create_group(&mut self, domain: Fqdn, user_name: impl Into) -> Uuid { + let backend = self.backends.get_mut(&domain).unwrap(); + backend.create_group(user_name).await + } + + pub async fn invite_to_group( + &mut self, + domain: Fqdn, + conversation_id: Uuid, + inviter_name: impl Into, + invitee_names: Vec>, + ) { + let backend = self.backends.get_mut(&domain).unwrap(); + backend + .invite_to_group(conversation_id, inviter_name, invitee_names) + .await; + } + + pub async fn remove_from_group( + &mut self, + domain: Fqdn, + conversation_id: Uuid, + remover_name: impl Into, + removed_names: Vec>, + ) { + let backend = self.backends.get_mut(&domain).unwrap(); + backend + .remove_from_group(conversation_id, remover_name, removed_names) + .await; + } + + pub async fn leave_group( + &mut self, + domain: Fqdn, + conversation_id: Uuid, + leaver_name: impl Into, + ) { + let backend = self.backends.get_mut(&domain).unwrap(); + backend.leave_group(conversation_id, leaver_name).await; + } +}