Skip to content

Commit

Permalink
rust: use thiserror for error types (#4341)
Browse files Browse the repository at this point in the history
Summary:
Our errors now have nice string formatting, easier `From` conversions,
and `std::error::Error` implementations.

Test Plan:
Included some tests for the `Display` implementations. They’re often not
necessary—one benefit of deriving traits is that you can be confident in
the implementation without manually testing it. But sometimes, if the
format string is non-trivial, it can be nice to actually see the full
text written out.

wchargin-branch: rust-use-thiserror
  • Loading branch information
wchargin authored Nov 18, 2020
1 parent 44752ea commit a9e1bba
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 33 deletions.
1 change: 1 addition & 0 deletions tensorboard/data/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ rust_library(
"//third_party/rust:prost",
"//third_party/rust:rand",
"//third_party/rust:rand_chacha",
"//third_party/rust:thiserror",
"//third_party/rust:tonic",
],
)
Expand Down
30 changes: 8 additions & 22 deletions tensorboard/data/server/event_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,34 +35,20 @@ pub struct EventFileReader<R> {
}

/// Error returned by [`EventFileReader::read_event`].
#[derive(Debug)]
#[derive(Debug, thiserror::Error)]
pub enum ReadEventError {
/// The record failed its checksum.
InvalidRecord(ChecksumError),
#[error(transparent)]
InvalidRecord(#[from] ChecksumError),
/// The record passed its checksum, but the contained protocol buffer is invalid.
InvalidProto(DecodeError),
#[error(transparent)]
InvalidProto(#[from] DecodeError),
/// The record is a valid `Event` proto, but its `wall_time` is `NaN`.
#[error("NaN wall time at step {}", .0.step)]
NanWallTime(Event),
/// An error occurred reading the record. May or may not be fatal.
ReadRecordError(ReadRecordError),
}

impl From<DecodeError> for ReadEventError {
fn from(e: DecodeError) -> Self {
ReadEventError::InvalidProto(e)
}
}

impl From<ChecksumError> for ReadEventError {
fn from(e: ChecksumError) -> Self {
ReadEventError::InvalidRecord(e)
}
}

impl From<ReadRecordError> for ReadEventError {
fn from(e: ReadRecordError) -> Self {
ReadEventError::ReadRecordError(e)
}
#[error(transparent)]
ReadRecordError(#[from] ReadRecordError),
}

impl ReadEventError {
Expand Down
12 changes: 10 additions & 2 deletions tensorboard/data/server/masked_crc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ limitations under the License.

//! Checksums as used by TFRecords.

use std::fmt::{self, Debug};
use std::fmt::{self, Debug, Display};

/// A CRC-32C (Castagnoli) checksum that has undergone a masking permutation.
///
Expand All @@ -30,7 +30,13 @@ pub struct MaskedCrc(pub u32);
// Implement `Debug` manually to use zero-padded hex output.
impl Debug for MaskedCrc {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "MaskedCrc({:#010x?})", self.0)
write!(f, "MaskedCrc({})", self)
}
}

impl Display for MaskedCrc {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:#010x?}", self.0)
}
}

Expand Down Expand Up @@ -78,9 +84,11 @@ mod tests {
#[test]
fn test_debug() {
let long_crc = MaskedCrc(0xf1234567);
assert_eq!(format!("{}", long_crc), "0xf1234567");
assert_eq!(format!("{:?}", long_crc), "MaskedCrc(0xf1234567)");

let short_crc = MaskedCrc(0x00000123);
assert_eq!(format!("{}", short_crc), "0x00000123");
assert_eq!(format!("{:?}", short_crc), "MaskedCrc(0x00000123)");
}
}
43 changes: 34 additions & 9 deletions tensorboard/data/server/tf_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ pub struct TfRecord {
}

/// A buffer's checksum was computed, but it did not match the expected value.
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[error("checksum mismatch: got {got}, want {want}")]
pub struct ChecksumError {
/// The actual checksum of the buffer.
pub got: MaskedCrc,
Expand Down Expand Up @@ -112,31 +113,29 @@ impl TfRecord {
}

/// Error returned by [`TfRecordReader::read_record`].
#[derive(Debug)]
#[derive(Debug, thiserror::Error)]
pub enum ReadRecordError {
/// Length field failed checksum. The file is corrupt, and reading must abort.
#[error("length checksum mismatch: got {}, want {}", .0.got, .0.want)]
BadLengthCrc(ChecksumError),
/// No fatal errors so far, but the record is not complete. Call `read_record` again with the
/// same state buffer once new data may be available.
///
/// This includes the "trivial truncation" case where there are no bytes in a new record, so
/// repeatedly reading records from a file of zero or more well-formed records will always
/// finish with a `Truncated` error.
#[error("record truncated")]
Truncated,
/// Record is too large to be represented in memory on this system.
///
/// In principle, it would be possible to recover from this error, but in practice this should
/// rarely occur since serialized protocol buffers do not exceed 2 GiB in size. Thus, no
/// recovery codepath has been implemented, so reading must abort.
#[error("record too large to fit in memory ({0} bytes)")]
TooLarge(u64),
/// Underlying I/O error. May be retryable if the underlying error is.
Io(io::Error),
}

impl From<io::Error> for ReadRecordError {
fn from(e: io::Error) -> Self {
ReadRecordError::Io(e)
}
#[error(transparent)]
Io(#[from] io::Error),
}

impl<R: Debug> Debug for TfRecordReader<R> {
Expand Down Expand Up @@ -405,6 +404,32 @@ mod tests {
}
}

#[test]
fn test_error_display() {
let e = ReadRecordError::BadLengthCrc(ChecksumError {
got: MaskedCrc(0x01234567),
want: MaskedCrc(0xfedcba98),
});
assert_eq!(
e.to_string(),
"length checksum mismatch: got 0x01234567, want 0xfedcba98"
);

let e = ReadRecordError::Truncated;
assert_eq!(e.to_string(), "record truncated");

let e = ReadRecordError::TooLarge(999);
assert_eq!(
e.to_string(),
"record too large to fit in memory (999 bytes)"
);

let io_error = io::Error::new(io::ErrorKind::BrokenPipe, "pipe machine broke");
let expected_message = io_error.to_string();
let e = ReadRecordError::Io(io_error);
assert_eq!(e.to_string(), expected_message);
}

#[test]
fn test_from_data() {
let test_cases = vec![
Expand Down

0 comments on commit a9e1bba

Please sign in to comment.