Skip to content

Commit

Permalink
Add header for evaluator.
Browse files Browse the repository at this point in the history
Add initial version of evaluator and tests

WIP

Work towards simple examples in the evaluator

Requires implementation of lowering ops and monomorph

Evaluator now works on simple cases

Restore Function case in Evaluator

WIP

Fix rebase issues

working towards working version

RTS is now working again

RTS can add numbers now

Fix some rebase issues

Fix up tests post rebase

WIP

Issue type checking MLP

Remove dead file

Clean up evaluator

Remove accidental change

Reset changes from apache#1962
  • Loading branch information
jroesch committed Oct 23, 2018
1 parent c2b3615 commit 52ddd94
Show file tree
Hide file tree
Showing 25 changed files with 2,040 additions and 13 deletions.
7 changes: 7 additions & 0 deletions include/tvm/relay/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ namespace tvm {
* You can find more about Relay by reading the language reference.
*/
namespace relay {

#define RELAY_DEBUG(...) \
{ auto fdebug = runtime::Registry::Get("relay.debug"); \
CHECK(fdebug) << "Could not find Relay Python debugger function."; \
(*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
}

/*!
* \brief we always used NodeRef for referencing nodes.
*
Expand Down
140 changes: 140 additions & 0 deletions include/tvm/relay/evaluator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/evaluator.h
* \brief An evaluator for Relay.
*
* This file implements a simple reference interpreter for Relay programs.
* Given a Relay environment, an a Relay expression it produces a value.
*
* This is intended as an implementation of the reference semantics for
* the Relay IR, as well as for debugging and testing.
*/
#ifndef TVM_RELAY_EVALUATOR_H_
#define TVM_RELAY_EVALUATOR_H_

#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>

namespace tvm {
namespace relay {

/*!
* \brief A Relay value.
*/
class Value;

/*! \brief Evaluate an expression in the environment producing a value.
*
* This implements the reference semantics of Relay, giving us a tool
* for debugging and testing, especially in the development of alternative
* backends/runtimes.
*
* The resulting value can be passed to Python, making it easy to use
* for testing.
*
* The evaluator interprets the program pieces between TVM operators
* using TVM to back all Relay operator's evaluation.
*
* This is not intended to be an efficient implementation of Relay's
* semantics, eventually the TVM runtime will grow to support Relay's
* features.
*/
Value Evaluate(Environment env, Expr e);

/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
};

class Value : public NodeRef {
public:
Value() {}
explicit Value(NodePtr<Node> n) : NodeRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(node_.get());
}

using ContainerType = ValueNode;
};

/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;

/*! \brief The container type of Closures. */
class ClosureNode : public ValueNode {
public:
/*! \brief The set of free variables in the closure.
*
* These are the captured variables which are required for
* evaluation when we call the closure.
*/
tvm::Map<Var, Value> env;
/*! \brief The function which implements the closure.
*
* \note May reference the variables contained in the env.
*/
Function func;

ClosureNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("env", &env);
v->Visit("func", &func);
}

TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);

static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);

/*! \brief A tuple value. */
class TupleValue;

/*! \brief Tuple (x, ... y). */
struct TupleValueNode : ValueNode {
tvm::Array<Value> fields;

TupleValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }

TVM_DLL static TupleValue make(tvm::Array<Value> value);

static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value);

/*! \brief A tensor value. */
class TensorValue;

/*! \brief The tensor value container, wrapping an NDArray. */
struct TensorValueNode : ValueNode {
runtime::NDArray data;

TensorValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); }

/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);

/*! \brief Construct an empty tensor value from t. */
TVM_DLL static TensorValue FromType(const Type& t);

static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);


} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EVALUATOR_H_
26 changes: 19 additions & 7 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

#include <tvm/lowered_func.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>

Expand Down Expand Up @@ -94,7 +95,8 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
*
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, although x is not shadowed.
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice,
* although x is not shadowed.
*
* \param e the expression to check.
*
Expand All @@ -104,7 +106,8 @@ bool WellFormed(const Expr& e);

/*! \brief Get free variables from expression e.
*
* Free variables are variables that are not bound by a let or a function parameter in the context.
* Free variables are variables that are not bound by a let or a function
* parameter in the context.
*
* \param e the expression.
*
Expand All @@ -114,7 +117,8 @@ tvm::Array<Var> FreeVariables(const Expr& e);

/*! \brief Get free type parameters from expression e.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
* Free type parameters are type parameters that are not bound by a function
* type in the context.
*
* \param e the expression.
*
Expand All @@ -124,7 +128,8 @@ tvm::Array<TypeVar> FreeTypeVariables(const Expr& e);

/*! \brief Get free type parameters from type t.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
* Free type parameters are type parameters that are not bound by a function
* type in the context.
*
* \param t the type.
*
Expand All @@ -134,17 +139,24 @@ tvm::Array<TypeVar> FreeTypeVariables(const Type& t);

/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let binding that are not referenced, and if branch that are not entered.
* It will remove let binding that are not referenced, and if branch that are
* not entered.
*
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a.
* Another example is `if (true) then 1 else 2` will be optimized into 1.
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
* the expression does not depend on a. Another example is `if (true) then 1
* else 2` will be optimized into 1.
*
* \param e the expression to optimize.
*
* \return the optimized expression.
*/
Expr DeadCodeElimination(const Expr& e);

Expr Monomorph(const Environment& env, const Expr& e);

Array<LoweredFunc> LowerOps(const Expr& e, const std::string& target = "llvm");

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_PASS_H_
15 changes: 15 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from ..api import register_func
from . import base
from . import ty
from . import expr
from . import env
from . import ir_pass
from . import testing

# Root operators
from .op import Op
Expand Down Expand Up @@ -46,6 +48,19 @@
If = expr.If
TupleGetItem = expr.TupleGetItem


# helper functions
var = expr.var
const = expr.const

@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tv):
return str(tv.data.asnumpy())

@register_func("relay._constant_repr")
def _tensor_value_repr(tv):
return str(tv.data.asnumpy())

@register_func("relay.debug")
def _debug(*args):
import pdb; pdb.set_trace()
4 changes: 4 additions & 0 deletions python/tvm/relay/_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""The interface to the Evaluator exposed from C++."""
from tvm._ffi.function import _init_api

_init_api("relay._eval", __name__)
88 changes: 88 additions & 0 deletions python/tvm/relay/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import absolute_import
import numpy as np
from .. import register_func, nd
from .base import NodeBase, register_relay_node
from . import _make
from . import _eval
from . import ir_pass
from .expr import Call, Constant

class Value(NodeBase):
"""Base class of all values.
"""
pass

@staticmethod
@register_func("relay.from_scalar")
def from_scalar(i, dtype=None):
if dtype is None:
if isinstance(i, int):
dtype = 'int32'
elif isinstance(i, float):
dtype = 'float32'
elif isinstance(i, bool):
dtype = 'uint8'
else:
raise Exception("unable to infer dtype {0}".format(type(i)))

return TensorValue(nd.array(np.array(i, dtype=dtype)))


@register_relay_node
class TupleValue(Value):
def __init__(self, *fields):
self.__init_handle_by_constructor__(
_make.TupleValue, fields)

def __getitem__(self, field_no):
return self.fields[field_no]


@register_relay_node
class Closure(Value):
pass


@register_relay_node
class TensorValue(Value):
"""A Tensor value produced by the evaluator."""

def __init__(self, data):
"""Allocate a new TensorValue and copy the data from `array` into
the new array.
"""
if isinstance(data, np.ndarray):
data = nd.array(data)

self.__init_handle_by_constructor__(
_make.TensorValue, data)

def as_ndarray(self):
"""Convert a Relay TensorValue into a tvm.ndarray."""
return self.data
def asnumpy(self):
"""Convert a Relay TensorValue into a numpy.ndarray."""
return self.data.asnumpy()

def __eq__(self, other):
return self.data == other.data

def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data)
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
else:
raise Exception("errr")

def evaluate(env, expr, *args):
# assert len(args) == 0
relay_args = []
for arg in args:
relay_args.append(_arg_to_ast(arg))

expr = Call(expr, relay_args)

ck_expr = ir_pass.infer_type(expr, env)
mm_expr = ir_pass.monomorph(env, ck_expr)
return _eval.evaluate(env, mm_expr)
5 changes: 4 additions & 1 deletion python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint: disable=no-else-return,
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
"""The set of passes for Relay.
Expand Down Expand Up @@ -141,3 +141,6 @@ def alpha_equal(lhs, rhs):
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))

lower_ops = _ir_pass.LowerOps
monomorph = _ir_pass.Monomorph
14 changes: 14 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,16 @@
#pylint: disable=invalid-name
"""Backend compiler related feature registration"""
import tvm
import topi
from . import register

def add_compiler(attrs, inputs, output_type):
assert len(inputs) == 2
return [topi.add(inputs[0], inputs[1])]

def add_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)

register("add", "FTVMCompute", add_compiler)
register("add", "FTVMSchedule", add_schedule)
Loading

0 comments on commit 52ddd94

Please sign in to comment.