diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 4ae35f585c6fd..520d1557a58ae 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -22,6 +22,13 @@ namespace tvm { * You can find more about Relay by reading the language reference. */ namespace relay { + +#define RELAY_DEBUG(...) \ +{ auto fdebug = runtime::Registry::Get("relay.debug"); \ + CHECK(fdebug) << "Could not find Relay Python debugger function."; \ + (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ +} + /*! * \brief we always used NodeRef for referencing nodes. * diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h new file mode 100644 index 0000000000000..3521d468c1a45 --- /dev/null +++ b/include/tvm/relay/interpreter.h @@ -0,0 +1,140 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/interpreter.h + * \brief An interpreter for Relay. + * + * This file implements a simple reference interpreter for Relay programs. + * Given a Relay environment, an a Relay expression it produces a value. + * + * This is intended as an implementation of the reference semantics for + * the Relay IR, as well as for debugging and testing. + */ +#ifndef TVM_RELAY_INTERPRETER_H_ +#define TVM_RELAY_INTERPRETER_H_ + +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief A Relay value. + */ +class Value; + +/*! \brief Evaluate an expression using the interpreter producing a value. + * + * This implements the reference semantics of Relay, giving us a tool + * for debugging and testing, especially in the development of alternative + * backends/runtimes. + * + * The resulting value can be passed to Python, making it easy to use + * for testing. + * + * The interpreter interprets the program pieces between TVM operators + * using TVM to back all Relay operator's evaluation. + * + * This is not intended to be an efficient implementation of Relay's + * semantics, eventually the TVM runtime will grow to support Relay's + * features. + */ +Value Evaluate(Environment env, Expr e); + +/*! \brief The base container type of Relay values. */ +class ValueNode : public RelayNode { + public: + static constexpr const char* _type_key = "relay.Value"; + TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); +}; + +class Value : public NodeRef { + public: + Value() {} + explicit Value(NodePtr n) : NodeRef(n) {} + const ValueNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = ValueNode; +}; + +/*! \brief A Relay closure, i.e a scope and a function. */ +class Closure; + +/*! \brief The container type of Closures. */ +class ClosureNode : public ValueNode { + public: + /*! \brief The set of free variables in the closure. + * + * These are the captured variables which are required for + * evaluation when we call the closure. + */ + tvm::Map env; + /*! \brief The function which implements the closure. + * + * \note May reference the variables contained in the env. + */ + Function func; + + ClosureNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("env", &env); + v->Visit("func", &func); + } + + TVM_DLL static Closure make(tvm::Map env, Function func); + + static constexpr const char* _type_key = "relay.Closure"; + TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); + +/*! \brief A tuple value. */ +class TupleValue; + +/*! \brief Tuple (x, ... y). */ +struct TupleValueNode : ValueNode { + tvm::Array fields; + + TupleValueNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } + + TVM_DLL static TupleValue make(tvm::Array value); + + static constexpr const char* _type_key = "relay.TupleValue"; + TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value); + +/*! \brief A tensor value. */ +class TensorValue; + +/*! \brief The tensor value container, wrapping an NDArray. */ +struct TensorValueNode : ValueNode { + runtime::NDArray data; + + TensorValueNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); } + + /*! \brief Build a value from an NDArray. */ + TVM_DLL static TensorValue make(runtime::NDArray data); + + /*! \brief Construct an empty tensor value from t. */ + TVM_DLL static TensorValue FromType(const Type& t); + + static constexpr const char* _type_key = "relay.TensorValue"; + TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); + + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_INTERPRETER_H_ \ No newline at end of file diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 1b3462659e18a..1165df083d5c6 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -6,6 +6,7 @@ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ +#include #include #include @@ -94,7 +95,8 @@ bool AlphaEqual(const Type& t1, const Type& t2); * * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice. * - * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, although x is not shadowed. + * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, + * although x is not shadowed. * * \param e the expression to check. * @@ -103,6 +105,17 @@ bool AlphaEqual(const Type& t1, const Type& t2); bool WellFormed(const Expr& e); /*! \brief Get free Vars from expr in PostDFS order. + * + * Free variables are variables that are not bound by a let or a function + * parameter in the context. + * + * \param e the expression. + * + * \return the set of free variable. + */ +tvm::Array FreeVariables(const Expr& e); + +/*! \brief Get free type parameters from expression e. * * Free variables are variables that are not bound by a * let or a function parameter in the context. @@ -115,7 +128,8 @@ tvm::Array FreeVars(const Expr& expr); /*! \brief Get free TypeVars from expression expr. * - * Free type parameters are type parameters that are not bound by a function type in the context. + * Free type parameters are type parameters that are not bound by a function + * type in the context. * * \param expr the expression. * @@ -125,10 +139,12 @@ tvm::Array FreeTypeVars(const Expr& expr); /*! \brief Remove expressions which does not effect the program result. * - * It will remove let binding that are not referenced, and if branch that are not entered. + * It will remove let binding that are not referenced, and if branch that are + * not entered. * - * For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a. - * Another example is `if (true) then 1 else 2` will be optimized into 1. + * For example, this pass should turn `let a = 1 in 2` into `2`, as the value of + * the expression does not depend on a. Another example is `if (true) then 1 + * else 2` will be optimized into 1. * * \param e the expression to optimize. * @@ -136,6 +152,11 @@ tvm::Array FreeTypeVars(const Expr& expr); */ Expr DeadCodeElimination(const Expr& e); +Expr Monomorph(const Environment& env, const Expr& e); + +Array LowerOps(const Expr& e, const std::string& target = "llvm"); + } // namespace relay } // namespace tvm + #endif // TVM_RELAY_PASS_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 731a816460eee..c49802f2a142d 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -1,10 +1,12 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" +from ..api import register_func from . import base from . import ty from . import expr from . import env from . import ir_pass +from . import testing # Root operators from .op import Op @@ -46,6 +48,19 @@ If = expr.If TupleGetItem = expr.TupleGetItem + # helper functions var = expr.var const = expr.const + +@register_func("relay._tensor_value_repr") +def _tensor_value_repr(tv): + return str(tv.data.asnumpy()) + +@register_func("relay._constant_repr") +def _tensor_value_repr(tv): + return str(tv.data.asnumpy()) + +@register_func("relay.debug") +def _debug(*args): + import pdb; pdb.set_trace() \ No newline at end of file diff --git a/python/tvm/relay/_eval.py b/python/tvm/relay/_eval.py new file mode 100644 index 0000000000000..8f7ddcc9bc675 --- /dev/null +++ b/python/tvm/relay/_eval.py @@ -0,0 +1,4 @@ +"""The interface to the Evaluator exposed from C++.""" +from tvm._ffi.function import _init_api + +_init_api("relay._eval", __name__) \ No newline at end of file diff --git a/python/tvm/relay/eval.py b/python/tvm/relay/eval.py new file mode 100644 index 0000000000000..ada046bb12aa0 --- /dev/null +++ b/python/tvm/relay/eval.py @@ -0,0 +1,94 @@ +from __future__ import absolute_import +import numpy as np +from .. import register_func, nd +from .base import NodeBase, register_relay_node +from . import _make +from . import _eval +from . import ir_pass +from .expr import Call, Constant +from . import const + +class Value(NodeBase): + """Base class of all values. + """ + pass + + @staticmethod + @register_func("relay.from_scalar") + def from_scalar(i, dtype=None): + if dtype is None: + if isinstance(i, int): + dtype = 'int32' + elif isinstance(i, float): + dtype = 'float32' + elif isinstance(i, bool): + dtype = 'uint8' + else: + raise Exception("unable to infer dtype {0}".format(type(i))) + + return TensorValue(nd.array(np.array(i, dtype=dtype))) + + +@register_relay_node +class TupleValue(Value): + def __init__(self, *fields): + self.__init_handle_by_constructor__( + _make.TupleValue, fields) + + def __getitem__(self, field_no): + return self.fields[field_no] + + +@register_relay_node +class Closure(Value): + pass + + +@register_relay_node +class TensorValue(Value): + """A Tensor value produced by the evaluator.""" + + def __init__(self, data): + """Allocate a new TensorValue and copy the data from `array` into + the new array. + """ + if isinstance(data, np.ndarray): + data = nd.array(data) + + self.__init_handle_by_constructor__( + _make.TensorValue, data) + + def as_ndarray(self): + """Convert a Relay TensorValue into a tvm.ndarray.""" + return self.data + def asnumpy(self): + """Convert a Relay TensorValue into a numpy.ndarray.""" + return self.data.asnumpy() + + def __eq__(self, other): + return self.data == other.data + +def _arg_to_ast(arg): + if isinstance(arg, TensorValue): + return Constant(arg.data) + elif isinstance(arg, np.ndarray): + return Constant(nd.array(arg)) + elif isinstance(arg, Constant): + return arg + else: + return const(arg) + +def apply_passes(expr, env=None): + ck_expr = ir_pass.infer_type(expr, env) + fused_expr = ir_pass.fuse_ops(ck_expr, env) + return fused_expr + +def evaluate(env, expr, *args): + # assert len(args) == 0 + relay_args = [] + for arg in args: + relay_args.append(_arg_to_ast(arg)) + + expr = Call(expr, relay_args) + opt_expr = apply_passes(expr, env) + return _eval.evaluate(env, opt_expr) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index c6d5aa7515bcc..496f24e56ece6 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -1,4 +1,4 @@ -# pylint: disable=no-else-return, +# pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck """The set of passes for Relay. @@ -148,6 +148,9 @@ def alpha_equal(lhs, rhs): """ return bool(_make._alpha_equal(lhs, rhs)) +lower_ops = _ir_pass.LowerOps +fuse_ops = _ir_pass.FuseOps +monomorph = _ir_pass.Monomorph def graph_equal(lhs, rhs): """Compare two Relay expr for data-flow equivalence. diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 0bc2054cebdfd..30d6e8a308610 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -1,2 +1,16 @@ #pylint: disable=invalid-name """Backend compiler related feature registration""" +import tvm +import topi +from . import register + +def add_compiler(attrs, inputs, output_type): + assert len(inputs) == 2 + return [topi.add(inputs[0], inputs[1])] + +def add_schedule(outputs, target): + assert len(outputs) == 1 + return tvm.create_schedule(outputs[0].op) + +register("add", "FTVMCompute", add_compiler) +register("add", "FTVMSchedule", add_schedule) \ No newline at end of file diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index f1130b52e7ce4..ed78d35cd2c75 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -3,7 +3,8 @@ from ..base import register_relay_node from ..expr import Expr - +from ...api import register_func +from ...build_module import lower, build @register_relay_node class Op(Expr): @@ -75,3 +76,12 @@ def _register(v): _init_api("relay.op", __name__) + +@register_func("relay.op.compiler._lower") +def _lower(name, schedule, inputs, outputs): + lf = lower(schedule, list(inputs) + list(outputs), name=name) + return lf + +@register_func("relay.op.compiler._build") +def _build(lowered_funcs): + return build(lowered_funcs, target="llvm") \ No newline at end of file diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py new file mode 100644 index 0000000000000..f61c7230d8199 --- /dev/null +++ b/python/tvm/relay/to_tvm.py @@ -0,0 +1,356 @@ +"""A compiler from Relay programs to TVM's graph runtime. +""" +from __future__ import absolute_import +import json +import attr +from . import ir_pass +from .op import Op +from .ty import TensorType +from .expr import Var, Function, Let, Call, If, GlobalVar, Constant, Let +from ..build_module import build +from typing import Any, Dict, List, Tuple +from .. contrib import graph_runtime +from .ir_pass import infer_type, monomorph +from .. import cpu + +class AbstractExprVisitor(object): + """A visitor over Expr in Python.""" + + # pylint: disable=no-else-return + def visit(self, expr): + """Apply the visitor to an expression.""" + if isinstance(expr, Function): + return self.visit_function(expr) + elif isinstance(expr, Call): + return self.visit_call(expr) + elif isinstance(expr, Let): + return self.visit_let(expr) + elif isinstance(expr, Var): + return self.visit_var(expr) + elif isinstance(expr, GlobalVar): + return self.visit_global_var(expr) + elif isinstance(expr, If): + return self.visit_if(expr) + elif isinstance(expr, Tuple): + return self.visit_tuple(expr) + elif isinstance(expr, Constant): + return self.visit_constant(expr) + else: + raise Exception(f"warning unhandled case: {type(expr)}") + + def visit_function(self, _): + raise Exception("Abstract method please implement me.") + + def visit_let(self, _): + raise Exception("Abstract method please implement me.") + + def visit_call(self, _): + raise Exception("Abstract method please implement me.") + + def visit_var(self, _): + raise Exception("Abstract method please implement me.") + + def visit_type(self, typ): + return typ + + def visit_if(self, _): + raise Exception("Abstract method please implement me.") + + def visit_tuple(self, _): + raise Exception("Abstract method please implement me.") + + def visit_constant(self, _): + raise Exception("Abstract method please implement me.") + + def visit_global_var(self, _): + raise Exception("Abstract method please implement me.") + +class ExprMutator(AbstractExprVisitor): + """A functional visitor over Expr in Python.""" + + def visit_function(self, fn): + new_body = self.visit(fn.body) + return Function( + list(fn.params), + fn.ret_type, new_body, + fn.type_params) + + def visit_let(self, let): + new_var = self.visit(let.var) + new_val = self.visit(let.value) + new_body = self.visit(let.body) + return Let(new_var, new_val, new_body) + + def visit_call(self, call): + new_fn = self.visit(call.op) + new_args = [self.visit(arg) for arg in call.args] + return Call(new_fn, new_args, call.attrs) + + def visit_var(self, var): + return var + + def visit_global_id(self, global_var): + return global_var + + def visit_if(self, ite): + return If( + self.visit(ite.guard), + self.visit(ite.true_b), + self.visit(ite.false_b)) + + def visit_tuple(self, tup): + return Tuple([self.visit(field) for field in tup.fields]) + + def visit_constant(self, const): + return const + +@attr.s +class NodeRef(object): + ident = attr.ib() + index = attr.ib(default=0) + version = attr.ib(default=0) + + def to_json(self): + return [self.ident, self.index, self.version] + + +@attr.s +class Node(object): + name = attr.ib() + attrs = attr.ib() + is_output = attr.ib() + + def to_json(self) -> Any: + raise Exception("Abstract method, please implement me.") + + +@attr.s +class InputNode(Node): + """An input node in the graph representation we lower to before NNVM's graph.""" + name = attr.ib() + attrs = attr.ib() + is_output = attr.ib(default=False) + + def to_json(self): + return { + "op": "null", + "name": self.name, + "inputs": [] + } + + +@attr.s +class OpNode(Node): + """An operator node in the graph representation we lower to before NNVM's graph.""" + op_name = attr.ib() + inputs = attr.ib() + op_attrs = attr.ib() + is_output = attr.ib(default=False) + + def to_json(self): + attrs = dict.copy(self.op_attrs) + # Extend ops with extra info. + attrs['func_name'] = self.op_name + # When do we flatten? + attrs['flatten_data'] = "0" + # Fix me! + attrs['num_inputs'] = str(len(self.inputs)) + attrs['num_outputs'] = "1" + + return { + "op": "tvm_op", + "name": self.name, + "attrs": attrs, + "inputs": self.inputs + } + + +def shape_to_json(shape): + return [sh.value for sh in shape] + + +def from_tensor(typ): + return (typ.dtype, shape_to_json(typ.shape)) + + +class TVMRTSCompiler(ExprMutator): + """The compiler from Relay to the TVM runtime system.""" + nodes = attr.ib() + id_map = attr.ib() + all_ops = attr.ib() + + def __init__(self): + self.nodes = [] + self.id_map = {} + self.all_ops = set() + + def add_node(self, node): + self.nodes.append(node) + ident = len(self.nodes) - 1 + return NodeRef(ident) + + def add_binding(self, ident, ref): + self.id_map[ident] = ref + + def let_bind(self, ident, node): + ref = self.add_node(node) + self.add_binding(ident, ref) + return ref + + def get_node(self, ref): + return self.nodes[ref.ident] + + def lookup(self, ident): + return self.id_map[ident] + + def compile(self, func): + """Compile a single function into a graph.""" + # TODO: (@jroesch) Restore me + # assert len(fn.ty_params) == 0 + + # First we convert all the parameters into input nodes. + params = func.params + + for param in params: + dtype, shape = from_tensor(param.type_annotation) + node = InputNode(f"{param.name_hint}", { + "shape": shape, + "dtype": dtype, + }) + self.let_bind(param, node) + + # Then we compile the body into a graph which can depend + # on input variables. + output_ref = self.visit(func.body) + + # Finally we retreive return value of program, which will + # become our output node. + self.get_node(output_ref).is_output = True + + def visit_let(self, let): + """Visit the Let binding, by first traversing its value, + then setting the metadata on the returned NodeRef. + + Finally visit the body, and return the NodeRef corresponding + to it. + """ + ident = let.var + val = let.value + body = let.body + + # Need to add type info? + val_ref = self.visit(val) + dtype, shape = from_tensor(val.checked_type()) + val_node = self.get_node(val_ref) + val_node.attrs["dtype"] = dtype + val_node.attrs["shape"] = shape + self.add_binding(ident, val_ref) + return self.visit(body) + + def visit_var(self, ident): + return self.lookup(ident) + + def visit_call(self, call): + """Transform a ::tvm.relay.Call into an operator in the TVM graph.""" + inputs = [] + for arg in call.args: + inputs.append(self.visit(arg).to_json()) + + if isinstance(call.op, Op): + self.all_ops.add(call.op.name) + else: + raise Exception("TVM runtime does not support function calls.") + + op_name = call.op.name + attrs = {'shape': shape_to_json(call.checked_type.shape), + 'dtype': call.checked_type.dtype} + op_node = OpNode("call_name", attrs, op_name, inputs, {}) + return self.add_node(op_node) + + def to_json(self): + """Convert the sequence of nodes stored by the compiler into the + JSON format defined in: https://docs.tvm.ai/dev/nnvm_json_spec.html. + """ + nodes = [] + # First we compute "nodes" field. + for node in self.nodes: + nodes.append(node.to_json()) + + arg_nodes = [] + heads = [] + # Compute "arg_nodes" and "heads" fields. + for i, node in enumerate(self.nodes): + if isinstance(node, InputNode): + arg_nodes.append(i) + + if node.is_output: + # Need to fix this. + heads.append(NodeRef(i).to_json()) + + # Compute "node_row_ptr". + # TODO + + # Compute "attrs" field. + attrs = {} + + # These fields are mandatory. + shapes = [] + storage_ids = [] + dtype = [] + dltype = [] + + for i, node in enumerate(self.nodes): + storage_ids.append(i) + shapes.append(node.attrs['shape']) + if node.attrs['dtype'] == 'float32': + dtype.append(0) + dltype.append('float32') + + attrs["shape"] = ["list_shape", shapes] + attrs["storage_id"] = ["list_int", storage_ids] + attrs["dtype"] = ["list_int", dtype] + attrs["dltype"] = ["list_str", dltype] + + json_dict = { + "nodes": nodes, + "arg_nodes": arg_nodes, + "heads": heads, + "attrs": attrs + } + + return json.dumps(json_dict) + + +def compile_to_tvm(func, target=None): + """Compile a single function to the components needed by the + TVM RTS. + """ + if target is None: + target = 'llvm' + + comp = TVMRTSCompiler() + comp.compile(func) + lowered_funcs = ir_pass.lower_ops(func) + mod = build(lowered_funcs, target) + graph_json = comp.to_json() + return graph_json, mod, None # params currently isn't supported by API + +def evaluate_rts(env, func, *args): + func = infer_type(func, env) + func = monomorph(env, func) + func = infer_type(func, env) + graph_json, mod, params = compile_to_tvm(func) + assert params is None + # Temporary hack for node_row_ptr + import nnvm + graph = nnvm.graph.load_json(graph_json) + gmodule = graph_runtime.create(graph, mod, cpu(0)) + # Create map of inputs. + inputs = {} + for i, arg in enumerate(args): + inputs[func.params[i].name_hint] = arg + # Set the inputs here. + gmodule.set_input(**inputs) + # Run the module, and fetch the output. + gmodule.run() + return gmodule.get_output(0) \ No newline at end of file diff --git a/src/relay/interpreter.cc b/src/relay/interpreter.cc new file mode 100644 index 0000000000000..750dbd66e09dc --- /dev/null +++ b/src/relay/interpreter.cc @@ -0,0 +1,431 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/interpreter.cc + * \brief An interpreter for the Relay IR. + */ + +#include +#include +#include +#include +#include +#include +#include "./ir/type_functor.h" + +namespace tvm { +namespace relay { + +using namespace runtime; + +inline const PackedFunc& GetPackedFunc(const std::string& name) { + const PackedFunc* pf = tvm::runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; + return *pf; +} + +/* Value Implementation */ +Closure ClosureNode::make(tvm::Map env, Function func) { + NodePtr n = make_node(); + n->env = std::move(env); + n->func = std::move(func); + return Closure(n); +} + +TVM_REGISTER_API("relay._make.Closure") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ClosureNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const ClosureNode* node, tvm::IRPrinter* p) { + p->stream << "ClosureNode(" << node->func << ")"; + }); + +TupleValue TupleValueNode::make(tvm::Array value) { + NodePtr n = make_node(); + n->fields = value; + return TupleValue(n); +} + +TVM_REGISTER_API("relay._make.TupleValue") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TupleValueNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TupleValueNode* node, + tvm::IRPrinter* p) { + p->stream << "TupleValueNode(" << node->fields << ")"; + }); + +TensorValue TensorValueNode::make(runtime::NDArray data) { + NodePtr n = make_node(); + n->data = std::move(data); + return TensorValue(n); +} + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TensorValueNode* node, + tvm::IRPrinter* p) { + auto to_str = GetPackedFunc("relay._tensor_value_repr"); + std::string data_str = to_str(GetRef(node)); + p->stream << "TensorValueNode(" << data_str << ")"; + }); + +TensorValue TensorValueNode::FromType(const Type& t) { + if (auto tt_node = t.as()) { + std::vector dims; + + for (auto dim : tt_node->shape) { + auto int_node = dim.as(); + CHECK(int_node) << "expected concrete dimensions"; + dims.push_back(int_node->value); + } + + DLDataType dtype; + DLContext context; + + switch (tt_node->dtype.code()) { + case halideir_type_int: + dtype.code = kDLInt; + break; + case halideir_type_uint: + dtype.code = kDLUInt; + break; + case halideir_type_float: + dtype.code = kDLFloat; + break; + default: + throw dmlc::Error("can not convert HalideIR type into DLTensor dtype"); + } + + dtype.bits = tt_node->dtype.bits(); + dtype.lanes = tt_node->dtype.lanes(); + + // TODO(@jroesch): Is this the right place to place the tensor? + context.device_type = DLDeviceType::kDLCPU; + context.device_id = 0; + runtime::NDArray data = NDArray::Empty(dims, dtype, context); + return TensorValueNode::make(data); + } else { + LOG(FATAL) << "expected a tensor type"; + return TensorValue(); + } +} + +TVM_REGISTER_API("relay._make.TensorValue") + .set_body([](TVMArgs args, TVMRetValue* ret) { + runtime::NDArray data = args[0]; + *ret = TensorValueNode::make(data); + }); + +/* Evaluator Implementation. */ +struct EvalError : dmlc::Error { + explicit EvalError(const std::string& msg) : Error(msg) {} +}; + +struct IsSimpleType : TypeVisitor<> { + bool is_simple; + IsSimpleType() : is_simple(true) {} + void VisitType_(const FuncTypeNode* fn_ty) override { + if (fn_ty->type_params.size() != 0) { + is_simple = false; + } + } +}; + +bool is_simple_type(const Type& t) { + IsSimpleType ist; + ist.VisitType(t); + return ist.is_simple; +} + +struct Frame { + // In the efficient version this should seperate args, locals, and return + // address. + tvm::Map locals; + + explicit Frame(tvm::Map locals) : locals(locals) {} +}; + +struct Stack { + std::vector frames; + Stack() : frames() { frames.push_back(Frame({})); } + + Frame& current_frame() { return frames.back(); } + + Value lookup(const Var& local) { + for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) { + if (frame->locals.find(local) != frame->locals.end()) { + return frame->locals.at(local); + } + } + throw dmlc::Error("internal error could not find"); + } + struct LocalFrame { + Stack& st; + explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { + st.frames.push_back(fr); + } + ~LocalFrame() { st.frames.pop_back(); } + }; +}; + +struct Interpreter : ExprFunctor { + Environment env; + Stack stack; + std::unordered_map operator_map_; + + template + T with_frame(const Frame& fr, const std::function& f) { + Stack::LocalFrame lf(stack, fr); + return f(); + } + + Interpreter(Environment env) : env(env), operator_map_() {} + + void extend(const Var& id, Value v) { + this->stack.current_frame().locals.Set(id, v); + } + + inline Value lookup(const Var& local) { + return this->stack.lookup(local); + } + + Value Eval(const Expr& expr) { + return (*this)(expr); + } + + Value VisitExpr(const Expr& expr) override { + RELAY_LOG(INFO) << "VisitExpr: " << expr << std::endl; + auto ret = ExprFunctor::VisitExpr(expr); + return ret; + } + + Value VisitExpr_(const VarNode* var_node) override { + Var var = GetRef(var_node); + for (auto frame = this->stack.frames.rbegin(); + frame != this->stack.frames.rend(); frame++) { + auto ivar = frame->locals.find(var); + if (ivar != frame->locals.end()) { + Value result = (*ivar).second; + return result; + } + } + + throw EvalError("internal error local variable can not be found " + + var->name_hint); + } + + Value VisitExpr_(const GlobalVarNode* op) override { + Function func = this->env->Lookup(GetRef(op)); + return Eval(func->body); + } + + Value VisitExpr_(const OpNode* id) override { + // TODO(@jroesch): Eta-expand and return in this case. + throw EvalError( + "internal error, need to wrap intrinsic into call synthetic call node " + "in " + "this case, eta expand"); + } + + Value VisitExpr_(const ConstantNode* op) override { + return TensorValueNode::make(op->data); + } + + Value VisitExpr_(const TupleNode* op) override { + std::vector values; + + for (auto field : op->fields) { + Value field_value = Eval(field); + values.push_back(field_value); + } + + return TupleValueNode::make(values); + } + + Value VisitExpr_(const FunctionNode* func_node) override { + auto func = GetRef(func_node); + tvm::Map captured_env; + Array free_vars = FreeVariables(func); + + for (const auto& var : free_vars) { + captured_env.Set(var, Eval(var)); + } + + return ClosureNode::make(captured_env, func); + } + + Value invoke_operator(const Op& op, tvm::Array& args) { + auto op_type = op->op_type; + + std::cout << op->name << std::endl; + PackedFunc op_impl = Op::GetAttr("FEvaluate")[op]; + + // Marshal the arguments. + auto arg_len = args.size() + 1; + std::vector values(arg_len); + std::vector codes(arg_len); + TVMArgsSetter setter(values.data(), codes.data()); + TVMRetValue ret; + + // We need real type information to properly allocate the structure. + for (size_t i = 0; i < args.size(); i++) { + if (const TensorValueNode* tv = args[i].as()) { + setter(i, tv->data); + } + } + + if (auto* tan = op_type.as()) { + // TVM's calling convention is that the final argument is the output + // buffer. To preserve the illusion of being a functional language + // we need to allocate space for the output buffer based on the + // return type. + + CHECK(tan->ret_type.as()); + + auto out_tensor = TensorValueNode::FromType(tan->ret_type); + + setter(arg_len - 1, out_tensor->data); + op_impl.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &ret); + return out_tensor; + } else { + throw EvalError("operators must have function types"); + } + } + + Value invoke(const Closure& closure, const tvm::Array& args) { + // Get a reference to the function inside the closure. + auto func = closure->func; + + // Allocate a frame with the parameters and free variables. + tvm::Map locals; + + CHECK(func->params.size() == args.size()); + + for (size_t i = 0; i < func->params.size(); i++) { + locals.Set(func->params[i], args[i]); + } + + // Add the var to value mappings from the Closure's environment. + for (auto it = closure->env.begin(); it != closure->env.end(); ++it) { + locals.Set((*it).first, (*it).second); + } + + return with_frame(Frame(locals), [&]() { return Eval(func->body); }); + } + + Value VisitExpr_(const CallNode* op) override { + tvm::Array args; + for (auto arg : op->args) { + args.push_back(Eval(arg)); + } + + std::vector arg_types; + + // We need to catch the case where we are invoking a primitive directly. + if (const OpNode* intr = op->op.as()) { + return this->invoke_operator(GetRef(intr), args); + } else { + Value fn_val = Eval(op->op); + if (const ClosureNode* closure_node = fn_val.as()) { + auto closure = GetRef(closure_node); + return this->invoke(closure, args); + } else { + throw EvalError( + "Type error, expected function value in the call position"); + } + } + } + + Value VisitExpr_(const LetNode* op) override { + auto value = Eval(op->value); + this->extend(op->var, value); + return Eval(op->body); + } + + Value VisitExpr_(const TupleGetItemNode* op) override { + Value val = Eval(op->tuple); + if (auto product_node = val.as()) { + return product_node->fields[op->index]; + } else { + throw EvalError("not a product"); + } + } + + Value VisitExpr_(const IfNode* op) override { + Value v = Eval(op->cond); + if (const TensorValueNode* bv = v.as()) { + // TODO(@jroesch): Ask TQ + if (reinterpret_cast(bv->data->data)[0]) { + return Eval(op->true_branch); + } else { + return Eval(op->false_branch); + } + } else { + throw EvalError("type error, type system should have caught this"); + } + } +}; + +void CompileOperators(const Expr& e) { + auto lowered_funcs = LowerOps(e); + RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_funcs << std::endl; + if (lowered_funcs.size()) { + const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build"); + CHECK(fbuild_ptr); + auto fbuild = *fbuild_ptr; + Module module = fbuild(lowered_funcs); + for (auto lf : lowered_funcs) { + RELAY_LOG(INFO) << "LoweredFunc: " << lf->name << std::endl; + auto fevaluate = module.GetFunction(lf->name); + auto op_reg_ptr = tvm::relay::OpRegistry::Registry()->Find(lf->name); + CHECK(op_reg_ptr); + OpRegistry op_reg = *op_reg_ptr; + op_reg.set_attr("FEvaluate", fevaluate, 1); + } + } +} + +Value Evaluate(Environment env, Expr e) { + CompileOperators(e); + Interpreter interp(env); + return interp.Eval(e); +} + +TVM_REGISTER_API("relay._eval.evaluate") + .set_body([](TVMArgs args, TVMRetValue* ret) { + Environment env = args[0]; + Expr expr = args[1]; + *ret = Evaluate(env, expr); + }); + +// TVM_REGISTER_API("relay._eval.invoke") +// .set_body([](TVMArgs args, TVMRetValue* ret) { +// // tood maybe tweak interface +// Environment env = args[0]; +// GlobalVar id = args[1]; +// tvm::Array relay_args = args[2]; + +// // Because we are interfacing with the runtime here, we first need to +// type +// // check the arguments to the function at runtime. +// // +// // Because we have values we can easily compute a type from them and +// just +// // type check the call before execution. +// Evaluator eval(env); +// Value fn_val = eval.Eval(id); +// if (const ClosureNode* closure_node = fn_val.as()) { +// auto closure = GetRef(closure_node); +// auto result = eval.invoke(closure, relay_args); +// *ret = result; +// } else { +// throw EvalError( +// "Type error, expected function value in the call position"); +// } +// }); + +} // namespace tvm +} // namespace tvm \ No newline at end of file diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index c75c414c8ce9b..d3bbeb80dbd02 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -26,7 +26,10 @@ TVM_REGISTER_API("relay._make.Constant") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ConstantNode* node, tvm::IRPrinter* p) { - p->stream << "Constant(TODO)"; + const PackedFunc* fprint = Registry::Get("relay._constant_repr"); + CHECK(fprint) << "unable to find printing function for constants"; + std::string data = (*fprint)(GetRef(node)); + p->stream << "Constant(" << data << ")"; }); TensorType ConstantNode::tensor_type() const { diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 557daa98e8998..a5374aec979cf 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -194,6 +194,7 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) { void ExprVisitor::VisitExpr_(const CallNode* op) { this->VisitExpr(op->op); + for (auto ty_arg : op->type_args) { this->VisitType(ty_arg); } diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc new file mode 100644 index 0000000000000..39df0f53f5d3a --- /dev/null +++ b/src/relay/pass/fuse_ops.cc @@ -0,0 +1,96 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file src/tvm/relay/pass/fuse_ops.cc + * + * \brief Fuse Relay eligble sequences of Relay operators into a single one. + * + */ +#include +#include +#include +#include +#include +#include +#include "../ir/type_functor.h" + +namespace tvm { +namespace relay { + +using namespace runtime; + +struct AbstractFusableOps : ExprMutator { + Environment env; + Array fusable_funcs; + int counter = 0; + + AbstractFusableOps(Environment env) : env(env) {} + + Expr VisitExpr_(const CallNode* call) { + if (auto op_node = call->op.as()) { + // Placeholder fusion algorithm which abstracts + // single definitions into functions only. + Array params; + Array args; + + int param_number = 0; + for (auto arg : call->args) { + auto name = std::string("p") + std::to_string(param_number); + auto type = arg->checked_type(); + auto var = VarNode::make(name, type); + params.push_back(var); + args.push_back(var); + } + + auto body = CallNode::make(call->op, args, call->attrs); + auto func = FunctionNode::make(params, body, call->checked_type(), {}); + std::string func_name = "fused_"; + func_name += op_node->name; + fuc_name += "_"; + func_name += std::to_string(counter++); + auto gv = env->GetGlobalVar(func_name); + env->Add(gv, func); + fusable_funcs.push_back(gv); + return CallNode::make(gv, call->args, Attrs()); + } else { + return ExprMutator::VisitExpr_(call); + } + } +}; + +struct RewriteFusable : ExprMutator { + Environment env; + Array func_to_ops; + int counter = 0; + + AbstractFusableOps(Environment env, func) : env(env), funcs_to_ops(funcs) {} + + Expr VisitExpr_(const CallNode* call) { + const GlobalVar* gv = call->op.as(); + if (gv) { + + } else { + return GetRef(call); + } + } +}; + +Expr FuseOps(const Environment& env, const Expr& e) { + // First we convert all chains of fusable ops into + // abstracted functions which we mark as primtive + // then we convert these primtive functions into + // new operators. + auto abstract = AbstractFusableOps(env); + auto abstracted_e = abstract.VisitExpr(e); + std::cout << "Abstracted Thing: " << abstracted_e << std::endl; + return e; +} + +TVM_REGISTER_API("relay._ir_pass.FuseOps") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FuseOps(args[1], args[0]); +}); + + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/lower_ops.cc b/src/relay/pass/lower_ops.cc new file mode 100644 index 0000000000000..01c211d22261c --- /dev/null +++ b/src/relay/pass/lower_ops.cc @@ -0,0 +1,85 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file src/tvm/relay/pass/lower_ops.cc + * + * \brief Lower a Relay program to set of TVM operators. + * + */ +#include +#include +#include +#include +#include +#include +#include "../ir/type_functor.h" + +namespace tvm { +namespace relay { + +using namespace runtime; + +// TODO(@jroesch): do full liveness through definitions. +struct LiveOps : ExprVisitor { + LiveOps() : calls() {} + // std::set ops; + Array calls; + + // void VisitExpr_(const OpNode* node) final { + // ops.insert(GetRef(node)); + // } + + void VisitExpr_(const CallNode* call) final { + if (call->op.as()) { + calls.push_back(GetRef(call)); + } + } +}; + +/*! \brief Return the set of operators in their TVM format. */ +Array LowerOps(const Expr& e, const std::string& target) { + RELAY_LOG(INFO) << "LowerOps: e=" << e; + auto flower_ptr = Registry::Get("relay.op.compiler._lower"); + CHECK(flower_ptr); + PackedFunc flower = *flower_ptr; + auto live_ops = LiveOps(); + live_ops.VisitExpr(e); + + auto schedule_reg = Op::GetAttr("FTVMSchedule"); + auto compute_reg = Op::GetAttr("FTVMCompute"); + + Array lowered_funcs; + + for (const Call& call : live_ops.calls) { + CHECK(IsPrimitiveOp(call->op)) << "failed to lower " + << call->op << "can only lower primitve operations"; + + auto op = Downcast(call->op); + + Array inputs; + std::string input_name = "in"; + int i = 0; + for (auto type_arg : call->type_args) { + auto tt = Downcast(type_arg); + inputs.push_back(PlaceholderOpNode::make(input_name + std::to_string(i), tt->shape, tt->dtype).output(0)); + i++; + } + auto output_tt = op->op_type->ret_type; + Array outputs = compute_reg[op](call->attrs, inputs, output_tt); + auto schedule = schedule_reg[op](outputs, target); + size_t call_addr = (size_t)call.operator->(); + LoweredFunc lf = flower(op->name + std::to_string(call_addr), schedule, inputs, outputs); + lowered_funcs.push_back(lf); + } + + return lowered_funcs; +} + +TVM_REGISTER_API("relay._ir_pass.LowerOps") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LowerOps(args[0]); +}); + + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/monomorph.cc b/src/relay/pass/monomorph.cc new file mode 100644 index 0000000000000..10e6a1345ae3c --- /dev/null +++ b/src/relay/pass/monomorph.cc @@ -0,0 +1,131 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file src/tvm/relay/pass/monomorph.cc + * + * \brief Remove polymorphism/generics from a Relay program. + * + */ +#include +#include +#include +#include +#include +#include "./type_subst.h" +#include "../ir/type_functor.h" + +namespace tvm { +namespace relay { + +using MMCacheKey = std::pair>; + +struct MMCacheKeyEqual : std::binary_function { + bool operator()(const MMCacheKey& x, const MMCacheKey& y) const { + bool expr_match = AlphaEqual(x.first, y.first); + + if (x.second.size() != y.second.size()) { + return false; + } + + bool types_match = true; + for (size_t i = 0; i < x.second.size(); i++) { + types_match &= x.second[i] == y.second[i]; + } + + return expr_match && types_match; + } +}; + +struct MonoMorphizer : ExprMutator { + std::map mm_cache; + + FuncType Instantiate(FuncType fn_ty, Array type_args) const { + tvm::Map subst_map; + + CHECK(fn_ty->type_params.size() == type_args.size()) << + "internal error: type parameters " << fn_ty->type_params << + "do not match the number of type arguments" << type_args; + + // Build a subsitituion map up from the function type and type arguments. + for (size_t i = 0; i < type_args.size(); i++) { + subst_map.Set(fn_ty->type_params[i], type_args[i]); + } + + // TODO(@jroesch): handle type constraints. + Type inst_ty = + FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); + + return Downcast(TypeSubst(inst_ty, subst_map)); + } + + Op SpecializeOp(const Op& op, const std::string& mangled_op_name, + Array type_args) const { + auto registry = ::tvm::relay::OpRegistry::Registry(); + auto spec_op_reg = registry->Find(mangled_op_name); + + if (spec_op_reg) { + return spec_op_reg->op(); + } else { + OpRegistry& new_op_reg = registry->__REGISTER_OR_GET__(mangled_op_name).set_name(); + + auto fn_ty = op->op_type; + new_op_reg.op()->op_type = Instantiate(fn_ty, type_args); + + // Now we want to copy over some attributes. + PackedFunc compiler = Op::GetAttr("FTVMCompute")[op]; + PackedFunc schedule = Op::GetAttr("FTVMSchedule")[op]; + + new_op_reg.set_attr("FTVMCompute", compiler); + new_op_reg.set_attr("FTVMSchedule", schedule); + return new_op_reg.op(); + } + } + + std::string Mangle(const std::string& name, const Array& args, const Attrs attrs) const { + // TODO(@jroesch): How do we make it possible for multiple programs to monomorph. + // We should really compute hash or soemthing? + std::stringstream ss; + ss << name << args << attrs; + return ss.str(); + } + + Expr VisitExpr_(const CallNode* call) { + // Process the arguments. + Array mm_args; + for (auto arg : call->args) { + mm_args.push_back(this->VisitExpr(arg)); + } + + if (auto op_node = call->op.as()) { + auto op = GetRef(op_node); + + // Check the cache. + MMCacheKey key = {op, call->type_args}; + auto in_cache = this->mm_cache.find(key); + if (in_cache != this->mm_cache.end()) { + return CallNode::make(in_cache->second, mm_args, call->attrs, {}); + } else { + auto new_name = Mangle(op->name, call->type_args, call->attrs); + auto new_op = SpecializeOp(op, new_name, call->type_args); + this->mm_cache.insert({key, new_op}); + return CallNode::make(new_op, mm_args, call->attrs, {}); + } + } else { + auto mm_op = this->VisitExpr(call->op); + return CallNode::make(mm_op, mm_args, call->attrs, {}); + } + } +}; + +Expr Monomorph(const Environment& env, const Expr& e) { + auto mm = MonoMorphizer(); + return mm.VisitExpr(e); +} + +TVM_REGISTER_API("relay._ir_pass.Monomorph") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = Monomorph(args[0], args[1]); +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 7c8eeef92c5d5..80cc9c1edd94c 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -272,8 +272,8 @@ class TypeInferencer : private ExprFunctor { auto* fn_ty_node = ftype.as(); CHECK(fn_ty_node != nullptr) - << "only expressions with function types can be called, at " - << call->span; + << "only expressions with function types can be called, found " + << ftype << " at " << call->span; Array type_args; FuncType fn_ty = Instantiate(fn_ty_node, &type_args); diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index c1f00c7b65e02..336ff615c5cf0 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -3,10 +3,11 @@ * * \file util.cc * - * \brief simple util for relay. + * \brief Utility functions for Relay. */ #include #include +#include #include "../ir/type_functor.h" namespace tvm { diff --git a/tests/python/relay/test_evaluator.py b/tests/python/relay/test_evaluator.py new file mode 100644 index 0000000000000..aa4f50238b9e7 --- /dev/null +++ b/tests/python/relay/test_evaluator.py @@ -0,0 +1,94 @@ +import numpy as np +import tvm +from tvm import relay +from tvm.relay.eval import Value, TupleValue, evaluate +from tvm.relay import op +from tvm.relay.scope_builder import ScopeBuilder + + +def check_eval(expr, args, expected_result, env=None): + if env is None: + env = relay.env.Environment({}) + + result = evaluate(env, expr, *args) + np.testing.assert_allclose(result.asnumpy(), expected_result) + + +def test_from_scalar(): + np.testing.assert_allclose(Value.from_scalar(1).asnumpy(), 1) + np.testing.assert_allclose(Value.from_scalar(10.0).asnumpy(), 10.0) + np.testing.assert_allclose(Value.from_scalar(True).asnumpy(), True) + + +def test_tuple_value(): + tv = TupleValue(Value.from_scalar( + 1), Value.from_scalar(2), Value.from_scalar(3)) + np.testing.assert_allclose(tv[0].asnumpy(), 1) + np.testing.assert_allclose(tv[1].asnumpy(), 2) + np.testing.assert_allclose(tv[2].asnumpy(), 3) + + +def test_id(): + x = relay.var('x', 'float32') + ident = relay.Function([x], x) + env = relay.env.Environment({}) + res = evaluate(env, ident, 1.0) + check_eval(ident, [1.0], 1.0) + + +def test_add_const(): + two = op.add(relay.const(1), relay.const(1)) + func = relay.Function([], two) + check_eval(func, [], 2) + + +def test_mul_param(): + x = relay.var('x', shape=(10, 10)) + y = relay.var('y', shape=(1, 10)) + func = relay.Function([x, y], op.multiply(x, y)) + x_data = np.random.rand(10, 10) + y_data = np.random.rand(1, 10) + check_eval(func, [x_data, y_data], x_data * y_data) + + +def test_linear(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + y = op.add(op.nn.dense(x, w), b) + func = relay.Function([x, w, b], y) + x_data = np.random.rand(10, 10) + w_data = np.random.rand(10, 10) + b_data = np.random.rand(10) + check_eval(func, [x_data, w_data, b_data], x_data @ w_data + b_data) + + +def test_loop(): + pass + +# @no_type_check +# @relay +# def loop_debug(i: Tensor[Int, (10, 1)]) -> Tensor[Int, (10, 1)]: +# return relay.debug(i - 1) + +# out = loop_debug(np.ones((10, 1), dtype=np.int32)) + +# import pdb; pdb.set_trace() + +# # @no_type_check +# # @relay +# # def loop_debug(i: Int[64], step: Int[64], zero: Int[64]) -> Int[64]: +# # if relay.iequal(i, zero): +# # return i +# # else: +# # return loop_debug(relay.debug(relay.isubtract(i, step)), step, zero) + +# # out = loop_debug(10, 1, 0) + +# # import pdb; pdb.set_trace() + + +if __name__ == "__main__": + test_id() + test_add_const() + test_linear() diff --git a/tests/python/relay/test_tvm_rts.py b/tests/python/relay/test_tvm_rts.py new file mode 100644 index 0000000000000..9e2931078deec --- /dev/null +++ b/tests/python/relay/test_tvm_rts.py @@ -0,0 +1,86 @@ +import numpy as np + +from tvm import relay +from tvm.relay.ir_pass import infer_type +from tvm.relay.eval import evaluate +from tvm.relay.to_tvm import evaluate_rts +from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay.op import add +from tvm.relay.env import Environment + +# @tq, @jr should we put this in testing ns? +def check_rts(env, expr, args, expected_result): + """ + Check that evaluating `expr` applied to the arguments produces + `result` on both the evaluator and TVM runtime. + + Parameters + ---------- + expr: + The expression to evaluate + + args: list of Expr + The arguments to supply the expr. + + expected_result: + The expected result of running the expression. + """ + eval_result = evaluate(env, expr, *args) + rts_result = evaluate_rts(env, expr, *args) + np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy()) + +def test_add_op_scalar(): + """ + Program: + fn (x, y) { + return x + y; + } + """ + env = Environment() + x = relay.var('x', shape=()) + y = relay.var('y', shape=()) + func = relay.Function([x, y], add(x, y)) + x_data = np.array(10.0, dtype='float32') + y_data = np.array(1.0, dtype='float32') + check_rts(env, func, [x_data, y_data], x_data + y_data) + +def test_add_op_tensor(): + """ + Program: + fn (x, y) { + return x + y; + } + """ + env = Environment() + x = relay.var('x', shape=(10, 5)) + y = relay.var('y', shape=(10, 5)) + func = relay.Function([x, y], add(x, y)) + x_data = np.random.rand(10, 5).astype('float32') + y_data = np.random.rand(10, 5).astype('float32') + check_rts(env, func, [x_data, y_data], x_data + y_data) + +def test_add_op_broadcast(): + """ + Program: + fn (x, y) { + return x + y; + } + """ + env = Environment() + x = relay.var('x', shape=(10, 5)) + y = relay.var('y', shape=(1, 5)) + func = relay.Function([x, y], add(x, y)) + x_data = np.random.rand(10, 5).astype('float32') + y_data = np.random.rand(1, 5).astype('float32') + check_rts(env, func, [x_data, y_data], x_data + y_data) + +def test_mlp(): + net, params = relay.testing.mlp.get_workload(1, 10) + import pdb; pdb.set_trace() + + +if __name__ == "__main__": + test_add_op_scalar() + test_add_op_tensor() + test_add_op_broadcast() + test_mlp() \ No newline at end of file diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index e1d749e758631..a7ab06c548fa1 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -6,6 +6,13 @@ from tvm.relay.ir_pass import infer_type from tvm import relay +def assert_has_type(expr, typ, env=Environment({})): + checked_expr = infer_type(env, expr) + checked_type = checked_expr.checked_type + if checked_type != typ: + raise RuntimeError("Type mismatch %s vs %s" % ( + checked_type, typ)) + def test_monomorphic_let(): "Program: let x = 1; return x" @@ -16,6 +23,51 @@ def test_monomorphic_let(): assert xchecked.checked_type == relay.scalar_type("float64") +# def test_single_op(): +# "Program: fn (x : float32) { let t1 = f(x); t1 }" +# b = IRBuilder() +# with b.function(('x', 'float32')) as func: +# x, = func.param_ids() +# t1 = b.let('t1', log(x)) +# b.ret(t1) +# assert_has_type(func.to_func(), func_type(['float32'], 'float32')) + +# def test_add_op(): +# """ +# Program: +# fn (x, y) { +# return x + y; +# } +# """ +# b = IRBuilder() +# x = b.param('x', tensor_type(5, 5, 5)) +# y = b.param('y', tensor_type(5, 5, 5)) +# with b.function(x, y) as func: +# b.ret(add(x.var, y.var)) +# b.ret(func) +# prog, env = b.get() +# ttype = tensor_type(5, 5, 5) +# expected_ty = func_type([ttype, ttype], ttype) +# assert_has_type(func.to_func(), expected_ty) + +# def test_add_broadcast_op(): +# """ +# Program: +# fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { +# return x + y; +# } +# """ +# b = IRBuilder() +# x = b.param('x', tensor_type(10, 4)) +# y = b.param('y', tensor_type(5, 10, 1)) +# with b.function(x, y) as func: +# b.ret(add(x.var, y.var)) +# b.ret(func) +# prog, env = b.get() +# ttype = tensor_type(5, 5, 5) +# expected_ty = func_type([ttype, ttype], ttype) +# assert_has_type(func.to_func(), expected_ty) + def test_dual_op(): """Program: fn (x : Tensor[f32, (10, 10)]) { @@ -76,10 +128,23 @@ def f(n: i32, data: f32) -> f32 { assert "%3 = @f(%1, %2)" in env.astext() assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32) +# This currently fails and should pass under the type system. +# This test is to illustrate problem with +def test_incomplete_call(): + ib = IRBuilder() + inc_call = ib.global_var('inc_call') + x = ib.param('x', ty='int32') + f = ib.param('f') + with ib.decl(inc_call, x, f): + ib.ret(f(x)) + import pdb; pdb.set_trace() +# This currently fails and should pass under the type system. def test_tuple(): tp = relay.TensorType((10,)) x = relay.var("x", tp) + f = ib.param('f') + with ib.decl(dup, x): res = relay.Tuple([x, x]) assert (relay.ir_pass.infer_type(res).checked_type == relay.TupleType([tp, tp])) @@ -110,10 +175,13 @@ def test_type_args(): if __name__ == "__main__": test_free_expr() test_dual_op() + test_single_op() + test_add_op() test_recursion() test_monomorphic_let() test_decl() test_recursion() test_tuple() + test_incomplete_call() test_free_expr() test_type_args()