diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index 2fcee49d3abde..4359db9b8c20b 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -144,7 +144,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let wrapper = quote! { pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { let func_ref: #tvm_rt_crate::Function = #global_name.clone(); - let func_ref: Box #result_type<#ret_type>> = func_ref.to_boxed_fn(); + let func_ref: Box #result_type<#ret_type>> = func_ref.into(); let res: #ret_type = func_ref(#(#args),*)?; Ok(res) } diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md index 7c87939db3016..662687e0e32f8 100644 --- a/rust/tvm-rt/README.md +++ b/rust/tvm-rt/README.md @@ -53,7 +53,7 @@ fn sum(x: i64, y: i64, z: i64) -> i64 { fn main() { register(sum, "mysum".to_owned()).unwrap(); let func = Function::get("mysum").unwrap(); - let boxed_fn = func.to_boxed_fn:: Result>(); + let boxed_fn: Box Result> = func.into(); let ret = boxed_fn(10, 20, 30).unwrap(); assert_eq!(ret, 60); } diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 591b5cce8cc70..94a20ac56b685 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -25,7 +25,7 @@ //! //! See the tests and examples repository for more examples. -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::{ ffi::CString, os::raw::{c_char, c_int}, @@ -34,8 +34,6 @@ use std::{ use crate::errors::Error; -use super::to_boxed_fn::ToBoxedFn; - pub use super::to_function::{ToFunction, Typed}; pub use tvm_sys::{ffi, ArgValue, RetValue}; @@ -94,11 +92,13 @@ impl Function { } } - pub fn get_boxed>(name: S) -> Option> + pub fn get_boxed(name: S) -> Option> where - F: ToBoxedFn, + S: AsRef, + F: ?Sized, + Self: Into>, { - Self::get(name).map(|f| f.to_boxed_fn::()) + Self::get(name).map(|f| f.into()) } /// Returns the underlying TVM function handle. @@ -141,15 +141,31 @@ impl Function { Ok(rv) } +} - pub fn to_boxed_fn(self) -> Box - where - F: ToBoxedFn, - { - F::to_boxed_fn(self) - } +macro_rules! impl_to_fn { + () => { impl_to_fn!(@impl); }; + ($t:ident, $($ts:ident,)*) => { impl_to_fn!(@impl $t, $($ts,)*); impl_to_fn!($($ts,)*); }; + (@impl $($t:ident,)*) => { + impl From for Box Result> + where + Error: From, + Out: TryFrom, + $($t: Into>),* + { + fn from(func: Function) -> Self { + #[allow(non_snake_case)] + Box::new(move |$($t : $t),*| { + let args = vec![ $($t.into()),* ]; + Ok(func.invoke(args)?.try_into()?) + }) + } + } + }; } +impl_to_fn!(T1, T2, T3, T4, T5, T6,); + impl Clone for Function { fn clone(&self) -> Function { Self { @@ -248,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function { /// /// register(sum, "mysum".to_owned()).unwrap(); /// let func = Function::get("mysum").unwrap(); -/// let boxed_fn = func.to_boxed_fn:: Result>(); +/// let boxed_fn: Box Result> = func.into(); /// let ret = boxed_fn(10, 20, 30).unwrap(); /// assert_eq!(ret, 60); /// ``` diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index a56a25be82fbb..ad4c1ca885f9f 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -97,7 +97,6 @@ pub mod errors; pub mod function; pub mod module; pub mod ndarray; -pub mod to_boxed_fn; mod to_function; pub mod value; diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs deleted file mode 100644 index 8416f2ce650f9..0000000000000 --- a/rust/tvm-rt/src/to_boxed_fn.rs +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This module provides a method for converting type erased TVM functions -//! into a boxed Rust closure. -//! -//! To call a registered function check the [`ToBoxedFn::to_boxed_fn`] method. -//! -//! See the tests and examples repository for more examples. - -pub use tvm_sys::{ffi, ArgValue, RetValue}; - -use crate::errors; - -use super::function::{Function, Result}; - -pub trait ToBoxedFn { - fn to_boxed_fn(func: Function) -> Box; -} - -use std::convert::{TryFrom, TryInto}; - -impl ToBoxedFn for dyn Fn() -> Result -where - errors::Error: From, - O: TryFrom, -{ - fn to_boxed_fn(func: Function) -> Box { - Box::new(move || { - let res = func.invoke(vec![])?; - let res = res.try_into()?; - Ok(res) - }) - } -} - -impl ToBoxedFn for dyn Fn(A) -> Result -where - errors::Error: From, - A: Into>, - O: TryFrom, -{ - fn to_boxed_fn(func: Function) -> Box { - Box::new(move |a: A| { - let args = vec![a.into()]; - let res = func.invoke(args)?; - let res = res.try_into()?; - Ok(res) - }) - } -} - -impl ToBoxedFn for dyn Fn(A, B) -> Result -where - errors::Error: From, - A: Into>, - B: Into>, - O: TryFrom, -{ - fn to_boxed_fn(func: Function) -> Box { - Box::new(move |a: A, b: B| { - let args = vec![a.into(), b.into()]; - let res = func.invoke(args)?; - let res = res.try_into()?; - Ok(res) - }) - } -} - -impl ToBoxedFn for dyn Fn(A, B, C) -> Result -where - errors::Error: From, - A: Into>, - B: Into>, - C: Into>, - O: TryFrom, -{ - fn to_boxed_fn(func: Function) -> Box { - Box::new(move |a: A, b: B, c: C| { - let args = vec![a.into(), b.into(), c.into()]; - let res = func.invoke(args)?; - let res = res.try_into()?; - Ok(res) - }) - } -} - -impl ToBoxedFn for dyn Fn(A, B, C, D) -> Result -where - errors::Error: From, - A: Into>, - B: Into>, - C: Into>, - D: Into>, - O: TryFrom, -{ - fn to_boxed_fn(func: Function) -> Box { - Box::new(move |a: A, b: B, c: C, d: D| { - let args = vec![a.into(), b.into(), c.into(), d.into()]; - let res = func.invoke(args)?; - let res = res.try_into()?; - Ok(res) - }) - } -} - -#[cfg(test)] -mod tests { - use crate::function::{self, Function, Result}; - - #[test] - fn to_boxed_fn0() { - fn boxed0() -> i64 { - return 10; - } - - function::register_override(boxed0, "boxed0".to_owned(), true).unwrap(); - let func = Function::get("boxed0").unwrap(); - let typed_func: Box Result> = func.to_boxed_fn(); - assert_eq!(typed_func().unwrap(), 10); - } -} diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 445c99ea98694..a89652b0378c8 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -49,85 +49,6 @@ pub trait Typed { fn ret(o: O) -> Result; } -impl Typed<(), O> for F -where - F: Fn() -> O, - Error: From, - O: TryInto, -{ - fn args(_args: Vec>) -> Result<()> { - debug_assert!(_args.len() == 0); - Ok(()) - } - - fn ret(o: O) -> Result { - o.try_into().map_err(|e| e.into()) - } -} - -impl Typed<(A,), O> for F -where - F: Fn(A) -> O, - Error: From, - Error: From, - A: TryFrom, Error = E1>, - O: TryInto, -{ - fn args(args: Vec>) -> Result<(A,)> { - debug_assert!(args.len() == 1); - let a: A = args[0].clone().try_into()?; - Ok((a,)) - } - - fn ret(o: O) -> Result { - o.try_into().map_err(|e| e.into()) - } -} - -impl Typed<(A, B), O> for F -where - F: Fn(A, B) -> O, - Error: From, - Error: From, - A: TryFrom, Error = E1>, - B: TryFrom, Error = E1>, - O: TryInto, -{ - fn args(args: Vec>) -> Result<(A, B)> { - debug_assert!(args.len() == 2); - let a: A = args[0].clone().try_into()?; - let b: B = args[1].clone().try_into()?; - Ok((a, b)) - } - - fn ret(o: O) -> Result { - o.try_into().map_err(|e| e.into()) - } -} - -impl Typed<(A, B, C), O> for F -where - F: Fn(A, B, C) -> O, - Error: From, - Error: From, - A: TryFrom, Error = E1>, - B: TryFrom, Error = E1>, - C: TryFrom, Error = E1>, - O: TryInto, -{ - fn args(args: Vec>) -> Result<(A, B, C)> { - debug_assert!(args.len() == 3); - let a: A = args[0].clone().try_into()?; - let b: B = args[1].clone().try_into()?; - let c: C = args[2].clone().try_into()?; - Ok((a, b, c)) - } - - fn ret(o: O) -> Result { - o.try_into().map_err(|e| e.into()) - } -} - pub trait ToFunction: Sized { type Handle; @@ -269,95 +190,100 @@ impl ToFunction>, RetValue> fn drop(_: *mut Self::Handle) {} } -impl ToFunction<(), O> for F -where - F: Fn() -> O + 'static, -{ - type Handle = Box O + 'static>; - - fn into_raw(self) -> *mut Self::Handle { - let ptr: Box = Box::new(Box::new(self)); - Box::into_raw(ptr) - } +macro_rules! impl_typed_and_to_function { + ($len:literal; $($t:ident),*) => { + impl Typed<($($t,)*), Out> for F + where + F: Fn($($t),*) -> Out, + Out: TryInto, + Error: From, + $( $t: TryFrom>, + Error: From<$t::Error>, )* + { + #[allow(non_snake_case, unused_variables, unused_mut)] + fn args(args: Vec>) -> Result<($($t,)*)> { + if args.len() != $len { + return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n", + std::any::type_name::(), + $len, args.len()))) + } + let mut args = args.into_iter(); + $(let $t = args.next().unwrap().try_into()?;)* + Ok(($($t,)*)) + } - fn call(handle: *mut Self::Handle, _: Vec>) -> Result - where - F: Typed<(), O>, - { - // Ideally we shouldn't need to clone, probably doesn't really matter. - let out = unsafe { (*handle)() }; - F::ret(out) - } + fn ret(out: Out) -> Result { + out.try_into().map_err(|e| e.into()) + } + } - fn drop(_: *mut Self::Handle) {} -} -macro_rules! to_function_instance { - ($(($param:ident,$index:tt),)+) => { - impl ToFunction<($($param,)+), O> for - F where F: Fn($($param,)+) -> O + 'static { - type Handle = Box O + 'static>; + impl ToFunction<($($t,)*), Out> for F + where + F: Fn($($t,)*) -> Out + 'static + { + type Handle = Box Out + 'static>; fn into_raw(self) -> *mut Self::Handle { let ptr: Box = Box::new(Box::new(self)); Box::into_raw(ptr) } - fn call(handle: *mut Self::Handle, args: Vec>) -> Result where F: Typed<($($param,)+), O> { - // Ideally we shouldn't need to clone, probably doesn't really matter. - let args = F::args(args)?; - let out = unsafe { - (*handle)($(args.$index),+) - }; + #[allow(non_snake_case)] + fn call(handle: *mut Self::Handle, args: Vec>) -> Result + where + F: Typed<($($t,)*), Out> + { + let ($($t,)*) = F::args(args)?; + let out = unsafe { (*handle)($($t),*) }; F::ret(out) } - fn drop(_: *mut Self::Handle) {} + fn drop(ptr: *mut Self::Handle) { + let bx = unsafe { Box::from_raw(ptr) }; + std::mem::drop(bx) + } } } } -to_function_instance!((A, 0),); -to_function_instance!((A, 0), (B, 1),); -to_function_instance!((A, 0), (B, 1), (C, 2),); -to_function_instance!((A, 0), (B, 1), (C, 2), (D, 3),); +impl_typed_and_to_function!(0;); +impl_typed_and_to_function!(1; A); +impl_typed_and_to_function!(2; A, B); +impl_typed_and_to_function!(3; A, B, C); +impl_typed_and_to_function!(4; A, B, C, D); +impl_typed_and_to_function!(5; A, B, C, D, E); #[cfg(test)] mod tests { - use super::{Function, ToFunction, Typed}; - - fn zero() -> i32 { - 10 - } + use super::*; - fn helper(f: F) -> Function + fn call(f: F, args: Vec>) -> Result where F: ToFunction, F: Typed, { - f.to_function() + F::call(f.into_raw(), args) } #[test] fn test_to_function0() { - helper(zero); - } - - fn one_arg(i: i32) -> i32 { - i - } - - #[test] - fn test_to_function1() { - helper(one_arg); - } - - fn two_arg(i: i32, j: i32) -> i32 { - i + j + fn zero() -> i32 { + 10 + } + let _ = zero.to_function(); + let good = call(zero, vec![]).unwrap(); + assert_eq!(i32::try_from(good).unwrap(), 10); + let bad = call(zero, vec![1.into()]).unwrap_err(); + assert!(matches!(bad, Error::CallFailed(..))); } #[test] fn test_to_function2() { - helper(two_arg); + fn two_arg(i: i32, j: i32) -> i32 { + i + j + } + let good = call(two_arg, vec![3.into(), 4.into()]).unwrap(); + assert_eq!(i32::try_from(good).unwrap(), 7); } }