From 90ae01673bf382fa9e2d3bc87b1dbf3b38631eb7 Mon Sep 17 00:00:00 2001 From: Hayden Stainsby Date: Wed, 6 Sep 2023 15:42:53 +0200 Subject: [PATCH] test(subscriber): add initial integration tests (#452) The `console-subscriber` crate has no integration tests. There are some unit tests, but without very high coverage of features. Recently, we've found or fixed a few errors which probably could have been caught by a medium level of integration testing. However, testing `console-subscriber` isn't straight forward. It is effectively a tracing subscriber (or layer) on one end, and a gRPC server on the other end. This change adds enough of a testing framework to write some initial integration tests. It is the first step towards closing #450. Each test comprises 2 parts: - One or more "expected tasks" - A future which will be driven to completion on a dedicated Tokio runtime. Behind the scenes, a console subscriber layer is created and its server part is connected to a duplex stream. The client of the duplex stream then records incoming updates and reconstructs "actual tasks". The layer itself is set as the default subscriber for the duration of `block_on` which is used to drive the provided future to completioin. The expected tasks have a set of "matches", which is how we find the actual task that we want to validate against. Currently, the only value we match on is the task's name. The expected tasks also have a set of "expectations". These are other fields on the actual task which are validated once a matching task is found. Currently, the two fields which can have expectations set on them are `wakes` and `self_wakes`. So, to construct an expected task, which will match a task with the name `"my-task"` and then validate that the matched task gets woken once, the code would be: ```rust ExpectedTask::default() .match_name("my-task") .expect_wakes(1); ``` A future which passes this test could be: ```rust async { task::Builder::new() .name("my-task") .spawn(async { tokio::time::sleep(std::time::Duration::ZERO).await }) } ``` The full test would then look like: ```rust fn wakes_once() { let expected_task = ExpectedTask::default() .match_name("my-task") .expect_wakes(1); let future = async { task::Builder::new() .name("my-task") .spawn(async { tokio::time::sleep(std::time::Duration::ZERO).await }) }; assert_task(expected_task, future); } ``` The PR depends on 2 others: - #447 which fixes an error in the logic that determines whether a task is retained in the aggregator or not. - #451 which exposes the server parts and is necessary to allow us to connect the instrument server and client via a duplex channel. This change contains some initial tests for wakes and self wakes which would have caught the error fixed in #430. Additionally there are tests for the functionality of the testing framework itself. Co-authored-by: Eliza Weisman --- Cargo.lock | 1 + console-subscriber/Cargo.toml | 1 + console-subscriber/tests/framework.rs | 184 ++++++++++ console-subscriber/tests/support/mod.rs | 47 +++ console-subscriber/tests/support/state.rs | 143 ++++++++ .../tests/support/subscriber.rs | 339 ++++++++++++++++++ console-subscriber/tests/support/task.rs | 242 +++++++++++++ console-subscriber/tests/wake.rs | 48 +++ 8 files changed, 1005 insertions(+) create mode 100644 console-subscriber/tests/framework.rs create mode 100644 console-subscriber/tests/support/mod.rs create mode 100644 console-subscriber/tests/support/state.rs create mode 100644 console-subscriber/tests/support/subscriber.rs create mode 100644 console-subscriber/tests/support/task.rs create mode 100644 console-subscriber/tests/wake.rs diff --git a/Cargo.lock b/Cargo.lock index 062c2e45a..e78a52f1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -285,6 +285,7 @@ dependencies = [ "tokio", "tokio-stream", "tonic", + "tower", "tracing", "tracing-core", "tracing-subscriber", diff --git a/console-subscriber/Cargo.toml b/console-subscriber/Cargo.toml index c69eedcc8..1bfb60e24 100644 --- a/console-subscriber/Cargo.toml +++ b/console-subscriber/Cargo.toml @@ -55,6 +55,7 @@ crossbeam-channel = "0.5" [dev-dependencies] tokio = { version = "^1.21", features = ["full", "rt-multi-thread"] } +tower = { version = "0.4", default-features = false } futures = "0.3" [package.metadata.docs.rs] diff --git a/console-subscriber/tests/framework.rs b/console-subscriber/tests/framework.rs new file mode 100644 index 000000000..855f778ac --- /dev/null +++ b/console-subscriber/tests/framework.rs @@ -0,0 +1,184 @@ +//! Framework tests +//! +//! The tests in this module are here to verify the testing framework itself. +//! As such, some of these tests may be repeated elsewhere (where we wish to +//! actually test the functionality of `console-subscriber`) and others are +//! negative tests that should panic. + +use std::time::Duration; + +use tokio::{task, time::sleep}; + +mod support; +use support::{assert_task, assert_tasks, ExpectedTask}; + +#[test] +fn expect_present() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_present(); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task { name=console-test::main }: no expectations set, if you want to just expect that a matching task is present, use `expect_present()` +")] +fn fail_no_expectations() { + let expected_task = ExpectedTask::default().match_default_name(); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn wakes() { + let expected_task = ExpectedTask::default().match_default_name().expect_wakes(1); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task { name=console-test::main }: expected `wakes` to be 5, but actual was 1 +")] +fn fail_wakes() { + let expected_task = ExpectedTask::default().match_default_name().expect_wakes(5); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn self_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_self_wakes(1); + + let future = async { task::yield_now().await }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task { name=console-test::main }: expected `self_wakes` to be 1, but actual was 0 +")] +fn fail_self_wake() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_self_wakes(1); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn test_spawned_task() { + let expected_task = ExpectedTask::default() + .match_name("another-name".into()) + .expect_present(); + + let future = async { + task::Builder::new() + .name("another-name") + .spawn(async { task::yield_now().await }) + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task { name=wrong-name }: no matching actual task was found +")] +fn fail_wrong_task_name() { + let expected_task = ExpectedTask::default().match_name("wrong-name".into()); + + let future = async { task::yield_now().await }; + + assert_task(expected_task, future); +} + +#[test] +fn multiple_tasks() { + let expected_tasks = vec![ + ExpectedTask::default() + .match_name("task-1".into()) + .expect_wakes(1), + ExpectedTask::default() + .match_name("task-2".into()) + .expect_wakes(1), + ]; + + let future = async { + let task1 = task::Builder::new() + .name("task-1") + .spawn(async { task::yield_now().await }) + .unwrap(); + let task2 = task::Builder::new() + .name("task-2") + .spawn(async { task::yield_now().await }) + .unwrap(); + + tokio::try_join! { + task1, + task2, + } + .unwrap(); + }; + + assert_tasks(expected_tasks, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task { name=task-2 }: expected `wakes` to be 2, but actual was 1 +")] +fn fail_1_of_2_expected_tasks() { + let expected_tasks = vec![ + ExpectedTask::default() + .match_name("task-1".into()) + .expect_wakes(1), + ExpectedTask::default() + .match_name("task-2".into()) + .expect_wakes(2), + ]; + + let future = async { + let task1 = task::Builder::new() + .name("task-1") + .spawn(async { task::yield_now().await }) + .unwrap(); + let task2 = task::Builder::new() + .name("task-2") + .spawn(async { task::yield_now().await }) + .unwrap(); + + tokio::try_join! { + task1, + task2, + } + .unwrap(); + }; + + assert_tasks(expected_tasks, future); +} diff --git a/console-subscriber/tests/support/mod.rs b/console-subscriber/tests/support/mod.rs new file mode 100644 index 000000000..4937aff6a --- /dev/null +++ b/console-subscriber/tests/support/mod.rs @@ -0,0 +1,47 @@ +use futures::Future; + +mod state; +mod subscriber; +mod task; + +use subscriber::run_test; + +pub(crate) use subscriber::MAIN_TASK_NAME; +pub(crate) use task::ExpectedTask; + +/// Assert that an `expected_task` is recorded by a console-subscriber +/// when driving the provided `future` to completion. +/// +/// This function is equivalent to calling [`assert_tasks`] with a vector +/// containing a single task. +/// +/// # Panics +/// +/// This function will panic if the expectations on the expected task are not +/// met or if a matching task is not recorded. +#[track_caller] +#[allow(dead_code)] +pub(crate) fn assert_task(expected_task: ExpectedTask, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + run_test(vec![expected_task], future) +} + +/// Assert that the `expected_tasks` are recorded by a console-subscriber +/// when driving the provided `future` to completion. +/// +/// # Panics +/// +/// This function will panic if the expectations on any of the expected tasks +/// are not met or if matching tasks are not recorded for all expected tasks. +#[track_caller] +#[allow(dead_code)] +pub(crate) fn assert_tasks(expected_tasks: Vec, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + run_test(expected_tasks, future) +} diff --git a/console-subscriber/tests/support/state.rs b/console-subscriber/tests/support/state.rs new file mode 100644 index 000000000..97aef9d48 --- /dev/null +++ b/console-subscriber/tests/support/state.rs @@ -0,0 +1,143 @@ +use std::fmt; + +use tokio::sync::broadcast::{ + self, + error::{RecvError, TryRecvError}, +}; + +/// A step in the running of the test +#[derive(Clone, Debug, PartialEq, PartialOrd)] +pub(super) enum TestStep { + /// The overall test has begun + Start, + /// The instrument server has been started + ServerStarted, + /// The client has connected to the instrument server + ClientConnected, + /// The future being driven has completed + TestFinished, + /// The client has finished recording updates + UpdatesRecorded, +} + +impl fmt::Display for TestStep { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (self as &dyn fmt::Debug).fmt(f) + } +} + +/// The state of the test. +/// +/// This struct is used by various parts of the test framework to wait until +/// a specific test step has been reached and advance the test state to a new +/// step. +pub(super) struct TestState { + receiver: broadcast::Receiver, + sender: broadcast::Sender, + step: TestStep, +} + +impl TestState { + pub(super) fn new() -> Self { + let (sender, receiver) = broadcast::channel(1); + Self { + receiver, + sender, + step: TestStep::Start, + } + } + + /// Wait asynchronously until the desired step has been reached. + /// + /// # Panics + /// + /// This function will panic if the underlying channel gets closed. + pub(super) async fn wait_for_step(&mut self, desired_step: TestStep) { + while self.step < desired_step { + match self.receiver.recv().await { + Ok(step) => self.step = step, + Err(RecvError::Lagged(_)) => { + // we don't mind being lagged, we'll just get the latest state + } + Err(RecvError::Closed) => panic!( + "console-test error: failed to receive current step, \ + waiting for step: {desired_step}. This shouldn't happen, \ + did the test abort?" + ), + } + } + } + + /// Returns `true` if the current step is `desired_step` or later. + pub(super) fn is_step(&mut self, desired_step: TestStep) -> bool { + self.update_step(); + + self.step == desired_step + } + + /// Advance to the next step. + /// + /// The test must be at the step prior to the next step before starting. + /// Being in a different step is likely to indicate a logic error in the + /// test framework. + /// + /// # Panics + /// + /// This method will panic if the test state is not at the step prior to + /// `next_step`, or if the underlying channel is closed. + #[track_caller] + pub(super) fn advance_to_step(&mut self, next_step: TestStep) { + self.update_step(); + + assert!( + self.step < next_step, + "console-test error: cannot advance to previous or current step! \ + current step: {current}, next step: {next_step}. This shouldn't \ + happen.", + current = self.step, + ); + + match (&self.step, &next_step) { + (TestStep::Start, TestStep::ServerStarted) + | (TestStep::ServerStarted, TestStep::ClientConnected) + | (TestStep::ClientConnected, TestStep::TestFinished) + | (TestStep::TestFinished, TestStep::UpdatesRecorded) => {} + (current, _) => panic!( + "console-test error: test cannot advance more than one step! \ + current step: {current}, next step: {next_step}. This \ + shouldn't happen." + ), + } + + self.sender.send(next_step).expect( + "console-test error: failed to send the next test step. \ + This shouldn't happen, did the test abort?", + ); + } + + fn update_step(&mut self) { + loop { + match self.receiver.try_recv() { + Ok(step) => self.step = step, + Err(TryRecvError::Lagged(_)) => { + // we don't mind being lagged, we'll just get the latest state + } + Err(TryRecvError::Closed) => panic!( + "console-test error: failed to update current step, did \ + the test abort?" + ), + Err(TryRecvError::Empty) => break, + } + } + } +} + +impl Clone for TestState { + fn clone(&self) -> Self { + Self { + receiver: self.receiver.resubscribe(), + sender: self.sender.clone(), + step: self.step.clone(), + } + } +} diff --git a/console-subscriber/tests/support/subscriber.rs b/console-subscriber/tests/support/subscriber.rs new file mode 100644 index 000000000..ace48397d --- /dev/null +++ b/console-subscriber/tests/support/subscriber.rs @@ -0,0 +1,339 @@ +use std::{collections::HashMap, fmt, future::Future, thread}; + +use console_api::{ + field::Value, + instrument::{instrument_client::InstrumentClient, InstrumentRequest}, +}; +use console_subscriber::ServerParts; +use futures::stream::StreamExt; +use tokio::{io::DuplexStream, task}; +use tonic::transport::{Channel, Endpoint, Server, Uri}; +use tower::service_fn; + +use super::state::{TestState, TestStep}; +use super::task::{ActualTask, ExpectedTask, TaskValidationFailure}; + +pub(crate) const MAIN_TASK_NAME: &str = "console-test::main"; +const END_SIGNAL_TASK_NAME: &str = "console-test::signal"; + +#[derive(Debug)] +struct TestFailure { + failures: Vec, +} + +impl fmt::Display for TestFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Task validation failed:\n")?; + for failure in &self.failures { + write!(f, " - {failure}\n")?; + } + Ok(()) + } +} + +/// Runs the test +/// +/// This function runs the whole test. It sets up a `console-subscriber` layer +/// together with the gRPC server and connects a client to it. The subscriber +/// is then used to record traces as the provided future is driven to +/// completion on a current thread tokio runtime. +/// +/// This function will panic if the expectations on any of the expected tasks +/// are not met or if matching tasks are not recorded for all expected tasks. +#[track_caller] +pub(super) fn run_test(expected_tasks: Vec, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + use tracing_subscriber::prelude::*; + + let (client_stream, server_stream) = tokio::io::duplex(1024); + let (console_layer, server) = console_subscriber::ConsoleLayer::builder().build(); + let registry = tracing_subscriber::registry().with(console_layer); + + let mut test_state = TestState::new(); + let mut test_state_test = test_state.clone(); + + let thread_name = { + // Include the name of the test thread in the spawned subscriber thread, + // to make it clearer which test it belongs to. + let current_thread = thread::current(); + let test = current_thread.name().unwrap_or(""); + format!("{test}-console::subscriber") + }; + let join_handle = thread::Builder::new() + .name(thread_name) + // Run the test's console server and client tasks in a separate thread + // from the main test, ensuring that any `tracing` emitted by the + // console worker and the client are not collected by the subscriber + // under test. + .spawn(move || { + let _subscriber_guard = + tracing::subscriber::set_default(tracing_core::subscriber::NoSubscriber::default()); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .expect("console-test error: failed to initialize console subscriber runtime"); + + runtime.block_on(async move { + task::Builder::new() + .name("console::serve") + .spawn(console_server(server, server_stream, test_state.clone())) + .expect("console-test error: could not spawn 'console-server' task"); + + let actual_tasks = task::Builder::new() + .name("console::client") + .spawn(console_client(client_stream, test_state.clone())) + .expect("console-test error: could not spawn 'console-client' task") + .await + .expect("console-test error: failed to await 'console-client' task"); + + test_state.advance_to_step(TestStep::UpdatesRecorded); + actual_tasks + }) + }) + .expect("console-test error: console subscriber could not spawn thread"); + + tracing::subscriber::with_default(registry, || { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + runtime.block_on(async move { + test_state_test + .wait_for_step(TestStep::ClientConnected) + .await; + + // Run the future that we are testing. + _ = task::Builder::new() + .name(MAIN_TASK_NAME) + .spawn(future) + .expect("console-test error: couldn't spawn test task") + .await; + _ = task::Builder::new() + .name(END_SIGNAL_TASK_NAME) + .spawn(futures::future::ready(())) + .expect("console-test error: couldn't spawn end signal task") + .await; + test_state_test.advance_to_step(TestStep::TestFinished); + + test_state_test + .wait_for_step(TestStep::UpdatesRecorded) + .await; + }); + }); + + let actual_tasks = join_handle + .join() + .expect("console-test error: failed to join 'console-subscriber' thread"); + + if let Err(test_failure) = validate_expected_tasks(expected_tasks, actual_tasks) { + panic!("Test failed: {test_failure}") + } +} + +/// Starts the console server. +/// +/// The server will start serving over its side of the duplex stream. +/// +/// Once the server gets spawned into its task, the test state is advanced +/// to the `ServerStarted` step. This function will then wait until the test +/// state reaches the `UpdatesRecorded` step (indicating that all validation of the +/// received updates has been completed) before dropping the aggregator. +/// +/// # Test State +/// +/// 1. Advances to: `ServerStarted` +/// 2. Waits for: `UpdatesRecorded` +async fn console_server( + server: console_subscriber::Server, + server_stream: DuplexStream, + mut test_state: TestState, +) { + let ServerParts { + instrument_server: service, + aggregator, + .. + } = server.into_parts(); + let aggregate = task::Builder::new() + .name("console::aggregate") + .spawn(aggregator.run()) + .expect("console-test error: couldn't spawn aggregator"); + Server::builder() + .add_service(service) + // .serve_with_incoming(futures::stream::once(Ok::<_, std::io::Error>(server_stream))) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + server_stream, + )])) + .await + .expect("console-test error: couldn't start instrument server."); + test_state.advance_to_step(TestStep::ServerStarted); + + test_state.wait_for_step(TestStep::UpdatesRecorded).await; + aggregate.abort(); +} + +/// Starts the console client and validates the expected tasks. +/// +/// First we wait until the server has started (test step `ServerStarted`), then +/// the client is connected to its half of the duplex stream and we start recording +/// the actual tasks. +/// +/// Once recording finishes (see [`record_actual_tasks()`] for details on the test +/// state condition), the actual tasks returned. +/// +/// # Test State +/// +/// 1. Waits for: `ServerStarted` +/// 2. Advances to: `ClientConnected` +async fn console_client(client_stream: DuplexStream, mut test_state: TestState) -> Vec { + test_state.wait_for_step(TestStep::ServerStarted).await; + + let mut client_stream = Some(client_stream); + // Note: we won't actually try to connect to this port on localhost, + // because we will call `connect_with_connector` with a service that + // just returns the `DuplexStream`, instead of making an actual + // network connection. + let endpoint = Endpoint::try_from("http://[::]:6669").expect("Could not create endpoint"); + let channel = endpoint + .connect_with_connector(service_fn(move |_: Uri| { + let client = client_stream.take(); + + async move { + // We need to return a Result from this async block, which is + // why we don't unwrap the `client` here. + client.ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::Other, + "console-test error: client already taken. This shouldn't happen.", + ) + }) + } + })) + .await + .expect("console-test client error: couldn't create client"); + test_state.advance_to_step(TestStep::ClientConnected); + + record_actual_tasks(channel, test_state).await +} + +/// Records the actual tasks which are received by the client channel. +/// +/// Updates will be received until the test state reaches the `TestFinished` step +/// (indicating that the test itself has finished running), at which point we wait +/// for a final update before returning all the actual tasks which were recorded. +/// +/// # Test State +/// +/// 1. Waits for: `TestFinished` +async fn record_actual_tasks( + client_channel: Channel, + mut test_state: TestState, +) -> Vec { + let mut client = InstrumentClient::new(client_channel); + + let mut stream = match client + .watch_updates(tonic::Request::new(InstrumentRequest {})) + .await + { + Ok(stream) => stream.into_inner(), + Err(err) => panic!("console-test error: client cannot connect to watch updates: {err}"), + }; + + let mut tasks = HashMap::new(); + + // The console-subscriber aggregator is a bit of an unknown entity for us, + // especially with respect to its update loops. We can't guarantee that + // it will have seen all the tasks in our test N iterations after the test + // ends for some known N. For this reason we need to use a signal task to + // check for and end the collection of events at that point. + let signal_task = ExpectedTask::default().match_name(END_SIGNAL_TASK_NAME.into()); + let mut signal_task_read = false; + while let Some(update) = stream.next().await { + let update = update.expect("console-test error: update stream error"); + + if let Some(task_update) = &update.task_update { + for new_task in &task_update.new_tasks { + let mut actual_task = match new_task.id { + Some(id) => ActualTask::new(id.id), + None => continue, + }; + for field in &new_task.fields { + match field.name.as_ref() { + Some(console_api::field::Name::StrName(name)) if name == "task.name" => { + actual_task.name = match field.value.as_ref() { + Some(Value::DebugVal(value)) => Some(value.clone()), + Some(Value::StrVal(value)) => Some(value.clone()), + _ => continue, + }; + } + _ => {} + } + } + + if signal_task.matches_actual_task(&actual_task) { + signal_task_read = true; + } else { + tasks.insert(actual_task.id, actual_task); + } + } + + for (id, stats) in &task_update.stats_update { + if let Some(task) = tasks.get_mut(id) { + task.wakes = stats.wakes; + task.self_wakes = stats.self_wakes; + } + } + } + + if test_state.is_step(TestStep::TestFinished) && signal_task_read { + // Once the test finishes running and we've read the signal task, the test ends. + break; + } + } + + tasks.into_values().collect() +} + +/// Validate the expected tasks against the actual tasks. +/// +/// Each expected task is checked in turn. +/// +/// A matching actual task is searched for. If one is found it, the +/// expected task is validated against the actual task. +/// +/// Any validation errors result in failure. If no matches +fn validate_expected_tasks( + expected_tasks: Vec, + actual_tasks: Vec, +) -> Result<(), TestFailure> { + let failures: Vec<_> = expected_tasks + .iter() + .map(|expected| validate_expected_task(expected, &actual_tasks)) + .filter_map(Result::err) + .collect(); + + if failures.is_empty() { + Ok(()) + } else { + Err(TestFailure { failures: failures }) + } +} + +fn validate_expected_task( + expected: &ExpectedTask, + actual_tasks: &Vec, +) -> Result<(), TaskValidationFailure> { + for actual in actual_tasks { + if expected.matches_actual_task(actual) { + // We only match a single task. + // FIXME(hds): We should probably create an error or a warning if multiple tasks match. + return expected.validate_actual_task(actual); + } + } + + expected.no_match_error() +} diff --git a/console-subscriber/tests/support/task.rs b/console-subscriber/tests/support/task.rs new file mode 100644 index 000000000..63814d016 --- /dev/null +++ b/console-subscriber/tests/support/task.rs @@ -0,0 +1,242 @@ +use std::{error, fmt}; + +use super::MAIN_TASK_NAME; + +/// An actual task +/// +/// This struct contains the values recorded from the console subscriber +/// client and represents what is known about an actual task running on +/// the test's runtime. +#[derive(Clone, Debug)] +pub(super) struct ActualTask { + pub(super) id: u64, + pub(super) name: Option, + pub(super) wakes: u64, + pub(super) self_wakes: u64, +} + +impl ActualTask { + pub(super) fn new(id: u64) -> Self { + Self { + id, + name: None, + wakes: 0, + self_wakes: 0, + } + } +} + +/// An error in task validation. +pub(super) struct TaskValidationFailure { + /// The expected task whose expectations were not met. + expected: ExpectedTask, + /// The actual task which failed the validation + actual: Option, + /// A textual description of the validation failure + failure: String, +} + +impl error::Error for TaskValidationFailure {} + +impl fmt::Display for TaskValidationFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.failure) + } +} + +impl fmt::Debug for TaskValidationFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.actual { + Some(actual) => write!( + f, + "Task Validation Failed!\n Expected Task: {expected:?}\ + \n Actual Task: {actual:?}\ + \n Failure: {failure}", + expected = self.expected, + failure = self.failure, + ), + None => write!( + f, + "Task Validation Failed!\n Expected Task: {expected:?}\ + \n Actual Task: \ + \n Failure: {failure}", + expected = self.expected, + failure = self.failure, + ), + } + } +} + +/// An expected task. +/// +/// This struct contains the fields that an expected task will attempt to match +/// actual tasks on, as well as the expectations that will be used to validate +/// which the actual task is as expected. +#[derive(Clone, Debug)] +pub(crate) struct ExpectedTask { + match_name: Option, + expect_present: Option, + expect_wakes: Option, + expect_self_wakes: Option, +} + +impl Default for ExpectedTask { + fn default() -> Self { + Self { + match_name: None, + expect_present: None, + expect_wakes: None, + expect_self_wakes: None, + } + } +} + +impl ExpectedTask { + /// Returns whether or not an actual task matches this expected task. + /// + /// All matching rules will be run, if they all succeed, then `true` will + /// be returned, otherwise `false`. + pub(super) fn matches_actual_task(&self, actual_task: &ActualTask) -> bool { + if let Some(match_name) = &self.match_name { + if Some(match_name) == actual_task.name.as_ref() { + return true; + } + } + + false + } + + /// Returns an error specifying that no match was found for this expected + /// task. + pub(super) fn no_match_error(&self) -> Result<(), TaskValidationFailure> { + Err(TaskValidationFailure { + expected: self.clone(), + actual: None, + failure: format!("{self}: no matching actual task was found"), + }) + } + + /// Validates all expectations against the provided actual task. + /// + /// No check that the actual task matches is performed. That must have been + /// done prior. + /// + /// If all expections are met, this method returns `Ok(())`. If any + /// expectations are not met, then the first incorrect expectation will + /// be returned as an `Err`. + pub(super) fn validate_actual_task( + &self, + actual_task: &ActualTask, + ) -> Result<(), TaskValidationFailure> { + let mut no_expectations = true; + if let Some(_expected) = self.expect_present { + no_expectations = false; + } + + if let Some(expected_wakes) = self.expect_wakes { + no_expectations = false; + if expected_wakes != actual_task.wakes { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: expected `wakes` to be {expected_wakes}, but \ + actual was {actual_wakes}", + actual_wakes = actual_task.wakes, + ), + }); + } + } + + if let Some(expected_self_wakes) = self.expect_self_wakes { + no_expectations = false; + if expected_self_wakes != actual_task.self_wakes { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: expected `self_wakes` to be \ + {expected_self_wakes}, but actual was \ + {actual_self_wakes}", + actual_self_wakes = actual_task.self_wakes, + ), + }); + } + } + + if no_expectations { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: no expectations set, if you want to just expect \ + that a matching task is present, use `expect_present()`", + ), + }); + } + + Ok(()) + } + + /// Matches tasks by name. + /// + /// To match this expected task, an actual task must have the name `name`. + #[allow(dead_code)] + pub(crate) fn match_name(mut self, name: String) -> Self { + self.match_name = Some(name); + self + } + + /// Matches tasks by the default task name. + /// + /// To match this expected task, an actual task must have the default name + /// assigned to the task which runs the future provided to [`assert_task`] + /// or [`assert_tasks`]. + /// + /// [`assert_task`]: fn@support::assert_task + /// [`assert_tasks`]: fn@support::assert_tasks + #[allow(dead_code)] + pub(crate) fn match_default_name(mut self) -> Self { + self.match_name = Some(MAIN_TASK_NAME.into()); + self + } + + /// Expects that a task is present. + /// + /// To validate, an actual task matching this expected task must be found. + #[allow(dead_code)] + pub(crate) fn expect_present(mut self) -> Self { + self.expect_present = Some(true); + self + } + + /// Expects that a task has a specific value for `wakes`. + /// + /// To validate, the actual task matching this expected task must have + /// a count of wakes equal to `wakes`. + #[allow(dead_code)] + pub(crate) fn expect_wakes(mut self, wakes: u64) -> Self { + self.expect_wakes = Some(wakes); + self + } + + /// Expects that a task has a specific value for `self_wakes`. + /// + /// To validate, the actual task matching this expected task must have + /// a count of self wakes equal to `self_wakes`. + #[allow(dead_code)] + pub(crate) fn expect_self_wakes(mut self, self_wakes: u64) -> Self { + self.expect_self_wakes = Some(self_wakes); + self + } +} + +impl fmt::Display for ExpectedTask { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let fields = match &self.match_name { + Some(name) => format!("name={name}"), + None => "(no fields to match on)".into(), + }; + write!(f, "Task {{ {fields} }}") + } +} diff --git a/console-subscriber/tests/wake.rs b/console-subscriber/tests/wake.rs new file mode 100644 index 000000000..e64e87a6e --- /dev/null +++ b/console-subscriber/tests/wake.rs @@ -0,0 +1,48 @@ +mod support; +use std::time::Duration; + +use support::{assert_task, ExpectedTask}; +use tokio::{task, time::sleep}; + +#[test] +fn sleep_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(1) + .expect_self_wakes(0); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn double_sleep_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(2) + .expect_self_wakes(0); + + let future = async { + sleep(Duration::ZERO).await; + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn self_wake() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(1) + .expect_self_wakes(1); + + let future = async { + task::yield_now().await; + }; + + assert_task(expected_task, future); +}