From 65a7cdc92af12fdeacec630d8bfb8cb03f6d45e7 Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Thu, 27 Jun 2024 13:46:23 -0400 Subject: [PATCH 1/9] failing test --- arrow-flight/src/encode.rs | 60 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index f59c29e68173..aa16dbf01413 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -1045,6 +1045,66 @@ mod tests { } } + #[tokio::test] + async fn test_send_multiple_dictionaries() { + // Create a schema with two dictionary fields that have the same dict ID + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict_1", + DataType::UInt16, + DataType::Utf8, + false, + ),Field::new_dictionary( + "dict_2", + DataType::UInt16, + DataType::Utf8, + false, + )])); + + let arr_one_1: Arc> = + Arc::new(vec!["a", "a", "b"].into_iter().collect()); + let arr_one_2: Arc> = + Arc::new(vec!["c", "c", "d"].into_iter().collect()); + let arr_two_1: Arc> = + Arc::new(vec!["b", "a", "c"].into_iter().collect()); + let arr_two_2: Arc> = + Arc::new(vec!["k", "d", "e"].into_iter().collect()); + let batch_one = RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()]).unwrap(); + let batch_two = RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()]).unwrap(); + + let encoder = FlightDataEncoderBuilder::default() + .with_dictionary_handling(DictionaryHandling::Resend) + .build(futures::stream::iter(vec![Ok(batch_one), Ok(batch_two)])); + + let mut decoder = FlightDataDecoder::new(encoder); + let mut expected_array_1 = arr_one_1; + let mut expected_array_2 = arr_one_2; + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), schema); + + let actual_1 = 
Arc::new(downcast_array::>( + b.column_by_name("dict_1").unwrap(), + )); + + assert_eq!(actual_1, expected_array_1); + + let actual_2 = Arc::new(downcast_array::>( + b.column_by_name("dict_2").unwrap(), + )); + + assert_eq!(actual_2, expected_array_2); + + expected_array_1 = arr_two_1.clone(); + expected_array_2 = arr_two_2.clone(); + } + } + } + } + #[test] fn test_schema_metadata_encoded() { let schema = Schema::new(vec![Field::new("data", DataType::Int32, false)]).with_metadata( From 737d350cf841752984eff5fbca5239314ae6af8b Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Thu, 27 Jun 2024 15:07:17 -0400 Subject: [PATCH 2/9] Handle dict ID assignment during flight encoding/decoding --- arrow-flight/src/encode.rs | 417 +++++++++++++++++++++++++++---------- arrow-ipc/src/reader.rs | 7 +- arrow-ipc/src/writer.rs | 68 +++--- 3 files changed, 346 insertions(+), 146 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index aa16dbf01413..6a689dac7ae0 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -248,6 +248,9 @@ pub struct FlightDataEncoder { /// Deterimines how `DictionaryArray`s are encoded for transport. /// See [`DictionaryHandling`] for more information. 
dictionary_handling: DictionaryHandling, + /// Tracks dictionary IDs for dictionary fields so that we can assign a unique + /// dictionary ID for each field + next_dict_id: i64, } impl FlightDataEncoder { @@ -273,6 +276,7 @@ impl FlightDataEncoder { done: false, descriptor, dictionary_handling, + next_dict_id: 0, }; // If schema is known up front, enqueue it immediately @@ -304,7 +308,11 @@ impl FlightDataEncoder { // The first message is the schema message, and all // batches have the same schema let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend; - let schema = Arc::new(prepare_schema_for_flight(schema, send_dictionaries)); + let schema = Arc::new(prepare_schema_for_flight( + schema, + &mut self.next_dict_id, + send_dictionaries, + )); let mut schema_flight_data = self.encoder.encode_schema(&schema); // attach any metadata requested @@ -325,6 +333,8 @@ impl FlightDataEncoder { None => self.encode_schema(batch.schema_ref()), }; + println!("schema: {schema:#?}"); + let batch = match self.dictionary_handling { DictionaryHandling::Resend => batch, DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?, @@ -438,24 +448,28 @@ pub enum DictionaryHandling { Resend, } -fn prepare_field_for_flight(field: &FieldRef, send_dictionaries: bool) -> Field { +fn prepare_field_for_flight( + field: &FieldRef, + next_dict_id: &mut i64, + send_dictionaries: bool, +) -> Field { match field.data_type() { DataType::List(inner) => Field::new_list( field.name(), - prepare_field_for_flight(inner, send_dictionaries), + prepare_field_for_flight(inner, next_dict_id, send_dictionaries), field.is_nullable(), ) .with_metadata(field.metadata().clone()), DataType::LargeList(inner) => Field::new_list( field.name(), - prepare_field_for_flight(inner, send_dictionaries), + prepare_field_for_flight(inner, next_dict_id, send_dictionaries), field.is_nullable(), ) .with_metadata(field.metadata().clone()), DataType::Struct(fields) => { let new_fields: Vec = fields 
.iter() - .map(|f| prepare_field_for_flight(f, send_dictionaries)) + .map(|f| prepare_field_for_flight(f, next_dict_id, send_dictionaries)) .collect(); Field::new_struct(field.name(), new_fields, field.is_nullable()) .with_metadata(field.metadata().clone()) @@ -463,17 +477,37 @@ fn prepare_field_for_flight(field: &FieldRef, send_dictionaries: bool) -> Field DataType::Union(fields, mode) => { let (type_ids, new_fields): (Vec, Vec) = fields .iter() - .map(|(type_id, f)| (type_id, prepare_field_for_flight(f, send_dictionaries))) + .map(|(type_id, f)| { + ( + type_id, + prepare_field_for_flight(f, next_dict_id, send_dictionaries), + ) + }) .unzip(); Field::new_union(field.name(), type_ids, new_fields, *mode) } - DataType::Dictionary(_, value_type) if !send_dictionaries => Field::new( - field.name(), - value_type.as_ref().clone(), - field.is_nullable(), - ) - .with_metadata(field.metadata().clone()), + DataType::Dictionary(_, value_type) => { + if !send_dictionaries { + Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()) + } else { + let dict_id = *next_dict_id; + *next_dict_id += 1; + Field::new_dict( + field.name(), + field.data_type().clone(), + field.is_nullable(), + dict_id, + field.dict_is_ordered().unwrap_or_default(), + ) + .with_metadata(field.metadata().clone()) + } + } _ => field.as_ref().clone(), } } @@ -483,18 +517,39 @@ fn prepare_field_for_flight(field: &FieldRef, send_dictionaries: bool) -> Field /// Convert dictionary types to underlying types /// /// See hydrate_dictionary for more information -fn prepare_schema_for_flight(schema: &Schema, send_dictionaries: bool) -> Schema { +fn prepare_schema_for_flight( + schema: &Schema, + next_dict_id: &mut i64, + send_dictionaries: bool, +) -> Schema { let fields: Fields = schema .fields() .iter() .map(|field| match field.data_type() { - DataType::Dictionary(_, value_type) if !send_dictionaries => Field::new( - field.name(), - 
value_type.as_ref().clone(), - field.is_nullable(), - ) - .with_metadata(field.metadata().clone()), - tpe if tpe.is_nested() => prepare_field_for_flight(field, send_dictionaries), + DataType::Dictionary(_, value_type) => { + if !send_dictionaries { + Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()) + } else { + let dict_id = *next_dict_id; + *next_dict_id += 1; + Field::new_dict( + field.name(), + field.data_type().clone(), + field.is_nullable(), + dict_id, + field.dict_is_ordered().unwrap_or_default(), + ) + .with_metadata(field.metadata().clone()) + } + } + tpe if tpe.is_nested() => { + prepare_field_for_flight(field, next_dict_id, send_dictionaries) + } _ => field.as_ref().clone(), }) .collect(); @@ -619,7 +674,10 @@ fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result = vec!["a", "a", "b"].into_iter().collect(); + let arr2: DictionaryArray = vec!["c", "c", "d"].into_iter().collect(); + + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap(); + + verify_flight_round_trip(vec![batch1, batch2]).await; + } + + #[tokio::test] + async fn test_multiple_dictionaries_resend() { + // Create a schema with two dictionary fields that have the same dict ID + let schema = Arc::new(Schema::new(vec![ + Field::new_dictionary("dict_1", DataType::UInt16, DataType::Utf8, false), + Field::new_dictionary("dict_2", DataType::UInt16, DataType::Utf8, false), + ])); + + let arr_one_1: Arc> = + Arc::new(vec!["a", "a", "b"].into_iter().collect()); + let arr_one_2: Arc> = + Arc::new(vec!["c", "c", "d"].into_iter().collect()); + let arr_two_1: Arc> = + Arc::new(vec!["b", "a", "c"].into_iter().collect()); + let arr_two_2: Arc> = + Arc::new(vec!["k", "d", 
"e"].into_iter().collect()); + let batch1 = + RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()]) + .unwrap(); + let batch2 = + RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()]) + .unwrap(); + + verify_flight_round_trip(vec![batch1, batch2]).await; + } + #[tokio::test] async fn test_dictionary_list_hydration() { - let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::::new()); + let mut builder = ListBuilder::new(StringDictionaryBuilder::::new()); builder.append_value(vec![Some("a"), None, Some("b")]); @@ -766,6 +865,30 @@ mod tests { } } + #[tokio::test] + async fn test_dictionary_list_resend() { + let mut builder = ListBuilder::new(StringDictionaryBuilder::::new()); + + builder.append_value(vec![Some("a"), None, Some("b")]); + + let arr1 = builder.finish(); + + builder.append_value(vec![Some("c"), None, Some("d")]); + + let arr2 = builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap(); + + verify_flight_round_trip(vec![batch1, batch2]).await; + } + #[tokio::test] async fn test_dictionary_struct_hydration() { let struct_fields = vec![Field::new_list( @@ -774,26 +897,38 @@ mod tests { true, )]; - let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::::new()); + let mut struct_builder = StructBuilder::new( + struct_fields.clone(), + vec![Box::new(builder::ListBuilder::new( + StringDictionaryBuilder::::new(), + ))], + ); - builder.append_value(vec![Some("a"), None, Some("b")]); + struct_builder + .field_builder::>>>(0) + .unwrap() + .append_value(vec![Some("a"), None, Some("b")]); - let arr1 = Arc::new(builder.finish()); - let arr1 = StructArray::new(struct_fields.clone().into(), 
vec![arr1], None); + struct_builder.append(true); - builder.append_value(vec![Some("c"), None, Some("d")]); + let arr1 = struct_builder.finish(); - let arr2 = Arc::new(builder.finish()); - let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None); + struct_builder + .field_builder::>>>(0) + .unwrap() + .append_value(vec![Some("c"), None, Some("d")]); + struct_builder.append(true); + + let arr2 = struct_builder.finish(); let schema = Arc::new(Schema::new(vec![Field::new_struct( "struct", - struct_fields.clone(), + struct_fields, true, )])); let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); - let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap(); + let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap(); let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]); @@ -838,6 +973,43 @@ mod tests { } } + #[tokio::test] + async fn test_dictionary_struct_resend() { + let struct_fields = vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; + + let mut struct_builder = StructBuilder::new( + struct_fields.clone(), + vec![Box::new(builder::ListBuilder::new( + StringDictionaryBuilder::::new(), + ))], + ); + + struct_builder.field_builder::>>>(0).unwrap().append_value(vec![Some("a"), None, Some("b")]); + struct_builder.append(true); + + let arr1 = struct_builder.finish(); + + struct_builder.field_builder::>>>(0).unwrap().append_value(vec![Some("c"), None, Some("d")]); + struct_builder.append(true); + + let arr2 = struct_builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new_struct( + "struct", + struct_fields, + true, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap(); + + verify_flight_round_trip(vec![batch1, batch2]).await; + } + #[tokio::test] async fn 
test_dictionary_union_hydration() { let struct_fields = vec![Field::new_list( @@ -1004,102 +1176,123 @@ mod tests { } #[tokio::test] - async fn test_send_dictionaries() { - let schema = Arc::new(Schema::new(vec![Field::new_dictionary( - "dict", - DataType::UInt16, - DataType::Utf8, - false, - )])); + async fn test_dictionary_union_resend() { + let struct_fields = vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; - let arr_one: Arc> = - Arc::new(vec!["a", "a", "b"].into_iter().collect()); - let arr_two: Arc> = - Arc::new(vec!["b", "a", "c"].into_iter().collect()); - let batch_one = RecordBatch::try_new(schema.clone(), vec![arr_one.clone()]).unwrap(); - let batch_two = RecordBatch::try_new(schema.clone(), vec![arr_two.clone()]).unwrap(); + let union_fields = [ + ( + 0, + Arc::new(Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )), + ), + ( + 1, + Arc::new(Field::new_struct("struct", struct_fields.clone(), true)), + ), + (2, Arc::new(Field::new("string", DataType::Utf8, true))), + ] + .into_iter() + .collect::(); - let encoder = FlightDataEncoderBuilder::default() - .with_dictionary_handling(DictionaryHandling::Resend) - .build(futures::stream::iter(vec![Ok(batch_one), Ok(batch_two)])); + let struct_fields = vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; - let mut decoder = FlightDataDecoder::new(encoder); - let mut expected_array = arr_one; - while let Some(decoded) = decoder.next().await { - let decoded = decoded.unwrap(); - match decoded.payload { - DecodedPayload::None => {} - DecodedPayload::Schema(s) => assert_eq!(s, schema), - DecodedPayload::RecordBatch(b) => { - assert_eq!(b.schema(), schema); + let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::::new()); - let actual_array = Arc::new(downcast_array::>( - 
b.column_by_name("dict").unwrap(), - )); + builder.append_value(vec![Some("a"), None, Some("b")]); - assert_eq!(actual_array, expected_array); + let arr1 = builder.finish(); - expected_array = arr_two.clone(); - } - } - } - } + let type_id_buffer = [0].into_iter().collect::>(); + let arr1 = UnionArray::try_new( + union_fields.clone(), + type_id_buffer, + None, + vec![ + Arc::new(arr1) as Arc, + new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1), + new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1), + ], + ) + .unwrap(); - #[tokio::test] - async fn test_send_multiple_dictionaries() { - // Create a schema with two dictionary fields that have the same dict ID - let schema = Arc::new(Schema::new(vec![Field::new_dictionary( - "dict_1", - DataType::UInt16, - DataType::Utf8, - false, - ),Field::new_dictionary( - "dict_2", - DataType::UInt16, - DataType::Utf8, - false, + builder.append_value(vec![Some("c"), None, Some("d")]); + + let arr2 = Arc::new(builder.finish()); + let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None); + + let type_id_buffer = [1].into_iter().collect::>(); + let arr2 = UnionArray::try_new( + union_fields.clone(), + type_id_buffer, + None, + vec![ + new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1), + Arc::new(arr2), + new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1), + ], + ) + .unwrap(); + + let type_id_buffer = [2].into_iter().collect::>(); + let arr3 = UnionArray::try_new( + union_fields.clone(), + type_id_buffer, + None, + vec![ + new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1), + new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1), + Arc::new(StringArray::from(vec!["e"])), + ], + ) + .unwrap(); + + let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields + .iter() + .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone())) + .unzip(); + let schema = Arc::new(Schema::new(vec![Field::new_union( + 
"union", + type_ids.clone(), + union_fields.clone(), + UnionMode::Sparse, )])); - let arr_one_1: Arc> = - Arc::new(vec!["a", "a", "b"].into_iter().collect()); - let arr_one_2: Arc> = - Arc::new(vec!["c", "c", "d"].into_iter().collect()); - let arr_two_1: Arc> = - Arc::new(vec!["b", "a", "c"].into_iter().collect()); - let arr_two_2: Arc> = - Arc::new(vec!["k", "d", "e"].into_iter().collect()); - let batch_one = RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()]).unwrap(); - let batch_two = RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()]).unwrap(); + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap(); + let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap(); + + verify_flight_round_trip(vec![batch1, batch2, batch3]).await; + } + + async fn verify_flight_round_trip(mut batches: Vec) { + let expected_schema = batches.first().unwrap().schema(); let encoder = FlightDataEncoderBuilder::default() .with_dictionary_handling(DictionaryHandling::Resend) - .build(futures::stream::iter(vec![Ok(batch_one), Ok(batch_two)])); + .build(futures::stream::iter(batches.clone().into_iter().map(Ok))); + + let mut expected_batches = batches.drain(..); let mut decoder = FlightDataDecoder::new(encoder); - let mut expected_array_1 = arr_one_1; - let mut expected_array_2 = arr_one_2; while let Some(decoded) = decoder.next().await { let decoded = decoded.unwrap(); match decoded.payload { DecodedPayload::None => {} - DecodedPayload::Schema(s) => assert_eq!(s, schema), + DecodedPayload::Schema(s) => assert_eq!(s, expected_schema), DecodedPayload::RecordBatch(b) => { - assert_eq!(b.schema(), schema); - - let actual_1 = Arc::new(downcast_array::>( - b.column_by_name("dict_1").unwrap(), - )); - - assert_eq!(actual_1, expected_array_1); - - let actual_2 = Arc::new(downcast_array::>( - 
b.column_by_name("dict_2").unwrap(), - )); - - assert_eq!(actual_2, expected_array_2); - - expected_array_1 = arr_two_1.clone(); - expected_array_2 = arr_two_2.clone(); + let expected_batch = expected_batches.next().unwrap(); + assert_eq!(b, expected_batch); } } } @@ -1111,7 +1304,9 @@ mod tests { HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), ); - let got = prepare_schema_for_flight(&schema, false); + let mut next_dict_id = 0; + + let got = prepare_schema_for_flight(&schema, &mut next_dict_id, false); assert!(got.metadata().contains_key("some_key")); } diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 3423d06b6fca..70767a7a349e 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -617,7 +617,8 @@ fn read_dictionary_impl( let id = batch.id(); let fields_using_this_dictionary = schema.fields_with_dict_id(id); let first_field = fields_using_this_dictionary.first().ok_or_else(|| { - ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string()) + println!("schema: {schema:#?}"); + ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema")) })?; // As the dictionary batch does not contain the type of the @@ -643,7 +644,7 @@ fn read_dictionary_impl( _ => None, } .ok_or_else(|| { - ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string()) + ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema")) })?; // We don't currently record the isOrdered field. 
This could be general @@ -1812,7 +1813,7 @@ mod tests { "values", DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), true, - 1, + 2, false, )); let entry_struct = StructArray::from(vec![ diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index f74a86e013cd..81e85f71b7b0 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -163,7 +163,7 @@ impl Default for IpcWriteOptions { /// /// // encode the batch into zero or more encoded dictionaries /// // and the data for the actual array. -/// let data_gen = IpcDataGenerator {}; +/// let data_gen = IpcDataGenerator::default(); /// let (encoded_dictionaries, encoded_message) = data_gen /// .encoded_batch(&batch, &mut dictionary_tracker, &options) /// .unwrap(); @@ -204,21 +204,22 @@ impl IpcDataGenerator { encoded_dictionaries: &mut Vec, dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, + next_dict_id: &mut i64, ) -> Result<(), ArrowError> { match column.data_type() { - DataType::Struct(fields) => { + DataType::Struct(_) => { let s = as_struct_array(column); - for (field, column) in fields.iter().zip(s.columns()) { + for column in s.columns() { self.encode_dictionaries( - field, column, encoded_dictionaries, dictionary_tracker, write_options, + next_dict_id, )?; } } - DataType::RunEndEncoded(_, values) => { + DataType::RunEndEncoded(_, _) => { let data = column.to_data(); if data.child_data().len() != 2 { return Err(ArrowError::InvalidArgumentError(format!( @@ -230,82 +231,77 @@ impl IpcDataGenerator { // only for values array. 
let values_array = make_array(data.child_data()[1].clone()); self.encode_dictionaries( - values, &values_array, encoded_dictionaries, dictionary_tracker, write_options, + next_dict_id, )?; } - DataType::List(field) => { + DataType::List(_) => { let list = as_list_array(column); self.encode_dictionaries( - field, list.values(), encoded_dictionaries, dictionary_tracker, write_options, + next_dict_id, )?; } - DataType::LargeList(field) => { + DataType::LargeList(_) => { let list = as_large_list_array(column); self.encode_dictionaries( - field, list.values(), encoded_dictionaries, dictionary_tracker, write_options, + next_dict_id, )?; } - DataType::FixedSizeList(field, _) => { + DataType::FixedSizeList(_, _) => { let list = column .as_any() .downcast_ref::() .expect("Unable to downcast to fixed size list array"); self.encode_dictionaries( - field, list.values(), encoded_dictionaries, dictionary_tracker, write_options, + next_dict_id, )?; } - DataType::Map(field, _) => { + DataType::Map(_, _) => { let map_array = as_map_array(column); - let (keys, values) = match field.data_type() { - DataType::Struct(fields) if fields.len() == 2 => (&fields[0], &fields[1]), - _ => panic!("Incorrect field data type {:?}", field.data_type()), - }; - // keys self.encode_dictionaries( - keys, map_array.keys(), encoded_dictionaries, dictionary_tracker, write_options, + next_dict_id, )?; // values self.encode_dictionaries( - values, map_array.values(), encoded_dictionaries, dictionary_tracker, write_options, + next_dict_id, )?; } DataType::Union(fields, _) => { let union = as_union_array(column); - for (type_id, field) in fields.iter() { + for (type_id, _) in fields.iter() { let column = union.child(type_id); self.encode_dictionaries( - field, column, encoded_dictionaries, dictionary_tracker, write_options, + next_dict_id, )?; } } @@ -317,17 +313,17 @@ impl IpcDataGenerator { fn encode_dictionaries( &self, - field: &Field, column: &ArrayRef, encoded_dictionaries: &mut Vec, 
dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, + next_dict_id: &mut i64, ) -> Result<(), ArrowError> { match column.data_type() { DataType::Dictionary(_key_type, _value_type) => { - let dict_id = field - .dict_id() - .expect("All Dictionary types have `dict_id`"); + let dict_id = *next_dict_id; + *next_dict_id += 1; + let dict_data = column.to_data(); let dict_values = &dict_data.child_data()[0]; @@ -338,6 +334,7 @@ impl IpcDataGenerator { encoded_dictionaries, dictionary_tracker, write_options, + next_dict_id, )?; let emit = dictionary_tracker.insert(dict_id, column)?; @@ -355,6 +352,7 @@ impl IpcDataGenerator { encoded_dictionaries, dictionary_tracker, write_options, + next_dict_id, )?, } @@ -373,14 +371,16 @@ impl IpcDataGenerator { let schema = batch.schema(); let mut encoded_dictionaries = Vec::with_capacity(schema.all_fields().len()); - for (i, field) in schema.fields().iter().enumerate() { + let mut next_dict_id = 0; + + for i in 0..schema.fields().len() { let column = batch.column(i); self.encode_dictionaries( - field, column, &mut encoded_dictionaries, dictionary_tracker, write_options, + &mut next_dict_id, )?; } @@ -695,6 +695,8 @@ impl DictionaryTracker { /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just /// inserted. 
pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result { + println!("\n\ninserting dict id {dict_id}\n\n"); + let dict_data = column.to_data(); let dict_values = &dict_data.child_data()[0]; @@ -1821,8 +1823,9 @@ mod tests { gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); - // Dictionary with id 2 should have been written to the dict tracker - assert!(dict_tracker.written.contains_key(&2)); + // The encoder will assign dict IDs itself to ensure uniqueness and ignore the dict ID in the schema + // so we expect the dict will be keyed to 0 + assert!(dict_tracker.written.contains_key(&0)); } #[test] @@ -1856,8 +1859,9 @@ mod tests { gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); - // Dictionary with id 2 should have been written to the dict tracker - assert!(dict_tracker.written.contains_key(&2)); + // The encoder will assign dict IDs itself to ensure uniqueness and ignore the dict ID in the schema + // so we expect the dict will be keyed to 0 + assert!(dict_tracker.written.contains_key(&0)); } fn write_union_file(options: IpcWriteOptions) { From f125868f84d0e9fbcffa9e63f171c7b4378c3b1b Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Thu, 27 Jun 2024 15:17:55 -0400 Subject: [PATCH 3/9] remove println --- arrow-flight/src/encode.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 6a689dac7ae0..160d32064ad5 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -333,8 +333,6 @@ impl FlightDataEncoder { None => self.encode_schema(batch.schema_ref()), }; - println!("schema: {schema:#?}"); - let batch = match self.dictionary_handling { DictionaryHandling::Resend => batch, DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?, From a932508c9ebee7ff07cefb68d2de9cbe95e7e12a Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Thu, 27 Jun 2024 15:19:47 -0400 Subject: [PATCH 4/9] One more println --- 
arrow-ipc/src/writer.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 81e85f71b7b0..ca82f3bdef96 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -695,8 +695,6 @@ impl DictionaryTracker { /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just /// inserted. pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result { - println!("\n\ninserting dict id {dict_id}\n\n"); - let dict_data = column.to_data(); let dict_values = &dict_data.child_data()[0]; From 648e819a576ca68e1cc0f8df2d7f5e9aed8e6866 Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Fri, 28 Jun 2024 10:02:35 -0400 Subject: [PATCH 5/9] Make auto-assign optional --- arrow-flight/src/encode.rs | 43 +++-- arrow-flight/src/lib.rs | 2 +- arrow-flight/src/utils.rs | 4 +- .../integration_test.rs | 2 +- .../integration_test.rs | 2 +- arrow-integration-testing/tests/ipc_writer.rs | 4 +- arrow-ipc/src/reader.rs | 5 +- arrow-ipc/src/writer.rs | 164 ++++++++++++------ 8 files changed, 141 insertions(+), 85 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 160d32064ad5..1d0868a7231a 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -248,9 +248,6 @@ pub struct FlightDataEncoder { /// Deterimines how `DictionaryArray`s are encoded for transport. /// See [`DictionaryHandling`] for more information. 
dictionary_handling: DictionaryHandling, - /// Tracks dictionary IDs for dictionary fields so that we can assign a unique - /// dictionary ID for each field - next_dict_id: i64, } impl FlightDataEncoder { @@ -276,7 +273,6 @@ impl FlightDataEncoder { done: false, descriptor, dictionary_handling, - next_dict_id: 0, }; // If schema is known up front, enqueue it immediately @@ -310,7 +306,7 @@ impl FlightDataEncoder { let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend; let schema = Arc::new(prepare_schema_for_flight( schema, - &mut self.next_dict_id, + &mut self.encoder.dictionary_tracker, send_dictionaries, )); let mut schema_flight_data = self.encoder.encode_schema(&schema); @@ -448,26 +444,26 @@ pub enum DictionaryHandling { fn prepare_field_for_flight( field: &FieldRef, - next_dict_id: &mut i64, + dictionary_tracker: &mut DictionaryTracker, send_dictionaries: bool, ) -> Field { match field.data_type() { DataType::List(inner) => Field::new_list( field.name(), - prepare_field_for_flight(inner, next_dict_id, send_dictionaries), + prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries), field.is_nullable(), ) .with_metadata(field.metadata().clone()), DataType::LargeList(inner) => Field::new_list( field.name(), - prepare_field_for_flight(inner, next_dict_id, send_dictionaries), + prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries), field.is_nullable(), ) .with_metadata(field.metadata().clone()), DataType::Struct(fields) => { let new_fields: Vec = fields .iter() - .map(|f| prepare_field_for_flight(f, next_dict_id, send_dictionaries)) + .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries)) .collect(); Field::new_struct(field.name(), new_fields, field.is_nullable()) .with_metadata(field.metadata().clone()) @@ -478,7 +474,7 @@ fn prepare_field_for_flight( .map(|(type_id, f)| { ( type_id, - prepare_field_for_flight(f, next_dict_id, send_dictionaries), + prepare_field_for_flight(f, 
dictionary_tracker, send_dictionaries), ) }) .unzip(); @@ -494,8 +490,8 @@ fn prepare_field_for_flight( ) .with_metadata(field.metadata().clone()) } else { - let dict_id = *next_dict_id; - *next_dict_id += 1; + let dict_id = dictionary_tracker.push_dict_id(field.as_ref()); + Field::new_dict( field.name(), field.data_type().clone(), @@ -517,7 +513,7 @@ fn prepare_field_for_flight( /// See hydrate_dictionary for more information fn prepare_schema_for_flight( schema: &Schema, - next_dict_id: &mut i64, + dictionary_tracker: &mut DictionaryTracker, send_dictionaries: bool, ) -> Schema { let fields: Fields = schema @@ -533,8 +529,7 @@ fn prepare_schema_for_flight( ) .with_metadata(field.metadata().clone()) } else { - let dict_id = *next_dict_id; - *next_dict_id += 1; + let dict_id = dictionary_tracker.push_dict_id(field.as_ref()); Field::new_dict( field.name(), field.data_type().clone(), @@ -546,7 +541,7 @@ fn prepare_schema_for_flight( } } tpe if tpe.is_nested() => { - prepare_field_for_flight(field, next_dict_id, send_dictionaries) + prepare_field_for_flight(field, dictionary_tracker, send_dictionaries) } _ => field.as_ref().clone(), }) @@ -601,10 +596,11 @@ struct FlightIpcEncoder { impl FlightIpcEncoder { fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self { + let preserve_dict_id = options.preserve_dict_id(); Self { options, data_gen: IpcDataGenerator::default(), - dictionary_tracker: DictionaryTracker::new(error_on_replacement), + dictionary_tracker: DictionaryTracker::new(error_on_replacement, preserve_dict_id), } } @@ -691,7 +687,7 @@ mod tests { /// fn test_encode_flight_data() { // use 8-byte alignment - default alignment is 64 which produces bigger ipc data - let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); + let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5, true).unwrap(); let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as 
ArrayRef)]) @@ -1277,6 +1273,7 @@ mod tests { let expected_schema = batches.first().unwrap().schema(); let encoder = FlightDataEncoderBuilder::default() + .with_options(IpcWriteOptions::default().with_preserve_dict_id(false)) .with_dictionary_handling(DictionaryHandling::Resend) .build(futures::stream::iter(batches.clone().into_iter().map(Ok))); @@ -1302,9 +1299,9 @@ mod tests { HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), ); - let mut next_dict_id = 0; + let mut dictionary_tracker = DictionaryTracker::new(false, true); - let got = prepare_schema_for_flight(&schema, &mut next_dict_id, false); + let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false); assert!(got.metadata().contains_key("some_key")); } @@ -1325,7 +1322,7 @@ mod tests { options: &IpcWriteOptions, ) -> (Vec, FlightData) { let data_gen = IpcDataGenerator::default(); - let mut dictionary_tracker = DictionaryTracker::new(false); + let mut dictionary_tracker = DictionaryTracker::new(false, true); let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, &mut dictionary_tracker, options) @@ -1586,7 +1583,9 @@ mod tests { let mut stream = FlightDataEncoderBuilder::new() .with_max_flight_data_size(max_flight_data_size) // use 8-byte alignment - default alignment is 64 which produces bigger ipc data - .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()) + .with_options( + IpcWriteOptions::try_new(8, false, MetadataVersion::V5, true).unwrap(), + ) .build(futures::stream::iter([Ok(batch.clone())])); let mut i = 0; diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index a4b4ab7bc316..d162f3fe3b47 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -876,7 +876,7 @@ mod tests { // V4 with write_legacy_ipc_format = true // this will not write the continuation marker - let option = IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap(); + let option = IpcWriteOptions::try_new(8, true, 
MetadataVersion::V4, true).unwrap(); let schema_ipc = SchemaAsIpc::new(&schema, &option); let result: SchemaResult = schema_ipc.try_into().unwrap(); let des_schema: Schema = (&result).try_into().unwrap(); diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 32716a52eb0d..008990a4df01 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -39,7 +39,7 @@ pub fn flight_data_from_arrow_batch( options: &IpcWriteOptions, ) -> (Vec, FlightData) { let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new(false); + let mut dictionary_tracker = writer::DictionaryTracker::new(false, options.preserve_dict_id()); let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, &mut dictionary_tracker, options) @@ -149,7 +149,7 @@ pub fn batches_to_flight_data( let mut flight_data = vec![]; let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new(false); + let mut dictionary_tracker = writer::DictionaryTracker::new(false, options.preserve_dict_id()); for batch in batches.iter() { let (encoded_dictionaries, encoded_batch) = diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index c6b5a72ca6e2..7bb69dabfdca 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -123,7 +123,7 @@ async fn send_batch( options: &writer::IpcWriteOptions, ) -> Result { let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new(false); + let mut dictionary_tracker = writer::DictionaryTracker::new(false, true); let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, &mut dictionary_tracker, options) diff --git 
a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs index 25203ecb7697..5cdb4b6703ec 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs @@ -119,7 +119,7 @@ impl FlightService for FlightServiceImpl { .enumerate() .flat_map(|(counter, batch)| { let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new(false); + let mut dictionary_tracker = writer::DictionaryTracker::new(false, true); let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, &mut dictionary_tracker, &options) diff --git a/arrow-integration-testing/tests/ipc_writer.rs b/arrow-integration-testing/tests/ipc_writer.rs index d780eb2ee0b5..77be217351dc 100644 --- a/arrow-integration-testing/tests/ipc_writer.rs +++ b/arrow-integration-testing/tests/ipc_writer.rs @@ -98,12 +98,12 @@ fn write_2_0_0_compression() { // writer options for each compression type let all_options = [ - IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5, true) .unwrap() .try_with_compression(Some(ipc::CompressionType::LZ4_FRAME)) .unwrap(), // write IPC version 5 with zstd - IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5, true) .unwrap() .try_with_compression(Some(ipc::CompressionType::ZSTD)) .unwrap(), diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 70767a7a349e..0567580f3785 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -617,7 +617,6 @@ fn read_dictionary_impl( let id = batch.id(); let fields_using_this_dictionary = schema.fields_with_dict_id(id); let first_field = fields_using_this_dictionary.first().ok_or_else(|| { - println!("schema: {schema:#?}"); 
ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema")) })?; @@ -2083,7 +2082,7 @@ mod tests { .unwrap(); let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false); + let mut dict_tracker = DictionaryTracker::new(false, true); let (_, encoded) = gen .encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); @@ -2120,7 +2119,7 @@ mod tests { .unwrap(); let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false); + let mut dict_tracker = DictionaryTracker::new(false, true); let (_, encoded) = gen .encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index ca82f3bdef96..e50921272d31 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -59,6 +59,9 @@ pub struct IpcWriteOptions { /// Compression, if desired. Will result in a runtime error /// if the corresponding feature is not enabled batch_compression_type: Option, + /// Flag indicating whether the writer should preserver the dictionary IDs defined in the + /// schema or generate unique dictionary IDs internally during encoding. 
+ preserve_dict_id: bool, } impl IpcWriteOptions { @@ -81,11 +84,12 @@ impl IpcWriteOptions { } Ok(self) } - /// Try create IpcWriteOptions, checking for incompatible settings + /// Try to create IpcWriteOptions, checking for incompatible settings pub fn try_new( alignment: usize, write_legacy_ipc_format: bool, metadata_version: crate::MetadataVersion, + preserve_dict_id: bool, ) -> Result { let is_alignment_valid = alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64; @@ -106,6 +110,7 @@ impl IpcWriteOptions { write_legacy_ipc_format, metadata_version, batch_compression_type: None, + preserve_dict_id, }), crate::MetadataVersion::V5 => { if write_legacy_ipc_format { @@ -118,6 +123,7 @@ impl IpcWriteOptions { write_legacy_ipc_format, metadata_version, batch_compression_type: None, + preserve_dict_id, }) } } @@ -126,6 +132,15 @@ impl IpcWriteOptions { ))), } } + + pub fn preserve_dict_id(&self) -> bool { + self.preserve_dict_id + } + + pub fn with_preserve_dict_id(mut self, preserve_dict_id: bool) -> Self { + self.preserve_dict_id = preserve_dict_id; + self + } } impl Default for IpcWriteOptions { @@ -135,6 +150,7 @@ impl Default for IpcWriteOptions { write_legacy_ipc_format: false, metadata_version: crate::MetadataVersion::V5, batch_compression_type: None, + preserve_dict_id: true, } } } @@ -159,7 +175,7 @@ impl Default for IpcWriteOptions { /// // Error of dictionary ids are replaced. /// let error_on_replacement = true; /// let options = IpcWriteOptions::default(); -/// let mut dictionary_tracker = DictionaryTracker::new(error_on_replacement); +/// let mut dictionary_tracker = DictionaryTracker::new(error_on_replacement,true); /// /// // encode the batch into zero or more encoded dictionaries /// // and the data for the actual array. 
@@ -198,28 +214,29 @@ impl IpcDataGenerator { } } - fn _encode_dictionaries( + fn _encode_dictionaries>( &self, column: &ArrayRef, encoded_dictionaries: &mut Vec, dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, - next_dict_id: &mut i64, + dict_id: &mut I, ) -> Result<(), ArrowError> { match column.data_type() { - DataType::Struct(_) => { + DataType::Struct(fields) => { let s = as_struct_array(column); - for column in s.columns() { + for (field, column) in fields.iter().zip(s.columns()) { self.encode_dictionaries( + field, column, encoded_dictionaries, dictionary_tracker, write_options, - next_dict_id, + dict_id, )?; } } - DataType::RunEndEncoded(_, _) => { + DataType::RunEndEncoded(_, values) => { let data = column.to_data(); if data.child_data().len() != 2 { return Err(ArrowError::InvalidArgumentError(format!( @@ -231,77 +248,89 @@ impl IpcDataGenerator { // only for values array. let values_array = make_array(data.child_data()[1].clone()); self.encode_dictionaries( + values, &values_array, encoded_dictionaries, dictionary_tracker, write_options, - next_dict_id, + dict_id, )?; } - DataType::List(_) => { + DataType::List(field) => { let list = as_list_array(column); self.encode_dictionaries( + field, list.values(), encoded_dictionaries, dictionary_tracker, write_options, - next_dict_id, + dict_id, )?; } - DataType::LargeList(_) => { + DataType::LargeList(field) => { let list = as_large_list_array(column); self.encode_dictionaries( + field, list.values(), encoded_dictionaries, dictionary_tracker, write_options, - next_dict_id, + dict_id, )?; } - DataType::FixedSizeList(_, _) => { + DataType::FixedSizeList(field, _) => { let list = column .as_any() .downcast_ref::() .expect("Unable to downcast to fixed size list array"); self.encode_dictionaries( + field, list.values(), encoded_dictionaries, dictionary_tracker, write_options, - next_dict_id, + dict_id, )?; } - DataType::Map(_, _) => { + DataType::Map(field, _) => { let map_array = 
as_map_array(column); + let (keys, values) = match field.data_type() { + DataType::Struct(fields) if fields.len() == 2 => (&fields[0], &fields[1]), + _ => panic!("Incorrect field data type {:?}", field.data_type()), + }; + // keys self.encode_dictionaries( + keys, map_array.keys(), encoded_dictionaries, dictionary_tracker, write_options, - next_dict_id, + dict_id, )?; // values self.encode_dictionaries( + values, map_array.values(), encoded_dictionaries, dictionary_tracker, write_options, - next_dict_id, + dict_id, )?; } DataType::Union(fields, _) => { let union = as_union_array(column); - for (type_id, _) in fields.iter() { + for (type_id, field) in fields.iter() { let column = union.child(type_id); self.encode_dictionaries( + field, column, encoded_dictionaries, dictionary_tracker, write_options, - next_dict_id, + dict_id, )?; } } @@ -311,18 +340,23 @@ impl IpcDataGenerator { Ok(()) } - fn encode_dictionaries( + fn encode_dictionaries>( &self, + field: &Field, column: &ArrayRef, encoded_dictionaries: &mut Vec, dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, - next_dict_id: &mut i64, + dict_id_seq: &mut I, ) -> Result<(), ArrowError> { match column.data_type() { DataType::Dictionary(_key_type, _value_type) => { - let dict_id = *next_dict_id; - *next_dict_id += 1; + let dict_id = dict_id_seq + .next() + .or_else(|| field.dict_id()) + .ok_or_else(|| { + ArrowError::IpcError(format!("no dict id for field {}", field.name())) + })?; let dict_data = column.to_data(); let dict_values = &dict_data.child_data()[0]; @@ -334,7 +368,7 @@ impl IpcDataGenerator { encoded_dictionaries, dictionary_tracker, write_options, - next_dict_id, + dict_id_seq, )?; let emit = dictionary_tracker.insert(dict_id, column)?; @@ -352,7 +386,7 @@ impl IpcDataGenerator { encoded_dictionaries, dictionary_tracker, write_options, - next_dict_id, + dict_id_seq, )?, } @@ -371,16 +405,17 @@ impl IpcDataGenerator { let schema = batch.schema(); let mut encoded_dictionaries = 
Vec::with_capacity(schema.all_fields().len()); - let mut next_dict_id = 0; + let mut dict_id = dictionary_tracker.dict_ids.clone().into_iter(); - for i in 0..schema.fields().len() { + for (i, field) in schema.fields().iter().enumerate() { let column = batch.column(i); self.encode_dictionaries( + field, column, &mut encoded_dictionaries, dictionary_tracker, write_options, - &mut next_dict_id, + &mut dict_id, )?; } @@ -671,20 +706,43 @@ fn into_zero_offset_run_array( /// isn't allowed in the `FileWriter`. pub struct DictionaryTracker { written: HashMap, + dict_ids: Vec, error_on_replacement: bool, + preserve_dict_id: bool, } impl DictionaryTracker { /// Create a new [`DictionaryTracker`]. If `error_on_replacement` /// is true, an error will be generated if an update to an /// existing dictionary is attempted. - pub fn new(error_on_replacement: bool) -> Self { + pub fn new(error_on_replacement: bool, preserve_dict_id: bool) -> Self { Self { written: HashMap::new(), + dict_ids: Vec::new(), error_on_replacement, + preserve_dict_id, } } + pub fn push_dict_id(&mut self, field: &Field) -> i64 { + let next = if self.preserve_dict_id { + field.dict_id().expect("no dict_id in field") + } else { + self.dict_ids + .last() + .copied() + .map(|i| i + 1) + .unwrap_or_default() + }; + + self.dict_ids.push(next); + next + } + + pub fn dict_id(&mut self) -> &[i64] { + &self.dict_ids + } + /// Keep track of the dictionary with the given ID and values. 
Behavior: /// /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate @@ -771,6 +829,7 @@ impl FileWriter { // write the schema, set the written bytes to the schema + header let encoded_message = data_gen.schema_to_bytes(schema, &write_options); let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?; + let preserve_dict_id = write_options.preserve_dict_id; Ok(Self { writer, write_options, @@ -779,7 +838,7 @@ impl FileWriter { dictionary_blocks: vec![], record_blocks: vec![], finished: false, - dictionary_tracker: DictionaryTracker::new(true), + dictionary_tracker: DictionaryTracker::new(true, preserve_dict_id), custom_metadata: HashMap::new(), data_gen, }) @@ -936,11 +995,12 @@ impl StreamWriter { // write the schema, set the written bytes to the schema let encoded_message = data_gen.schema_to_bytes(schema, &write_options); write_message(&mut writer, encoded_message, &write_options)?; + let preserve_dict_id = write_options.preserve_dict_id; Ok(Self { writer, write_options, finished: false, - dictionary_tracker: DictionaryTracker::new(false), + dictionary_tracker: DictionaryTracker::new(false, preserve_dict_id), data_gen, }) } @@ -1025,7 +1085,7 @@ impl StreamWriter { /// /// let schema = Schema::empty(); /// let buffer: Vec = Vec::new(); - /// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)?; + /// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5,true)?; /// let stream_writer = StreamWriter::try_new_with_options(buffer, &schema, options)?; /// /// assert_eq!(stream_writer.into_inner()?, expected); @@ -1556,7 +1616,7 @@ mod tests { let mut stream_writer = StreamWriter::try_new_with_options( vec![], record.schema_ref(), - IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), + IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5, true).unwrap(), ) .unwrap(); stream_writer.write(record).unwrap(); @@ -1581,7 +1641,7 @@ mod 
tests { let mut file = tempfile::tempfile().unwrap(); { - let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5) + let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5, true) .unwrap() .try_with_compression(Some(crate::CompressionType::LZ4_FRAME)) .unwrap(); @@ -1621,7 +1681,7 @@ mod tests { let mut file = tempfile::tempfile().unwrap(); { - let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5) + let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5, true) .unwrap() .try_with_compression(Some(crate::CompressionType::LZ4_FRAME)) .unwrap(); @@ -1660,7 +1720,7 @@ mod tests { RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap(); let mut file = tempfile::tempfile().unwrap(); { - let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5) + let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5, true) .unwrap() .try_with_compression(Some(crate::CompressionType::ZSTD)) .unwrap(); @@ -1781,16 +1841,16 @@ mod tests { } #[test] fn test_write_null_file_v4() { - write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap()); - write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap()); - write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4).unwrap()); - write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4).unwrap()); + write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4, true).unwrap()); + write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4, true).unwrap()); + write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4, true).unwrap()); + write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4, true).unwrap()); } #[test] fn test_write_null_file_v5() { - write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); - 
write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5).unwrap()); + write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5, true).unwrap()); + write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5, true).unwrap()); } #[test] @@ -1817,13 +1877,13 @@ mod tests { let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap(); let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false); + let mut dict_tracker = DictionaryTracker::new(false, true); gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); - // The encoder will assign dict IDs itself to ensure uniqueness and ignore the dict ID in the schema - // so we expect the dict will be keyed to 0 - assert!(dict_tracker.written.contains_key(&0)); + // With `preserve_dict_id` enabled the tracker keeps the dict ID declared in the schema, + // so we expect the dict will be keyed to 2 + assert!(dict_tracker.written.contains_key(&2)); } #[test] @@ -1853,13 +1913,11 @@ mod tests { let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap(); let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false); + let mut dict_tracker = DictionaryTracker::new(false, true); gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); - // The encoder will assign dict IDs itself to ensure uniqueness and ignore the dict ID in the schema - // so we expect the dict will be keyed to 0 - assert!(dict_tracker.written.contains_key(&0)); + assert!(dict_tracker.written.contains_key(&2)); } fn write_union_file(options: IpcWriteOptions) { @@ -1912,8 +1970,8 @@ mod tests { #[test] fn test_write_union_file_v4_v5() { - write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap()); - write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); + write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4, true).unwrap()); + write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5, true).unwrap()); } #[test] @@ -2423,7 +2481,7 @@ mod tests { let mut writer =
FileWriter::try_new_with_options( Vec::new(), batch.schema_ref(), - IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), + IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5, true).unwrap(), ) .unwrap(); writer.write(&batch).unwrap(); @@ -2478,7 +2536,7 @@ mod tests { let mut writer = FileWriter::try_new_with_options( Vec::new(), batch.schema_ref(), - IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), + IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5, true).unwrap(), ) .unwrap(); writer.write(&batch).unwrap(); From d5d04a583ab9ce85d73d5a16e8a2aaf68db92907 Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Fri, 28 Jun 2024 10:21:42 -0400 Subject: [PATCH 6/9] Update docs --- arrow-flight/src/encode.rs | 4 ++-- arrow-ipc/src/writer.rs | 24 ++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 1d0868a7231a..4de0dfa751c3 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -490,7 +490,7 @@ fn prepare_field_for_flight( ) .with_metadata(field.metadata().clone()) } else { - let dict_id = dictionary_tracker.push_dict_id(field.as_ref()); + let dict_id = dictionary_tracker.set_dict_id(field.as_ref()); Field::new_dict( field.name(), @@ -529,7 +529,7 @@ fn prepare_schema_for_flight( ) .with_metadata(field.metadata().clone()) } else { - let dict_id = dictionary_tracker.push_dict_id(field.as_ref()); + let dict_id = dictionary_tracker.set_dict_id(field.as_ref()); Field::new_dict( field.name(), field.data_type().clone(), diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index e50921272d31..aa31dda5f3c1 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -61,6 +61,8 @@ pub struct IpcWriteOptions { batch_compression_type: Option, /// Flag indicating whether the writer should preserver the dictionary IDs defined in the /// schema or generate unique 
dictionary IDs internally during encoding. + /// + /// Defaults to `true` preserve_dict_id: bool, } @@ -137,6 +139,8 @@ impl IpcWriteOptions { self.preserve_dict_id } + /// Set whether the IPC writer should preserve the dictionary IDs in the schema + /// or auto-assign uniquer dictionary IDs during encoding pub fn with_preserve_dict_id(mut self, preserve_dict_id: bool) -> Self { self.preserve_dict_id = preserve_dict_id; self @@ -712,9 +716,16 @@ pub struct DictionaryTracker { } impl DictionaryTracker { - /// Create a new [`DictionaryTracker`]. If `error_on_replacement` + /// Create a new [`DictionaryTracker`]. + /// + /// If `error_on_replacement` /// is true, an error will be generated if an update to an /// existing dictionary is attempted. + /// + /// If `preserve_dict_id` is true, the dictionary ID defined in the schema + /// is used, otherwise a unique dictionary ID will be assigned by incrementing + /// the last seen dictionary ID (or using `0` if no other dictionary IDs have been + /// seen) pub fn new(error_on_replacement: bool, preserve_dict_id: bool) -> Self { Self { written: HashMap::new(), @@ -724,7 +735,14 @@ impl DictionaryTracker { } } - pub fn push_dict_id(&mut self, field: &Field) -> i64 { + /// Set the dictionary ID for `field`. + /// + /// If `preserve_dict_id` is true, this will return the `dict_id` in `field` (or panic if `field` does + /// not have a `dict_id` defined). 
+ /// + /// If `preserve_dict_id` is false, this will return the value of the last `dict_id` assigned incremented by 1 + /// or 0 in the case where no dictionary IDs have yet been assigned + pub fn set_dict_id(&mut self, field: &Field) -> i64 { let next = if self.preserve_dict_id { field.dict_id().expect("no dict_id in field") } else { @@ -739,6 +757,8 @@ impl DictionaryTracker { next } + /// Return the sequence of dictionary IDs in the order they should be observed while + /// traversing the schema pub fn dict_id(&mut self) -> &[i64] { &self.dict_ids } From 9ab8a5efc0b8e74ce220e6ab727248a5708f8873 Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Fri, 28 Jun 2024 11:12:46 -0400 Subject: [PATCH 7/9] Remove breaking change --- arrow-flight/src/encode.rs | 6 ++-- arrow-flight/src/lib.rs | 2 +- arrow-integration-testing/tests/ipc_writer.rs | 4 +-- arrow-ipc/src/writer.rs | 35 +++++++++---------- 4 files changed, 22 insertions(+), 25 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 4de0dfa751c3..b1a229584c39 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -687,7 +687,7 @@ mod tests { /// fn test_encode_flight_data() { // use 8-byte alignment - default alignment is 64 which produces bigger ipc data - let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5, true).unwrap(); + let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)]) @@ -1583,9 +1583,7 @@ mod tests { let mut stream = FlightDataEncoderBuilder::new() .with_max_flight_data_size(max_flight_data_size) // use 8-byte alignment - default alignment is 64 which produces bigger ipc data - .with_options( - IpcWriteOptions::try_new(8, false, MetadataVersion::V5, true).unwrap(), - ) + .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()) 
.build(futures::stream::iter([Ok(batch.clone())])); let mut i = 0; diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index d162f3fe3b47..a4b4ab7bc316 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -876,7 +876,7 @@ mod tests { // V4 with write_legacy_ipc_format = true // this will not write the continuation marker - let option = IpcWriteOptions::try_new(8, true, MetadataVersion::V4, true).unwrap(); + let option = IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap(); let schema_ipc = SchemaAsIpc::new(&schema, &option); let result: SchemaResult = schema_ipc.try_into().unwrap(); let des_schema: Schema = (&result).try_into().unwrap(); diff --git a/arrow-integration-testing/tests/ipc_writer.rs b/arrow-integration-testing/tests/ipc_writer.rs index 77be217351dc..d780eb2ee0b5 100644 --- a/arrow-integration-testing/tests/ipc_writer.rs +++ b/arrow-integration-testing/tests/ipc_writer.rs @@ -98,12 +98,12 @@ fn write_2_0_0_compression() { // writer options for each compression type let all_options = [ - IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5, true) + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) .unwrap() .try_with_compression(Some(ipc::CompressionType::LZ4_FRAME)) .unwrap(), // write IPC version 5 with zstd - IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5, true) + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) .unwrap() .try_with_compression(Some(ipc::CompressionType::ZSTD)) .unwrap(), diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index aa31dda5f3c1..412f8dea9c5b 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -91,7 +91,6 @@ impl IpcWriteOptions { alignment: usize, write_legacy_ipc_format: bool, metadata_version: crate::MetadataVersion, - preserve_dict_id: bool, ) -> Result { let is_alignment_valid = alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64; @@ -112,7 +111,7 @@ impl IpcWriteOptions { 
write_legacy_ipc_format, metadata_version, batch_compression_type: None, - preserve_dict_id, + preserve_dict_id: true, }), crate::MetadataVersion::V5 => { if write_legacy_ipc_format { @@ -125,7 +124,7 @@ impl IpcWriteOptions { write_legacy_ipc_format, metadata_version, batch_compression_type: None, - preserve_dict_id, + preserve_dict_id: true, }) } } @@ -1105,7 +1104,7 @@ impl StreamWriter { /// /// let schema = Schema::empty(); /// let buffer: Vec = Vec::new(); - /// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5,true)?; + /// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)?; /// let stream_writer = StreamWriter::try_new_with_options(buffer, &schema, options)?; /// /// assert_eq!(stream_writer.into_inner()?, expected); @@ -1636,7 +1635,7 @@ mod tests { let mut stream_writer = StreamWriter::try_new_with_options( vec![], record.schema_ref(), - IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5, true).unwrap(), + IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), ) .unwrap(); stream_writer.write(record).unwrap(); @@ -1661,7 +1660,7 @@ mod tests { let mut file = tempfile::tempfile().unwrap(); { - let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5, true) + let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5) .unwrap() .try_with_compression(Some(crate::CompressionType::LZ4_FRAME)) .unwrap(); @@ -1701,7 +1700,7 @@ mod tests { let mut file = tempfile::tempfile().unwrap(); { - let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5, true) + let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5) .unwrap() .try_with_compression(Some(crate::CompressionType::LZ4_FRAME)) .unwrap(); @@ -1740,7 +1739,7 @@ mod tests { RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap(); let mut file = tempfile::tempfile().unwrap(); { - let write_option = 
IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5, true) + let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5) .unwrap() .try_with_compression(Some(crate::CompressionType::ZSTD)) .unwrap(); @@ -1861,16 +1860,16 @@ mod tests { } #[test] fn test_write_null_file_v4() { - write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4, true).unwrap()); - write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4, true).unwrap()); - write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4, true).unwrap()); - write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4, true).unwrap()); + write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap()); + write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap()); + write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4).unwrap()); + write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4).unwrap()); } #[test] fn test_write_null_file_v5() { - write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5, true).unwrap()); - write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5, true).unwrap()); + write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); + write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5).unwrap()); } #[test] @@ -1990,8 +1989,8 @@ mod tests { #[test] fn test_write_union_file_v4_v5() { - write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4, true).unwrap()); - write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5, true).unwrap()); + write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap()); + write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); } #[test] @@ -2501,7 +2500,7 @@ mod tests { let mut writer = FileWriter::try_new_with_options( Vec::new(), batch.schema_ref(), - 
IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5, true).unwrap(), + IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), ) .unwrap(); writer.write(&batch).unwrap(); @@ -2556,7 +2555,7 @@ mod tests { let mut writer = FileWriter::try_new_with_options( Vec::new(), batch.schema_ref(), - IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5, true).unwrap(), + IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), ) .unwrap(); writer.write(&batch).unwrap(); From 3c297cdba41b85dd43d5625d985daaaaa8cbe763 Mon Sep 17 00:00:00 2001 From: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> Date: Fri, 28 Jun 2024 16:35:49 -0400 Subject: [PATCH 8/9] Update arrow-ipc/src/writer.rs Co-authored-by: Andrew Lamb --- arrow-ipc/src/writer.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 412f8dea9c5b..9e802796da37 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -139,7 +139,12 @@ impl IpcWriteOptions { } /// Set whether the IPC writer should preserve the dictionary IDs in the schema - /// or auto-assign uniquer dictionary IDs during encoding + /// or auto-assign unique dictionary IDs during encoding (defaults to true) + /// + /// If this option is true, the application must handle assigning IDs + /// to the dictionary batches in order to encode them correctly + /// + /// The default will change to `false` in future releases pub fn with_preserve_dict_id(mut self, preserve_dict_id: bool) -> Self { self.preserve_dict_id = preserve_dict_id; self From 9d4553243fce46c0f6d64839c663b370af3d8050 Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Fri, 28 Jun 2024 17:02:33 -0400 Subject: [PATCH 9/9] Remove breaking change to DictionaryTracker ctor --- arrow-flight/src/encode.rs | 9 +++-- arrow-flight/src/utils.rs | 6 ++-- .../integration_test.rs | 2 +- .../integration_test.rs | 3 +- arrow-ipc/src/reader.rs | 4
+-- arrow-ipc/src/writer.rs | 34 +++++++++++++++---- 6 files changed, 42 insertions(+), 16 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index b1a229584c39..e7722fd7f0a8 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -600,7 +600,10 @@ impl FlightIpcEncoder { Self { options, data_gen: IpcDataGenerator::default(), - dictionary_tracker: DictionaryTracker::new(error_on_replacement, preserve_dict_id), + dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id( + error_on_replacement, + preserve_dict_id, + ), } } @@ -1299,7 +1302,7 @@ mod tests { HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), ); - let mut dictionary_tracker = DictionaryTracker::new(false, true); + let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false); assert!(got.metadata().contains_key("some_key")); @@ -1322,7 +1325,7 @@ mod tests { options: &IpcWriteOptions, ) -> (Vec, FlightData) { let data_gen = IpcDataGenerator::default(); - let mut dictionary_tracker = DictionaryTracker::new(false, true); + let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, &mut dictionary_tracker, options) diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 008990a4df01..c1e2d61fc5a9 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -39,7 +39,8 @@ pub fn flight_data_from_arrow_batch( options: &IpcWriteOptions, ) -> (Vec, FlightData) { let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new(false, options.preserve_dict_id()); + let mut dictionary_tracker = + writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, 
&mut dictionary_tracker, options) @@ -149,7 +150,8 @@ pub fn batches_to_flight_data( let mut flight_data = vec![]; let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new(false, options.preserve_dict_id()); + let mut dictionary_tracker = + writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); for batch in batches.iter() { let (encoded_dictionaries, encoded_batch) = diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index 7bb69dabfdca..ec88ce36a4d2 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -123,7 +123,7 @@ async fn send_batch( options: &writer::IpcWriteOptions, ) -> Result { let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new(false, true); + let mut dictionary_tracker = writer::DictionaryTracker::new_with_preserve_dict_id(false, true); let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, &mut dictionary_tracker, options) diff --git a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs index 5cdb4b6703ec..a03c1cd1a31a 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs @@ -119,7 +119,8 @@ impl FlightService for FlightServiceImpl { .enumerate() .flat_map(|(counter, batch)| { let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new(false, true); + let mut dictionary_tracker = + writer::DictionaryTracker::new_with_preserve_dict_id(false, true); let (encoded_dictionaries, encoded_batch) = data_gen 
.encoded_batch(batch, &mut dictionary_tracker, &options) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 0567580f3785..1f83200d65f8 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -2082,7 +2082,7 @@ mod tests { .unwrap(); let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false, true); + let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); let (_, encoded) = gen .encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); @@ -2119,7 +2119,7 @@ mod tests { .unwrap(); let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false, true); + let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); let (_, encoded) = gen .encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 9e802796da37..c0782195999d 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -139,7 +139,7 @@ impl IpcWriteOptions { } /// Set whether the IPC writer should preserve the dictionary IDs in the schema - /// or auto-assign uniquer dictionary IDs during encoding (defaults to true) + /// or auto-assign unique dictionary IDs during encoding (defaults to true) /// /// If this option is true, the application must handle assigning ids /// to the dictionary batches in order to encode them correctly @@ -183,7 +183,7 @@ impl Default for IpcWriteOptions { /// // Error of dictionary ids are replaced. /// let error_on_replacement = true; /// let options = IpcWriteOptions::default(); -/// let mut dictionary_tracker = DictionaryTracker::new(error_on_replacement,true); +/// let mut dictionary_tracker = DictionaryTracker::new(error_on_replacement); /// /// // encode the batch into zero or more encoded dictionaries /// // and the data for the actual array. 
@@ -730,7 +730,21 @@ impl DictionaryTracker { /// is used, otherwise a unique dictionary ID will be assigned by incrementing /// the last seen dictionary ID (or using `0` if no other dictionary IDs have been /// seen) - pub fn new(error_on_replacement: bool, preserve_dict_id: bool) -> Self { + pub fn new(error_on_replacement: bool) -> Self { + Self { + written: HashMap::new(), + dict_ids: Vec::new(), + error_on_replacement, + preserve_dict_id: true, + } + } + + /// Create a new [`DictionaryTracker`]. + /// + /// If `error_on_replacement` + /// is true, an error will be generated if an update to an + /// existing dictionary is attempted. + pub fn new_with_preserve_dict_id(error_on_replacement: bool, preserve_dict_id: bool) -> Self { Self { written: HashMap::new(), dict_ids: Vec::new(), @@ -862,7 +876,10 @@ impl FileWriter { dictionary_blocks: vec![], record_blocks: vec![], finished: false, - dictionary_tracker: DictionaryTracker::new(true, preserve_dict_id), + dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id( + true, + preserve_dict_id, + ), custom_metadata: HashMap::new(), data_gen, }) @@ -1024,7 +1041,10 @@ impl StreamWriter { writer, write_options, finished: false, - dictionary_tracker: DictionaryTracker::new(false, preserve_dict_id), + dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id( + false, + preserve_dict_id, + ), data_gen, }) } @@ -1901,7 +1921,7 @@ mod tests { let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap(); let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false, true); + let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap(); @@ -1937,7 +1957,7 @@ mod tests { let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap(); let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false, true); + let mut dict_tracker = 
DictionaryTracker::new_with_preserve_dict_id(false, true); gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) .unwrap();