From 6dc8e22b2696f410ddb53208e0f2821fc8f80d38 Mon Sep 17 00:00:00 2001 From: Adelbert Chang Date: Mon, 2 Nov 2020 15:35:09 -0800 Subject: [PATCH] [rust][tvm-graph-rt]: maintain error sources when propagating errors, swap Mutex for RwLock (#6815) --- rust/tvm-graph-rt/src/errors.rs | 14 ++++---- rust/tvm-graph-rt/src/graph.rs | 49 ++++++++++++++++++-------- rust/tvm-graph-rt/src/module/syslib.rs | 10 +++--- 3 files changed, 46 insertions(+), 27 deletions(-) diff --git a/rust/tvm-graph-rt/src/errors.rs b/rust/tvm-graph-rt/src/errors.rs index 2ca97bdabb6b..c4bddb85b0de 100644 --- a/rust/tvm-graph-rt/src/errors.rs +++ b/rust/tvm-graph-rt/src/errors.rs @@ -22,14 +22,14 @@ use tvm_sys::DataType; #[derive(Debug, Error)] pub enum GraphFormatError { - #[error("Could not parse graph json")] - Parse(#[from] serde_json::Error), - #[error("Could not parse graph params")] - Params, - #[error("{0} is missing attr: {1}")] + #[error("Failed to parse graph with error: {0}")] + Parse(#[source] serde_json::Error), + #[error("Failed to parse graph parameters with error: {0:?}")] + Params(#[source] Option, nom::error::ErrorKind)>>), + #[error("{0} is missing attribute: {1}")] MissingAttr(String, String), - #[error("Graph has invalid attr that can't be parsed: {0}")] - InvalidAttr(#[from] std::num::ParseIntError), + #[error("Failed to parse graph attribute '{0}' with error: {1}")] + InvalidAttr(String, #[source] std::num::ParseIntError), #[error("Missing field: {0}")] MissingField(&'static str), #[error("Invalid DLType: {0}")] diff --git a/rust/tvm-graph-rt/src/graph.rs b/rust/tvm-graph-rt/src/graph.rs index 87dd4a76d5e4..646a20daaf5b 100644 --- a/rust/tvm-graph-rt/src/graph.rs +++ b/rust/tvm-graph-rt/src/graph.rs @@ -26,7 +26,7 @@ use nom::{ character::complete::{alpha1, digit1}, complete, count, do_parse, length_count, map, named, number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8}, - opt, tag, take, tuple, + opt, tag, take, tuple, Err as NomErr, }; use serde::{Deserialize, Serialize}; use serde_json; @@ -121,10 +121,22 @@ impl Node { .attrs .as_ref() .ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?; + + let func_name = get_node_attr!(self.name, attrs, "func_name")?.to_owned(); + + let num_outputs = get_node_attr!(self.name, attrs, "num_outputs")? + .parse::() + .map_err(|error| GraphFormatError::InvalidAttr("num_outputs".to_string(), error))?; + + let flatten_data = get_node_attr!(self.name, attrs, "flatten_data")? + .parse::() + .map(|val| val == 1) + .map_err(|error| GraphFormatError::InvalidAttr("flatten_data".to_string(), error))?; + Ok(NodeAttrs { - func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(), - num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::()?, - flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::()? == 1, + func_name, + num_outputs, + flatten_data, }) } } @@ -132,16 +144,14 @@ impl Node { impl<'a> TryFrom<&'a String> for Graph { type Error = GraphFormatError; fn try_from(graph_json: &String) -> Result { - let graph = serde_json::from_str(graph_json)?; - Ok(graph) + serde_json::from_str(graph_json).map_err(|error| GraphFormatError::Parse(error)) } } impl<'a> TryFrom<&'a str> for Graph { type Error = GraphFormatError; fn try_from(graph_json: &'a str) -> Result { - let graph = serde_json::from_str(graph_json)?; - Ok(graph) + serde_json::from_str(graph_json).map_err(|error| GraphFormatError::Parse(error)) } } @@ -475,14 +485,23 @@ named! { /// Loads a param dict saved using `relay.save_param_dict`. pub fn load_param_dict(bytes: &[u8]) -> Result, GraphFormatError> { - if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { - if remaining_bytes.is_empty() { - Ok(param_dict) - } else { - Err(GraphFormatError::Params) + match parse_param_dict(bytes) { + Ok((remaining_bytes, param_dict)) => { + if remaining_bytes.is_empty() { + Ok(param_dict) + } else { + Err(GraphFormatError::Params(None)) + } } - } else { - Err(GraphFormatError::Params) + Err(error) => Err(match error { + NomErr::Incomplete(error) => GraphFormatError::Params(Some(NomErr::Incomplete(error))), + NomErr::Error((remainder, error_kind)) => { + GraphFormatError::Params(Some(NomErr::Error((remainder.into(), error_kind)))) + } + NomErr::Failure((remainder, error_kind)) => { + GraphFormatError::Params(Some(NomErr::Failure((remainder.into(), error_kind)))) + } + }), } } diff --git a/rust/tvm-graph-rt/src/module/syslib.rs b/rust/tvm-graph-rt/src/module/syslib.rs index 0279e31be079..efc29a336620 100644 --- a/rust/tvm-graph-rt/src/module/syslib.rs +++ b/rust/tvm-graph-rt/src/module/syslib.rs @@ -18,7 +18,7 @@ */ use std::{ - collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, + collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::RwLock, }; use lazy_static::lazy_static; @@ -35,14 +35,14 @@ extern "C" { } lazy_static! { - static ref SYSTEM_LIB_FUNCTIONS: Mutex> = - Mutex::new(HashMap::new()); + static ref SYSTEM_LIB_FUNCTIONS: RwLock> = + RwLock::new(HashMap::new()); } impl Module for SystemLibModule { fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)> { SYSTEM_LIB_FUNCTIONS - .lock() + .read() .unwrap() .get(name.as_ref()) .copied() @@ -65,7 +65,7 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol( func: BackendPackedCFunc, ) -> i32 { let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; - SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert( + SYSTEM_LIB_FUNCTIONS.write().unwrap().insert( name.to_string(), &*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)), );