Skip to content

Commit

Permalink
Clean up and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Sep 16, 2018
1 parent 3f5e2c8 commit c908b7a
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 241 deletions.
46 changes: 37 additions & 9 deletions include/tvm/relay/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,28 +51,56 @@ class EnvironmentNode : public RelayNode {

TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);

/*! \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
* \param update Controls whether you can replace a definition in the
* environment.
*/
void Add(const GlobalVar& var, const Function& func, bool update = false);

/*! \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
void Update(const GlobalVar& var, const Function& func);

/*! \brief Remove a function from the global environment.
* \param var The name of the global function to update.
*/
void Remove(const GlobalVar& var);

/*! \brief Lookup a global function by its variable. */
/*! \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalVar GetGlobalVar(const std::string& str);

/*! \brief Lookup a global function by its variable. */
Function Lookup(const GlobalVar& id);
/*! \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
Function Lookup(const GlobalVar& var);

/*! \brief Lookup a global function by its string name */
Function Lookup(const std::string& s);
/*! \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
Function Lookup(const std::string& name);

// TODO(@jroesch, @tqchen): what are the semantics here
void Merge(const Environment& env);
/*! \brief Combine with another Environment.
* \param other The other environment.
*/
void Merge(const Environment& other);

using Transformer =
runtime::TypedPackedFunc<runtime::TypedPackedFunc<Function(
const GlobalVar&, const Function&)>(const Environment&)>;

/*! \brief Apply a function over every function in the global environment. */
void Transform(Transformer tranformer);
/*! \brief Apply a function over every function in the global environment.
* \param transformer The transformation function.
*/
void Transform(Transformer transformer);

static constexpr const char* _type_key = "relay.Environment";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
Expand Down
14 changes: 7 additions & 7 deletions include/tvm/relay/expr_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor<Expr(const Expr& n)> {
auto type = this->VisitType(op->type);
return ParamNode::make(var, type);
} else {
CHECK(false) << "the default param visitor expected a Var found: "
LOG(FATAL) << "the default param visitor expected a Var found: "
<< var_expr << std::endl;
__builtin_unreachable();
return Expr();
}
}

Expand All @@ -112,10 +112,10 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor<Expr(const Expr& n)> {
auto ty_param_ref = GetRef<TypeParam>(ty_param);
ty_params.push_back(ty_param_ref);
} else {
CHECK(false)
LOG(FATAL)
<< "the default function visitor expected a TypeParam found: "
<< ty_param_type << std::endl;
__builtin_unreachable();
return Expr();
}
}

Expand All @@ -128,7 +128,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor<Expr(const Expr& n)> {
} else {
CHECK(false) << "the default function visitor expected a Param found: "
<< param_expr << std::endl;
__builtin_unreachable();
return Expr();
}
}

Expand Down Expand Up @@ -165,9 +165,9 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor<Expr(const Expr& n)> {
auto body = this->VisitExpr(op->body);
return LetNode::make(var, value, body, type);
} else {
CHECK(false) << "the default let visitor expected a Var found: "
LOG(FATAL) << "the default let visitor expected a Var found: "
<< var_expr << std::endl;
__builtin_unreachable();
return Expr();
}
}

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ class OpRegistry {
}
return *this;
}
/*! \return The global single retistry */
/*! \return The global single registry */
TVM_DLL static ::dmlc::Registry<OpRegistry>* Registry();

private:
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from . import _make
from . import _env


@register_relay_node
class Environment(NodeBase):
"""The global Relay environment containing functions,
options and more.
"""

def __init__(self, funcs) -> None:
"""Construct an environment.
Expand Down
227 changes: 3 additions & 224 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
@@ -1,232 +1,11 @@
# pylint: disable=no-else-return,
# pylint: disable=unidiomatic-typecheck
"""The optimizer for Relay.
"""The set of passes for Relay.
Exposes an interface for configuring the optimizer and scripting
it directly in Python.
Exposes an interface for configuring the passes and scripting
them in Python.
"""
from typing import TypeVar, Generic, Union
from typing import Dict, Tuple, List, Callable
import tvm

from .expr import Expr
from .expr import Function, Let, Call, Var
from .expr import GlobalVar, If, Constant
from .type import Type, TypeParam
from .env import Environment
from .op import Op
from .op.op import specialize_op
# import relay.make as relay_mk
# from relay import ir
# from relay.env import Environment
# from relay.tyck import check_expr
# from relay.first_order_reverse_ad import fo_with_gradient
# from relay.anf import to_anf
from . import _ir_pass

# Expose checking expression, should rename to infer_type.
# pylint: disable=invalid-name
check_expr = _ir_pass.check_expr

# # pylint: disable=invalid-name
# concretize = _opt.concretize

# # pylint: disable=invalid-name
# optimize = _opt.optimize

# # pylint: disable=invalid-name
# type_specialize = _opt.type_specialize

# # pylint: disable=invalid-name
# compile_ops_to_module = _opt.compile_ops_to_module


@tvm.register_func("relay.mangle")
def mangle(name: str, types: List[Type]) -> str:
for typ in types:
name += str(typ) + "_"
return name


T = TypeVar('T')


class AbstractExprVisitor(Generic[T]):
"""A functional visitor over Expr in Python."""

# pylint: disable=no-else-return
def visit(self, expr: Expr) -> T:
"""Apply the visitor to an expression."""
if isinstance(expr, Function):
return self.visit_function(expr)
elif isinstance(expr, Call):
return self.visit_call(expr)
elif isinstance(expr, Let):
return self.visit_let(expr)
elif isinstance(expr, Var):
return self.visit_local_var(expr)
elif isinstance(expr, GlobalVar):
return self.visit_global_var(expr)
elif isinstance(expr, If):
return self.visit_if(expr)
elif isinstance(expr, Tuple):
return self.visit_tuple(expr)
elif isinstance(expr, Constant):
return self.visit_constant(expr)
else:
raise Exception(f"warning unhandled case: {type(expr)}")

def visit_function(self, _: Function) -> T:
raise Exception("Abstract method please implement me.")

def visit_let(self, _: Let) -> T:
raise Exception("Abstract method please implement me.")

def visit_call(self, _: Call) -> T:
raise Exception("Abstract method please implement me.")

def visit_local_id(self, _: Var) -> T:
raise Exception("Abstract method please implement me.")

def visit_type(self, typ: Type) -> Type:
return typ

def visit_if(self, _: If) -> T:
raise Exception("Abstract method please implement me.")

def visit_tuple(self, _: Tuple) -> T:
raise Exception("Abstract method please implement me.")

def visit_constant(self, _: Constant) -> T:
raise Exception("Abstract method please implement me.")

def visit_global_var(self, _: GlobalVar) -> T:
raise Exception("Abstract method please implement me.")

@classmethod
def to_pass(cls) -> Callable[[Environment], Callable[[GlobalVar, Function], Function]]:
def _outer_wrapper(env):
visitor = cls(env)

def _inner_wrapper(_, func):
return visitor.visit(func)
return _inner_wrapper
return _outer_wrapper


class ExprVisitor(AbstractExprVisitor[Expr]):
"""A functional visitor over Expr in Python."""

def visit_function(self, fn: Function) -> Expr:
new_body = self.visit(fn.body)
return Function(
list(fn.params),
fn.ret_type, new_body,
fn.type_params)

def visit_let(self, let: Let) -> Expr:
new_var = self.visit(let.var)
new_value_type = self.visit_type(let.value_type)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
return Let(new_var, new_val, new_body, new_value_type)

def visit_call(self, call: Call) -> Expr:
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)

def visit_local_var(self, local_var: Var) -> Expr:
return local_var

def visit_global_id(self, global_var: GlobalVar) -> Expr:
return global_var

def visit_if(self, ite: If) -> Expr:
return If(
self.visit(ite.guard),
self.visit(ite.true_b),
self.visit(ite.false_b))

def visit_tuple(self, tup: Tuple) -> Expr:
return Tuple([self.visit(field) for field in tup.fields])

def visit_constant(self, const: Constant) -> Expr:
return const


MMCacheKey = Tuple[Union[GlobalVar, str], List[Type]]


class Monomorphize(ExprVisitor):
"""A monomorphization pass.
Implements what is known as "monomorphization" in
classic compiler literature. This pass removes
polymorphism replacing calls to functions and
operators with type specialized versions.
"""
monomorph_map: Dict[MMCacheKey, Union[Op, Function]]

# pylint: disable=super-init-not-called
def __init__(self, env: Environment) -> None:
self.env = env
# Stores (GlobalVar, Type), should eventually store attributes.
self.monomorph_map = {}

# pylint: disable=no-else-return
def visit_call(self, call: Call) -> Expr:
cache_key = (call.op, call.type_args)
new_args = [self.visit(arg) for arg in call.args]

if cache_key in self.monomorph_map:
op = self.monomorph_map[cache_key]
new_args = [self.visit(arg) for arg in call.args]
return Call(op, new_args, call.attrs)
else:
if isinstance(call.op, Op):
poly_name = call.op.name
mono_name = mangle(poly_name, call.type_args)
for arg in call.type_args:
if isinstance(arg, TypeParam):
# raise Exception("...") # Fix me in the morning!!!
return call

mono_op = specialize_op(poly_name, mono_name, call.type_args)
self.monomorph_map[cache_key] = mono_op
return Call(mono_op, new_args, call.attrs, [])
elif isinstance(call.op, GlobalVar):
return call
# defn = self.env.lookup(call.op)
# new_id = self.env.global_id(defn.id.name + str(1))
# cache_key = (call.op, call.type_args)
# self.monomorph_map[cache_key] = new_id
# new_body = self.visit(type_specialize(call.type_args, defn.body))
# new_body = Function(
# [], new_body.params, new_body.ret_type, new_body.body)
# new_ty = check_expr(self.env, new_body)
# # TODO(@jroesch): move into C++
# # TODO(@joresch): implement and call name mangler
# defn = Defn(new_id, new_ty, new_body)
# self.env.add(defn)
# self.visit_item(defn)
# return Call(new_id, call.args, call.attrs)

elif isinstance(call.op, Function):
return call
# new_func = type_specialize(call.type_args, call.op)
# new_func = self.visit(new_func)
# new_func = Function([],
# new_func.params,
# new_func.ret_type,
# new_func.body)
# check_expr(self.env, new_func)
# return Call(new_func, call.args, call.attrs)
else:
new_fn = self.visit(call.op)
return Call(new_fn, new_args, call.attrs)


# TODO(@jroesch): Fix up my type
__tgt_host__ = __tgt__ = "llvm"
__relay_tvm_context__ = tvm.cpu()
2 changes: 2 additions & 0 deletions tests/scripts/task_python_integration.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ TVM_FFI=cython python -m nose -v tests/python/integration || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/integration || exit -1
TVM_FFI=cython python -m nose -v tests/python/contrib || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/contrib || exit -1
TVM_FFI=cython python -m nose -v tests/python/relay || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/relay || exit -1

# Do not enabke OpenGL
# TVM_FFI=cython python -m nose -v tests/webgl || exit -1
Expand Down

0 comments on commit c908b7a

Please sign in to comment.