Skip to content

Commit

Permalink
[Rust] Clean up conversions between TVM and Rust functions (#6114)
Browse files Browse the repository at this point in the history
* Replace ToBoxedFn with From

* Compact and improve Typed and ToFunction impls

- Clone one less time
- Don't panic if number of args is wrong, return an error
- Actually drop functions/closures on the rust side

* Retry
  • Loading branch information
mwillsey authored Jul 23, 2020
1 parent 5046ff2 commit 06d7565
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 291 deletions.
2 changes: 1 addition & 1 deletion rust/tvm-macros/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let wrapper = quote! {
pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
let func_ref: #tvm_rt_crate::Function = #global_name.clone();
let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.to_boxed_fn();
let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.into();
let res: #ret_type = func_ref(#(#args),*)?;
Ok(res)
}
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-rt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ fn sum(x: i64, y: i64, z: i64) -> i64 {
fn main() {
register(sum, "mysum".to_owned()).unwrap();
let func = Function::get("mysum").unwrap();
let boxed_fn = func.to_boxed_fn::<dyn Fn(i64, i64, i64) -> Result<i64>>();
let boxed_fn: Box<dyn Fn(i64, i64, i64) -> Result<i64>> = func.into();
let ret = boxed_fn(10, 20, 30).unwrap();
assert_eq!(ret, 60);
}
Expand Down
42 changes: 29 additions & 13 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
//!
//! See the tests and examples repository for more examples.

use std::convert::TryFrom;
use std::convert::{TryFrom, TryInto};
use std::{
ffi::CString,
os::raw::{c_char, c_int},
Expand All @@ -34,8 +34,6 @@ use std::{

use crate::errors::Error;

use super::to_boxed_fn::ToBoxedFn;

pub use super::to_function::{ToFunction, Typed};
pub use tvm_sys::{ffi, ArgValue, RetValue};

Expand Down Expand Up @@ -94,11 +92,13 @@ impl Function {
}
}

pub fn get_boxed<F: ?Sized, S: AsRef<str>>(name: S) -> Option<Box<F>>
pub fn get_boxed<F, S>(name: S) -> Option<Box<F>>
where
F: ToBoxedFn,
S: AsRef<str>,
F: ?Sized,
Self: Into<Box<F>>,
{
Self::get(name).map(|f| f.to_boxed_fn::<F>())
Self::get(name).map(|f| f.into())
}

/// Returns the underlying TVM function handle.
Expand Down Expand Up @@ -141,15 +141,31 @@ impl Function {

Ok(rv)
}
}

pub fn to_boxed_fn<F: ?Sized>(self) -> Box<F>
where
F: ToBoxedFn,
{
F::to_boxed_fn(self)
}
macro_rules! impl_to_fn {
() => { impl_to_fn!(@impl); };
($t:ident, $($ts:ident,)*) => { impl_to_fn!(@impl $t, $($ts,)*); impl_to_fn!($($ts,)*); };
(@impl $($t:ident,)*) => {
impl<Err, Out, $($t,)*> From<Function> for Box<dyn Fn($($t,)*) -> Result<Out>>
where
Error: From<Err>,
Out: TryFrom<RetValue, Error = Err>,
$($t: Into<ArgValue<'static>>),*
{
fn from(func: Function) -> Self {
#[allow(non_snake_case)]
Box::new(move |$($t : $t),*| {
let args = vec![ $($t.into()),* ];
Ok(func.invoke(args)?.try_into()?)
})
}
}
};
}

impl_to_fn!(T1, T2, T3, T4, T5, T6,);

impl Clone for Function {
fn clone(&self) -> Function {
Self {
Expand Down Expand Up @@ -248,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
///
/// register(sum, "mysum".to_owned()).unwrap();
/// let func = Function::get("mysum").unwrap();
/// let boxed_fn = func.to_boxed_fn::<dyn Fn(i64, i64, i64) -> Result<i64>>();
/// let boxed_fn: Box<dyn Fn(i64, i64, i64) -> Result<i64>> = func.into();
/// let ret = boxed_fn(10, 20, 30).unwrap();
/// assert_eq!(ret, 60);
/// ```
Expand Down
1 change: 0 additions & 1 deletion rust/tvm-rt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ pub mod errors;
pub mod function;
pub mod module;
pub mod ndarray;
pub mod to_boxed_fn;
mod to_function;
pub mod value;

Expand Down
138 changes: 0 additions & 138 deletions rust/tvm-rt/src/to_boxed_fn.rs

This file was deleted.

Loading

0 comments on commit 06d7565

Please sign in to comment.