Skip to content

Commit

Permalink
Add IPC FileDecoder (#5249)
Browse files Browse the repository at this point in the history
* Add IPC FileDecoder

* Clippy

* Update arrow-ipc/src/reader.rs

Co-authored-by: Liang-Chi Hsieh <[email protected]>

---------

Co-authored-by: Liang-Chi Hsieh <[email protected]>
  • Loading branch information
tustvold and viirya authored Dec 30, 2023
1 parent fad103a commit 9863486
Showing 1 changed file with 186 additions and 97 deletions.
283 changes: 186 additions & 97 deletions arrow-ipc/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ pub fn read_dictionary(
batch: crate::DictionaryBatch,
schema: &Schema,
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
metadata: &crate::MetadataVersion,
metadata: &MetadataVersion,
) -> Result<(), ArrowError> {
if batch.isDelta() {
return Err(ArrowError::InvalidArgumentError(
Expand Down Expand Up @@ -522,6 +522,174 @@ fn parse_message(buf: &[u8]) -> Result<Message, ArrowError> {
.map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))
}

/// Decode the footer length stored in the trailing 10 bytes of an Arrow IPC file
///
/// The trailer is expected to hold a 4 byte little-endian footer length
/// immediately followed by the magic `b"ARROW1"`.
///
/// # Errors
///
/// Returns [`ArrowError::ParseError`] when the magic bytes do not match, or
/// when the encoded length cannot be represented as a `usize` (e.g. it is
/// negative).
pub fn read_footer_length(buf: [u8; 10]) -> Result<usize, ArrowError> {
    let (len_bytes, magic) = buf.split_at(4);

    if *magic != super::ARROW_MAGIC {
        let msg = "Arrow file does not contain correct footer".to_string();
        return Err(ArrowError::ParseError(msg));
    }

    // The length is stored as a signed 32-bit integer; reject anything that
    // does not convert cleanly to `usize`.
    let raw_len = i32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], len_bytes[3]]);
    match usize::try_from(raw_len) {
        Ok(len) => Ok(len),
        Err(_) => Err(ArrowError::ParseError(format!(
            "Invalid footer length: {raw_len}"
        ))),
    }
}

/// A low-level, push-based interface for reading an IPC file
///
/// For a higher-level interface see [`FileReader`]
///
/// The decoder performs no I/O itself: the caller supplies the bytes of each
/// block. Dictionaries decoded through `read_dictionary` are retained by the
/// decoder and made available to later `read_record_batch` calls.
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::*;
/// # use arrow_array::types::Int32Type;
/// # use arrow_buffer::Buffer;
/// # use arrow_ipc::convert::fb_to_schema;
/// # use arrow_ipc::reader::{FileDecoder, read_footer_length};
/// # use arrow_ipc::root_as_footer;
/// # use arrow_ipc::writer::FileWriter;
/// // Write an IPC file
///
/// let batch = RecordBatch::try_from_iter([
/// ("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
/// ("b", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
/// ("c", Arc::new(DictionaryArray::<Int32Type>::from_iter(["hello", "hello", "world"])) as _),
/// ]).unwrap();
///
/// let schema = batch.schema();
///
/// let mut out = Vec::with_capacity(1024);
/// let mut writer = FileWriter::try_new(&mut out, schema.as_ref()).unwrap();
/// writer.write(&batch).unwrap();
/// writer.finish().unwrap();
///
/// drop(writer);
///
/// // Read IPC file
///
/// let buffer = Buffer::from_vec(out);
/// let trailer_start = buffer.len() - 10;
/// let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
/// let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
///
/// let back = fb_to_schema(footer.schema().unwrap());
/// assert_eq!(&back, schema.as_ref());
///
/// let mut decoder = FileDecoder::new(schema, footer.version());
///
/// // Read dictionaries
/// for block in footer.dictionaries().iter().flatten() {
/// let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
/// let data = buffer.slice_with_length(block.offset() as _, block_len);
/// decoder.read_dictionary(&block, &data).unwrap();
/// }
///
/// // Read record batch
/// let batches = footer.recordBatches().unwrap();
/// assert_eq!(batches.len(), 1); // Only wrote a single batch
///
/// let block = batches.get(0);
/// let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
/// let data = buffer.slice_with_length(block.offset() as _, block_len);
/// let back = decoder.read_record_batch(block, &data).unwrap().unwrap();
///
/// assert_eq!(batch, back);
/// ```
#[derive(Debug)]
pub struct FileDecoder {
/// Schema supplied at construction; passed to the dictionary and record
/// batch decoding routines.
schema: SchemaRef,
/// Dictionaries decoded so far by `read_dictionary`, consulted when
/// decoding record batches. Keys are presumably dictionary ids — confirm
/// against the free function `read_dictionary`.
dictionaries: HashMap<i64, ArrayRef>,
/// Expected metadata version. Each message's version must match it,
/// unless this is `MetadataVersion::V1`, in which case the check is skipped.
version: MetadataVersion,
/// Optional projection, forwarded as a slice to the record batch decoding
/// routine; `None` means no projection is applied.
projection: Option<Vec<usize>>,
}

impl FileDecoder {
    /// Create a new [`FileDecoder`] with the given schema and version
    ///
    /// The decoder starts with no dictionaries and no projection.
    pub fn new(schema: SchemaRef, version: MetadataVersion) -> Self {
        Self {
            schema,
            version,
            dictionaries: Default::default(),
            projection: None,
        }
    }

    /// Specify a projection
    ///
    /// The projection is forwarded to the record batch decoding routine for
    /// every batch read through this decoder.
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Parse `buf` as an IPC message and check its metadata version
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be parsed, or if its version
    /// differs from the version this decoder was created with.
    fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, ArrowError> {
        let message = parse_message(buf)?;

        // some old test data's footer metadata is not set, so we account for that
        if self.version != MetadataVersion::V1 && message.version() != self.version {
            return Err(ArrowError::IpcError(
                "Could not read IPC message as metadata versions mismatch".to_string(),
            ));
        }
        Ok(message)
    }

    /// Return the part of `buf` that follows the block's metadata
    ///
    /// The metadata length comes from the file footer and is therefore
    /// untrusted: a negative value, or one larger than `buf`, is reported as
    /// an error instead of panicking inside `Buffer::slice`.
    fn block_body(block: &Block, buf: &Buffer) -> Result<Buffer, ArrowError> {
        let metadata_len = block.metaDataLength();
        match usize::try_from(metadata_len) {
            Ok(offset) if offset <= buf.len() => Ok(buf.slice(offset)),
            _ => Err(ArrowError::IpcError(format!(
                "Invalid block metadata length {metadata_len} for buffer of {} bytes",
                buf.len()
            ))),
        }
    }

    /// Read the dictionary with the given block and data buffer
    ///
    /// `buf` must contain the block's metadata followed by its body. On
    /// success the decoded dictionary is stored in this decoder for use by
    /// subsequent [`Self::read_record_batch`] calls.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be parsed, has a mismatched
    /// metadata version, is not a dictionary batch, declares an invalid
    /// metadata length, or if decoding the dictionary fails.
    pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> {
        let message = self.read_message(buf)?;
        match message.header_type() {
            crate::MessageHeader::DictionaryBatch => {
                // Mirror `read_record_batch`: surface a missing header as an
                // error rather than panicking.
                let batch = message.header_as_dictionary_batch().ok_or_else(|| {
                    ArrowError::IpcError(
                        "Unable to read IPC message as dictionary batch".to_string(),
                    )
                })?;
                let body = Self::block_body(block, buf)?;
                read_dictionary(
                    &body,
                    batch,
                    &self.schema,
                    &mut self.dictionaries,
                    &message.version(),
                )
            }
            t => Err(ArrowError::ParseError(format!(
                "Expecting DictionaryBatch in dictionary blocks, found {t:?}."
            ))),
        }
    }

    /// Read the RecordBatch with the given block and data buffer
    ///
    /// `buf` must contain the block's metadata followed by its body.
    /// Returns `Ok(None)` when the message header type is `NONE`.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be parsed, has a mismatched
    /// metadata version, carries a schema or any other unsupported header
    /// type, declares an invalid metadata length, or if decoding the batch
    /// fails.
    pub fn read_record_batch(
        &self,
        block: &Block,
        buf: &Buffer,
    ) -> Result<Option<RecordBatch>, ArrowError> {
        let message = self.read_message(buf)?;
        match message.header_type() {
            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
                "Not expecting a schema when messages are read".to_string(),
            )),
            crate::MessageHeader::RecordBatch => {
                let batch = message.header_as_record_batch().ok_or_else(|| {
                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
                })?;
                // the record batch body starts right after the message metadata
                let body = Self::block_body(block, buf)?;
                read_record_batch(
                    &body,
                    batch,
                    self.schema.clone(),
                    &self.dictionaries,
                    self.projection.as_deref(),
                    &message.version(),
                )
                .map(Some)
            }
            crate::MessageHeader::NONE => Ok(None),
            t => Err(ArrowError::InvalidArgumentError(format!(
                "Reading types other than record batches not yet supported, unable to read {t:?}"
            ))),
        }
    }
}

/// Build an Arrow [`FileReader`] with custom options.
#[derive(Debug)]
pub struct FileReaderBuilder {
Expand Down Expand Up @@ -599,17 +767,10 @@ impl FileReaderBuilder {
reader.seek(SeekFrom::End(-10))?;
reader.read_exact(&mut buffer)?;

if buffer[4..] != super::ARROW_MAGIC {
return Err(ArrowError::ParseError(
"Arrow file does not contain correct footer".to_string(),
));
}

// read footer length
let footer_len = i32::from_le_bytes(buffer[..4].try_into().unwrap());
let footer_len = read_footer_length(buffer)?;

// read footer
let mut footer_data = vec![0; footer_len as usize];
let mut footer_data = vec![0; footer_len];
reader.seek(SeekFrom::End(-10 - footer_len as i64))?;
reader.read_exact(&mut footer_data)?;

Expand Down Expand Up @@ -641,50 +802,26 @@ impl FileReaderBuilder {
}
}

let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
if let Some(projection) = self.projection {
decoder = decoder.with_projection(projection)
}

// Create an array of optional dictionary value arrays, one per field.
let mut dictionaries_by_id = HashMap::new();
if let Some(dictionaries) = footer.dictionaries() {
for block in dictionaries {
let buf = read_block(&mut reader, block)?;
let message = parse_message(&buf)?;

match message.header_type() {
crate::MessageHeader::DictionaryBatch => {
let batch = message.header_as_dictionary_batch().unwrap();
read_dictionary(
&buf.slice(block.metaDataLength() as _),
batch,
&schema,
&mut dictionaries_by_id,
&message.version(),
)?;
}
t => {
return Err(ArrowError::ParseError(format!(
"Expecting DictionaryBatch in dictionary blocks, found {t:?}."
)));
}
}
decoder.read_dictionary(block, &buf)?;
}
}
let projection = match self.projection {
Some(projection_indices) => {
let schema = schema.project(&projection_indices)?;
Some((projection_indices, schema))
}
_ => None,
};

Ok(FileReader {
reader,
schema: Arc::new(schema),
blocks: blocks.iter().copied().collect(),
current_block: 0,
total_blocks,
dictionaries_by_id,
metadata_version: footer.version(),
decoder,
custom_metadata,
projection,
})
}
}
Expand All @@ -694,45 +831,31 @@ pub struct FileReader<R: Read + Seek> {
/// Buffered file reader that supports reading and seeking
reader: R,

/// The schema that is read from the file header
schema: SchemaRef,
/// The decoder
decoder: FileDecoder,

/// The blocks in the file
///
/// A block indicates the regions in the file to read to get data
blocks: Vec<crate::Block>,
blocks: Vec<Block>,

/// A counter to keep track of the current block that should be read
current_block: usize,

/// The total number of blocks, which may contain record batches and other types
total_blocks: usize,

/// Optional dictionaries for each schema field.
///
/// Dictionaries may be appended to in the streaming format.
dictionaries_by_id: HashMap<i64, ArrayRef>,

/// Metadata version
metadata_version: crate::MetadataVersion,

/// User defined metadata
custom_metadata: HashMap<String, String>,

/// Optional projection and projected_schema
projection: Option<(Vec<usize>, Schema)>,
}

impl<R: Read + Seek> fmt::Debug for FileReader<R> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
f.debug_struct("FileReader<R>")
.field("schema", &self.schema)
.field("decoder", &self.decoder)
.field("blocks", &self.blocks)
.field("current_block", &self.current_block)
.field("total_blocks", &self.total_blocks)
.field("dictionaries_by_id", &self.dictionaries_by_id)
.field("metadata_version", &self.metadata_version)
.field("projection", &self.projection)
.finish_non_exhaustive()
}
}
Expand Down Expand Up @@ -761,7 +884,7 @@ impl<R: Read + Seek> FileReader<R> {

/// Return the schema of the file
pub fn schema(&self) -> SchemaRef {
self.schema.clone()
self.decoder.schema.clone()
}

/// Read a specific record batch
Expand All @@ -785,41 +908,7 @@ impl<R: Read + Seek> FileReader<R> {

// read length
let buffer = read_block(&mut self.reader, block)?;
let message = parse_message(&buffer)?;

// some old test data's footer metadata is not set, so we account for that
if self.metadata_version != MetadataVersion::V1
&& message.version() != self.metadata_version
{
return Err(ArrowError::IpcError(
"Could not read IPC message as metadata versions mismatch".to_string(),
));
}

match message.header_type() {
crate::MessageHeader::Schema => Err(ArrowError::IpcError(
"Not expecting a schema when messages are read".to_string(),
)),
crate::MessageHeader::RecordBatch => {
let batch = message.header_as_record_batch().ok_or_else(|| {
ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
})?;
// read the block that makes up the record batch into a buffer
read_record_batch(
&buffer.slice(block.metaDataLength() as _),
batch,
self.schema(),
&self.dictionaries_by_id,
self.projection.as_ref().map(|x| x.0.as_ref()),
&message.version(),
)
.map(Some)
}
crate::MessageHeader::NONE => Ok(None),
t => Err(ArrowError::InvalidArgumentError(format!(
"Reading types other than record batches not yet supported, unable to read {t:?}"
))),
}
self.decoder.read_record_batch(block, &buffer)
}

/// Gets a reference to the underlying reader.
Expand Down Expand Up @@ -852,7 +941,7 @@ impl<R: Read + Seek> Iterator for FileReader<R> {

impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
fn schema(&self) -> SchemaRef {
self.schema.clone()
self.schema()
}
}

Expand Down

0 comments on commit 9863486

Please sign in to comment.