diff --git a/rust/common/Cargo.toml b/rust/common/Cargo.toml index 5d21ee509b02..5be1da82e820 100644 --- a/rust/common/Cargo.toml +++ b/rust/common/Cargo.toml @@ -3,6 +3,7 @@ name = "tvm-common" version = "0.1.0" authors = ["TVM Contributors"] license = "Apache-2.0" +edition = "2018" [features] bindings = [] diff --git a/rust/common/src/errors.rs b/rust/common/src/errors.rs index ad72f36433c0..5df02f11035e 100644 --- a/rust/common/src/errors.rs +++ b/rust/common/src/errors.rs @@ -1,47 +1,11 @@ -use std::fmt; - -static TYPE_CODE_STRS: [&str; 15] = [ - "int", - "uint", - "float", - "handle", - "null", - "TVMType", - "TVMContext", - "ArrayHandle", - "NodeHandle", - "ModuleHandle", - "FuncHandle", - "str", - "bytes", - "NDArrayContainer", - "ExtBegin", -]; - #[derive(Debug, Fail)] +#[fail( + display = "Could not downcast `{}` into `{}`", + expected_type, actual_type +)] pub struct ValueDowncastError { - actual_type_code: i64, - expected_type_code: i64, -} - -impl ValueDowncastError { - pub fn new(actual_type_code: i64, expected_type_code: i64) -> Self { - Self { - actual_type_code, - expected_type_code, - } - } -} - -impl fmt::Display for ValueDowncastError { - fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - write!( - formatter, - "Could not downcast TVMValue: expected `{}` but was {}", - TYPE_CODE_STRS[self.actual_type_code as usize], - TYPE_CODE_STRS[self.expected_type_code as usize] - ) - } + pub actual_type: String, + pub expected_type: &'static str, } #[derive(Debug, Fail)] @@ -62,18 +26,3 @@ impl FuncCallError { } } } - -// error_chain! { -// errors { -// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) { -// description("mismatched types while downcasting TVMRetValue") -// display("invalid downcast: expected `{}` but was `{}`", -// expected_type, type_code_to_string(actual_type_code)) -// } -// } -// foreign_links { -// IntoString(std::ffi::IntoStringError); -// ParseInt(std::num::ParseIntError); -// Utf8(std::str::Utf8Error); -// } -// } diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index 966655e802f8..4a3a4ae57bb3 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -1,7 +1,7 @@ //! This crate contains the refactored basic components required //! for `runtime` and `frontend` TVM crates. -#![feature(box_syntax, trait_alias)] +#![feature(box_syntax, type_alias_enum_variants, trait_alias)] #[macro_use] extern crate failure; @@ -25,5 +25,5 @@ pub mod packed_func; pub mod value; pub use errors::*; -pub use ffi::{TVMContext, TVMType}; +pub use ffi::{TVMByteArray, TVMContext, TVMType}; pub use packed_func::{TVMArgValue, TVMRetValue}; diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index a564fe656415..5ed36e4bd64c 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -1,9 +1,11 @@ -use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; - -use failure::Error; +use std::{ + convert::TryFrom, + ffi::{CStr, CString}, + os::raw::c_void, +}; pub use crate::ffi::TVMValue; -use crate::ffi::*; +use crate::{errors::ValueDowncastError, ffi::*}; pub trait PackedFunc = Fn(&[TVMArgValue]) -> Result + Send + Sync; @@ -15,298 +17,308 @@ pub trait PackedFunc = /// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` #[macro_export] macro_rules! call_packed { - ($fn:expr, $($args:expr),+) => { - $fn(&[$($args.into(),)+]) - }; - ($fn:expr) => { - $fn(&Vec::new()) - }; + ($fn:expr, $($args:expr),+) => { + $fn(&[$($args.into(),)+]) + }; + ($fn:expr) => { + $fn(&Vec::new()) + }; } -/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way -/// to obtain a `TVMArgValue` is automatically via `call_packed!`. -#[derive(Clone, Copy)] -pub struct TVMArgValue<'a> { - pub _lifetime: PhantomData<&'a ()>, - pub value: TVMValue, - pub type_code: i64, +/// Constructs a derivative of a TVMPodValue. +macro_rules! TVMPODValue { + { + $(#[$m:meta])+ + $name:ident $(<$a:lifetime>)? { + $($extra_variant:ident ( $variant_type:ty ) ),+ $(,)? + }, + match $value:ident { + $($tvm_type:ident => { $from_tvm_type:expr })+ + }, + match &self { + $($self_type:ident ( $val:ident ) => { $from_self_type:expr })+ + } + $(,)? + } => { + $(#[$m])+ + #[derive(Clone, Debug)] + pub enum $name $(<$a>)? { + Int(i64), + UInt(i64), + Float(f64), + Null, + Type(TVMType), + String(CString), + Context(TVMContext), + Handle(*mut c_void), + ArrayHandle(TVMArrayHandle), + NodeHandle(*mut c_void), + ModuleHandle(TVMModuleHandle), + FuncHandle(TVMFunctionHandle), + NDArrayContainer(*mut c_void), + $($extra_variant($variant_type)),+ + } + + impl $(<$a>)? $name $(<$a>)? { + pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self { + use $name::*; + #[allow(non_upper_case_globals)] + unsafe { + match type_code { + DLDataTypeCode_kDLInt => Int($value.v_int64), + DLDataTypeCode_kDLUInt => UInt($value.v_int64), + DLDataTypeCode_kDLFloat => Float($value.v_float64), + TVMTypeCode_kNull => Null, + TVMTypeCode_kTVMType => Type($value.v_type), + TVMTypeCode_kTVMContext => Context($value.v_ctx), + TVMTypeCode_kHandle => Handle($value.v_handle), + TVMTypeCode_kArrayHandle => ArrayHandle($value.v_handle as TVMArrayHandle), + TVMTypeCode_kNodeHandle => NodeHandle($value.v_handle), + TVMTypeCode_kModuleHandle => ModuleHandle($value.v_handle), + TVMTypeCode_kFuncHandle => FuncHandle($value.v_handle), + TVMTypeCode_kNDArrayContainer => NDArrayContainer($value.v_handle), + $( $tvm_type => { $from_tvm_type } ),+ + _ => unimplemented!("{}", type_code), + } + } + } + + pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) { + use $name::*; + match self { + Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), + UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), + Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), + Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kNull), + Type(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMType), + Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext), + String(val) => { + ( + TVMValue { v_handle: val.as_ptr() as *mut c_void }, + TVMTypeCode_kStr, + ) + } + Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kHandle), + ArrayHandle(val) => { + ( + TVMValue { v_handle: *val as *const _ as *mut c_void }, + TVMTypeCode_kArrayHandle, + ) + }, + NodeHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kNodeHandle), + ModuleHandle(val) => + (TVMValue { v_handle: *val }, TVMTypeCode_kModuleHandle), + FuncHandle(val) => ( + TVMValue { v_handle: *val }, + TVMTypeCode_kFuncHandle + ), + NDArrayContainer(val) => + (TVMValue { v_handle: *val }, TVMTypeCode_kNDArrayContainer), + $( $self_type($val) => { $from_self_type } ),+ + } + } + } + } } -impl<'a> TVMArgValue<'a> { - pub fn new(value: TVMValue, type_code: i64) -> Self { - TVMArgValue { - _lifetime: PhantomData, - value: value, - type_code: type_code, +TVMPODValue! { + /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way + /// to obtain a `TVMArgValue` is automatically via `call_packed!`. + TVMArgValue<'a> { + Bytes(&'a TVMByteArray), + Str(&'a CStr), + }, + match value { + TVMTypeCode_kBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } + TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } + }, + match &self { + Bytes(val) => { + (TVMValue { v_handle: val.clone() as *const _ as *mut c_void }, TVMTypeCode_kBytes) } + Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr)} + } +} + +TVMPODValue! { + /// An owned TVMPODValue. Can be converted from a variety of primitive and object types. + /// Can be downcasted using `try_from` if it contains the desired type. + /// + /// # Example + /// + /// ``` + /// let a = 42u32; + /// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); + /// + /// let s = "hello, world!"; + /// let t: TVMRetValue = s.into(); + /// assert_eq!(String::try_from(t).unwrap(), s); + /// ``` + TVMRetValue { + Bytes(TVMByteArray), + Str(&'static CStr), + }, + match value { + TVMTypeCode_kBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } + TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } + }, + match &self { + Bytes(val) => + { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kBytes ) } + Str(val) => + { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kStr ) } } } #[macro_export] -macro_rules! ensure_type { - ($val:ident, $expected_type_code:expr) => { - ensure!( - $val.type_code == $expected_type_code as i64, - $crate::errors::ValueDowncastError::new( - $val.type_code as i64, - $expected_type_code as i64 - ) - ); +macro_rules! try_downcast { + ($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => { + match $val { + $( $pat => { Ok($converter) } )+ + _ => Err($crate::errors::ValueDowncastError { + actual_type: format!("{:?}", $val), + expected_type: stringify!($into), + }), + } }; } /// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode. -macro_rules! impl_prim_tvm_arg { - ($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => { +macro_rules! impl_pod_value { + ($variant:ident, $inner_ty:ty, [ $( $type:ty ),+ ] ) => { $( - impl From<$type> for TVMArgValue<'static> { + impl<'a> From<$type> for TVMArgValue<'a> { fn from(val: $type) -> Self { - TVMArgValue { - value: TVMValue { $field: val as $field_type }, - type_code: $type_code as i64, - _lifetime: PhantomData, - } + Self::$variant(val as $inner_ty) } } - impl<'a> From<&'a $type> for TVMArgValue<'a> { + + impl<'a, 'v> From<&'a $type> for TVMArgValue<'v> { fn from(val: &'a $type) -> Self { - TVMArgValue { - value: TVMValue { - $field: val.to_owned() as $field_type, - }, - type_code: $type_code as i64, - _lifetime: PhantomData, - } + Self::$variant(*val as $inner_ty) } } + impl<'a> TryFrom> for $type { - type Error = Error; + type Error = $crate::errors::ValueDowncastError; fn try_from(val: TVMArgValue<'a>) -> Result { - ensure_type!(val, $type_code); - Ok(unsafe { val.value.$field as $type }) + try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { val as $type }) } } - impl<'a> TryFrom<&TVMArgValue<'a>> for $type { - type Error = Error; - fn try_from(val: &TVMArgValue<'a>) -> Result { - ensure_type!(val, $type_code); - Ok(unsafe { val.value.$field as $type }) + impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $type { + type Error = $crate::errors::ValueDowncastError; + fn try_from(val: &'a TVMArgValue<'v>) -> Result { + try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { *val as $type }) + } + } + + impl From<$type> for TVMRetValue { + fn from(val: $type) -> Self { + Self::$variant(val as $inner_ty) + } + } + + impl TryFrom for $type { + type Error = $crate::errors::ValueDowncastError; + fn try_from(val: TVMRetValue) -> Result { + try_downcast!(val -> $type, |TVMRetValue::$variant(val)| { val as $type }) } } )+ }; } -impl_prim_tvm_arg!(DLDataTypeCode_kDLFloat, v_float64, f64, [f32, f64]); -impl_prim_tvm_arg!( - DLDataTypeCode_kDLInt, - v_int64, - i64, - [i8, i16, i32, i64, isize] -); -impl_prim_tvm_arg!( - DLDataTypeCode_kDLUInt, - v_int64, - i64, - [u8, u16, u32, u64, usize] -); +impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); +impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); +impl_pod_value!(Float, f64, [f32, f64]); +impl_pod_value!(Type, TVMType, [TVMType]); +impl_pod_value!(Context, TVMContext, [TVMContext]); -#[cfg(feature = "bindings")] -// only allow this in bindings because pure-rust can't take ownership of leaked CString -impl<'a> From<&String> for TVMArgValue<'a> { - fn from(string: &String) -> Self { - TVMArgValue { - value: TVMValue { - v_str: std::ffi::CString::new(string.clone()).unwrap().into_raw(), - }, - type_code: TVMTypeCode_kStr as i64, - _lifetime: PhantomData, - } +impl<'a> From<&'a str> for TVMArgValue<'a> { + fn from(s: &'a str) -> Self { + Self::String(CString::new(s).unwrap()) } } -impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> { - fn from(string: &std::ffi::CString) -> Self { - TVMArgValue { - value: TVMValue { - v_str: string.as_ptr(), - }, - type_code: TVMTypeCode_kStr as i64, - _lifetime: PhantomData, - } +impl<'a> From<&'a CStr> for TVMArgValue<'a> { + fn from(s: &'a CStr) -> Self { + Self::Str(s) } } -impl<'a> TryFrom> for &str { - type Error = Error; - fn try_from(arg: TVMArgValue<'a>) -> Result { - ensure_type!(arg, TVMTypeCode_kStr); - Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?) +impl<'a> TryFrom> for &'a str { + type Error = ValueDowncastError; + fn try_from(val: TVMArgValue<'a>) -> Result { + try_downcast!(val -> &str, |TVMArgValue::Str(s)| { s.to_str().unwrap() }) } } -impl<'a> TryFrom<&TVMArgValue<'a>> for &str { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - ensure_type!(arg, TVMTypeCode_kStr); - Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?) +impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for &'v str { + type Error = ValueDowncastError; + fn try_from(val: &'a TVMArgValue<'v>) -> Result { + try_downcast!(val -> &str, |TVMArgValue::Str(s)| { s.to_str().unwrap() }) } } -/// Creates a conversion to a `TVMArgValue` for an object handle. -impl<'a, T> From<*const T> for TVMArgValue<'a> { +/// Converts an unspecialized handle to a TVMArgValue. +impl From<*const T> for TVMArgValue<'static> { fn from(ptr: *const T) -> Self { - TVMArgValue { - value: TVMValue { - v_handle: ptr as *mut T as *mut c_void, - }, - type_code: TVMTypeCode_kArrayHandle as i64, - _lifetime: PhantomData, - } + Self::Handle(ptr as *mut c_void) } } -/// Creates a conversion to a `TVMArgValue` for a mutable object handle. -impl<'a, T> From<*mut T> for TVMArgValue<'a> { +/// Converts an unspecialized mutable handle to a TVMArgValue. +impl From<*mut T> for TVMArgValue<'static> { fn from(ptr: *mut T) -> Self { - TVMArgValue { - value: TVMValue { - v_handle: ptr as *mut c_void, - }, - type_code: TVMTypeCode_kHandle as i64, - _lifetime: PhantomData, - } + Self::Handle(ptr as *mut c_void) } } impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> { fn from(arr: &'a mut DLTensor) -> Self { - TVMArgValue { - value: TVMValue { - v_handle: arr as *mut _ as *mut c_void, - }, - type_code: TVMTypeCode_kArrayHandle as i64, - _lifetime: PhantomData, - } + Self::ArrayHandle(arr as *mut DLTensor) } } impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { fn from(arr: &'a DLTensor) -> Self { - TVMArgValue { - value: TVMValue { - v_handle: arr as *const _ as *mut DLTensor as *mut c_void, - }, - type_code: TVMTypeCode_kArrayHandle as i64, - _lifetime: PhantomData, - } + Self::ArrayHandle(arr as *const _ as *mut DLTensor) } } -impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType { - type Error = Error; - fn try_from(arg: &'a TVMArgValue<'v>) -> Result { - ensure_type!(arg, TVMTypeCode_kTVMType); - Ok(unsafe { arg.value.v_type.into() }) +impl TryFrom for String { + type Error = ValueDowncastError; + fn try_from(val: TVMRetValue) -> Result { + try_downcast!( + val -> String, + |TVMRetValue::String(s)| { s.into_string().unwrap() }, + |TVMRetValue::Str(s)| { s.to_str().unwrap().to_string() } + ) } } -/// An owned TVMPODValue. Can be converted from a variety of primitive and object types. -/// Can be downcasted using `try_from` if it contains the desired type. -/// -/// # Example -/// -/// ``` -/// let a = 42u32; -/// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); -/// -/// let s = "hello, world!"; -/// let t: TVMRetValue = s.into(); -/// assert_eq!(String::try_from(t).unwrap(), s); -/// ``` -pub struct TVMRetValue { - pub value: TVMValue, - pub box_value: Box, - pub type_code: i64, -} - -impl TVMRetValue { - pub fn from_tvm_value(value: TVMValue, type_code: i64) -> Self { - Self { - value, - type_code, - box_value: box (), - } - } - - pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) { - (self.value, self.type_code as TVMTypeCode) +impl From for TVMRetValue { + fn from(s: String) -> Self { + Self::String(std::ffi::CString::new(s).unwrap()) } } -impl Default for TVMRetValue { - fn default() -> Self { - TVMRetValue { - value: TVMValue { v_int64: 0 as i64 }, - type_code: 0, - box_value: box (), - } +impl From for TVMRetValue { + fn from(arr: TVMByteArray) -> Self { + Self::Bytes(arr) } } -macro_rules! impl_pod_ret_value { - ($code:expr, [ $( $ty:ty ),+ ] ) => { - $( - impl From<$ty> for TVMRetValue { - fn from(val: $ty) -> Self { - Self { - value: val.into(), - type_code: $code as i64, - box_value: box (), - } - } - } - - impl TryFrom for $ty { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> { - ensure_type!(ret, $code); - Ok(ret.value.into()) - } - } - )+ - }; -} - -impl_pod_ret_value!(DLDataTypeCode_kDLInt, [i8, i16, i32, i64, isize]); -impl_pod_ret_value!(DLDataTypeCode_kDLUInt, [u8, u16, u32, u64, usize]); -impl_pod_ret_value!(DLDataTypeCode_kDLFloat, [f32, f64]); -impl_pod_ret_value!(TVMTypeCode_kTVMType, [TVMType]); -impl_pod_ret_value!(TVMTypeCode_kTVMContext, [TVMContext]); - -impl TryFrom for String { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - ensure_type!(ret, TVMTypeCode_kStr); - let cs = unsafe { std::ffi::CString::from_raw(ret.value.v_handle as *mut i8) }; - let ret_str = cs.clone().into_string(); - if cfg!(feature = "bindings") { - std::mem::forget(cs); // TVM C++ takes ownership of CString. (@see TVMFuncCall) - } - Ok(ret_str?) +impl TryFrom for TVMByteArray { + type Error = ValueDowncastError; + fn try_from(val: TVMRetValue) -> Result { + try_downcast!(val -> TVMByteArray, |TVMRetValue::Bytes(val)| { val }) } } -impl From for TVMRetValue { - fn from(s: String) -> Self { - let cs = std::ffi::CString::new(s).unwrap(); - Self { - value: TVMValue { - v_str: cs.into_raw() as *mut i8, - }, - box_value: box (), - type_code: TVMTypeCode_kStr as i64, - } +impl Default for TVMRetValue { + fn default() -> Self { + Self::Int(0) } } diff --git a/rust/common/src/value.rs b/rust/common/src/value.rs index c7c040b0060e..0ece1f636c2a 100644 --- a/rust/common/src/value.rs +++ b/rust/common/src/value.rs @@ -137,3 +137,18 @@ impl_tvm_context!( DLDeviceType_kDLROCM: [rocm], DLDeviceType_kDLExtDev: [ext_dev] ); + +impl TVMByteArray { + pub fn data(&self) -> &'static [u8] { + unsafe { std::slice::from_raw_parts(self.data as *const u8, self.size) } + } +} + +impl<'a> From<&'a [u8]> for TVMByteArray { + fn from(bytes: &[u8]) -> Self { + Self { + data: bytes.as_ptr() as *const i8, + size: bytes.len(), + } + } +} diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml index eb1f5b8db021..a37b1eb08614 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/frontend/Cargo.toml @@ -9,6 +9,7 @@ readme = "README.md" keywords = ["rust", "tvm", "nnvm"] categories = ["api-bindings", "science"] authors = ["TVM Contributors"] +edition = "2018" [lib] name = "tvm_frontend" diff --git a/rust/frontend/src/bytearray.rs b/rust/frontend/src/bytearray.rs index 9274dba862da..db455abbf648 100644 --- a/rust/frontend/src/bytearray.rs +++ b/rust/frontend/src/bytearray.rs @@ -3,9 +3,9 @@ //! //! For more detail, please see the example `resnet` in `examples` repository. -use std::os::raw::{c_char, c_void}; +use std::os::raw::c_char; -use tvm_common::{ffi, TVMArgValue}; +use tvm_common::ffi; /// A struct holding TVM byte-array. /// @@ -44,8 +44,9 @@ impl TVMByteArray { } } -impl<'a> From<&'a Vec> for TVMByteArray { - fn from(arg: &Vec) -> Self { +impl<'a, T: AsRef<[u8]>> From for TVMByteArray { + fn from(arg: T) -> Self { + let arg = arg.as_ref(); let barr = ffi::TVMByteArray { data: arg.as_ptr() as *const c_char, size: arg.len(), @@ -54,18 +55,6 @@ impl<'a> From<&'a Vec> for TVMByteArray { } } -impl<'a> From<&TVMByteArray> for TVMArgValue<'a> { - fn from(arr: &TVMByteArray) -> Self { - Self { - value: ffi::TVMValue { - v_handle: &arr.inner as *const ffi::TVMByteArray as *const c_void as *mut c_void, - }, - type_code: ffi::TVMTypeCode_kBytes as i64, - _lifetime: std::marker::PhantomData, - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/rust/frontend/src/context.rs b/rust/frontend/src/context.rs index 5d800a8b9644..6a561fceb9f9 100644 --- a/rust/frontend/src/context.rs +++ b/rust/frontend/src/context.rs @@ -26,10 +26,7 @@ use std::{ use failure::Error; -use tvm_common::{ - ffi::{self, TVMValue}, - TVMArgValue, -}; +use tvm_common::ffi; use crate::function; @@ -125,18 +122,6 @@ impl<'a> From<&'a str> for TVMDeviceType { } } -impl<'a> From<&'a TVMDeviceType> for TVMArgValue<'a> { - fn from(dev_type: &'a TVMDeviceType) -> Self { - Self { - value: TVMValue { - v_int64: dev_type.0 as i64, - }, - type_code: ffi::DLDataTypeCode_kDLInt as i64, - _lifetime: std::marker::PhantomData, - } - } -} - /// Represents the underlying device context. Default is cpu. /// /// ## Examples @@ -209,7 +194,7 @@ impl TVMContext { let dt = self.device_type.0 as usize; // `unwrap` is ok here because if there is any error, // if would occure inside `call_packed!` - let ret: u64 = call_packed!(func, &dt, &self.device_id, &0) + let ret: u64 = call_packed!(func, dt, self.device_id, 0) .unwrap() .try_into() .unwrap(); @@ -238,7 +223,9 @@ macro_rules! impl_device_attrs { // `unwrap` is ok here because if there is any error, // if would occur in function call. function::Builder::from(func) - .args(&[dt, self.device_id as usize, $attr_kind]) + .arg(dt) + .arg(self.device_id as usize) + .arg($attr_kind) .invoke() .unwrap() .try_into() diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index f0fbcbe67e25..099195c95f78 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -156,9 +156,9 @@ impl<'a, 'm> Builder<'a, 'm> { } /// Pushes a [`TVMArgValue`] into the function argument buffer. - pub fn arg(&mut self, arg: &'a T) -> &mut Self + pub fn arg(&mut self, arg: T) -> &mut Self where - TVMArgValue<'a>: From<&'a T>, + TVMArgValue<'a>: From, { self.arg_buf.push(arg.into()); self @@ -192,14 +192,11 @@ impl<'a, 'm> Builder<'a, 'm> { ensure!(self.func.is_some(), errors::FunctionNotFoundError); let num_args = self.arg_buf.len(); - let (mut values, mut type_codes): (Vec, Vec) = self - .arg_buf - .iter() - .map(|tvm_arg| (tvm_arg.value, tvm_arg.type_code as ffi::TVMTypeCode)) - .unzip(); + let (mut values, mut type_codes): (Vec, Vec) = + self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip(); let mut ret_val = unsafe { std::mem::uninitialized::() }; - let mut ret_type_code = 0; + let mut ret_type_code = 0i32; check_call!(ffi::TVMFuncCall( self.func.ok_or(errors::FunctionNotFoundError)?.handle, values.as_mut_ptr(), @@ -209,7 +206,7 @@ impl<'a, 'm> Builder<'a, 'm> { &mut ret_type_code as *mut _ )); - Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64) }) + Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as u32) }) } } @@ -254,7 +251,7 @@ unsafe extern "C" fn tvm_callback( { check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode)); } - local_args.push(TVMArgValue::new(value.into(), (tcode as i64).into())); + local_args.push(TVMArgValue::from_tvm_value(value.into(), tcode as u32)); } let rv = match rust_fn(local_args.as_slice()) { @@ -265,7 +262,7 @@ unsafe extern "C" fn tvm_callback( } }; - let (mut ret_val, ret_tcode) = rv.into_tvm_value(); + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); let mut ret_type_code = ret_tcode as c_int; check_call!(ffi::TVMCFuncSetReturn( ret, @@ -437,8 +434,9 @@ mod tests { let str_arg = CString::new("test").unwrap(); let mut func = Builder::default(); func.get_function("tvm.graph_runtime.remote_create") - .args(&[10, 20]) - .arg(&str_arg); + .arg(10) + .arg(20) + .arg(str_arg.as_c_str()); assert_eq!(func.arg_buf.len(), 3); } } diff --git a/rust/frontend/src/module.rs b/rust/frontend/src/module.rs index 9c27387520dc..a27a4ca93b47 100644 --- a/rust/frontend/src/module.rs +++ b/rust/frontend/src/module.rs @@ -80,7 +80,7 @@ impl Module { CString::new(path.as_ref().to_str().ok_or_else(|| { format_err!("Bad module load path: `{}`.", path.as_ref().display()) })?)?; - let ret: Module = call_packed!(func, &cpath, &ext)?.try_into()?; + let ret: Module = call_packed!(func, cpath.as_c_str(), ext.as_c_str())?.try_into()?; Ok(ret) } @@ -90,7 +90,10 @@ impl Module { // `unwrap` is safe here because if there is any error during the // function call, it would occur in `call_packed!`. let tgt = CString::new(target).unwrap(); - let ret: i64 = call_packed!(func, &tgt).unwrap().try_into().unwrap(); + let ret: i64 = call_packed!(func, tgt.as_c_str()) + .unwrap() + .try_into() + .unwrap(); ret != 0 } diff --git a/rust/frontend/src/ndarray.rs b/rust/frontend/src/ndarray.rs index 1939c92c0f0b..9a774fa185b5 100644 --- a/rust/frontend/src/ndarray.rs +++ b/rust/frontend/src/ndarray.rs @@ -161,7 +161,7 @@ impl NDArray { /// Converts the NDArray to [`TVMByteArray`]. pub fn to_bytearray(&self) -> Result { let v = self.to_vec::()?; - Ok(TVMByteArray::from(&v)) + Ok(TVMByteArray::from(v)) } /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. diff --git a/rust/frontend/src/value.rs b/rust/frontend/src/value.rs index eb62f10cabec..3f383527add3 100644 --- a/rust/frontend/src/value.rs +++ b/rust/frontend/src/value.rs @@ -2,140 +2,80 @@ //! and their conversions needed for the types used in frontend crate. //! `TVMRetValue` is the owned version of `TVMPODValue`. -use std::{convert::TryFrom, os::raw::c_void}; +use std::convert::TryFrom; -use failure::Error; use tvm_common::{ - ensure_type, - ffi::{self, TVMValue}, + errors::ValueDowncastError, + ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle}, + try_downcast, }; -use crate::{ - common_errors::*, context::TVMContext, Function, Module, NDArray, TVMArgValue, TVMByteArray, - TVMRetValue, -}; +use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue}; -macro_rules! impl_tvm_val_from_handle { - ($ty:ident, $type_code:expr, $handle:ty) => { - impl<'a> From<&'a $ty> for TVMArgValue<'a> { - fn from(arg: &$ty) -> Self { - TVMArgValue { - value: TVMValue { - v_handle: arg.handle as *mut _ as *mut c_void, - }, - type_code: $type_code as i64, - _lifetime: std::marker::PhantomData, - } +macro_rules! impl_handle_val { + ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => { + impl<'a> From<&'a $type> for TVMArgValue<'a> { + fn from(arg: &'a $type) -> Self { + TVMArgValue::$variant(arg.handle() as $inner_type) } } - impl<'a> From<&'a mut $ty> for TVMArgValue<'a> { - fn from(arg: &mut $ty) -> Self { - TVMArgValue { - value: TVMValue { - v_handle: arg.handle as *mut _ as *mut c_void, - }, - type_code: $type_code as i64, - _lifetime: std::marker::PhantomData, - } + impl<'a> From<&'a mut $type> for TVMArgValue<'a> { + fn from(arg: &'a mut $type) -> Self { + TVMArgValue::$variant(arg.handle() as $inner_type) } } - impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $ty { - type Error = Error; - fn try_from(arg: &TVMArgValue<'v>) -> Result<$ty, Self::Error> { - ensure_type!(arg, $type_code); - Ok($ty::new(unsafe { arg.value.v_handle as $handle })) + impl<'a> TryFrom> for $type { + type Error = ValueDowncastError; + fn try_from(val: TVMArgValue<'a>) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { $ctor(val) }) } } - impl From<$ty> for TVMRetValue { - fn from(val: $ty) -> TVMRetValue { - TVMRetValue { - value: TVMValue { - v_handle: val.handle() as *mut c_void, - }, - box_value: box val, - type_code: $type_code as i64, - } + impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $type { + type Error = ValueDowncastError; + fn try_from(val: &'a TVMArgValue<'v>) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |TVMArgValue::$variant(val)| { $ctor(*val) }) } } - impl TryFrom for $ty { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> { - ensure_type!(ret, $type_code); - Ok($ty::new(unsafe { ret.value.v_handle as $handle })) - } - } - }; -} - -impl_tvm_val_from_handle!( - Function, - ffi::TVMTypeCode_kFuncHandle, - ffi::TVMFunctionHandle -); -impl_tvm_val_from_handle!(Module, ffi::TVMTypeCode_kModuleHandle, ffi::TVMModuleHandle); -impl_tvm_val_from_handle!(NDArray, ffi::TVMTypeCode_kArrayHandle, ffi::TVMArrayHandle); - -impl<'a> From<&'a TVMByteArray> for TVMValue { - fn from(barr: &TVMByteArray) -> Self { - TVMValue { - v_handle: &barr.inner as *const ffi::TVMByteArray as *mut c_void, - } - } -} - -macro_rules! impl_boxed_ret_value { - ($type:ty, $code:expr) => { impl From<$type> for TVMRetValue { - fn from(val: $type) -> Self { - TVMRetValue { - value: TVMValue { v_int64: 0 }, - box_value: box val, - type_code: $code as i64, - } + fn from(val: $type) -> TVMRetValue { + TVMRetValue::$variant(val.handle() as $inner_type) } } + impl TryFrom for $type { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type, Self::Error> { - if let Ok(val) = ret.box_value.downcast::<$type>() { - Ok(*val) - } else { - bail!(ValueDowncastError::new($code as i64, ret.type_code as i64)) - } + type Error = ValueDowncastError; + fn try_from(val: TVMRetValue) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |TVMRetValue::$variant(val)| { $ctor(val) }) } } }; } -impl_boxed_ret_value!(TVMContext, ffi::TVMTypeCode_kTVMContext); -impl_boxed_ret_value!(TVMByteArray, ffi::TVMTypeCode_kBytes); - -impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray { - type Error = Error; - fn try_from(arg: &TVMArgValue<'v>) -> Result { - ensure_type!(arg, ffi::TVMTypeCode_kBytes); - Ok(TVMByteArray::new(unsafe { - *(arg.value.v_handle as *mut ffi::TVMByteArray) - })) - } -} +impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new); +impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); +impl_handle_val!(NDArray, ArrayHandle, TVMArrayHandle, NDArray::new); #[cfg(test)] mod tests { - use super::*; use std::{convert::TryInto, str::FromStr}; - use tvm_common::ffi::TVMType; + + use tvm_common::{TVMByteArray, TVMContext, TVMType}; + + use super::*; #[test] fn bytearray() { let w = vec![1u8, 2, 3, 4, 5]; - let v = TVMByteArray::from(&w); + let v = TVMByteArray::from(w.as_slice()); let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap(); - assert_eq!(tvm.data(), w.iter().map(|e| *e as i8).collect::>()); + assert_eq!( + tvm.data(), + w.iter().map(|e| *e).collect::>().as_slice() + ); } #[test] @@ -147,7 +87,7 @@ mod tests { #[test] fn ctx() { - let c = TVMContext::from("gpu"); + let c = TVMContext::from_str("gpu").unwrap(); let tvm: TVMContext = TVMRetValue::from(c).try_into().unwrap(); assert_eq!(tvm, c); } diff --git a/rust/frontend/tests/callback/src/bin/string.rs b/rust/frontend/tests/callback/src/bin/string.rs index 3b2ad65a2f45..02daa84f38c9 100644 --- a/rust/frontend/tests/callback/src/bin/string.rs +++ b/rust/frontend/tests/callback/src/bin/string.rs @@ -24,9 +24,9 @@ fn main() { registered.get_function("concate_str"); assert!(registered.func.is_some()); let ret: String = registered - .arg(&a) - .arg(&b) - .arg(&c) + .arg(a.as_c_str()) + .arg(b.as_c_str()) + .arg(c.as_c_str()) .invoke() .unwrap() .try_into() diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index ae73ae721224..33e59e16d8e7 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -8,6 +8,7 @@ readme = "README.md" keywords = ["tvm", "nnvm"] categories = ["api-bindings", "science"] authors = ["TVM Contributors"] +edition = "2018" [features] default = ["nom/std"] diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index eeb819c61289..440cf4d0aded 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -265,7 +265,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { .iter() .map(|t| t.into()) .collect::>(); - func(args.as_slice()).unwrap(); + func(&args).unwrap(); }; op_execs.push(op); } @@ -283,7 +283,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { // TODO: consider `new_with_params` to avoid ever allocating let ptr = self.tensors[idx].data.as_ptr(); let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr); - let mut owner = to_replace.nth(0).unwrap(); + let owner = to_replace.nth(0).unwrap(); if value.data.is_owned() { // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr // mem::replace(&mut (*owner), value); diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module.rs index 636c4e8ff5cf..3faf81ed24a3 100644 --- a/rust/runtime/src/module.rs +++ b/rust/runtime/src/module.rs @@ -40,17 +40,14 @@ pub(super) fn wrap_backend_packed_func( func: BackendPackedCFunc, ) -> Box { box move |args: &[TVMArgValue]| { - let exit_code = func( - args.iter() - .map(|ref arg| arg.value) - .collect::>() - .as_ptr(), - args.iter() - .map(|ref arg| arg.type_code as i32) - .collect::>() - .as_ptr() as *const i32, - args.len() as i32, - ); + let (values, type_codes): (Vec, Vec) = args + .into_iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); if exit_code == 0 { Ok(TVMRetValue::default()) } else { diff --git a/rust/runtime/tests/test_graph_serde.rs b/rust/runtime/tests/test_graph_serde.rs index 18ac19a79df3..b52e98b57d87 100644 --- a/rust/runtime/tests/test_graph_serde.rs +++ b/rust/runtime/tests/test_graph_serde.rs @@ -1,5 +1,3 @@ -#![feature(try_from)] - extern crate serde; extern crate serde_json; diff --git a/rust/runtime/tests/test_nnvm/src/main.rs b/rust/runtime/tests/test_nnvm/src/main.rs index 50179798cd32..8cdc2d9467e9 100644 --- a/rust/runtime/tests/test_nnvm/src/main.rs +++ b/rust/runtime/tests/test_nnvm/src/main.rs @@ -1,5 +1,3 @@ -#![feature(try_from)] - #[macro_use] extern crate ndarray; extern crate serde;