Skip to content

Commit

Permalink
[rust][tvm-graph-rt]: maintain error sources when propagating errors,…
Browse files Browse the repository at this point in the history
… swap Mutex for RwLock (apache#6815)
  • Loading branch information
adelbertc authored and Trevor Morris committed Dec 2, 2020
1 parent 1249b59 commit b241791
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 27 deletions.
14 changes: 7 additions & 7 deletions rust/tvm-graph-rt/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ use tvm_sys::DataType;

#[derive(Debug, Error)]
pub enum GraphFormatError {
#[error("Could not parse graph json")]
Parse(#[from] serde_json::Error),
#[error("Could not parse graph params")]
Params,
#[error("{0} is missing attr: {1}")]
#[error("Failed to parse graph with error: {0}")]
Parse(#[source] serde_json::Error),
#[error("Failed to parse graph parameters with error: {0:?}")]
Params(#[source] Option<nom::Err<(Vec<u8>, nom::error::ErrorKind)>>),
#[error("{0} is missing attribute: {1}")]
MissingAttr(String, String),
#[error("Graph has invalid attr that can't be parsed: {0}")]
InvalidAttr(#[from] std::num::ParseIntError),
#[error("Failed to parse graph attribute '{0}' with error: {1}")]
InvalidAttr(String, #[source] std::num::ParseIntError),
#[error("Missing field: {0}")]
MissingField(&'static str),
#[error("Invalid DLType: {0}")]
Expand Down
49 changes: 34 additions & 15 deletions rust/tvm-graph-rt/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use nom::{
character::complete::{alpha1, digit1},
complete, count, do_parse, length_count, map, named,
number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8},
opt, tag, take, tuple,
opt, tag, take, tuple, Err as NomErr,
};
use serde::{Deserialize, Serialize};
use serde_json;
Expand Down Expand Up @@ -121,27 +121,37 @@ impl Node {
.attrs
.as_ref()
.ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?;

let func_name = get_node_attr!(self.name, attrs, "func_name")?.to_owned();

let num_outputs = get_node_attr!(self.name, attrs, "num_outputs")?
.parse::<usize>()
.map_err(|error| GraphFormatError::InvalidAttr("num_outputs".to_string(), error))?;

let flatten_data = get_node_attr!(self.name, attrs, "flatten_data")?
.parse::<u8>()
.map(|val| val == 1)
.map_err(|error| GraphFormatError::InvalidAttr("flatten_data".to_string(), error))?;

Ok(NodeAttrs {
func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(),
num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?,
flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()? == 1,
func_name,
num_outputs,
flatten_data,
})
}
}

impl<'a> TryFrom<&'a String> for Graph {
type Error = GraphFormatError;
fn try_from(graph_json: &String) -> Result<Self, GraphFormatError> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
serde_json::from_str(graph_json).map_err(|error| GraphFormatError::Parse(error))
}
}

impl<'a> TryFrom<&'a str> for Graph {
type Error = GraphFormatError;
fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
serde_json::from_str(graph_json).map_err(|error| GraphFormatError::Parse(error))
}
}

Expand Down Expand Up @@ -475,14 +485,23 @@ named! {

/// Loads a param dict saved using `relay.save_param_dict`.
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> {
if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
if remaining_bytes.is_empty() {
Ok(param_dict)
} else {
Err(GraphFormatError::Params)
match parse_param_dict(bytes) {
Ok((remaining_bytes, param_dict)) => {
if remaining_bytes.is_empty() {
Ok(param_dict)
} else {
Err(GraphFormatError::Params(None))
}
}
} else {
Err(GraphFormatError::Params)
Err(error) => Err(match error {
NomErr::Incomplete(error) => GraphFormatError::Params(Some(NomErr::Incomplete(error))),
NomErr::Error((remainder, error_kind)) => {
GraphFormatError::Params(Some(NomErr::Error((remainder.into(), error_kind))))
}
NomErr::Failure((remainder, error_kind)) => {
GraphFormatError::Params(Some(NomErr::Failure((remainder.into(), error_kind))))
}
}),
}
}

Expand Down
10 changes: 5 additions & 5 deletions rust/tvm-graph-rt/src/module/syslib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

use std::{
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::RwLock,
};

use lazy_static::lazy_static;
Expand All @@ -35,14 +35,14 @@ extern "C" {
}

lazy_static! {
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
Mutex::new(HashMap::new());
static ref SYSTEM_LIB_FUNCTIONS: RwLock<HashMap<String, &'static (dyn PackedFunc)>> =
RwLock::new(HashMap::new());
}

impl Module for SystemLibModule {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
SYSTEM_LIB_FUNCTIONS
.lock()
.read()
.unwrap()
.get(name.as_ref())
.copied()
Expand All @@ -65,7 +65,7 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
func: BackendPackedCFunc,
) -> i32 {
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
SYSTEM_LIB_FUNCTIONS.write().unwrap().insert(
name.to_string(),
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)),
);
Expand Down

0 comments on commit b241791

Please sign in to comment.