Skip to content

Commit

Permalink
Merge branch 'konrad/identifiers' into konrad/as_provider
Browse files Browse the repository at this point in the history
  • Loading branch information
kkohbrok committed Sep 12, 2024
2 parents 291bff7 + 11ec6a5 commit 570642c
Show file tree
Hide file tree
Showing 12 changed files with 221 additions and 201 deletions.
12 changes: 4 additions & 8 deletions backend/src/ds/add_clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,20 @@ use tls_codec::DeserializeBytes;

use crate::messages::intra_backend::{DsFanOutMessage, DsFanOutPayload};

use super::{
group_state::ClientProfile,
process::{Provider, USER_EXPIRATION_DAYS},
};
use super::{group_state::ClientProfile, process::USER_EXPIRATION_DAYS};

use super::group_state::DsGroupState;

impl DsGroupState {
pub(crate) fn add_clients(
&mut self,
provider: &Provider,
params: AddClientsParams,
group_state_ear_key: &GroupStateEarKey,
) -> Result<(SerializedMlsMessage, Vec<DsFanOutMessage>), ClientAdditionError> {
// Process message (but don't apply it yet). This performs mls-assist-level validations.
let processed_assisted_message_plus = self
.group()
.process_assisted_message(provider.crypto(), params.commit)
.process_assisted_message(self.provider.crypto(), params.commit)
.map_err(|_| ClientAdditionError::ProcessingError)?;

// Perform DS-level validation
Expand Down Expand Up @@ -127,8 +123,8 @@ impl DsGroupState {
// Now we have to update the group state and distribute.

// We first accept the message into the group state ...
self.group_mut().accept_processed_message(
provider.storage(),
self.group.accept_processed_message(
self.provider.storage(),
processed_assisted_message_plus.processed_assisted_message,
Duration::days(USER_EXPIRATION_DAYS),
)?;
Expand Down
12 changes: 4 additions & 8 deletions backend/src/ds/add_users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,21 @@ use crate::{
qs::QsConnector,
};

use super::{
group_state::ClientProfile,
process::{Provider, USER_EXPIRATION_DAYS},
};
use super::{group_state::ClientProfile, process::USER_EXPIRATION_DAYS};

use super::group_state::DsGroupState;

impl DsGroupState {
pub(crate) async fn add_users<Q: QsConnector>(
&mut self,
provider: &Provider,
params: AddUsersParams,
group_state_ear_key: &GroupStateEarKey,
qs_provider: &Q,
) -> Result<(SerializedMlsMessage, Vec<DsFanOutMessage>), AddUsersError> {
// Process message (but don't apply it yet). This performs mls-assist-level validations.
let processed_assisted_message_plus = self
.group()
.process_assisted_message(provider.crypto(), params.commit)
.process_assisted_message(self.provider.crypto(), params.commit)
.map_err(|e| {
tracing::warn!("Error processing assisted message: {:?}", e);
AddUsersError::ProcessingError
Expand Down Expand Up @@ -219,8 +215,8 @@ impl DsGroupState {
// Now we have to update the group state and distribute.

// We first accept the message into the group state ...
self.group_mut().accept_processed_message(
provider.storage(),
self.group.accept_processed_message(
self.provider.storage(),
processed_assisted_message_plus.processed_assisted_message,
Duration::days(USER_EXPIRATION_DAYS),
)?;
Expand Down
5 changes: 2 additions & 3 deletions backend/src/ds/delete_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,17 @@ use mls_assist::{
};
use phnxtypes::{errors::GroupDeletionError, messages::client_ds::DeleteGroupParams};

use super::{group_state::DsGroupState, process::Provider};
use super::group_state::DsGroupState;

impl DsGroupState {
pub(crate) fn delete_group(
&mut self,
provider: &Provider,
params: DeleteGroupParams,
) -> Result<SerializedMlsMessage, GroupDeletionError> {
// Process message (but don't apply it yet). This performs mls-assist-level validations.
let processed_assisted_message_plus = self
.group()
.process_assisted_message(provider.crypto(), params.commit)
.process_assisted_message(self.provider.crypto(), params.commit)
.map_err(|_| GroupDeletionError::ProcessingError)?;

// Perform DS-level validation
Expand Down
172 changes: 113 additions & 59 deletions backend/src/ds/group_state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@
use std::collections::{BTreeMap, HashMap, HashSet};

use mls_assist::{
group::{errors::StorageError as MlsAssistStorageError, Group},
group::Group,
openmls::{
group::GroupId,
prelude::{GroupEpoch, LeafNodeIndex, QueuedRemoveProposal, Sender},
treesync::RatchetTree,
},
provider_traits::MlsAssistProvider,
MlsAssistRustCrypto,
};
use phnxtypes::{
codec::PhnxCodec,
credentials::EncryptedClientCredential,
crypto::{
ear::{
keys::{EncryptedSignatureEarKey, GroupStateEarKey},
Ciphertext, EarDecryptable, EarEncryptable,
},
errors::{DecryptionError, EncryptionError},
signatures::keys::{UserAuthVerifyingKey, UserKeyHash},
},
errors::{CborMlsAssistStorage, UpdateQueueConfigError, ValidationError},
Expand All @@ -28,6 +32,7 @@ use phnxtypes::{
};
use serde::{Deserialize, Serialize};
use sqlx::PgExecutor;
use thiserror::Error;
use uuid::Uuid;

use crate::persistence::StorageError;
Expand All @@ -52,69 +57,14 @@ pub(super) struct ClientProfile {
pub(super) activity_epoch: GroupEpoch,
}

#[derive(Serialize, Deserialize)]
pub(super) struct ProposalStore {}

#[derive(Serialize, Deserialize)]
pub(crate) struct SerializableDsGroupState {
pub(super) group_id: GroupId,
pub(super) serialized_provider: Vec<u8>,
pub(super) user_profiles: Vec<(UserKeyHash, UserProfile)>,
pub(super) unmerged_users: Vec<Vec<LeafNodeIndex>>,
pub(super) client_profiles: Vec<(LeafNodeIndex, ClientProfile)>,
}

impl SerializableDsGroupState {
pub(super) fn from_group_and_provider(
group_state: DsGroupState,
provider: &CborMlsAssistStorage,
) -> Result<Self, MlsAssistStorageError<CborMlsAssistStorage>> {
let group_id = group_state
.group()
.group_info()
.group_context()
.group_id()
.clone();
let user_profiles = group_state.user_profiles.into_iter().collect();
let client_profiles = group_state.client_profiles.into_iter().collect();
let serialized_provider = provider.serialize()?;
Ok(Self {
group_id,
serialized_provider,
user_profiles,
unmerged_users: group_state.unmerged_users,
client_profiles,
})
}

pub(super) fn into_group_state_and_provider(
self,
) -> Result<(DsGroupState, CborMlsAssistStorage), MlsAssistStorageError<CborMlsAssistStorage>>
{
let provider = CborMlsAssistStorage::deserialize(&self.serialized_provider)?;
// We unwrap here, because the constructor ensures that `self` always stores a group
let group = Group::load(&provider, &self.group_id)?.unwrap();
let user_profiles = self.user_profiles.into_iter().collect();
let client_profiles = self.client_profiles.into_iter().collect();
Ok((
DsGroupState {
group,
user_profiles,
unmerged_users: self.unmerged_users,
client_profiles,
},
provider,
))
}
}

/// The `DsGroupState` is the per-group state that the DS persists.
/// It is encrypted-at-rest with a roster key.
///
/// TODO: Past group states are now included in mls-assist. However, we might
/// have to store client credentials externally.
pub(crate) struct DsGroupState {
pub(super) group: Group,
pub(super) provider: MlsAssistRustCrypto<PhnxCodec>,
pub(super) user_profiles: HashMap<UserKeyHash, UserProfile>,
// Here we keep users that haven't set their user key yet.
pub(super) unmerged_users: Vec<Vec<LeafNodeIndex>>,
Expand All @@ -124,6 +74,7 @@ pub(crate) struct DsGroupState {
impl DsGroupState {
//#[instrument(level = "trace", skip_all)]
pub(crate) fn new(
provider: MlsAssistRustCrypto<PhnxCodec>,
group: Group,
creator_user_auth_key: UserAuthVerifyingKey,
creator_encrypted_client_credential: EncryptedClientCredential,
Expand All @@ -149,6 +100,7 @@ impl DsGroupState {
};
let client_profiles = [(LeafNodeIndex::new(0u32), creator_client_profile)].into();
Self {
provider,
group,
user_profiles,
client_profiles,
Expand Down Expand Up @@ -280,6 +232,41 @@ impl DsGroupState {
}
client_information
}

pub(super) fn encrypt(
self,
ear_key: &GroupStateEarKey,
) -> Result<EncryptedDsGroupState, DsGroupStateEncryptionError> {
let encrypted =
EncryptableDsGroupState::from(SerializableDsGroupState::from_group_state(self)?)
.encrypt(ear_key)?;
Ok(encrypted)
}

pub(super) fn decrypt(
encrypted_group_state: &EncryptedDsGroupState,
ear_key: &GroupStateEarKey,
) -> Result<Self, DsGroupStateDecryptionError> {
let encryptable = EncryptableDsGroupState::decrypt(ear_key, encrypted_group_state)?;
let group_state = SerializableDsGroupState::into_group_state(encryptable.into())?;
Ok(group_state)
}
}

#[derive(Debug, Error)]
pub(super) enum DsGroupStateEncryptionError {
#[error("Error decrypting group state: {0}")]
EncryptionError(#[from] EncryptionError),
#[error("Error deserializing group state: {0}")]
DeserializationError(#[from] phnxtypes::codec::Error),
}

#[derive(Debug, Error)]
pub(super) enum DsGroupStateDecryptionError {
#[error("Error decrypting group state: {0}")]
DecryptionError(#[from] DecryptionError),
#[error("Error deserializing group state: {0}")]
DeserializationError(#[from] phnxtypes::codec::Error),
}

#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
Expand Down Expand Up @@ -327,5 +314,72 @@ impl AsRef<Ciphertext> for EncryptedDsGroupState {
}
}

impl EarEncryptable<GroupStateEarKey, EncryptedDsGroupState> for SerializableDsGroupState {}
impl EarDecryptable<GroupStateEarKey, EncryptedDsGroupState> for SerializableDsGroupState {}
#[derive(Serialize, Deserialize)]
pub(crate) struct SerializableDsGroupState {
group_id: GroupId,
serialized_provider: Vec<u8>,
user_profiles: Vec<(UserKeyHash, UserProfile)>,
unmerged_users: Vec<Vec<LeafNodeIndex>>,
client_profiles: Vec<(LeafNodeIndex, ClientProfile)>,
}

impl SerializableDsGroupState {
pub(super) fn from_group_state(
group_state: DsGroupState,
) -> Result<Self, phnxtypes::codec::Error> {
let group_id = group_state
.group()
.group_info()
.group_context()
.group_id()
.clone();
let user_profiles = group_state.user_profiles.into_iter().collect();
let client_profiles = group_state.client_profiles.into_iter().collect();
let serialized_provider = group_state.provider.storage().serialize()?;
Ok(Self {
group_id,
serialized_provider,
user_profiles,
unmerged_users: group_state.unmerged_users,
client_profiles,
})
}

pub(super) fn into_group_state(self) -> Result<DsGroupState, phnxtypes::codec::Error> {
let storage = CborMlsAssistStorage::deserialize(&self.serialized_provider)?;
// We unwrap here, because the constructor ensures that `self` always stores a group
let group = Group::load(&storage, &self.group_id)?.unwrap();
let user_profiles = self.user_profiles.into_iter().collect();
let client_profiles = self.client_profiles.into_iter().collect();
let provider = MlsAssistRustCrypto::from(storage);
Ok(DsGroupState {
provider,
group,
user_profiles,
unmerged_users: self.unmerged_users,
client_profiles,
})
}
}

#[derive(Serialize, Deserialize)]
pub(super) enum EncryptableDsGroupState {
V1(SerializableDsGroupState),
}

impl From<EncryptableDsGroupState> for SerializableDsGroupState {
fn from(encryptable: EncryptableDsGroupState) -> Self {
match encryptable {
EncryptableDsGroupState::V1(serializable) => serializable,
}
}
}

impl From<SerializableDsGroupState> for EncryptableDsGroupState {
fn from(serializable: SerializableDsGroupState) -> Self {
EncryptableDsGroupState::V1(serializable)
}
}

impl EarEncryptable<GroupStateEarKey, EncryptedDsGroupState> for EncryptableDsGroupState {}
impl EarDecryptable<GroupStateEarKey, EncryptedDsGroupState> for EncryptableDsGroupState {}
9 changes: 4 additions & 5 deletions backend/src/ds/join_connection_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,18 @@ use tls_codec::DeserializeBytes;

use super::{
group_state::{ClientProfile, DsGroupState, UserProfile},
process::{Provider, USER_EXPIRATION_DAYS},
process::USER_EXPIRATION_DAYS,
};

impl DsGroupState {
pub(super) fn join_connection_group(
&mut self,
provider: &Provider,
params: JoinConnectionGroupParams,
) -> Result<SerializedMlsMessage, JoinConnectionGroupError> {
// Process message (but don't apply it yet). This performs mls-assist-level validations.
let processed_assisted_message_plus = self
.group()
.process_assisted_message(provider.crypto(), params.external_commit)
.process_assisted_message(self.provider.crypto(), params.external_commit)
.map_err(|e| {
tracing::warn!(
"Processing error: Could not process assisted message: {:?}",
Expand Down Expand Up @@ -87,8 +86,8 @@ impl DsGroupState {
let sender_credential = processed_message.credential().clone();

// Finalize processing.
self.group_mut().accept_processed_message(
provider.storage(),
self.group.accept_processed_message(
self.provider.storage(),
processed_assisted_message_plus.processed_assisted_message,
Duration::days(USER_EXPIRATION_DAYS),
)?;
Expand Down
Loading

0 comments on commit 570642c

Please sign in to comment.