diff --git a/vegafusion-rt-datafusion/src/transform/pivot.rs b/vegafusion-rt-datafusion/src/transform/pivot.rs index 14bc9eb53..aa47c8e0f 100644 --- a/vegafusion-rt-datafusion/src/transform/pivot.rs +++ b/vegafusion-rt-datafusion/src/transform/pivot.rs @@ -9,7 +9,7 @@ use crate::transform::utils::RecordBatchUtils; use crate::transform::TransformTrait; use async_trait::async_trait; use datafusion::prelude::Column; -use datafusion_expr::{coalesce, lit, min, BuiltInWindowFunction, Expr, WindowFunction}; +use datafusion_expr::{coalesce, col, lit, min, when, BuiltInWindowFunction, Expr, WindowFunction}; use sqlgen::dialect::DialectDisplay; use std::sync::Arc; use vegafusion_core::arrow::array::StringArray; @@ -27,7 +27,26 @@ impl TransformTrait for Pivot { ) -> Result<(Arc, Vec)> { // Make sure the pivot column is a string let pivot_dtype = data_type(&unescaped_col(&self.field), &dataframe.schema_df())?; - let dataframe = if !is_string_datatype(&pivot_dtype) { + let dataframe = if matches!(pivot_dtype, DataType::Boolean) { + // Boolean column type. For consistency with vega, replace 0 with "false" and 1 with "true" + let select_exprs: Vec<_> = dataframe + .schema() + .fields + .iter() + .map(|field| { + if field.name() == &self.field { + Ok(when(col(&self.field).eq(lit(true)), lit("true")) + .otherwise(lit("false")) + .expect("Failed to construct Case expression") + .alias(&self.field)) + } else { + Ok(unescaped_col(field.name())) + } + }) + .collect::>>()?; + dataframe.select(select_exprs)? + } else if !is_string_datatype(&pivot_dtype) { + // Column type is not string, so cast values to strings let select_exprs: Vec<_> = dataframe .schema() .fields @@ -47,6 +66,7 @@ impl TransformTrait for Pivot { .collect::>>()?; dataframe.select(select_exprs)? } else { + // Column type is string dataframe }; diff --git a/vegafusion-rt-datafusion/tests/test_transform_pivot.rs b/vegafusion-rt-datafusion/tests/test_transform_pivot.rs index e7aa5fd9f..a398e14f4 100644 --- a/vegafusion-rt-datafusion/tests/test_transform_pivot.rs +++ b/vegafusion-rt-datafusion/tests/test_transform_pivot.rs @@ -16,16 +16,16 @@ use vegafusion_core::data::table::VegaFusionTable; fn medals() -> VegaFusionTable { VegaFusionTable::from_json( &json!([ - {"country": "Germany", "type": "gold", "count": 14}, - {"country": "Norway", "type": "gold", "count": 14}, - {"country": "Norway", "type": "silver", "count": 14}, - {"country": "Canada", "type": "copper", "count": 10}, - {"country": "Norway", "type": "bronze", "count": 11}, - {"country": "Germany", "type": "silver", "count": 10}, - {"country": "Germany", "type": "bronze", "count": 7}, - {"country": "Canada", "type": "gold", "count": 11}, - {"country": "Canada", "type": "silver", "count": 8}, - {"country": "Canada", "type": "bronze", "count": 10}, + {"country": "Germany", "type": "gold", "count": 14, "is_gold": true}, + {"country": "Norway", "type": "gold", "count": 14, "is_gold": true}, + {"country": "Norway", "type": "silver", "count": 14, "is_gold": false}, + {"country": "Canada", "type": "copper", "count": 10, "is_gold": false}, + {"country": "Norway", "type": "bronze", "count": 11, "is_gold": false}, + {"country": "Germany", "type": "silver", "count": 10, "is_gold": false}, + {"country": "Germany", "type": "bronze", "count": 7, "is_gold": false}, + {"country": "Canada", "type": "gold", "count": 11, "is_gold": true}, + {"country": "Canada", "type": "silver", "count": 8, "is_gold": false}, + {"country": "Canada", "type": "bronze", "count": 10, "is_gold": false}, ]), 1024, ) @@ -129,3 +129,52 @@ mod test_pivot_no_group { ); } } + +#[cfg(test)] +mod test_pivot_no_group_boolean { + use crate::medals; + use crate::util::check::check_transform_evaluation; + use rstest::rstest; + use vegafusion_core::spec::transform::aggregate::AggregateOpSpec; + use vegafusion_core::spec::transform::pivot::PivotTransformSpec; + use vegafusion_core::spec::transform::TransformSpec; + + #[rstest( + op, + limit, + case(None, None), + case(Some(AggregateOpSpec::Sum), None), + case(Some(AggregateOpSpec::Sum), Some(2)), + case(Some(AggregateOpSpec::Count), None), + case(Some(AggregateOpSpec::Count), Some(3)), + case(Some(AggregateOpSpec::Mean), None), + case(Some(AggregateOpSpec::Mean), Some(4)), + case(Some(AggregateOpSpec::Max), None), + case(Some(AggregateOpSpec::Max), Some(10)), + case(Some(AggregateOpSpec::Min), None), + case(Some(AggregateOpSpec::Min), Some(0)) + )] + fn test(op: Option, limit: Option) { + let dataset = medals(); + + let pivot_spec = PivotTransformSpec { + field: "is_gold".to_string(), + value: "count".to_string(), + groupby: None, + limit, + op, + extra: Default::default(), + }; + let transform_specs = vec![TransformSpec::Pivot(pivot_spec)]; + + let comp_config = Default::default(); + let eq_config = Default::default(); + + check_transform_evaluation( + &dataset, + transform_specs.as_slice(), + &comp_config, + &eq_config, + ); + } +}