Skip to content

Commit

Permalink
feat: Do not error when the provider closes the connection (#124)
Browse files Browse the repository at this point in the history
This makes it so that the provider is allowed to close the connection
after a transfer is completed, without the getter resulting in an
error.

It modifies the provider to only emit the transfer completed once the
data has actually reached the client.
  • Loading branch information
flub authored Feb 13, 2023
1 parent 8e42874 commit 5bd545d
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 91 deletions.
13 changes: 7 additions & 6 deletions src/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use anyhow::{anyhow, bail, ensure, Result};
use bytes::BytesMut;
use futures::Future;
use postcard::experimental::max_size::MaxSize;
use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf};
use tokio::io::{AsyncRead, ReadBuf};
use tracing::debug;

use crate::bao_slice_decoder::AsyncSliceDecoder;
Expand Down Expand Up @@ -116,7 +116,7 @@ impl AsyncRead for DataStream {
/// Get a collection and all its blobs from a provider
pub async fn run<A, B, C, FutA, FutB, FutC>(
hash: Hash,
token: AuthToken,
auth_token: AuthToken,
opts: Options,
on_connected: A,
on_collection: B,
Expand Down Expand Up @@ -145,7 +145,7 @@ where
// 1. Send Handshake
{
debug!("sending handshake");
let handshake = Handshake::new(token);
let handshake = Handshake::new(auth_token);
let used = postcard::to_slice(&handshake, &mut out_buffer)?;
write_lp(&mut writer, used).await?;
}
Expand All @@ -158,6 +158,8 @@ where
let used = postcard::to_slice(&req, &mut out_buffer)?;
write_lp(&mut writer, used).await?;
}
writer.finish().await?;
drop(writer);

// 3. Read response
{
Expand Down Expand Up @@ -224,8 +226,7 @@ where
}

// Shut down the stream
debug!("shutting down stream");
writer.shutdown().await?;
drop(reader);

let elapsed = now.elapsed();

Expand All @@ -234,7 +235,7 @@ where
Ok(stats)
}
None => {
bail!("provider disconnected");
bail!("provider closed stream");
}
}
}
Expand Down
92 changes: 84 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@ mod tests {
net::SocketAddr,
path::PathBuf,
sync::{atomic::AtomicUsize, Arc},
time::Duration,
};

use anyhow::{anyhow, Context, Result};
use rand::RngCore;
use testdir::testdir;
use tokio::fs;
use tokio::io::{self, AsyncReadExt};

use crate::protocol::AuthToken;
use crate::provider::{create_collection, Event, Provider};
use crate::tls::PeerId;
use crate::{protocol::AuthToken, util::Hash};
use crate::util::Hash;

use super::*;
use anyhow::Result;
use rand::RngCore;
use testdir::testdir;
use tokio::io::AsyncReadExt;

#[tokio::test]
async fn basics() -> Result<()> {
Expand Down Expand Up @@ -204,7 +209,13 @@ mod tests {
let events_task = tokio::task::spawn(async move {
let mut events = Vec::new();
while let Ok(event) = provider_events.recv().await {
events.push(event);
match event {
Event::TransferCompleted { .. } => {
events.push(event);
break;
}
_ => events.push(event),
}
}
events
});
Expand Down Expand Up @@ -245,12 +256,77 @@ mod tests {
)
.await?;

// We have to wait for the completed event before shutting down the provider.
let events = tokio::time::timeout(Duration::from_secs(30), events_task)
.await
.expect("duration expired")
.expect("events task failed");
provider.shutdown();
provider.await.ok(); // .abort() makes this a Result::Err
provider.await?;

let events = events_task.await.unwrap();
assert_eq!(events.len(), 3);

Ok(())
}

#[tokio::test]
async fn test_server_close() {
// Prepare a Provider transferring a file.
let dir = testdir!();
let src = dir.join("src");
fs::write(&src, "hello there").await.unwrap();
let (db, hash) = create_collection(vec![src.into()]).await.unwrap();
let mut provider = Provider::builder(db)
.bind_addr("127.0.0.1:0".parse().unwrap())
.spawn()
.unwrap();
let auth_token = provider.auth_token();
let provider_addr = provider.listen_addr();

// This tasks closes the connection on the provider side as soon as the transfer
// completes.
let supervisor = tokio::spawn(async move {
let mut events = provider.subscribe();
loop {
tokio::select! {
biased;
res = &mut provider => break res.context("provider failed"),
maybe_event = events.recv() => {
match maybe_event {
Ok(event) => {
match event {
Event::TransferCompleted { .. } => provider.shutdown(),
Event::TransferAborted { .. } => {
break Err(anyhow!("transfer aborted"));
}
_ => (),
}
}
Err(err) => break Err(anyhow!("event failed: {err:#}")),
}
}
}
}
});

get::run(
hash,
auth_token,
get::Options {
addr: provider_addr,
peer_id: None,
},
|| async move { Ok(()) },
|_collection| async move { Ok(()) },
|_hash, mut stream, _name| async move {
io::copy(&mut stream, &mut io::sink()).await?;
Ok(stream)
},
)
.await
.unwrap();

// Unwrap the JoinHandle, then the result of the Provider
supervisor.await.unwrap().unwrap();
}
}
62 changes: 62 additions & 0 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::str::FromStr;
use anyhow::{ensure, Result};
use bytes::{Bytes, BytesMut};
use postcard::experimental::max_size::MaxSize;
use quinn::VarInt;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::debug;
Expand Down Expand Up @@ -212,6 +213,67 @@ impl FromStr for AuthToken {
}
}

/// Reasons to close connections or stop streams.
///
/// A QUIC **connection** can be *closed* and a **stream** can request the other side to
/// *stop* sending data. Both closing and stopping have an associated `error_code`, closing
/// also adds a `reason` as some arbitrary bytes.
///
/// This enum exists so we have a single namespace for `error_code`s used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub(crate) enum Closed {
/// The [`quinn::RecvStream`] was dropped.
///
/// Used implicitly when a [`quinn::RecvStream`] is dropped without explicit call to
/// [`quinn::RecvStream::stop`]. We don't use this explicitly but this is here as
/// documentation as to what happened to `0`.
StreamDropped = 0,
/// The provider is terminating.
///
/// When a provider terminates all connections and associated streams are closed.
ProviderTerminating = 1,
/// The provider has received the request.
///
/// Only a single request is allowed on a stream, once this request is received the
/// provider will close its [`quinn::RecvStream`] with this error code.
RequestReceived = 2,
}

impl Closed {
pub fn reason(&self) -> &'static [u8] {
match self {
Closed::StreamDropped => &b"stream dropped"[..],
Closed::ProviderTerminating => &b"provider terminating"[..],
Closed::RequestReceived => &b"request received"[..],
}
}
}

impl From<Closed> for VarInt {
fn from(source: Closed) -> Self {
VarInt::from(source as u16)
}
}

/// Unknown error_code, can not be converted into [`Closed`].
#[derive(thiserror::Error, Debug)]
#[error("Unknown error_code: {0}")]
pub(crate) struct UnknownErrorCode(u64);

impl TryFrom<VarInt> for Closed {
type Error = UnknownErrorCode;

fn try_from(value: VarInt) -> std::result::Result<Self, Self::Error> {
match value.into_inner() {
0 => Ok(Self::StreamDropped),
1 => Ok(Self::ProviderTerminating),
2 => Ok(Self::RequestReceived),
val => Err(UnknownErrorCode(val)),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 5bd545d

Please sign in to comment.