Add drop_columns to dataframe api (apache#11010)
* Add drop_columns to dataframe api apache#11007

* Prettier cleanup

* Added additional drop_columns tests and fixed issue with nonexistent columns.
Omega359 authored and xinlifoobar committed Jun 22, 2024
1 parent 6f2b577 commit 3add70a
Showing 2 changed files with 170 additions and 0 deletions.
169 changes: 169 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
@@ -275,6 +275,42 @@ impl DataFrame {
})
}

/// Returns a new DataFrame containing all columns except the specified columns.
/// Names that do not match any column in the schema are silently ignored.
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
/// let df = df.drop_columns(&["a"])?;
/// # Ok(())
/// # }
/// ```
pub fn drop_columns(self, columns: &[&str]) -> Result<DataFrame> {
    // Resolve each requested name against the schema; names that fail to
    // resolve (nonexistent columns) are filtered out rather than erroring.
    let fields_to_drop = columns
        .iter()
        .map(|name| {
            self.plan
                .schema()
                .qualified_field_with_unqualified_name(name)
        })
        .filter(|r| r.is_ok())
        .collect::<Result<Vec<_>>>()?;
    // Project every field that is not being dropped, preserving the
    // original column order.
    let expr: Vec<Expr> = self
        .plan
        .schema()
        .fields()
        .into_iter()
        .enumerate()
        .map(|(idx, _)| self.plan.schema().qualified_field(idx))
        .filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f)))
        .map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
        .collect();
    self.select(expr)
}
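
As the tests added below exercise, duplicate and nonexistent names are tolerated rather than raising an error. A minimal usage sketch of that behavior (the CSV path and column names follow the doc-test above and are illustrative):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Illustrative file; any CSV with columns a, b, c works the same way.
    let df = ctx
        .read_csv("tests/data/example.csv", CsvReadOptions::new())
        .await?;

    // Duplicate and nonexistent names are silently ignored:
    // only column "a" is dropped, leaving "b" and "c".
    let df = df.drop_columns(&["a", "a", "no_such_column"])?;
    df.show().await?;
    Ok(())
}
```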

/// Expand each list element of a column to multiple rows.
#[deprecated(since = "37.0.0", note = "use unnest_columns instead")]
pub fn unnest_column(self, column: &str) -> Result<DataFrame> {
@@ -1830,6 +1866,139 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn drop_columns() -> Result<()> {
// build plan using Table API
let t = test_table().await?;
let t2 = t.drop_columns(&["c2", "c11"])?;
let plan = t2.plan.clone();

// build query using SQL
let sql_plan = create_plan(
"SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100",
)
.await?;

// the two plans should be identical
assert_same_plan(&plan, &sql_plan);

Ok(())
}

#[tokio::test]
async fn drop_columns_with_duplicates() -> Result<()> {
// build plan using Table API
let t = test_table().await?;
let t2 = t.drop_columns(&["c2", "c11", "c2", "c2"])?;
let plan = t2.plan.clone();

// build query using SQL
let sql_plan = create_plan(
"SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100",
)
.await?;

// the two plans should be identical
assert_same_plan(&plan, &sql_plan);

Ok(())
}

#[tokio::test]
async fn drop_columns_with_nonexistent_columns() -> Result<()> {
// build plan using Table API
let t = test_table().await?;
let t2 = t.drop_columns(&["canada", "c2", "rocks"])?;
let plan = t2.plan.clone();

// build query using SQL
let sql_plan = create_plan(
"SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 FROM aggregate_test_100",
)
.await?;

// the two plans should be identical
assert_same_plan(&plan, &sql_plan);

Ok(())
}

#[tokio::test]
async fn drop_columns_with_empty_array() -> Result<()> {
// build plan using Table API
let t = test_table().await?;
let t2 = t.drop_columns(&[])?;
let plan = t2.plan.clone();

// build query using SQL
let sql_plan = create_plan(
"SELECT c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 FROM aggregate_test_100",
)
.await?;

// the two plans should be identical
assert_same_plan(&plan, &sql_plan);

Ok(())
}

#[tokio::test]
async fn drop_with_quotes() -> Result<()> {
// define data with a column name that contains a double quote:
let array1: Int32Array = [1, 10].into_iter().collect();
let array2: Int32Array = [2, 11].into_iter().collect();
let batch = RecordBatch::try_from_iter(vec![
("f\"c1", Arc::new(array1) as _),
("f\"c2", Arc::new(array2) as _),
])?;

let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;

let df = ctx.table("t").await?.drop_columns(&["f\"c1"])?;

let df_results = df.collect().await?;

assert_batches_sorted_eq!(
[
"+------+",
"| f\"c2 |",
"+------+",
"| 2 |",
"| 11 |",
"+------+"
],
&df_results
);

Ok(())
}

#[tokio::test]
async fn drop_with_periods() -> Result<()> {
// define data with a column name that has a "." in it:
let array1: Int32Array = [1, 10].into_iter().collect();
let array2: Int32Array = [2, 11].into_iter().collect();
let batch = RecordBatch::try_from_iter(vec![
("f.c1", Arc::new(array1) as _),
("f.c2", Arc::new(array2) as _),
])?;

let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;

let df = ctx.table("t").await?.drop_columns(&["f.c1"])?;

let df_results = df.collect().await?;

assert_batches_sorted_eq!(
["+------+", "| f.c2 |", "+------+", "| 2 |", "| 11 |", "+------+"],
&df_results
);

Ok(())
}
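
A nuance the two tests above pin down: `drop_columns` resolves names as unqualified field names, so a literal `"f.c1"` matches a column whose name contains a period rather than column `c1` of a table `f`. A short sketch of that behavior, using the same in-memory setup as the test:

```rust
use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let batch = RecordBatch::try_from_iter(vec![
        ("f.c1", Arc::new(Int32Array::from(vec![1, 10])) as _),
        ("f.c2", Arc::new(Int32Array::from(vec![2, 11])) as _),
    ])?;
    ctx.register_batch("t", batch)?;

    // "f.c1" is matched as the literal (unqualified) field name,
    // not parsed as table "f", column "c1".
    let df = ctx.table("t").await?.drop_columns(&["f.c1"])?;
    assert_eq!(df.schema().fields().len(), 1);
    Ok(())
}
```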

#[tokio::test]
async fn aggregate() -> Result<()> {
// build plan using DataFrame API
1 change: 1 addition & 0 deletions docs/source/user-guide/dataframe.md
@@ -64,6 +64,7 @@ execution. The plan is evaluated (executed) when an action method is invoked, su
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
| aggregate | Perform an aggregate query with optional grouping expressions. |
| distinct | Filter out duplicate rows. |
| drop_columns | Create a projection with all but the provided column names. |
| except | Calculate the exception of two DataFrames. The two DataFrames must have exactly the same schema |
| filter | Filter a DataFrame to only include rows that match the specified filter expression. |
| intersect | Calculate the intersection of two DataFrames. The two DataFrames must have exactly the same schema |
