Skip to content

Commit

Permalink
complete conversion to column name constants
Browse files: browse the repository at this point in the history
  • Loading branch information
wjones127 committed Oct 19, 2023
1 parent 0598046 commit 434e286
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 33 deletions.
3 changes: 2 additions & 1 deletion rust/lance/src/dataset/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow_array::{RecordBatch, RecordBatchReader, UInt64Array};
use futures::future::try_join_all;
use futures::stream::BoxStream;
use futures::{join, StreamExt, TryFutureExt, TryStreamExt};
use lance_core::ROW_ID;
use lance_core::{io::ReadBatchParams, Error, Result};
use object_store::path::Path;
use snafu::{location, Location};
Expand Down Expand Up @@ -485,7 +486,7 @@ impl FileFragment {
.try_into_stream()
.await?
.try_for_each(|batch| {
let array = batch["_rowid"].clone();
let array = batch[ROW_ID].clone();
let int_array: &UInt64Array = as_primitive_array(array.as_ref());

// _rowid is global, not scoped to a single fragment. The high bits
Expand Down
47 changes: 21 additions & 26 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@ mod test {
use arrow_schema::DataType;
use arrow_select::take;
use futures::TryStreamExt;
use lance_core::ROW_ID;
use lance_index::vector::DIST_COL;
use lance_testing::datagen::{BatchGenerator, IncrementingInt32};
use tempfile::tempdir;
Expand Down Expand Up @@ -1277,7 +1278,7 @@ mod test {
),
true,
),
ArrowField::new("_distance", DataType::Float32, true),
ArrowField::new(DIST_COL, DataType::Float32, true),
])
);

Expand Down Expand Up @@ -1381,7 +1382,7 @@ mod test {
),
true,
),
ArrowField::new("_distance", DataType::Float32, true),
ArrowField::new(DIST_COL, DataType::Float32, true),
])
);

Expand Down Expand Up @@ -1434,10 +1435,10 @@ mod test {
let plan = scan.create_plan().await.unwrap();

assert!(plan.as_any().is::<ProjectionExec>());
assert_eq!(plan.schema().field_names(), &["i", "_rowid"]);
assert_eq!(plan.schema().field_names(), &["i", ROW_ID]);
let scan = &plan.children()[0];
assert!(scan.as_any().is::<LanceScanExec>());
assert_eq!(scan.schema().field_names(), &["i", "_rowid"]);
assert_eq!(scan.schema().field_names(), &["i", ROW_ID]);
}

#[tokio::test]
Expand Down Expand Up @@ -1477,7 +1478,7 @@ mod test {

// If they aren't equal, they should be equal if we sort by row id
if ordered_batch != unordered_batch {
let sort_indices = sort_to_indices(&unordered_batch["_rowid"], None, None).unwrap();
let sort_indices = sort_to_indices(&unordered_batch[ROW_ID], None, None).unwrap();

let ordered_i = ordered_batch["i"].clone();
let sorted_i = take::take(&unordered_batch["i"], &sort_indices, None).unwrap();
Expand Down Expand Up @@ -1584,15 +1585,15 @@ mod test {

let take = &plan.children()[0];
assert!(take.as_any().is::<TakeExec>());
assert_eq!(take.schema().field_names(), ["i", "_rowid", "s"]);
assert_eq!(take.schema().field_names(), ["i", ROW_ID, "s"]);

let filter = &take.children()[0];
assert!(filter.as_any().is::<FilterExec>());
assert_eq!(filter.schema().field_names(), ["i", "_rowid"]);
assert_eq!(filter.schema().field_names(), ["i", ROW_ID]);

let scan = &filter.children()[0];
assert!(scan.as_any().is::<LanceScanExec>());
assert_eq!(filter.schema().field_names(), ["i", "_rowid"]);
assert_eq!(filter.schema().field_names(), ["i", ROW_ID]);
}

/// Test KNN with index
Expand Down Expand Up @@ -1622,14 +1623,14 @@ mod test {
.iter()
.map(|f| f.name())
.collect::<Vec<_>>(),
vec!["s", "vec", "_distance"]
vec!["s", "vec", DIST_COL]
);

let take = &plan.children()[0];
let take = take.as_any().downcast_ref::<TakeExec>().unwrap();
assert_eq!(
take.schema().field_names(),
["_distance", "_rowid", "vec", "i", "s"]
[DIST_COL, ROW_ID, "vec", "i", "s"]
);
assert_eq!(
take.extra_schema
Expand All @@ -1644,15 +1645,12 @@ mod test {
assert!(filter.as_any().is::<FilterExec>());
assert_eq!(
filter.schema().field_names(),
["_distance", "_rowid", "vec", "i"]
[DIST_COL, ROW_ID, "vec", "i"]
);

let take = &filter.children()[0];
let take = take.as_any().downcast_ref::<TakeExec>().unwrap();
assert_eq!(
take.schema().field_names(),
["_distance", "_rowid", "vec", "i"]
);
assert_eq!(take.schema().field_names(), [DIST_COL, ROW_ID, "vec", "i"]);
assert_eq!(
take.extra_schema
.fields
Expand All @@ -1665,7 +1663,7 @@ mod test {
// TODO: Two continuous take execs, we can merge them into one.
let take = &take.children()[0];
let take = take.as_any().downcast_ref::<TakeExec>().unwrap();
assert_eq!(take.schema().field_names(), ["_distance", "_rowid", "vec"]);
assert_eq!(take.schema().field_names(), [DIST_COL, ROW_ID, "vec"]);
assert_eq!(
take.extra_schema
.fields
Expand All @@ -1677,7 +1675,7 @@ mod test {

let knn = &take.children()[0];
assert!(knn.as_any().is::<KNNIndexExec>());
assert_eq!(knn.schema().field_names(), ["_distance", "_rowid"]);
assert_eq!(knn.schema().field_names(), [DIST_COL, ROW_ID]);
}

/// Test KNN index with refine factor
Expand Down Expand Up @@ -1709,14 +1707,14 @@ mod test {
.iter()
.map(|f| f.name())
.collect::<Vec<_>>(),
vec!["s", "vec", "_distance"]
vec!["s", "vec", DIST_COL]
);

let take = &plan.children()[0];
let take = take.as_any().downcast_ref::<TakeExec>().unwrap();
assert_eq!(
take.schema().field_names(),
["_distance", "_rowid", "vec", "i", "s"]
[DIST_COL, ROW_ID, "vec", "i", "s"]
);
assert_eq!(
take.extra_schema
Expand All @@ -1731,15 +1729,12 @@ mod test {
assert!(filter.as_any().is::<FilterExec>());
assert_eq!(
filter.schema().field_names(),
["_distance", "_rowid", "vec", "i"]
[DIST_COL, ROW_ID, "vec", "i"]
);

let take = &filter.children()[0];
let take = take.as_any().downcast_ref::<TakeExec>().unwrap();
assert_eq!(
take.schema().field_names(),
["_distance", "_rowid", "vec", "i"]
);
assert_eq!(take.schema().field_names(), [DIST_COL, ROW_ID, "vec", "i"]);
assert_eq!(
take.extra_schema
.fields
Expand All @@ -1755,7 +1750,7 @@ mod test {

let take = &flat.children()[0];
let take = take.as_any().downcast_ref::<TakeExec>().unwrap();
assert_eq!(take.schema().field_names(), ["_distance", "_rowid", "vec"]);
assert_eq!(take.schema().field_names(), [DIST_COL, ROW_ID, "vec"]);
assert_eq!(
take.extra_schema
.fields
Expand All @@ -1767,7 +1762,7 @@ mod test {

let knn = &take.children()[0];
assert!(knn.as_any().is::<KNNIndexExec>());
assert_eq!(knn.schema().field_names(), ["_distance", "_rowid"]);
assert_eq!(knn.schema().field_names(), [DIST_COL, ROW_ID]);
}

#[tokio::test]
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/io/exec/knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ mod tests {
.unwrap();
let results = stream.try_collect::<Vec<_>>().await.unwrap();

assert!(results[0].schema().column_with_name("_distance").is_some());
assert!(results[0].schema().column_with_name(DIST_COL).is_some());

assert_eq!(results.len(), 1);

Expand Down
4 changes: 2 additions & 2 deletions rust/lance/src/io/exec/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ mod tests {
let schema = take_exec.schema();
assert_eq!(
schema.fields.iter().map(|f| f.name()).collect::<Vec<_>>(),
vec!["i", "_rowid", "s"]
vec!["i", ROW_ID, "s"]
);
}

Expand Down Expand Up @@ -401,7 +401,7 @@ mod tests {
let schema = take_exec.schema();
assert_eq!(
schema.fields.iter().map(|f| f.name()).collect::<Vec<_>>(),
vec!["i", "s", "_rowid"]
vec!["i", "s", ROW_ID]
);
}

Expand Down
6 changes: 3 additions & 3 deletions rust/lance/src/io/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,7 @@ mod tests {

for b in 0..10 {
let batch = reader.read_batch(b, .., reader.schema()).await.unwrap();
let row_ids_col = &batch["_rowid"];
let row_ids_col = &batch[ROW_ID];
// Do the same computation as `compute_row_id`.
let start_pos = (fragment << 32) + 10 * b as u64;

Expand Down Expand Up @@ -1227,7 +1227,7 @@ mod tests {

for b in 0..10 {
let batch = reader.read_batch(b, .., reader.schema()).await.unwrap();
let row_ids_col = &batch["_rowid"];
let row_ids_col = &batch[ROW_ID];
// Do the same computation as `compute_row_id`.
let start_pos = (fragment << 32) + 10 * b as u64;

Expand Down Expand Up @@ -1292,7 +1292,7 @@ mod tests {
.fields()
.iter()
.map(|f| f.name().as_str())
.any(|name| name == "_rowid"))
.any(|name| name == ROW_ID))
}
}

Expand Down

0 comments on commit 434e286

Please sign in to comment.