diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index 862f5a8f9728..fe93b720af09 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -55,118 +55,110 @@ fn init() { #[tokio::test] async fn oom_sort() { - TestCase::new( - "select * from t order by host DESC", - vec![ + TestCase::new() + .with_query("select * from t order by host DESC") + .with_expected_errors(vec![ "Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)", - ], - 200_000, - ) + ]) + .with_memory_limit(200_000) .run() .await } #[tokio::test] async fn group_by_none() { - TestCase::new( - "select median(image) from t", - vec![ + TestCase::new() + .with_query("select median(image) from t") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "AggregateStream", - ], - 20_000, - ) - .run() - .await + ]) + .with_memory_limit(20_000) + .run() + .await } #[tokio::test] async fn group_by_row_hash() { - TestCase::new( - "select count(*) from t GROUP BY response_bytes", - vec![ + TestCase::new() + .with_query("select count(*) from t GROUP BY response_bytes") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "GroupedHashAggregateStream", - ], - 2_000, - ) - .run() - .await + ]) + .with_memory_limit(2_000) + .run() + .await } #[tokio::test] async fn group_by_hash() { - TestCase::new( + TestCase::new() // group by dict column - "select count(*) from t GROUP BY service, host, pod, container", - vec![ + .with_query("select count(*) from t GROUP BY service, host, pod, container") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "GroupedHashAggregateStream", - ], - 1_000, - ) - .run() - .await + ]) + .with_memory_limit(1_000) + .run() + .await } #[tokio::test] async fn join_by_key_multiple_partitions() { let config = SessionConfig::new().with_target_partitions(2); - TestCase::new( - "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", - vec![ + TestCase::new() + .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "HashJoinInput[0]", - ], - 1_000, - ) - .with_config(config) - .run() - .await + ]) + .with_memory_limit(1_000) + .with_config(config) + .run() + .await } #[tokio::test] async fn join_by_key_single_partition() { let config = SessionConfig::new().with_target_partitions(1); - TestCase::new( - "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", - vec![ + TestCase::new() + .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "HashJoinInput", - ], - 1_000, - ) - .with_config(config) - .run() - .await + ]) + .with_memory_limit(1_000) + .with_config(config) + .run() + .await } #[tokio::test] async fn join_by_expression() { - TestCase::new( - "select t1.* from t t1 JOIN t t2 ON t1.service != t2.service", - vec![ + TestCase::new() + .with_query("select t1.* from t t1 JOIN t t2 ON t1.service != t2.service") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "NestedLoopJoinLoad[0]", - ], - 1_000, - ) - .run() - .await + ]) + .with_memory_limit(1_000) + .run() + .await } #[tokio::test] async fn cross_join() { - TestCase::new( - "select t1.* from t t1 CROSS JOIN t t2", - vec![ + TestCase::new() + .with_query("select t1.* from t t1 CROSS JOIN t t2") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "CrossJoinExec", - ], - 1_000, - ) - .run() - .await + ]) + .with_memory_limit(1_000) + .run() + .await } #[tokio::test] @@ -176,49 +168,50 @@ async fn merge_join() { .with_target_partitions(2) .set_bool("datafusion.optimizer.prefer_hash_join", false); - TestCase::new( - "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", - vec![ + TestCase::new() + .with_query( + "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", + ) + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "SMJStream", - ], - 1_000, - ) - .with_config(config) - .run() - .await + ]) + .with_memory_limit(1_000) + .with_config(config) + .run() + .await } #[tokio::test] async fn symmetric_hash_join() { - TestCase::new( - "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", - vec![ + TestCase::new() + .with_query( + "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", + ) + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "SymmetricHashJoinStream", - ], - 1_000, - ) - .with_scenario(Scenario::AccessLogStreaming) - .run() - .await + ]) + .with_memory_limit(1_000) + .with_scenario(Scenario::AccessLogStreaming) + .run() + .await } #[tokio::test] async fn sort_preserving_merge() { let partition_size = batches_byte_size(&dict_batches()); - TestCase::new( - // This query uses the exact same ordering as the input table - // so only a merge is needed - "select * from t ORDER BY a ASC NULLS LAST, b ASC NULLS LAST LIMIT 10", - vec![ + TestCase::new() + // This query uses the exact same ordering as the input table + // so only a merge is needed + .with_query("select * from t ORDER BY a ASC NULLS LAST, b ASC NULLS LAST LIMIT 10") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "SortPreservingMergeExec", - ], + ]) // provide insufficient memory to merge - partition_size / 2, - ) + .with_memory_limit(partition_size / 2) // two partitions of data, so a merge is required .with_scenario(Scenario::DictionaryStrings(2)) .with_expected_plan( @@ -254,16 +247,14 @@ async fn sort_spill_reservation() { // This test case shows how sort_spill_reservation works by // purposely sorting data that requires non trivial memory to // sort/merge. - let test = TestCase::new( + let test = TestCase::new() // This query uses a different order than the input table to // force a sort. It also needs to have multiple columns to // force RowFormat / interner that makes merge require // substantial memory - "select * from t ORDER BY a , b DESC", - vec![], // expected errors set below + .with_query("select * from t ORDER BY a , b DESC") // enough memory to sort if we don't try to merge it all at once - (partition_size * 5) / 2, - ) + .with_memory_limit((partition_size * 5) / 2) // use a single partiton so only a sort is needed .with_scenario(Scenario::DictionaryStrings(1)) .with_disk_manager_config(DiskManagerConfig::NewOs) @@ -312,7 +303,7 @@ async fn sort_spill_reservation() { /// and verifies the expected errors are returned #[derive(Clone, Debug)] struct TestCase { - query: String, + query: Option, expected_errors: Vec, memory_limit: usize, config: SessionConfig, @@ -327,19 +318,11 @@ struct TestCase { } impl TestCase { - // TODO remove expected errors and memory limits and query from constructor - fn new<'a>( - query: impl Into, - expected_errors: impl IntoIterator, - memory_limit: usize, - ) -> Self { - let expected_errors: Vec = - expected_errors.into_iter().map(|s| s.to_string()).collect(); - + fn new() -> Self { Self { - query: query.into(), - expected_errors, - memory_limit, + query: None, + expected_errors: vec![], + memory_limit: 0, config: SessionConfig::new(), scenario: Scenario::AccessLog, disk_manager_config: DiskManagerConfig::Disabled, @@ -348,6 +331,12 @@ impl TestCase { } } + /// Set the query to run + fn with_query(mut self, query: impl Into) -> Self { + self.query = Some(query.into()); + self + } + /// Set a list of expected strings that must appear in any errors fn with_expected_errors<'a>( mut self, @@ -358,6 +347,12 @@ impl TestCase { self } + /// Set the amount of memory that can be used + fn with_memory_limit(mut self, memory_limit: usize) -> Self { + self.memory_limit = memory_limit; + self + } + /// Specify the configuration to use pub fn with_config(mut self, config: SessionConfig) -> Self { self.config = config; @@ -424,6 +419,7 @@ impl TestCase { let ctx = SessionContext::with_state(state); ctx.register_table("t", table).expect("registering table"); + let query = query.expect("Test error: query not specified"); let df = ctx.sql(&query).await.expect("Planning query"); if !expected_plan.is_empty() {