Skip to content

Commit

Permalink
federation prep
Browse files Browse the repository at this point in the history
  • Loading branch information
kkohbrok committed Jul 19, 2023
1 parent 011daa3 commit b7e6a8d
Show file tree
Hide file tree
Showing 29 changed files with 716 additions and 355 deletions.
45 changes: 31 additions & 14 deletions backend/src/auth_service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use opaque_ke::{
RegistrationResponse, RegistrationUpload, ServerRegistration,
};
use serde::{Deserialize, Serialize};
use tls_codec::{TlsDeserializeBytes, TlsSerialize, TlsSize};
use tls_codec::{
DeserializeBytes as TlsDeserializeTrait, Serialize as TlsSerializeTrait, TlsDeserializeBytes,
TlsSerialize, TlsSize,
};

use crate::{
crypto::{ratchet::QueueRatchet, OpaqueCiphersuite, RandomnessError, RatchetEncryptionKey},
Expand All @@ -25,6 +28,7 @@ use crate::{
client_as_out::VerifiableClientToAsMessage,
EncryptedAsQueueMessage,
},
qs::Fqdn,
};

use self::{
Expand All @@ -42,7 +46,6 @@ pub mod invitations;
pub mod key_packages;
pub mod registration;
pub mod storage_provider_trait;
pub mod username;

/*
Actions:
Expand Down Expand Up @@ -146,39 +149,53 @@ impl AsUserRecord {
)]
pub struct UserName {
pub(crate) user_name: Vec<u8>,
pub(crate) domain: Fqdn,
}

impl From<Vec<u8>> for UserName {
fn from(value: Vec<u8>) -> Self {
Self { user_name: value }
Self::tls_deserialize_exact(&value).unwrap()
}
}

// TODO: This string processing is way too simplistic, but it should do for now.
impl From<&str> for UserName {
fn from(value: &str) -> Self {
Self {
user_name: value.as_bytes().to_vec(),
}
let mut split_name = value.split('@');
let name = split_name.next().unwrap();
// UserNames MUST be qualified
let domain = split_name.next().unwrap();
assert!(split_name.next().is_none());
let domain = domain.into();
let user_name = name.as_bytes().to_vec();
Self { user_name, domain }
}
}

impl UserName {
pub fn as_bytes(&self) -> &[u8] {
&self.user_name
pub fn to_bytes(&self) -> Vec<u8> {
self.tls_serialize_detached().unwrap()
}

pub fn domain(&self) -> Fqdn {
self.domain.clone()
}
}

impl From<String> for UserName {
fn from(value: String) -> Self {
Self {
user_name: value.into_bytes(),
}
value.as_str().into()
}
}

impl ToString for UserName {
fn to_string(&self) -> String {
String::from_utf8(self.user_name.clone()).unwrap()
impl std::fmt::Display for UserName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}@{}",
String::from_utf8_lossy(&self.user_name),
self.domain
)
}
}

Expand Down
24 changes: 0 additions & 24 deletions backend/src/auth_service/username.rs

This file was deleted.

6 changes: 3 additions & 3 deletions backend/src/messages/intra_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@
//! passed internally within the backend.

use mls_assist::messages::SerializedMlsMessage;
use tls_codec::{TlsDeserializeBytes, TlsSize};
use tls_codec::{TlsDeserializeBytes, TlsSerialize, TlsSize};

use crate::qs::QsClientReference;

use super::client_ds::{EventMessage, QsQueueMessagePayload};

// === DS to QS ===

#[derive(TlsDeserializeBytes, TlsSize)]
#[derive(TlsSerialize, TlsDeserializeBytes, TlsSize)]
pub struct DsFanOutMessage {
pub payload: DsFanOutPayload,
pub client_reference: QsClientReference,
}

#[derive(Clone, TlsDeserializeBytes, TlsSize)]
#[derive(Clone, TlsSerialize, TlsDeserializeBytes, TlsSize)]
#[repr(u8)]
pub enum DsFanOutPayload {
QueueMessage(QsQueueMessagePayload),
Expand Down
1 change: 1 addition & 0 deletions backend/src/messages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod client_ds_out;
pub mod client_qs;
pub mod client_qs_out;
pub mod intra_backend;
pub mod qs_qs;

#[derive(
Serialize,
Expand Down
15 changes: 15 additions & 0 deletions backend/src/messages/qs_qs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH <[email protected]>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

use crate::qs::Fqdn;

use super::{intra_backend::DsFanOutMessage, MlsInfraVersion};

pub struct QsToQsMessage {
protocol_version: MlsInfraVersion,
sender: Fqdn,
recipient: Fqdn,
fan_out_message: DsFanOutMessage,
// TODO: Signature
}
2 changes: 1 addition & 1 deletion backend/src/qs/client_api/key_packages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl Qs {
.map_err(|_| QsKeyPackageBatchError::StorageError)?;

let key_package_batch_tbs = KeyPackageBatchTbs {
homeserver_domain: config.fqdn.clone(),
homeserver_domain: config.domain.clone(),
key_package_refs,
time_of_signature: TimeStamp::now(),
};
Expand Down
17 changes: 17 additions & 0 deletions backend/src/qs/dns_provider_trait.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH <[email protected]>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

use async_trait::async_trait;
use std::error::Error;
use std::fmt::Debug;
use std::net::SocketAddr;

use super::Fqdn;

#[async_trait]
pub trait DnsProvider: Sync + Send + 'static {
type DnsError: Error + Debug + Clone;

async fn resolve(&self, fqdn: &Fqdn) -> Result<SocketAddr, Self::DnsError>;
}
34 changes: 27 additions & 7 deletions backend/src/qs/ds_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
//
// SPDX-License-Identifier: AGPL-3.0-or-later

use tls_codec::Serialize;

use crate::{crypto::hpke::HpkeDecryptable, messages::intra_backend::DsFanOutMessage};

use super::{
errors::QsEnqueueError, storage_provider_trait::QsStorageProvider, ClientConfig, Qs,
WebsocketNotifier,
errors::QsEnqueueError, network_provider_trait::NetworkProvider,
storage_provider_trait::QsStorageProvider, ClientConfig, Qs, WebsocketNotifier,
};

impl Qs {
Expand All @@ -15,15 +17,33 @@ impl Qs {
/// quickly. It can attempt to do the full fanout and return potential
/// failed transmissions to the DS.
///
/// This endpoint is used for enqueining
/// messages in both local and remote queues, depending on the FQDN of the
/// client. For now, only local queues are supported.
/// This endpoint is used for enqueining messages in both local and remote
/// queues, depending on the FQDN of the client.
#[tracing::instrument(skip_all, err)]
pub async fn enqueue_message<S: QsStorageProvider, W: WebsocketNotifier>(
pub async fn enqueue_message<S: QsStorageProvider, W: WebsocketNotifier, N: NetworkProvider>(
storage_provider: &S,
websocket_notifier: &W,
network_provider: &N,
message: DsFanOutMessage,
) -> Result<(), QsEnqueueError<S>> {
) -> Result<(), QsEnqueueError<S, N>> {
if message.client_reference.client_homeserver_domain != storage_provider.own_domain().await
{
tracing::info!(
"Domains differ. Destination domain: {:?}, own domain: {:?}",
message.client_reference.client_homeserver_domain,
storage_provider.own_domain().await
);
let serialized_message = message
.tls_serialize_detached()
.map_err(|_| QsEnqueueError::LibraryError)?;
network_provider
.deliver(
serialized_message,
message.client_reference.client_homeserver_domain,
)
.await
.map_err(QsEnqueueError::NetworkError)?;
}
let decryption_key = storage_provider
.load_decryption_key()
.await
Expand Down
10 changes: 8 additions & 2 deletions backend/src/qs/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

use crate::crypto::DecryptionError;

use super::storage_provider_trait::QsStorageProvider;
use super::{network_provider_trait::NetworkProvider, storage_provider_trait::QsStorageProvider};
use thiserror::Error;
use tls_codec::{TlsDeserializeBytes, TlsSerialize, TlsSize};

// === DS API errors ===

/// Error fetching a message from the QS.
#[derive(Error, Debug, Clone)]
pub enum QsEnqueueError<S: QsStorageProvider> {
pub enum QsEnqueueError<S: QsStorageProvider, N: NetworkProvider> {
/// Couldn't find the requested queue.
#[error("Couldn't find the requested queue")]
QueueNotFound,
Expand All @@ -22,9 +22,15 @@ pub enum QsEnqueueError<S: QsStorageProvider> {
/// An error ocurred enqueueing in a fan out queue
#[error(transparent)]
EnqueueError(#[from] EnqueueError<S>),
/// An error ocurred while sending a message to the network
#[error("An error ocurred while sending a message to the network")]
NetworkError(N::NetworkError),
/// Storage provider error
#[error("Storage provider error")]
StorageError,
/// Unrecoverable implementation error
#[error("Library Error")]
LibraryError,
}

/// Error enqueuing a fanned-out message.
Expand Down
37 changes: 35 additions & 2 deletions backend/src/qs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
//! smaller than the smalles requested one and responds with the requested
//! messages.

use std::fmt::{Display, Formatter};

use crate::{
crypto::{
ear::{
Expand Down Expand Up @@ -95,8 +97,10 @@ use self::errors::SealError;

pub mod client_api;
pub mod client_record;
pub mod dns_provider_trait;
pub mod ds_api;
pub mod errors;
pub mod network_provider_trait;
pub mod storage_provider_trait;
pub mod user_record;

Expand Down Expand Up @@ -360,7 +364,7 @@ impl VerifyingKey for QsVerifyingKey {}

#[derive(Debug, Clone)]
pub struct QsConfig {
pub fqdn: Fqdn,
pub domain: Fqdn,
}

#[derive(Debug)]
Expand All @@ -379,7 +383,36 @@ pub struct Qs {}
Hash,
Debug,
)]
pub struct Fqdn {}
pub struct Fqdn {
// TODO: We should probably use a more restrictive type here.
domain: Vec<u8>,
}

impl Display for Fqdn {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", String::from_utf8_lossy(&self.domain))
}
}

impl Fqdn {
pub fn new(domain: String) -> Self {
Self {
domain: domain.into_bytes(),
}
}

pub fn as_bytes(&self) -> &[u8] {
&self.domain
}
}

impl From<&str> for Fqdn {
fn from(domain: &str) -> Self {
Self {
domain: domain.as_bytes().to_vec(),
}
}
}

#[derive(
Clone,
Expand Down
16 changes: 16 additions & 0 deletions backend/src/qs/network_provider_trait.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH <[email protected]>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

use async_trait::async_trait;
use std::error::Error;
use std::fmt::Debug;

use super::Fqdn;

#[async_trait]
pub trait NetworkProvider: Sync + Send + Debug + 'static {
type NetworkError: Error + Debug + Clone;

async fn deliver(&self, bytes: Vec<u8>, destination: Fqdn) -> Result<(), Self::NetworkError>;
}
3 changes: 3 additions & 0 deletions backend/src/qs/qs_api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// SPDX-FileCopyrightText: 2023 Phoenix R&D GmbH <[email protected]>
//
// SPDX-License-Identifier: AGPL-3.0-or-later
6 changes: 4 additions & 2 deletions backend/src/qs/storage_provider_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use tls_codec::{DeserializeBytes, Serialize, Size};
use crate::messages::{FriendshipToken, QueueMessage};

use super::{
client_record::QsClientRecord, user_record::QsUserRecord, ClientIdDecryptionKey, QsClientId,
QsConfig, QsEncryptedAddPackage, QsSigningKey, QsUserId,
client_record::QsClientRecord, user_record::QsUserRecord, ClientIdDecryptionKey, Fqdn,
QsClientId, QsConfig, QsEncryptedAddPackage, QsSigningKey, QsUserId,
};

/// Storage provider trait for the QS.
Expand All @@ -35,6 +35,8 @@ pub trait QsStorageProvider: Sync + Send + Debug + 'static {

type LoadConfigError: Error + Debug + Clone;

async fn own_domain(&self) -> Fqdn;

// === USERS ===

/// Returns a new unique user ID.
Expand Down
Loading

0 comments on commit b7e6a8d

Please sign in to comment.