From 871c9996019cc5af120d4115285014c59fb1521a Mon Sep 17 00:00:00 2001
From: Dan Harris <1327726+thinkharderdev@users.noreply.github.com>
Date: Sat, 29 Jun 2024 07:34:56 -0400
Subject: [PATCH] Handle flight dictionary ID assignment automatically (#5971)

* failing test

* Handle dict ID assignment during flight encoding/decoding

* remove println

* One more println

* Make auto-assign optional

* Update docs

* Remove breaking change

* Update arrow-ipc/src/writer.rs

Co-authored-by: Andrew Lamb

* Remove breaking change to DictionaryTracker ctor

---------

Co-authored-by: Andrew Lamb
---
 arrow-flight/src/encode.rs | 373 +++++++++++++++---
 arrow-flight/src/utils.rs  |   6 +-
 .../integration_test.rs    |   2 +-
 .../integration_test.rs    |   3 +-
 arrow-ipc/src/reader.rs    |  10 +-
 arrow-ipc/src/writer.rs    | 132 ++++++-
 6 files changed, 443 insertions(+), 83 deletions(-)

diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs
index f59c29e68173..e7722fd7f0a8 100644
--- a/arrow-flight/src/encode.rs
+++ b/arrow-flight/src/encode.rs
@@ -304,7 +304,11 @@ impl FlightDataEncoder {
         // The first message is the schema message, and all
         // batches have the same schema
         let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend;
-        let schema = Arc::new(prepare_schema_for_flight(schema, send_dictionaries));
+        let schema = Arc::new(prepare_schema_for_flight(
+            schema,
+            &mut self.encoder.dictionary_tracker,
+            send_dictionaries,
+        ));
         let mut schema_flight_data = self.encoder.encode_schema(&schema);
 
         // attach any metadata requested
@@ -438,24 +442,28 @@ pub enum DictionaryHandling {
     Resend,
 }
 
-fn prepare_field_for_flight(field: &FieldRef, send_dictionaries: bool) -> Field {
+fn prepare_field_for_flight(
+    field: &FieldRef,
+    dictionary_tracker: &mut DictionaryTracker,
+    send_dictionaries: bool,
+) -> Field {
     match field.data_type() {
         DataType::List(inner) => Field::new_list(
             field.name(),
-            prepare_field_for_flight(inner, send_dictionaries),
+            prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
             field.is_nullable(),
         )
         .with_metadata(field.metadata().clone()),
         DataType::LargeList(inner) => Field::new_list(
             field.name(),
-            prepare_field_for_flight(inner, send_dictionaries),
+            prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
             field.is_nullable(),
         )
         .with_metadata(field.metadata().clone()),
         DataType::Struct(fields) => {
             let new_fields: Vec<Field> = fields
                 .iter()
-                .map(|f| prepare_field_for_flight(f, send_dictionaries))
+                .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries))
                 .collect();
             Field::new_struct(field.name(), new_fields, field.is_nullable())
                 .with_metadata(field.metadata().clone())
@@ -463,17 +471,37 @@ fn prepare_field_for_flight(field: &FieldRef, send_dictionaries: bool) -> Field
         DataType::Union(fields, mode) => {
             let (type_ids, new_fields): (Vec<i8>, Vec<Field>) = fields
                 .iter()
-                .map(|(type_id, f)| (type_id, prepare_field_for_flight(f, send_dictionaries)))
+                .map(|(type_id, f)| {
+                    (
+                        type_id,
+                        prepare_field_for_flight(f, dictionary_tracker, send_dictionaries),
+                    )
+                })
                 .unzip();
 
             Field::new_union(field.name(), type_ids, new_fields, *mode)
         }
-        DataType::Dictionary(_, value_type) if !send_dictionaries => Field::new(
-            field.name(),
-            value_type.as_ref().clone(),
-            field.is_nullable(),
-        )
-        .with_metadata(field.metadata().clone()),
+        DataType::Dictionary(_, value_type) => {
+            if !send_dictionaries {
+                Field::new(
+                    field.name(),
+                    value_type.as_ref().clone(),
+                    field.is_nullable(),
+                )
+                .with_metadata(field.metadata().clone())
+            } else {
+                let dict_id = dictionary_tracker.set_dict_id(field.as_ref());
+
+                Field::new_dict(
+                    field.name(),
+                    field.data_type().clone(),
+                    field.is_nullable(),
+                    dict_id,
+                    field.dict_is_ordered().unwrap_or_default(),
+                )
+                .with_metadata(field.metadata().clone())
+            }
+        }
         _ => field.as_ref().clone(),
     }
 }
@@ -483,18 +511,38 @@ fn prepare_field_for_flight(field: &FieldRef, send_dictionaries: bool) -> Field
 /// Convert dictionary types to underlying types
 ///
 /// See hydrate_dictionary for more information
-fn prepare_schema_for_flight(schema: &Schema, send_dictionaries: bool) -> Schema {
+fn prepare_schema_for_flight(
+    schema: &Schema,
+    dictionary_tracker: &mut DictionaryTracker,
+    send_dictionaries: bool,
+) -> Schema {
     let fields: Fields = schema
         .fields()
         .iter()
         .map(|field| match field.data_type() {
-            DataType::Dictionary(_, value_type) if !send_dictionaries => Field::new(
-                field.name(),
-                value_type.as_ref().clone(),
-                field.is_nullable(),
-            )
-            .with_metadata(field.metadata().clone()),
-            tpe if tpe.is_nested() => prepare_field_for_flight(field, send_dictionaries),
+            DataType::Dictionary(_, value_type) => {
+                if !send_dictionaries {
+                    Field::new(
+                        field.name(),
+                        value_type.as_ref().clone(),
+                        field.is_nullable(),
+                    )
+                    .with_metadata(field.metadata().clone())
+                } else {
+                    let dict_id = dictionary_tracker.set_dict_id(field.as_ref());
+                    Field::new_dict(
+                        field.name(),
+                        field.data_type().clone(),
+                        field.is_nullable(),
+                        dict_id,
+                        field.dict_is_ordered().unwrap_or_default(),
+                    )
+                    .with_metadata(field.metadata().clone())
+                }
+            }
+            tpe if tpe.is_nested() => {
+                prepare_field_for_flight(field, dictionary_tracker, send_dictionaries)
+            }
             _ => field.as_ref().clone(),
         })
         .collect();
@@ -548,10 +596,14 @@ struct FlightIpcEncoder {
 
 impl FlightIpcEncoder {
     fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self {
+        let preserve_dict_id = options.preserve_dict_id();
         Self {
            options,
            data_gen: IpcDataGenerator::default(),
-            dictionary_tracker: DictionaryTracker::new(error_on_replacement),
+            dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id(
+                error_on_replacement,
+                preserve_dict_id,
+            ),
         }
     }
@@ -619,7 +671,10 @@ fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result
+    #[tokio::test]
+    async fn test_dictionary_resend() {
+        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
+        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
+
+        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
+            "dict",
+            DataType::UInt16,
+            DataType::Utf8,
+            false,
+        )]));
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
+        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
+
+        verify_flight_round_trip(vec![batch1, batch2]).await;
+    }
+
+    #[tokio::test]
+    async fn test_multiple_dictionaries_resend() {
+        // Create a schema with two dictionary fields that have the same dict ID
+        let schema = Arc::new(Schema::new(vec![
+            Field::new_dictionary("dict_1", DataType::UInt16, DataType::Utf8, false),
+            Field::new_dictionary("dict_2", DataType::UInt16, DataType::Utf8, false),
+        ]));
+
+        let arr_one_1: Arc<DictionaryArray<UInt16Type>> =
+            Arc::new(vec!["a", "a", "b"].into_iter().collect());
+        let arr_one_2: Arc<DictionaryArray<UInt16Type>> =
+            Arc::new(vec!["c", "c", "d"].into_iter().collect());
+        let arr_two_1: Arc<DictionaryArray<UInt16Type>> =
+            Arc::new(vec!["b", "a", "c"].into_iter().collect());
+        let arr_two_2: Arc<DictionaryArray<UInt16Type>> =
+            Arc::new(vec!["k", "d", "e"].into_iter().collect());
+        let batch1 =
+            RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()])
+                .unwrap();
+        let batch2 =
+            RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()])
+                .unwrap();
+
+        verify_flight_round_trip(vec![batch1, batch2]).await;
+    }
+
     #[tokio::test]
     async fn test_dictionary_list_hydration() {
-        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
+        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
 
         builder.append_value(vec![Some("a"), None, Some("b")]);
 
@@ -766,6 +862,30 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    async fn test_dictionary_list_resend() {
+        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
+
+        builder.append_value(vec![Some("a"), None, Some("b")]);
+
+        let arr1 = builder.finish();
+
+        builder.append_value(vec![Some("c"), None, Some("d")]);
+
+        let arr2 = builder.finish();
+
+        let schema = Arc::new(Schema::new(vec![Field::new_list(
+            "dict_list",
+            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
+            true,
+        )]));
+
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
+        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
+
+        verify_flight_round_trip(vec![batch1, batch2]).await;
+    }
+
     #[tokio::test]
     async fn test_dictionary_struct_hydration() {
         let struct_fields = vec![Field::new_list(
@@ -774,26 +894,38 @@ mod tests {
             true,
         )];
 
-        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
+        let mut struct_builder = StructBuilder::new(
+            struct_fields.clone(),
+            vec![Box::new(builder::ListBuilder::new(
+                StringDictionaryBuilder::<UInt16Type>::new(),
+            ))],
+        );
 
-        builder.append_value(vec![Some("a"), None, Some("b")]);
+        struct_builder
+            .field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(0)
+            .unwrap()
+            .append_value(vec![Some("a"), None, Some("b")]);
 
-        let arr1 = Arc::new(builder.finish());
-        let arr1 = StructArray::new(struct_fields.clone().into(), vec![arr1], None);
+        struct_builder.append(true);
 
-        builder.append_value(vec![Some("c"), None, Some("d")]);
+        let arr1 = struct_builder.finish();
 
-        let arr2 = Arc::new(builder.finish());
-        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);
+        struct_builder
+            .field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(0)
+            .unwrap()
+            .append_value(vec![Some("c"), None, Some("d")]);
+        struct_builder.append(true);
+
+        let arr2 = struct_builder.finish();
 
         let schema = Arc::new(Schema::new(vec![Field::new_struct(
             "struct",
-            struct_fields.clone(),
+            struct_fields,
             true,
         )]));
         let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
-        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
+        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
 
         let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
 
@@ -838,6 +970,43 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    async fn test_dictionary_struct_resend() {
+        let struct_fields = vec![Field::new_list(
+            "dict_list",
+            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
+            true,
+        )];
+
+        let mut struct_builder = StructBuilder::new(
+            struct_fields.clone(),
+            vec![Box::new(builder::ListBuilder::new(
+                StringDictionaryBuilder::<UInt16Type>::new(),
+            ))],
+        );
+
+        struct_builder.field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(0).unwrap().append_value(vec![Some("a"), None, Some("b")]);
+        struct_builder.append(true);
+
+        let arr1 = struct_builder.finish();
+
+        struct_builder.field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(0).unwrap().append_value(vec![Some("c"), None, Some("d")]);
+        struct_builder.append(true);
+
+        let arr2 = struct_builder.finish();
+
+        let schema = Arc::new(Schema::new(vec![Field::new_struct(
+            "struct",
+            struct_fields,
+            true,
+        )]));
+
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
+        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
+
+        verify_flight_round_trip(vec![batch1, batch2]).await;
+    }
+
     #[tokio::test]
     async fn test_dictionary_union_hydration() {
         let struct_fields = vec![Field::new_list(
@@ -1004,42 +1173,124 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn test_send_dictionaries() {
-        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
-            "dict",
-            DataType::UInt16,
-            DataType::Utf8,
-            false,
+    async fn test_dictionary_union_resend() {
+        let struct_fields = vec![Field::new_list(
+            "dict_list",
+            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
+            true,
+        )];
+
+        let union_fields = [
+            (
+                0,
+                Arc::new(Field::new_list(
+                    "dict_list",
+                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
+                    true,
+                )),
+            ),
+            (
+                1,
+                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
+            ),
+            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let struct_fields = vec![Field::new_list(
+            "dict_list",
+            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
+            true,
+        )];
+
+        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
+
+        builder.append_value(vec![Some("a"), None, Some("b")]);
+
+        let arr1 = builder.finish();
+
+        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
+        let arr1 = UnionArray::try_new(
+            union_fields.clone(),
+            type_id_buffer,
+            None,
+            vec![
+                Arc::new(arr1) as Arc<dyn Array>,
+                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
+                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
+            ],
+        )
+        .unwrap();
+
+        builder.append_value(vec![Some("c"), None, Some("d")]);
+
+        let arr2 = Arc::new(builder.finish());
+        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);
+
+        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
+        let arr2 = UnionArray::try_new(
+            union_fields.clone(),
+            type_id_buffer,
+            None,
+            vec![
+                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
+                Arc::new(arr2),
+                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
+            ],
+        )
+        .unwrap();
+
+        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
+        let arr3 = UnionArray::try_new(
+            union_fields.clone(),
+            type_id_buffer,
+            None,
+            vec![
+                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
+                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
+                Arc::new(StringArray::from(vec!["e"])),
+            ],
+        )
+        .unwrap();
+
+        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
+            .iter()
+            .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
+            .unzip();
+        let schema = Arc::new(Schema::new(vec![Field::new_union(
+            "union",
+            type_ids.clone(),
+            union_fields.clone(),
+            UnionMode::Sparse,
         )]));
 
-        let arr_one: Arc<DictionaryArray<UInt16Type>> =
-            Arc::new(vec!["a", "a", "b"].into_iter().collect());
-        let arr_two: Arc<DictionaryArray<UInt16Type>> =
-            Arc::new(vec!["b", "a", "c"].into_iter().collect());
-        let batch_one = RecordBatch::try_new(schema.clone(), vec![arr_one.clone()]).unwrap();
-        let batch_two = RecordBatch::try_new(schema.clone(), vec![arr_two.clone()]).unwrap();
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
+        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
+        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();
+
+        verify_flight_round_trip(vec![batch1, batch2, batch3]).await;
+    }
+
+    async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
+        let expected_schema = batches.first().unwrap().schema();
 
         let encoder = FlightDataEncoderBuilder::default()
+            .with_options(IpcWriteOptions::default().with_preserve_dict_id(false))
             .with_dictionary_handling(DictionaryHandling::Resend)
-            .build(futures::stream::iter(vec![Ok(batch_one), Ok(batch_two)]));
+            .build(futures::stream::iter(batches.clone().into_iter().map(Ok)));
+
+        let mut expected_batches = batches.drain(..);
 
         let mut decoder = FlightDataDecoder::new(encoder);
-        let mut expected_array = arr_one;
         while let Some(decoded) = decoder.next().await {
             let decoded = decoded.unwrap();
             match decoded.payload {
                 DecodedPayload::None => {}
-                DecodedPayload::Schema(s) => assert_eq!(s, schema),
+                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                 DecodedPayload::RecordBatch(b) => {
-                    assert_eq!(b.schema(), schema);
-
-                    let actual_array = Arc::new(downcast_array::<DictionaryArray<UInt16Type>>(
-                        b.column_by_name("dict").unwrap(),
-                    ));
-
-                    assert_eq!(actual_array, expected_array);
-
-                    expected_array = arr_two.clone();
+                    let expected_batch = expected_batches.next().unwrap();
+                    assert_eq!(b, expected_batch);
                 }
             }
         }
     }
 
@@ -1051,7 +1302,9 @@ mod tests {
             HashMap::from([("some_key".to_owned(), "some_value".to_owned())]),
         );
 
-        let got = prepare_schema_for_flight(&schema, false);
+        let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
+
+        let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false);
 
         assert!(got.metadata().contains_key("some_key"));
     }
 
@@ -1072,7 +1325,7 @@ mod tests {
         options: &IpcWriteOptions,
     ) -> (Vec<FlightData>, FlightData) {
         let data_gen = IpcDataGenerator::default();
-        let mut dictionary_tracker = DictionaryTracker::new(false);
+        let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
 
         let (encoded_dictionaries, encoded_batch) = data_gen
             .encoded_batch(batch, &mut dictionary_tracker, options)
diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs
index 32716a52eb0d..c1e2d61fc5a9 100644
--- a/arrow-flight/src/utils.rs
+++ b/arrow-flight/src/utils.rs
@@ -39,7 +39,8 @@ pub fn flight_data_from_arrow_batch(
     options: &IpcWriteOptions,
 ) -> (Vec<FlightData>, FlightData) {
     let data_gen = writer::IpcDataGenerator::default();
-    let mut dictionary_tracker = writer::DictionaryTracker::new(false);
+    let mut dictionary_tracker =
+        writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
 
     let (encoded_dictionaries, encoded_batch) = data_gen
         .encoded_batch(batch, &mut dictionary_tracker, options)
@@ -149,7 +150,8 @@ pub fn batches_to_flight_data(
     let mut flight_data = vec![];
 
     let data_gen = writer::IpcDataGenerator::default();
-    let mut dictionary_tracker = writer::DictionaryTracker::new(false);
+    let mut dictionary_tracker =
+        writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
 
     for batch in batches.iter() {
         let (encoded_dictionaries, encoded_batch) =
diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs
index c6b5a72ca6e2..ec88ce36a4d2 100644
--- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs
+++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs
@@ -123,7 +123,7 @@ async fn send_batch(
     options: &writer::IpcWriteOptions,
 ) -> Result {
     let data_gen = writer::IpcDataGenerator::default();
-    let mut dictionary_tracker = writer::DictionaryTracker::new(false);
+    let mut dictionary_tracker = writer::DictionaryTracker::new_with_preserve_dict_id(false, true);
 
     let (encoded_dictionaries, encoded_batch) = data_gen
         .encoded_batch(batch, &mut dictionary_tracker, options)
diff --git a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs
index 25203ecb7697..a03c1cd1a31a 100644
--- a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs
+++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs
@@ -119,7 +119,8 @@ impl FlightService for FlightServiceImpl {
             .enumerate()
             .flat_map(|(counter, batch)| {
                 let data_gen = writer::IpcDataGenerator::default();
-                let mut dictionary_tracker = writer::DictionaryTracker::new(false);
+                let mut dictionary_tracker =
+                    writer::DictionaryTracker::new_with_preserve_dict_id(false, true);
 
                 let (encoded_dictionaries, encoded_batch) = data_gen
                     .encoded_batch(batch, &mut dictionary_tracker, &options)
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index 3423d06b6fca..1f83200d65f8 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -617,7 +617,7 @@ fn read_dictionary_impl(
     let id = batch.id();
     let fields_using_this_dictionary = schema.fields_with_dict_id(id);
     let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
-        ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string())
+        ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
     })?;
 
     // As the dictionary batch does not contain the type of the
@@ -643,7 +643,7 @@ fn read_dictionary_impl(
         _ => None,
     }
     .ok_or_else(|| {
-        ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string())
+        ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
     })?;
 
     // We don't currently record the isOrdered field. This could be general
@@ -1812,7 +1812,7 @@ mod tests {
         "values",
         DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
         true,
-        1,
+        2,
         false,
     ));
     let entry_struct = StructArray::from(vec![
@@ -2082,7 +2082,7 @@ mod tests {
         .unwrap();
 
         let gen = IpcDataGenerator {};
-        let mut dict_tracker = DictionaryTracker::new(false);
+        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
         let (_, encoded) = gen
             .encoded_batch(&batch, &mut dict_tracker, &Default::default())
             .unwrap();
@@ -2119,7 +2119,7 @@ mod tests {
         .unwrap();
 
         let gen = IpcDataGenerator {};
-        let mut dict_tracker = DictionaryTracker::new(false);
+        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
         let (_, encoded) = gen
             .encoded_batch(&batch, &mut dict_tracker, &Default::default())
             .unwrap();
diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index f74a86e013cd..c0782195999d 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -59,6 +59,11 @@ pub struct IpcWriteOptions {
     /// Compression, if desired. Will result in a runtime error
     /// if the corresponding feature is not enabled
     batch_compression_type: Option<crate::CompressionType>,
+    /// Flag indicating whether the writer should preserve the dictionary IDs defined in the
+    /// schema or generate unique dictionary IDs internally during encoding.
+    ///
+    /// Defaults to `true`
+    preserve_dict_id: bool,
 }
 
 impl IpcWriteOptions {
@@ -81,7 +86,7 @@ impl IpcWriteOptions {
         }
         Ok(self)
     }
-    /// Try create IpcWriteOptions, checking for incompatible settings
+    /// Try to create IpcWriteOptions, checking for incompatible settings
     pub fn try_new(
         alignment: usize,
         write_legacy_ipc_format: bool,
@@ -106,6 +111,7 @@ impl IpcWriteOptions {
                 write_legacy_ipc_format,
                 metadata_version,
                 batch_compression_type: None,
+                preserve_dict_id: true,
             }),
             crate::MetadataVersion::V5 => {
                 if write_legacy_ipc_format {
@@ -118,6 +124,7 @@ impl IpcWriteOptions {
                         write_legacy_ipc_format,
                         metadata_version,
                         batch_compression_type: None,
+                        preserve_dict_id: true,
                     })
                 }
             }
@@ -126,6 +133,22 @@ impl IpcWriteOptions {
             ))),
         }
     }
+
+    pub fn preserve_dict_id(&self) -> bool {
+        self.preserve_dict_id
+    }
+
+    /// Set whether the IPC writer should preserve the dictionary IDs in the schema
+    /// or auto-assign unique dictionary IDs during encoding (defaults to true)
+    ///
+    /// If this option is true, the application must handle assigning ids
+    /// to the dictionary batches in order to encode them correctly
+    ///
+    /// The default will change to `false` in future releases
+    pub fn with_preserve_dict_id(mut self, preserve_dict_id: bool) -> Self {
+        self.preserve_dict_id = preserve_dict_id;
+        self
+    }
 }
 
 impl Default for IpcWriteOptions {
@@ -135,6 +158,7 @@ impl Default for IpcWriteOptions {
             write_legacy_ipc_format: false,
             metadata_version: crate::MetadataVersion::V5,
             batch_compression_type: None,
+            preserve_dict_id: true,
         }
     }
 }
@@ -163,7 +187,7 @@ impl Default for IpcWriteOptions {
 ///
 /// // encode the batch into zero or more encoded dictionaries
 /// // and the data for the actual array.
-/// let data_gen = IpcDataGenerator {};
+/// let data_gen = IpcDataGenerator::default();
 /// let (encoded_dictionaries, encoded_message) = data_gen
 ///     .encoded_batch(&batch, &mut dictionary_tracker, &options)
 ///     .unwrap();
@@ -198,12 +222,13 @@ impl IpcDataGenerator {
         }
     }
 
-    fn _encode_dictionaries(
+    fn _encode_dictionaries<I: Iterator<Item = i64>>(
         &self,
         column: &ArrayRef,
         encoded_dictionaries: &mut Vec<EncodedData>,
         dictionary_tracker: &mut DictionaryTracker,
         write_options: &IpcWriteOptions,
+        dict_id: &mut I,
     ) -> Result<(), ArrowError> {
         match column.data_type() {
             DataType::Struct(fields) => {
@@ -215,6 +240,7 @@ impl IpcDataGenerator {
                         encoded_dictionaries,
                         dictionary_tracker,
                         write_options,
+                        dict_id,
                     )?;
                 }
             }
@@ -235,6 +261,7 @@ impl IpcDataGenerator {
                     encoded_dictionaries,
                     dictionary_tracker,
                     write_options,
+                    dict_id,
                 )?;
             }
             DataType::List(field) => {
@@ -245,6 +272,7 @@ impl IpcDataGenerator {
                     encoded_dictionaries,
                     dictionary_tracker,
                     write_options,
+                    dict_id,
                 )?;
             }
             DataType::LargeList(field) => {
@@ -255,6 +283,7 @@ impl IpcDataGenerator {
                     encoded_dictionaries,
                     dictionary_tracker,
                     write_options,
+                    dict_id,
                 )?;
             }
             DataType::FixedSizeList(field, _) => {
@@ -268,6 +297,7 @@ impl IpcDataGenerator {
                     encoded_dictionaries,
                     dictionary_tracker,
                     write_options,
+                    dict_id,
                 )?;
             }
             DataType::Map(field, _) => {
@@ -285,6 +315,7 @@ impl IpcDataGenerator {
                     encoded_dictionaries,
                     dictionary_tracker,
                     write_options,
+                    dict_id,
                 )?;
 
                 // values
@@ -294,6 +325,7 @@ impl IpcDataGenerator {
                     encoded_dictionaries,
                     dictionary_tracker,
                     write_options,
+                    dict_id,
                 )?;
             }
             DataType::Union(fields, _) => {
@@ -306,6 +338,7 @@ impl IpcDataGenerator {
                         encoded_dictionaries,
                         dictionary_tracker,
                         write_options,
+                        dict_id,
                     )?;
                 }
             }
@@ -315,19 +348,24 @@ impl IpcDataGenerator {
         Ok(())
     }
 
-    fn encode_dictionaries(
+    fn encode_dictionaries<I: Iterator<Item = i64>>(
         &self,
         field: &Field,
         column: &ArrayRef,
         encoded_dictionaries: &mut Vec<EncodedData>,
         dictionary_tracker: &mut DictionaryTracker,
         write_options: &IpcWriteOptions,
+        dict_id_seq: &mut I,
     ) -> Result<(), ArrowError> {
         match column.data_type() {
             DataType::Dictionary(_key_type, _value_type) => {
-                let dict_id = field
-                    .dict_id()
-                    .expect("All Dictionary types have `dict_id`");
+                let dict_id = dict_id_seq
+                    .next()
+                    .or_else(|| field.dict_id())
+                    .ok_or_else(|| {
+                        ArrowError::IpcError(format!("no dict id for field {}", field.name()))
+                    })?;
+
                 let dict_data = column.to_data();
                 let dict_values = &dict_data.child_data()[0];
@@ -338,6 +376,7 @@ impl IpcDataGenerator {
                     encoded_dictionaries,
                     dictionary_tracker,
                     write_options,
+                    dict_id_seq,
                 )?;
 
                 let emit = dictionary_tracker.insert(dict_id, column)?;
@@ -355,6 +394,7 @@ impl IpcDataGenerator {
                 encoded_dictionaries,
                 dictionary_tracker,
                 write_options,
+                dict_id_seq,
             )?,
         }
 
@@ -373,6 +413,8 @@ impl IpcDataGenerator {
         let schema = batch.schema();
         let mut encoded_dictionaries = Vec::with_capacity(schema.all_fields().len());
 
+        let mut dict_id = dictionary_tracker.dict_ids.clone().into_iter();
+
         for (i, field) in schema.fields().iter().enumerate() {
             let column = batch.column(i);
             self.encode_dictionaries(
@@ -381,6 +423,7 @@ impl IpcDataGenerator {
                 &mut encoded_dictionaries,
                 dictionary_tracker,
                 write_options,
+                &mut dict_id,
             )?;
         }
 
@@ -671,20 +714,73 @@ fn into_zero_offset_run_array(
 /// isn't allowed in the `FileWriter`.
 pub struct DictionaryTracker {
     written: HashMap<i64, ArrayData>,
+    dict_ids: Vec<i64>,
     error_on_replacement: bool,
+    preserve_dict_id: bool,
 }
 
 impl DictionaryTracker {
-    /// Create a new [`DictionaryTracker`]. If `error_on_replacement`
+    /// Create a new [`DictionaryTracker`].
+    ///
+    /// If `error_on_replacement`
     /// is true, an error will be generated if an update to an
     /// existing dictionary is attempted.
+    ///
+    /// If `preserve_dict_id` is true, the dictionary ID defined in the schema
+    /// is used, otherwise a unique dictionary ID will be assigned by incrementing
+    /// the last seen dictionary ID (or using `0` if no other dictionary IDs have been
+    /// seen)
     pub fn new(error_on_replacement: bool) -> Self {
         Self {
             written: HashMap::new(),
+            dict_ids: Vec::new(),
             error_on_replacement,
+            preserve_dict_id: true,
         }
     }
 
+    /// Create a new [`DictionaryTracker`].
+    ///
+    /// If `error_on_replacement`
+    /// is true, an error will be generated if an update to an
+    /// existing dictionary is attempted.
+    pub fn new_with_preserve_dict_id(error_on_replacement: bool, preserve_dict_id: bool) -> Self {
+        Self {
+            written: HashMap::new(),
+            dict_ids: Vec::new(),
+            error_on_replacement,
+            preserve_dict_id,
+        }
+    }
+
+    /// Set the dictionary ID for `field`.
+    ///
+    /// If `preserve_dict_id` is true, this will return the `dict_id` in `field` (or panic if `field` does
+    /// not have a `dict_id` defined).
+    ///
+    /// If `preserve_dict_id` is false, this will return the value of the last `dict_id` assigned incremented by 1
+    /// or 0 in the case where no dictionary IDs have yet been assigned
+    pub fn set_dict_id(&mut self, field: &Field) -> i64 {
+        let next = if self.preserve_dict_id {
+            field.dict_id().expect("no dict_id in field")
+        } else {
+            self.dict_ids
+                .last()
+                .copied()
+                .map(|i| i + 1)
+                .unwrap_or_default()
+        };
+
+        self.dict_ids.push(next);
+        next
+    }
+
+    /// Return the sequence of dictionary IDs in the order they should be observed while
+    /// traversing the schema
+    pub fn dict_id(&mut self) -> &[i64] {
+        &self.dict_ids
+    }
+
     /// Keep track of the dictionary with the given ID and values. Behavior:
     ///
     /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
@@ -771,6 +867,7 @@ impl<W: Write> FileWriter<W> {
         // write the schema, set the written bytes to the schema + header
         let encoded_message = data_gen.schema_to_bytes(schema, &write_options);
         let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?;
+        let preserve_dict_id = write_options.preserve_dict_id;
         Ok(Self {
             writer,
             write_options,
@@ -779,7 +876,10 @@ impl<W: Write> FileWriter<W> {
             dictionary_blocks: vec![],
             record_blocks: vec![],
             finished: false,
-            dictionary_tracker: DictionaryTracker::new(true),
+            dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id(
+                true,
+                preserve_dict_id,
+            ),
             custom_metadata: HashMap::new(),
             data_gen,
         })
@@ -936,11 +1036,15 @@ impl<W: Write> StreamWriter<W> {
         // write the schema, set the written bytes to the schema
         let encoded_message = data_gen.schema_to_bytes(schema, &write_options);
         write_message(&mut writer, encoded_message, &write_options)?;
+        let preserve_dict_id = write_options.preserve_dict_id;
         Ok(Self {
             writer,
             write_options,
             finished: false,
-            dictionary_tracker: DictionaryTracker::new(false),
+            dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id(
+                false,
+                preserve_dict_id,
+            ),
             data_gen,
         })
     }
@@ -1817,11 +1921,12 @@ mod tests {
         let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();
 
         let gen = IpcDataGenerator {};
-        let mut dict_tracker = DictionaryTracker::new(false);
+        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
         gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
             .unwrap();
 
-        // Dictionary with id 2 should have been written to the dict tracker
+        // The tracker is configured to preserve the dictionary IDs defined in the schema,
+        // so the dictionary should be tracked under its schema-defined ID of 2
         assert!(dict_tracker.written.contains_key(&2));
     }
 
@@ -1852,11 +1957,10 @@ mod tests {
         let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
 
         let gen = IpcDataGenerator {};
-        let mut dict_tracker = DictionaryTracker::new(false);
+        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
         gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
             .unwrap();
 
-        // Dictionary with id 2 should have been written to the dict tracker
         assert!(dict_tracker.written.contains_key(&2));
     }
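
For anyone who wants to exercise the new behavior end to end, the following minimal sketch (not part of the patch) drives the opt-in path the same way the new verify_flight_round_trip test does. It assumes the arrow-flight, arrow-ipc, arrow-array, arrow-schema, futures, and tokio crates at this revision; everything it calls (with_preserve_dict_id, with_dictionary_handling, FlightDataDecoder) is introduced or touched by the diff above.

use std::sync::Arc;

use arrow_array::{types::UInt16Type, DictionaryArray, RecordBatch};
use arrow_flight::decode::{DecodedPayload, FlightDataDecoder};
use arrow_flight::encode::{DictionaryHandling, FlightDataEncoderBuilder};
use arrow_ipc::writer::IpcWriteOptions;
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;

#[tokio::main]
async fn main() {
    // Two batches sharing one dictionary-encoded column; with Resend the
    // dictionary is re-sent per batch instead of being hydrated into values.
    let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
        "dict",
        DataType::UInt16,
        DataType::Utf8,
        false,
    )]));
    let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
    let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

    let encoder = FlightDataEncoderBuilder::default()
        // Opt out of schema-defined dict IDs; the encoder's DictionaryTracker
        // then assigns unique IDs as the schema is prepared for flight.
        .with_options(IpcWriteOptions::default().with_preserve_dict_id(false))
        .with_dictionary_handling(DictionaryHandling::Resend)
        .build(futures::stream::iter(vec![Ok(batch1), Ok(batch2)]));

    // Decode what comes back: one schema message, then a dictionary batch
    // followed by a record batch for each input batch.
    let mut decoder = FlightDataDecoder::new(encoder);
    while let Some(decoded) = decoder.next().await {
        match decoded.unwrap().payload {
            DecodedPayload::None => {}
            DecodedPayload::Schema(s) => println!("schema: {s:?}"),
            DecodedPayload::RecordBatch(b) => println!("batch with {} rows", b.num_rows()),
        }
    }
}

Resending dictionaries without disabling preserve_dict_id keeps the schema-defined IDs (0 for every field built with Field::new_dictionary), which is exactly the duplicate-ID hazard the automatic assignment addresses.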
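The same opt-in applies when driving the IPC machinery directly, which is what the patched utils.rs and integration-test code do. A minimal sketch under the same assumptions; note that the tracker is constructed from options.preserve_dict_id() so the two settings cannot drift apart:

use std::sync::Arc;

use arrow_array::{types::UInt16Type, DictionaryArray, RecordBatch};
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use arrow_schema::{DataType, Field, Schema};

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
        "dict",
        DataType::UInt16,
        DataType::Utf8,
        false,
    )]));
    let arr: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
    let batch = RecordBatch::try_new(schema, vec![Arc::new(arr)]).unwrap();

    // Mirror the tracker to the options, as the patched utils.rs does. With
    // preserve_dict_id set to false the tracker assigns IDs via set_dict_id
    // during schema preparation, and encoded_batch falls back to the field's
    // own dict_id when no ID was assigned up front.
    let options = IpcWriteOptions::default().with_preserve_dict_id(false);
    let mut tracker =
        DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());

    let data_gen = IpcDataGenerator::default();
    let (encoded_dictionaries, encoded_batch) = data_gen
        .encoded_batch(&batch, &mut tracker, &options)
        .unwrap();
    println!(
        "{} dictionary message(s), batch message of {} metadata bytes",
        encoded_dictionaries.len(),
        encoded_batch.ipc_message.len()
    );
}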