Skip to content

Commit

Permalink
minor: improve join fuzz tests debug kit (apache#12397)
Browse files Browse the repository at this point in the history
  • Loading branch information
comphead authored Sep 9, 2024
1 parent 09f7592 commit 0a82ac3
Showing 1 changed file with 63 additions and 60 deletions.
123 changes: 63 additions & 60 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::SortOptions;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::Schema;
use std::sync::Arc;
use std::time::SystemTime;

use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::Literal;
Expand Down Expand Up @@ -474,11 +474,34 @@ impl JoinFuzzTestCase {
let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows());

if debug {
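// Format each join's collected output and sort the formatted lines so the results can be compared line by line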
let smj_formatted =
pretty_format_batches(&smj_collected).unwrap().to_string();
let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
let nlj_formatted =
pretty_format_batches(&nlj_collected).unwrap().to_string();

let mut smj_formatted_sorted: Vec<&str> =
smj_formatted.trim().lines().collect();
smj_formatted_sorted.sort_unstable();

let mut hj_formatted_sorted: Vec<&str> =
hj_formatted.trim().lines().collect();
hj_formatted_sorted.sort_unstable();

let mut nlj_formatted_sorted: Vec<&str> =
nlj_formatted.trim().lines().collect();
nlj_formatted_sorted.sort_unstable();

if debug
&& ((join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows)
|| (join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows))
{
let fuzz_debug = "fuzz_test_debug";
std::fs::remove_dir_all(fuzz_debug).unwrap_or(());
std::fs::create_dir_all(fuzz_debug).unwrap();
let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}");
println!("Test result data mismatch found. HJ rows {}, SMJ rows {}, NLJ rows {}", hj_rows, smj_rows, nlj_rows);
println!("The debug is ON. Input data will be saved to {out_dir_name}");

Self::save_partitioned_batches_as_parquet(
Expand All @@ -492,7 +515,15 @@ impl JoinFuzzTestCase {
"input2",
);

if join_tests.contains(&JoinTestType::NljHj) {
if join_tests.contains(&JoinTestType::NljHj)
&& join_tests.contains(&JoinTestType::NljHj)
&& nlj_rows != hj_rows
{
println!("=============== HashJoinExec ==================");
hj_formatted_sorted.iter().for_each(|s| println!("{}", s));
println!("=============== NestedLoopJoinExec ==================");
smj_formatted_sorted.iter().for_each(|s| println!("{}", s));

Self::save_partitioned_batches_as_parquet(
&nlj_collected,
out_dir_name,
Expand All @@ -505,7 +536,12 @@ impl JoinFuzzTestCase {
);
}

if join_tests.contains(&JoinTestType::HjSmj) {
if join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows {
println!("=============== HashJoinExec ==================");
hj_formatted_sorted.iter().for_each(|s| println!("{}", s));
println!("=============== SortMergeJoinExec ==================");
smj_formatted_sorted.iter().for_each(|s| println!("{}", s));

Self::save_partitioned_batches_as_parquet(
&hj_collected,
out_dir_name,
Expand All @@ -519,25 +555,6 @@ impl JoinFuzzTestCase {
}
}

// compare
let smj_formatted =
pretty_format_batches(&smj_collected).unwrap().to_string();
let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
let nlj_formatted =
pretty_format_batches(&nlj_collected).unwrap().to_string();

let mut smj_formatted_sorted: Vec<&str> =
smj_formatted.trim().lines().collect();
smj_formatted_sorted.sort_unstable();

let mut hj_formatted_sorted: Vec<&str> =
hj_formatted.trim().lines().collect();
hj_formatted_sorted.sort_unstable();

let mut nlj_formatted_sorted: Vec<&str> =
nlj_formatted.trim().lines().collect();
nlj_formatted_sorted.sort_unstable();

if join_tests.contains(&JoinTestType::NljHj) {
let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size);
assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str());
Expand Down Expand Up @@ -602,34 +619,6 @@ impl JoinFuzzTestCase {
/// )
/// .run_test(&[JoinTestType::HjSmj], false)
/// .await;
///
/// let ctx: SessionContext = SessionContext::new();
/// let df = ctx
/// .read_parquet(
/// "/tmp/input1/*.parquet",
/// datafusion::prelude::ParquetReadOptions::default(),
/// )
/// .await
/// .unwrap();
/// let left = df.collect().await.unwrap();
///
/// let df = ctx
/// .read_parquet(
/// "/tmp/input2/*.parquet",
/// datafusion::prelude::ParquetReadOptions::default(),
/// )
/// .await
/// .unwrap();
///
/// let right = df.collect().await.unwrap();
/// JoinFuzzTestCase::new(
/// left,
/// right,
/// JoinType::LeftSemi,
/// Some(Box::new(less_than_100_join_filter)),
/// )
/// .run_test()
/// .await
/// }
fn save_partitioned_batches_as_parquet(
input: &[RecordBatch],
Expand All @@ -641,9 +630,15 @@ impl JoinFuzzTestCase {
std::fs::create_dir_all(out_path).unwrap();

input.iter().enumerate().for_each(|(idx, batch)| {
let mut file =
std::fs::File::create(format!("{out_path}/file_{}.parquet", idx))
.unwrap();
let file_path = format!("{out_path}/file_{}.parquet", idx);
let mut file = std::fs::File::create(&file_path).unwrap();
println!(
"{}: Saving batch idx {} rows {} to parquet {}",
&out_name,
idx,
batch.num_rows(),
&file_path
);
let mut writer = parquet::arrow::ArrowWriter::try_new(
&mut file,
input.first().unwrap().schema(),
Expand All @@ -653,8 +648,6 @@ impl JoinFuzzTestCase {
writer.write(batch).unwrap();
writer.close().unwrap();
});

println!("The data {out_name} saved as parquet into {out_path}");
}

/// Read parquet files preserving partitions, i.e. 1 file -> 1 partition
Expand All @@ -667,10 +660,20 @@ impl JoinFuzzTestCase {
) -> std::io::Result<Vec<RecordBatch>> {
let ctx: SessionContext = SessionContext::new();
let mut batches: Vec<RecordBatch> = vec![];
let mut entries = std::fs::read_dir(dir)?
.map(|res| res.map(|e| e.path()))
.collect::<Result<Vec<_>, std::io::Error>>()?;

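// It is important to read the files in the same order in which they were written,
// so sort the entries by modification time (entries whose modification time
// cannot be read fall back to UNIX_EPOCH and therefore sort first).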
entries.sort_by_key(|path| {
std::fs::metadata(path)
.and_then(|metadata| metadata.modified())
.unwrap_or(SystemTime::UNIX_EPOCH)
});

for entry in std::fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
for entry in entries {
let path = entry.as_path();

if path.is_file() {
let mut batch = ctx
Expand Down

0 comments on commit 0a82ac3

Please sign in to comment.