Add drop_columns to dataframe api #11010

Merged · 4 commits · Jun 21, 2024
Changes from 3 commits
datafusion/core/src/dataframe/mod.rs (80 additions, 0 deletions)
@@ -244,6 +244,41 @@ impl DataFrame {
})
}

/// Returns a new DataFrame containing all columns except the specified columns.
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
/// let df = df.drop_columns(&["a"])?;
/// # Ok(())
/// # }
/// ```
pub fn drop_columns(self, columns: &[&str]) -> Result<DataFrame> {
let fields_to_drop = columns
.iter()
.map(|name| {
self.plan
.schema()
.qualified_field_with_unqualified_name(name)
})
.collect::<Result<Vec<_>>>()?;
Member: collect into a hash table might be good for wide table scenarios

Contributor Author: Interesting, I'll take a look. The logic for this function was based on the select_columns fn which did things that way.

let expr: Vec<Expr> = self
.plan
.schema()
.fields()
.into_iter()
.enumerate()
.map(|(idx, _)| self.plan.schema().qualified_field(idx))
.filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f)))
.map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
.collect();
self.select(expr)
}
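
A note on the reviewer's hash-table suggestion above: `fields_to_drop` is collected into a `Vec`, so the `contains` call in the filter performs a linear scan for every field in the schema. Below is a minimal, standalone sketch of that idea using a `HashSet` keyed by plain column names; it deliberately ignores DataFusion's qualified-field types and is only an illustration of the lookup pattern, not code from this PR.

```rust
use std::collections::HashSet;

/// Hypothetical helper (not DataFusion API): keep every field name that is not
/// in `drop`. Collecting `drop` into a HashSet once makes each membership test
/// O(1) instead of a scan over the drop list, which matters for wide schemas.
fn columns_to_keep<'a>(schema_fields: &[&'a str], drop: &[&str]) -> Vec<&'a str> {
    let drop_set: HashSet<&str> = drop.iter().copied().collect();
    schema_fields
        .iter()
        .copied()
        .filter(|name| !drop_set.contains(name))
        .collect()
}

fn main() {
    let fields = ["c1", "c2", "c3", "c11"];
    // Dropping two columns leaves the remaining fields in their original order.
    assert_eq!(columns_to_keep(&fields, &["c2", "c11"]), vec!["c1", "c3"]);
}
```

In the real implementation the set would presumably hold the qualified field (qualifier plus `Field`) rather than a bare name, so that qualified and unqualified column references stay distinguishable.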

/// Expand each list element of a column to multiple rows.
#[deprecated(since = "37.0.0", note = "use unnest_columns instead")]
pub fn unnest_column(self, column: &str) -> Result<DataFrame> {
@@ -1799,6 +1834,51 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn drop_columns() -> Result<()> {
// build plan using Table API

let t = test_table().await?;
let t2 = t.drop_columns(&["c2", "c11"])?;
let plan = t2.plan.clone();

// build query using SQL
let sql_plan = create_plan(
"SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100",
)
.await?;

// the two plans should be identical
assert_same_plan(&plan, &sql_plan);

Ok(())
}

#[tokio::test]
async fn drop_with_periods() -> Result<()> {
// define data with a column name that has a "." in it:
let array1: Int32Array = [1, 10].into_iter().collect();
let array2: Int32Array = [2, 11].into_iter().collect();
let batch = RecordBatch::try_from_iter(vec![
("f.c1", Arc::new(array1) as _),
("f.c2", Arc::new(array2) as _),
])?;

let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;

let df = ctx.table("t").await?.drop_columns(&["f.c1"])?;

let df_results = df.collect().await?;

assert_batches_sorted_eq!(
["+------+", "| f.c2 |", "+------+", "| 2 |", "| 11 |", "+------+"],
&df_results
);

Ok(())
}

#[tokio::test]
async fn aggregate() -> Result<()> {
// build plan using DataFrame API
docs/source/user-guide/dataframe.md (1 addition, 0 deletions)
@@ -64,6 +64,7 @@ execution. The plan is evaluated (executed) when an action method is invoked, su
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
| aggregate | Perform an aggregate query with optional grouping expressions. |
| distinct | Filter out duplicate rows. |
| drop_columns | Create a projection with all but the provided column names. |
| except | Calculate the exception of two DataFrames. The two DataFrames must have exactly the same schema |
| filter | Filter a DataFrame to only include rows that match the specified filter expression. |
| intersect | Calculate the intersection of two DataFrames. The two DataFrames must have exactly the same schema |
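
As a hedged usage sketch (not part of this diff), here is how two of the methods in the table above, filter and drop_columns, might be chained together; the CSV path and column names mirror the doc-comment example in this PR and are illustrative only.

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Illustrative file and column names, borrowed from the doc-comment example.
    let df = ctx
        .read_csv("tests/data/example.csv", CsvReadOptions::new())
        .await?;
    // Keep rows where `a` > 0, then project away column `b`.
    let df = df.filter(col("a").gt(lit(0)))?.drop_columns(&["b"])?;
    df.show().await?;
    Ok(())
}
```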