Skip to content

Commit

Permalink
Fix SCTP send_init/handle_init race (#615)
Browse files Browse the repository at this point in the history
* sctp: Don't derive Default for Stream and AssociationInternal

These structures should never be created using defaults.
Which also makes channel senders mandatory (not Optional).

* sctp: Simplify association creation

* sctp: Log error for unexpected INIT

* sctp: Hold association lock when sending client init

This fixes a race where read_loop may process incoming init and going to
established before send_init() is called.
  • Loading branch information
haaspors authored Oct 1, 2024
1 parent c7cfe3c commit be438ed
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 187 deletions.
19 changes: 6 additions & 13 deletions data/src/data_channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub struct Config {
}

/// DataChannel represents a data channel
#[derive(Debug, Default, Clone)]
#[derive(Debug, Clone)]
pub struct DataChannel {
pub config: Config,
stream: Arc<Stream>,
Expand All @@ -54,7 +54,11 @@ impl DataChannel {
Self {
config,
stream,
..Default::default()

messages_sent: Arc::new(AtomicUsize::default()),
messages_received: Arc::new(AtomicUsize::default()),
bytes_sent: Arc::new(AtomicUsize::default()),
bytes_received: Arc::new(AtomicUsize::default()),
}
}

Expand Down Expand Up @@ -404,17 +408,6 @@ pub struct PollDataChannel {

impl PollDataChannel {
/// Constructs a new `PollDataChannel`.
///
/// # Examples
///
/// ```
/// use webrtc_data::data_channel::{DataChannel, PollDataChannel, Config};
/// use sctp::stream::Stream;
/// use std::sync::Arc;
///
/// let dc = Arc::new(DataChannel::new(Arc::new(Stream::default()), Config::default()));
/// let poll_dc = PollDataChannel::new(dc);
/// ```
pub fn new(data_channel: Arc<DataChannel>) -> Self {
Self {
data_channel,
Expand Down
166 changes: 92 additions & 74 deletions sctp/src/association/association_internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@ use crate::param::param_forward_tsn_supported::ParamForwardTsnSupported;
use crate::param::param_type::ParamType;
use crate::param::param_unrecognized::ParamUnrecognized;

#[derive(Default)]
pub struct AssociationInternal {
pub(crate) name: String,
pub(crate) state: Arc<AtomicU8>,
pub(crate) max_message_size: Arc<AtomicU32>,
pub(crate) inflight_queue_length: Arc<AtomicUsize>,
pub(crate) will_send_shutdown: Arc<AtomicBool>,
awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>,
awake_write_loop_ch: Arc<mpsc::Sender<()>>,

peer_verification_tag: u32,
pub(crate) my_verification_tag: u32,
Expand Down Expand Up @@ -77,11 +76,8 @@ pub struct AssociationInternal {
streams: HashMap<u16, Arc<Stream>>,

close_loop_ch_tx: Option<broadcast::Sender<()>>,
accept_ch_tx: Option<mpsc::Sender<Arc<Stream>>>,
handshake_completed_ch_tx: Option<mpsc::Sender<Option<Error>>>,

// local error
silent_error: Option<Error>,
accept_ch_tx: mpsc::Sender<Arc<Stream>>,
handshake_completed_ch_tx: mpsc::Sender<Option<Error>>,

// per inbound packet context
delayed_ack_triggered: bool,
Expand Down Expand Up @@ -118,58 +114,94 @@ impl AssociationInternal {
if tsn == 0 {
tsn += 1;
}
let mut a = AssociationInternal {

let mtu = INITIAL_MTU;
// RFC 4690 Sec 7.2.1
// o The initial cwnd before DATA transmission or after a sufficiently
// long idle period MUST be set to min(4*MTU, max (2*MTU, 4380
// bytes)).
// TODO: Consider whether this should use `clamp`
#[allow(clippy::manual_clamp)]
let cwnd = std::cmp::min(4 * mtu, std::cmp::max(2 * mtu, 4380));

let ret = AssociationInternal {
name: config.name,
max_receive_buffer_size,
state: Arc::new(AtomicU8::new(AssociationState::Closed as u8)),
max_message_size: Arc::new(AtomicU32::new(max_message_size)),

my_max_num_outbound_streams: u16::MAX,
will_send_shutdown: Arc::new(AtomicBool::default()),
awake_write_loop_ch,
peer_verification_tag: 0,
my_verification_tag: random::<u32>(),

my_next_tsn: tsn,
peer_last_tsn: 0,
min_tsn2measure_rtt: tsn,
will_send_forward_tsn: false,
will_retransmit_fast: false,
will_retransmit_reconfig: false,
will_send_shutdown_ack: false,
will_send_shutdown_complete: false,

my_next_rsn: tsn,
reconfigs: HashMap::new(),
reconfig_requests: HashMap::new(),

source_port: 0,
destination_port: 0,
my_max_num_inbound_streams: u16::MAX,
my_max_num_outbound_streams: u16::MAX,
my_cookie: None,
payload_queue: PayloadQueue::new(Arc::new(AtomicUsize::new(0))),
inflight_queue: PayloadQueue::new(Arc::clone(&inflight_queue_length)),
inflight_queue_length,
pending_queue: Arc::new(PendingQueue::new()),
control_queue: ControlQueue::new(),
mtu: INITIAL_MTU,
max_payload_size: INITIAL_MTU - (COMMON_HEADER_SIZE + DATA_CHUNK_HEADER_SIZE),
my_verification_tag: random::<u32>(),
my_next_tsn: tsn,
my_next_rsn: tsn,
min_tsn2measure_rtt: tsn,
state: Arc::new(AtomicU8::new(AssociationState::Closed as u8)),
mtu,
max_payload_size: mtu - (COMMON_HEADER_SIZE + DATA_CHUNK_HEADER_SIZE),
cumulative_tsn_ack_point: tsn - 1,
advanced_peer_tsn_ack_point: tsn - 1,
use_forward_tsn: false,

max_receive_buffer_size,
cwnd,
rwnd: 0,
ssthresh: 0,
partial_bytes_acked: 0,
in_fast_recovery: false,
fast_recover_exit_point: 0,

rto_mgr: RtoManager::new(),
t1init: None,
t1cookie: None,
t2shutdown: None,
t3rtx: None,
treconfig: None,
ack_timer: None,

stored_init: None,
stored_cookie_echo: None,
streams: HashMap::new(),
reconfigs: HashMap::new(),
reconfig_requests: HashMap::new(),
accept_ch_tx: Some(accept_ch_tx),
close_loop_ch_tx: Some(close_loop_ch_tx),
handshake_completed_ch_tx: Some(handshake_completed_ch_tx),
cumulative_tsn_ack_point: tsn - 1,
advanced_peer_tsn_ack_point: tsn - 1,
silent_error: Some(Error::ErrSilentlyDiscard),
accept_ch_tx,
handshake_completed_ch_tx,

delayed_ack_triggered: false,
immediate_ack_triggered: false,
stats: Arc::new(AssociationStats::default()),
awake_write_loop_ch: Some(awake_write_loop_ch),
..Default::default()
ack_state: AckState::default(),
ack_mode: AckMode::default(),
};

// RFC 4690 Sec 7.2.1
// o The initial cwnd before DATA transmission or after a sufficiently
// long idle period MUST be set to min(4*MTU, max (2*MTU, 4380
// bytes)).
// TODO: Consider whether this should use `clamp`
#[allow(clippy::manual_clamp)]
{
a.cwnd = std::cmp::min(4 * a.mtu, std::cmp::max(2 * a.mtu, 4380));
}
log::trace!(
"[{}] updated cwnd={} ssthresh={} inflight={} (INI)",
a.name,
a.cwnd,
a.ssthresh,
a.inflight_queue.get_num_bytes()
ret.name,
ret.cwnd,
ret.ssthresh,
ret.inflight_queue.get_num_bytes()
);

a
ret
}

/// caller must hold self.lock
Expand Down Expand Up @@ -291,9 +323,7 @@ impl AssociationInternal {

fn awake_write_loop(&self) {
//log::debug!("[{}] awake_write_loop_ch.notify_one", self.name);
if let Some(awake_write_loop_ch) = &self.awake_write_loop_ch {
let _ = awake_write_loop_ch.try_send(());
}
let _ = self.awake_write_loop_ch.try_send(());
}

/// unregister_stream un-registers a stream from the association
Expand Down Expand Up @@ -606,7 +636,6 @@ impl AssociationInternal {

async fn handle_init(&mut self, p: &Packet, i: &ChunkInit) -> Result<Vec<Packet>> {
let state = self.get_state();
log::debug!("[{}] chunkInit received in state '{}'", self.name, state);

// https://tools.ietf.org/html/rfc4960#section-5.2.1
// Upon receipt of an INIT in the COOKIE-WAIT state, an endpoint MUST
Expand All @@ -619,11 +648,14 @@ impl AssociationInternal {
&& state != AssociationState::CookieWait
&& state != AssociationState::CookieEchoed
{
log::error!("[{}] chunkInit received in state '{}'", self.name, state);
// 5.2.2. Unexpected INIT in States Other than CLOSED, COOKIE-ECHOED,
// COOKIE-WAIT, and SHUTDOWN-ACK-SENT
return Err(Error::ErrHandleInitState);
}

log::debug!("[{}] chunkInit received in state '{}'", self.name, state);

// Should we be setting any of these permanently until we've ACKed further?
self.my_max_num_inbound_streams =
std::cmp::min(i.num_inbound_streams, self.my_max_num_inbound_streams);
Expand Down Expand Up @@ -855,9 +887,7 @@ impl AssociationInternal {
self.stored_cookie_echo = None;

self.set_state(AssociationState::Established);
if let Some(handshake_completed_ch) = &self.handshake_completed_ch_tx {
let _ = handshake_completed_ch.send(None).await;
}
let _ = self.handshake_completed_ch_tx.send(None).await;
}
_ => return Ok(vec![]),
};
Expand Down Expand Up @@ -891,9 +921,7 @@ impl AssociationInternal {
self.stored_cookie_echo = None;

self.set_state(AssociationState::Established);
if let Some(handshake_completed_ch) = &self.handshake_completed_ch_tx {
let _ = handshake_completed_ch.send(None).await;
}
let _ = self.handshake_completed_ch_tx.send(None).await;

Ok(vec![])
}
Expand Down Expand Up @@ -1049,22 +1077,14 @@ impl AssociationInternal {
));

if accept {
if let Some(accept_ch) = &self.accept_ch_tx {
if accept_ch.try_send(Arc::clone(&s)).is_ok() {
log::debug!(
"[{}] accepted a new stream (streamIdentifier: {})",
self.name,
stream_identifier
);
} else {
log::debug!("[{}] dropped a new stream due to accept_ch full", self.name);
return None;
}
} else {
if self.accept_ch_tx.try_send(Arc::clone(&s)).is_ok() {
log::debug!(
"[{}] dropped a new stream due to accept_ch_tx is None",
self.name
"[{}] accepted a new stream (streamIdentifier: {})",
self.name,
stream_identifier
);
} else {
log::debug!("[{}] dropped a new stream due to accept_ch full", self.name);
return None;
}
}
Expand Down Expand Up @@ -2389,19 +2409,17 @@ impl RtxTimerObserver for AssociationInternal {
match id {
RtxTimerId::T1Init => {
log::error!("[{}] retransmission failure: T1-init", self.name);
if let Some(handshake_completed_ch) = &self.handshake_completed_ch_tx {
let _ = handshake_completed_ch
.send(Some(Error::ErrHandshakeInitAck))
.await;
}
let _ = self
.handshake_completed_ch_tx
.send(Some(Error::ErrHandshakeInitAck))
.await;
}
RtxTimerId::T1Cookie => {
log::error!("[{}] retransmission failure: T1-cookie", self.name);
if let Some(handshake_completed_ch) = &self.handshake_completed_ch_tx {
let _ = handshake_completed_ch
.send(Some(Error::ErrHandshakeCookieEcho))
.await;
}
let _ = self
.handshake_completed_ch_tx
.send(Some(Error::ErrHandshakeCookieEcho))
.await;
}

RtxTimerId::T2Shutdown => {
Expand Down
Loading

0 comments on commit be438ed

Please sign in to comment.