diff --git a/CMakeLists.txt b/CMakeLists.txt index 572f4aef1432..65a7d9e36e2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -104,6 +104,12 @@ file(GLOB COMPILER_SRCS src/schedule/*.cc ) +file(GLOB_RECURSE RELAY_SRCS + src/relay/*.cc + ) +list(APPEND COMPILER_SRCS ${RELAY_SRCS}) + + if(NOT MSVC) file(GLOB COMPILER_VERILOG_SRCS src/codegen/verilog/*.cc) list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS}) diff --git a/docs/conf.py b/docs/conf.py index e3f7f6a82c24..717003824703 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -33,7 +33,7 @@ # General information about the project. project = u'tvm' author = u'%s developers' % project -copyright = u'2017, %s' % author +copyright = u'2018, %s' % author github_doc_root = 'https://github.com/tqchen/tvm/tree/master/docs/' # add markdown parser diff --git a/include/tvm/base.h b/include/tvm/base.h index c2d796b6002c..464259bc0527 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -134,6 +134,5 @@ struct NodeFactoryReg { */ #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) - } // namespace tvm #endif // TVM_BASE_H_ diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h new file mode 100644 index 000000000000..09f3a94e1edb --- /dev/null +++ b/include/tvm/relay/base.h @@ -0,0 +1,172 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/base.h + * \brief Base classes for the Relay IR. + */ +#ifndef TVM_RELAY_BASE_H_ +#define TVM_RELAY_BASE_H_ + +#include +#include +#include +#include + +namespace tvm { +/*! + * \brief Relay: a high level functional IR for TVM. + * + * This namespace contains the abstract syntax tree, and other + * essential data structures for the Relay IR. + * + * You can find more about Relay by reading the language reference. + */ +namespace relay { +/*! + * \brief we always used NodeRef for referencing nodes. + * + * By default, NodeRef is a std::shared_ptr of node + */ +using NodeRef = tvm::NodeRef; + +/*! + * \brief Content data type. + */ +using DataType = ::tvm::Type; + +/*! + * \brief Symbolic expression for tensor shape. + */ +using ShapeExpr = ::tvm::Expr; + +/*! + * \brief Hash function for nodes. + * e.g. std::unordered_map + */ +using NodeHash = ::tvm::NodeHash; +/*! + * \brief Equality check function for nodes. + */ +using NodeEqual = ::tvm::NodeEqual; + +/*! + * \brief Macro to make it easy to define node ref type given node + * \param TypeName The name of the reference type. + * \param NodeName The internal contrainer name. + * \param NodeRefBase The base type. + */ +#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ + class TypeName : public NodeRefBase { \ + public: \ + TypeName() {} \ + explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \ + const NodeName* operator->() const { \ + return static_cast(node_.get()); \ + } \ + operator bool() { return this->defined(); } \ + using ContainerType = NodeName; \ + }; + + +/*! + * \brief The source name in the Span + * \sa SourceNameNode, Span + */ +class SourceName; +/*! + * \brief The source name in the Span + */ +class SourceNameNode : public Node { + public: + /*! \brief The source name */ + std::string name; + // override attr visitor + void VisitAttrs(AttrVisitor* v) final { + v->Visit("name", &name); + } + + TVM_DLL static SourceName make(std::string name); + + static constexpr const char* _type_key = "relay.SourceName"; + TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); +}; + +RELAY_DEFINE_NODE_REF(SourceName, SourceNameNode, NodeRef); + +/*! 
+ * \brief Span information for debugging purposes + */ +class Span; +/*! + * \brief Stores locations in frontend source that generated a node. + * + */ +class SpanNode : public Node { + public: + /*! \brief The source name */ + SourceName source; + /*! \brief Line number */ + int lineno; + /*! \brief column offset */ + int col_offset; + // override attr visitor + void VisitAttrs(AttrVisitor* v) final { + v->Visit("source", &source); + v->Visit("lineno", &lineno); + v->Visit("col_offset", &col_offset); + } + + TVM_DLL static Span make(SourceName source, int lineno, int col_offset); + + static constexpr const char* _type_key = "relay.Span"; + TVM_DECLARE_NODE_TYPE_INFO(SpanNode, Node); +}; + +RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef); + +/*! + * \brief This is the base node container of all relay structures. + */ +class RelayNode : public Node { + public: + /*! \brief The debug information, can be null, check with span.defined() */ + mutable Span span; + + static constexpr const char* _type_key = "relay.Node"; + TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); +}; + +/*! + * \brief Get a reference type from a Node ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the node alive beyond the scope of the function. + * + * \param ptr The node pointer + * \tparam RefType The reference type + * \tparam NodeType The node type + * \return The corresponding RefType + */ +template +RefType GetRef(const NodeType* ptr) { + static_assert(std::is_same::value, + "Can only cast to the ref of same container type"); + return RefType(const_cast(ptr)->shared_from_this()); +} + +/*! + * \brief Get PackedFunction from global registry and + * report error if it does not exist + * \param name The name of the function. + * \return The created PackedFunc. + */ +inline const PackedFunc& GetPackedFunc(const std::string& name) { + const PackedFunc* pf = tvm::runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; + return *pf; +} + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BASE_H_ diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h new file mode 100644 index 000000000000..29cde295398d --- /dev/null +++ b/include/tvm/relay/environment.h @@ -0,0 +1,108 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/environment.h + * \brief The global environment: contains information needed to + * compile & optimize Relay programs. + */ +#ifndef TVM_RELAY_ENVIRONMENT_H_ +#define TVM_RELAY_ENVIRONMENT_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +struct Environment; + +/*! \brief The global environment of Relay programs. + * + * The global environment contains the global + * information needed to compile a Relay program. + * + * It contains all global functions, and configuration + * options. + * + * Many operations require access to the global + * Environment. We pass the Environment by value + * in a functional style as an explicit argument, + * but we will mutate the Environment while optimizing + * Relay programs. + * + * The functional style allows users to construct custom + * environments easily, for example each thread can store + * an Environment while auto-tuning. + * */ + +class EnvironmentNode : public RelayNode { + private: + /*! \brief A map from string names to global variables ensures global + * uniqueness. */ + tvm::Map global_map_; + /*! 
\brief A map from file names to source fragments. */ + SourceMap source_map_; + /*! \brief A list of the errors reported during the current run. */ + std::vector errors_; + + public: + /*! \brief A map from ids to all global functions. */ + tvm::Map functions; + + EnvironmentNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final {} + + TVM_DLL static Environment make(tvm::Map global_funcs); + + void Add(const GlobalVar& var, const Function& func, bool update = false); + void Update(const GlobalVar& var, const Function& func); + void Remove(const GlobalVar& var); + + /*! \brief Lookup a global function by its variable. */ + GlobalVar GetGlobalVar(const std::string& str); + + /*! \brief Lookup a global function by its variable. */ + Function Lookup(const GlobalVar& id); + + /*! \brief Lookup a global function by its string name */ + Function Lookup(const std::string& s); + + // TODO(@jroesch, @tqchen): what are the semantics here + void Merge(const Environment& env); + + /*! \brief Add a source fragment to the environment. */ + SourceName AddSource(std::string file_name, std::string source); + + using Transformer = runtime::TypedPackedFunc< + runtime::TypedPackedFunc(const Environment&)>; + + /*! \brief Apply a function over every function in the global environment. */ + void Transform(Transformer tranformer); + + void AddDiagnostic(SpannedError); + void DisplayErrors(); + + static constexpr const char* _type_key = "relay.Environment"; + TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); +}; + +struct Environment : public NodeRef { + Environment() {} + explicit Environment(std::shared_ptr p) : NodeRef(p) {} + + inline EnvironmentNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = EnvironmentNode; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ENVIRONMENT_H_ diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h new file mode 100644 index 000000000000..989285d341b3 --- /dev/null +++ b/include/tvm/relay/error.h @@ -0,0 +1,42 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file error.h + * \brief The set of errors raised by Relay. + */ +#ifndef TVM_RELAY_ERROR_H_ +#define TVM_RELAY_ERROR_H_ + +#include +#include "./base.h" + +namespace tvm { +namespace relay { + +struct Error : dmlc::Error { + explicit Error(const std::string &msg) : dmlc::Error(msg) {} +}; + +struct InternalError : Error { + explicit InternalError(const std::string &msg) : Error(msg) {} +}; + +struct SpannedError { + std::string msg; + Span sp; + SpannedError(const std::string &msg, Span sp) : msg(msg), sp(sp) {} +}; + +// FIX, we should change spanned errors to have a method which allow them to +// report on the Environment, inverting control to error definition. +struct FatalTypeError : dmlc::Error { + explicit FatalTypeError(const std::string &s) : dmlc::Error(s) {} +}; + +struct TypecheckerError : public dmlc::Error { + explicit TypecheckerError(const std::string &msg) : Error(msg) {} +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ERROR_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h new file mode 100644 index 000000000000..521fd57b880d --- /dev/null +++ b/include/tvm/relay/expr.h @@ -0,0 +1,386 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/expr.h + * \brief Relay expression language. 
+ */
+#ifndef TVM_RELAY_EXPR_H_
+#define TVM_RELAY_EXPR_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "./base.h"
+#include "./type.h"
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A Relay expression.
+ */
+class Expr;
+/*!
+ * \brief Base type of the Relay expression hierarchy.
+ */
+class ExprNode : public RelayNode {
+ public:
+  /*!
+   * \brief Stores the result of type inference (type checking).
+   *
+   * \note This can be undefined before type inference.
+   *       This value is discarded during serialization.
+   */
+  mutable Type checked_type_ = Type(nullptr);
+  /*!
+   * \return The checked_type
+   */
+  const Type& checked_type() const {
+    CHECK(checked_type_.defined()) << "internal error: the type checker has "
+                                      "not populated the checked_type "
+                                   << "field for this node";
+    return this->checked_type_;
+  }
+
+  static constexpr const char* _type_key = "relay.Expr";
+  TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
+};
+
+RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef);
+
+/*!
+ * \brief Constant tensor, backed by an NDArray on cpu(0).
+ *
+ * \note Scalar constants are represented by rank-0 constant tensors;
+ *       constant folding is handled uniformly via Tensor types.
+ */
+class Constant;
+/*!
+ * \brief Constant tensor type.
+ */
+class ConstantNode : public ExprNode {
+ public:
+  /*! \brief The data of the tensor */
+  runtime::NDArray data;
+
+  /*! \return The corresponding tensor type of the data */
+  TensorType tensor_type() const;
+
+  /*! \return whether it is a scalar (rank-0 tensor) */
+  bool is_scalar() const { return data->ndim == 0; }
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("data", &data);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static Constant make(runtime::NDArray data);
+
+  static constexpr const char* _type_key = "relay.Constant";
+  TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr);
+
+/*! \brief Tuple of multiple Exprs */
+class Tuple;
+/*! \brief Tuple container */
+class TupleNode : public ExprNode {
+ public:
+  /*! \brief the fields of the tuple */
+  tvm::Array<Expr> fields;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("fields", &fields);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static Tuple make(tvm::Array<Expr> fields);
+
+  static constexpr const char* _type_key = "relay.Tuple";
+  TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);
+
+/*!
+ * \brief Local variables used in the let expression.
+ * This is similar to the Var used in the low-level tensor expression language.
+ *
+ * \note Each LocalVar is bound only once and is immutable.
+ */
+class LocalVar;
+/*! \brief Container for LocalVar */
+class LocalVarNode : public ExprNode {
+ public:
+  /*! \brief The name of the variable, this only acts as a hint. */
+  std::string name_hint;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("name_hint", &name_hint);
+  }
+
+  TVM_DLL static LocalVar make(std::string name_hint);
+
+  static constexpr const char* _type_key = "relay.LocalVar";
+  TVM_DECLARE_NODE_TYPE_INFO(LocalVarNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(LocalVar, LocalVarNode, Expr);
+
+/*!
+ * \brief Global variable that lives in the top-level environment.
+ * This is used to enable recursive calls between functions.
+ *
+ * \note GlobalVar can only correspond to functions.
+ */
+class GlobalVar;
+/*! \brief Container for GlobalVar */
+class GlobalVarNode : public ExprNode {
+ public:
+  /*!
\brief The name of the variable, this only acts as a hint. */ + std::string name_hint; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name_hint", &name_hint); + } + + TVM_DLL static GlobalVar make(std::string name_hint); + + static constexpr const char* _type_key = "relay.GlobalVar"; + TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr); + +/*! + * \brief Function parameter declaration. + */ +class Param; +/*! \brief A parameter. */ +class ParamNode : public ExprNode { + public: + /*! \brief The variable */ + LocalVar var; + /*! \brief The type of the parameter */ + Type type; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("type", &type); + v->Visit("span", &span); + } + + TVM_DLL static Param make(LocalVar var, Type type); + + static constexpr const char* _type_key = "relay.Param"; + TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr); + +/*! + * \brief Function (subgraph in computational graph) + */ +class Function; +/*! \brief Function container */ +class FunctionNode : public ExprNode { + public: + /*! \brief Function parameters */ + tvm::Array params; + /*! \brief User annotated return type of the function. */ + Type ret_type; + /*! + * \brief + * The expression which represents the computation of the function, + * the expression may reference the parameters, and the type of it + * or sub-expressions may reference the type variables. + */ + Expr body; + /*! + * \brief Type parameters of the function. + * Enables the function to vary its type based on these. + * This corresponds to template paramaters in c++'s terminology. + * + * \note This can be usually empty for non-polymorphic functions. + */ + tvm::Array type_params; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("params", ¶ms); + v->Visit("ret_type", &ret_type); + v->Visit("body", &body); + v->Visit("type_params", &type_params); + v->Visit("span", &span); + } + + Type fn_type() const; + + TVM_DLL static Function make(tvm::Array params, Type ret_type, + Expr body, tvm::Array ty_params); + + static constexpr const char* _type_key = "relay.Function"; + TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); + +/*! + * \brief Call corresponds to operator invocation. + * Corresponds to the operator in computational graph terminology. + */ +class Call; +/*! \brief Call container. */ +class CallNode : public ExprNode { + public: + /*! + * \brief The operator(function) being invoked + * + * - It can be relay::Op which corresponds to the primitive operators. + * - It can also be user defined functions (Function, GlobalVar, LocalVar). + */ + Expr op; + + /*! \brief The arguments(inputs) of the call */ + tvm::Array args; + + /*! \brief The additional attributes */ + Attrs attrs; + + /*! + * \brief The type arguments passed to polymorphic(template) function. + * + * This is the advance feature that is only used when the function is + * polymorphic. It is safe to be ignored in most cases. For example, in the + * following code, the type_args of addone call is [int]. 
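+ * (That is, in the C++ analogy below, the template parameter T is
+ * instantiated to int at the call site.)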
+ * + * \code + * + * template + * T addone(T a) { return a + 1; } + * + * void main() { + * int x = addone(10); + * } + * + * \endcode + */ + tvm::Array type_args; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("op", &op); + v->Visit("args", &args); + v->Visit("attrs", &attrs); + v->Visit("type_args", &type_args); + v->Visit("span", &span); + } + + TVM_DLL static Call make(Expr op, + Array args, + Attrs attrs = Attrs(), + Array ty_args = Array()); + + static constexpr const char* _type_key = "relay.Call"; + TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Call, CallNode, Expr); + +/*! + * \brief Let binding that binds a local var and optionally a type annotation. + * + * \note Let is useful to transform the program to be A-normal form. + * where each of the expression corresponds to a let binding. + * + * For developers who are familar with the computational graph. + * Each of the let can be viewed as a operator node in the computational graph. + * Traversing the list of let bindings is similar to running + * PostDFS-order(topo-order) traversal on the computational graph. + */ +class Let; +/*! \brief A binding of a sub-network. */ +class LetNode : public ExprNode { + public: + /*! \brief The variable we bind to */ + LocalVar var; + /*! \brief The value we bind var to */ + Expr value; + /*! \brief The body of the let binding */ + Expr body; + /*! \brief type annotation of value, this can be null */ + Type value_type; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("body", &body); + v->Visit("value_type", &value_type); + v->Visit("span", &span); + } + + TVM_DLL static Let make(LocalVar var, Expr value, Expr body, Type value_type); + + static constexpr const char* _type_key = "relay.Let"; + TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); + +/*! + * \brief Condition expression + * + * Unlike traditional statement `if`s, the if evalutes + * to the result of the branch taken. + * + * let x = if (true) { 1 } else { 0 }; // x is 1 + * let y = if (false) { 1 } else { 0 }; // y is 0 + */ +class If; +/*! \brief container of If */ +class IfNode : public ExprNode { + public: + /*! \brief The condition */ + Expr cond; + /*! \brief The expression evaluated when condition is true. */ + Expr true_value; + /*! \brief The expression evaluated when condition is false */ + Expr false_value; + + IfNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("cond", &cond); + v->Visit("true_value", &true_value); + v->Visit("false_value", &false_value); + v->Visit("span", &span); + } + + TVM_DLL static If make(Expr cond, Expr true_value, Expr false_value); + + static constexpr const char* _type_key = "relay.If"; + TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(If, IfNode, Expr); + +// template +// T Downcast(U u) { + +// } + +} // namespace relay +} // namespace tvm + +namespace std { + +template<> +struct hash<::tvm::relay::LocalVar> { + std::size_t operator()(const ::tvm::relay::LocalVar & lv) const { + return lv.hash(); + } +}; + +} +#endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h new file mode 100644 index 000000000000..0d736212c9eb --- /dev/null +++ b/include/tvm/relay/expr_functor.h @@ -0,0 +1,116 @@ +/*! 
+ * Copyright (c) 2018 by Contributors + * \file tvm/relay/expr_functor.h + * \brief A more powerful visitor which enables defining arbitrary function + * signatures with type based dispatch on first argument. + */ +#ifndef TVM_RELAY_EXPR_FUNCTOR_H_ +#define TVM_RELAY_EXPR_FUNCTOR_H_ + +#include +#include +#include "./expr.h" +#include "./op.h" + +namespace tvm { +namespace relay { + +/*! + * \brief A dynamical functor that dispatches on in the first Expr argument. + * You can use this as a more powerful Visitor, since it allows you to + * define function signatures of Visit Function. + * + * \sa tvm/ir_functor.h + * + * \tparam FType function signiture + * This type is only defined for FType with function signature R(const Expr&, + * Args...) + */ +template +class ExprFunctor; + +// functions to be overriden. +#define EXPR_FUNCTOR_DEFAULT \ + { return VisitExprDefault_(op, std::forward(args)...); } + +#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); + +template +class ExprFunctor { + private: + using TSelf = ExprFunctor; + using FType = tvm::IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~ExprFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Expr& n, Args... args) { + return VisitExpr(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitExpr(const Expr& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitExpr_(const ConstantNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LocalVarNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GlobalVarNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FunctionNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IfNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OpNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExprDefault_(const Node* op, Args...) { + throw dmlc::Error(std::string("Do not have a default for ") + op->type_key()); + } + + private: + // initialize the vtable. 
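+  // The vtable is a tvm::IRFunctor (see tvm/ir_functor.h) keyed by each
+  // node's runtime type index; RELAY_EXPR_FUNCTOR_DISPATCH above installs one
+  // handler per node kind, so VisitExpr dispatches with a single indexed
+  // lookup rather than a chain of dynamic casts. A subclass only overrides
+  // the cases it cares about, e.g. (an illustrative sketch, not part of
+  // this change):
+  //   struct CallCounter : ExprFunctor<int(const Expr&)> {
+  //     int VisitExpr_(const CallNode* op) final { return 1; }
+  //     int VisitExprDefault_(const Node*) final { return 0; }
+  //   };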
+ static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode); + RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode); + RELAY_EXPR_FUNCTOR_DISPATCH(LocalVarNode); + RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); + RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode); + RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); + RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); + RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); + RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); + RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); + return vtable; + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_FUNCTOR_H_ diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h new file mode 100644 index 000000000000..0febad503b12 --- /dev/null +++ b/include/tvm/relay/expr_visitor.h @@ -0,0 +1,183 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/expr_visitor.h + * \brief A simple visitor wrapper around ExprFunctor. + * + * Exposes two visitors with default traversal strategies, one + * which doesn't compute a result but can mutate internal state, + * and another which functionally builds a new Expr. + */ +#ifndef TVM_RELAY_EXPR_VISITOR_H_ +#define TVM_RELAY_EXPR_VISITOR_H_ + +#include + +namespace tvm { +namespace relay { + +class ExprVisitor : public ::tvm::relay::ExprFunctor { + public: + void VisitExpr_(const LocalVarNode* op) override { return; } + + void VisitExpr_(const GlobalVarNode* op) override { return; } + + void VisitExpr_(const ConstantNode* op) override { return; } + + void VisitExpr_(const TupleNode* op) override { + for (auto field : op->fields) { + this->VisitExpr(field); + } + } + + void VisitExpr_(const ParamNode* op) override { + this->VisitExpr(op->var); + } + + void VisitExpr_(const FunctionNode* op) override { + for (auto param : op->params) { + this->VisitExpr(param); + } + + this->VisitExpr(op->body); + } + + void VisitExpr_(const CallNode* op) override { + this->VisitExpr(op->op); + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + + for (auto arg : op->args) { + this->VisitExpr(arg); + } + } + + void VisitExpr_(const LetNode* op) override { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + this->VisitExpr(op->body); + } + + void VisitExpr_(const IfNode* op) override { + this->VisitExpr(op->cond); + this->VisitExpr(op->true_value); + this->VisitExpr(op->false_value); + } + + void VisitExpr_(const OpNode* op) override { return; } + + virtual void VisitType(const Type& t) {} +}; + +class ExprFVisitor : public ::tvm::relay::ExprFunctor { + public: + Expr VisitExpr_(const LocalVarNode* op) override { + return GetRef(op); + } + + Expr VisitExpr_(const ConstantNode* op) override { + return GetRef(op); + } + + Expr VisitExpr_(const GlobalVarNode* op) override { + return GetRef(op); + } + + Expr VisitExpr_(const OpNode* op) override { + return GetRef(op); + } + + Expr VisitExpr_(const TupleNode* op) override { + tvm::Array fields; + for (auto field : op->fields) { + fields.push_back(this->VisitExpr(field)); + } + + return TupleNode::make(fields); + } + + Expr VisitExpr_(const ParamNode* op) override { + Expr var_expr = this->VisitExpr(op->var); + if (const LocalVarNode* var_node = var_expr.as()) { + auto var = GetRef(var_node); + auto type = this->VisitType(op->type); + return ParamNode::make(var, type); + } else { + throw dmlc::Error("the default param visitor has bug"); + } + } + + Expr VisitExpr_(const FunctionNode* op) override { + tvm::Array ty_params; + + for (auto ty : op->type_params) { + Type ty_param_type = VisitType(ty); + if 
(auto ty_param = ty_param_type.as()) { + auto ty_param_ref = GetRef(ty_param); + ty_params.push_back(ty_param_ref); + } else { + throw dmlc::Error("the default func visitor has bug"); + } + } + + tvm::Array params; + for (auto param : op->params) { + Expr param_expr = this->VisitExpr(param); + if (const ParamNode* param_node = param_expr.as()) { + auto param = GetRef(param_node); + params.push_back(param); + } else { + throw dmlc::Error("the default func visitor has bug"); + } + } + + auto ret_type = this->VisitType(op->ret_type); + auto body = this->VisitExpr(op->body); + return FunctionNode::make(params, ret_type, body, ty_params); + } + + Expr VisitExpr_(const CallNode* call_node) override { + auto fn = this->VisitExpr(call_node->op); + + tvm::Array ty_args; + for (auto ty_arg : call_node->type_args) { + auto new_ty_arg = this->VisitType(ty_arg); + ty_args.push_back(new_ty_arg); + } + + tvm::Array call_args; + for (auto arg : call_node->args) { + call_args.push_back(this->VisitExpr(arg)); + } + + auto call = CallNode::make(fn, call_args, call_node->attrs, ty_args); + + return call; + } + + Expr VisitExpr_(const LetNode* op) override { + Expr var_expr = this->VisitExpr(op->var); + if (const LocalVarNode* var_node = var_expr.as()) { + auto var = GetRef(var_node); + auto type = this->VisitType(op->value_type); + auto value = this->VisitExpr(op->value); + auto body = this->VisitExpr(op->body); + return LetNode::make(var, value, body, type); + } else { + throw dmlc::Error("the default let visitor has error"); + } + } + + Expr VisitExpr_(const IfNode* op) override { + auto guard = this->VisitExpr(op->cond); + auto true_b = this->VisitExpr(op->true_value); + auto false_b = this->VisitExpr(op->false_value); + return IfNode::make(guard, true_b, false_b); + } + + virtual Type VisitType(const Type& t) { return t; } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_VISITOR_H_ diff --git a/include/tvm/relay/logging.h b/include/tvm/relay/logging.h new file mode 100644 index 000000000000..c53cd15ee72e --- /dev/null +++ b/include/tvm/relay/logging.h @@ -0,0 +1,33 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/logging.h + * \brief A wrapper around dmlc-core/logging.h which adds the ability + * to toggle logging via an environment variable. + */ + +#ifndef TVM_RELAY_LOGGING_H_ +#define TVM_RELAY_LOGGING_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +static bool logging_enabled() { + if (auto var = std::getenv("RELAY_LOG")) { + std::string is_on(var); + return is_on == "1"; + } else { + return false; + } +} + +#define RELAY_LOG(severity) LOG_IF(severity, logging_enabled()) + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_LOGGING_H_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h new file mode 100644 index 000000000000..7d0a58265565 --- /dev/null +++ b/include/tvm/relay/op.h @@ -0,0 +1,473 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/op.h + * \brief Primitive operator definition. + */ +#ifndef TVM_RELAY_OP_H_ +#define TVM_RELAY_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "../attrs.h" +#include "./base.h" +#include "./expr.h" +#include "./type.h" + +namespace tvm { +namespace relay { + +// forward declare name. +template +class OpMap; +class GenericOpMap; +class OpRegistry; + +/*! + * \brief Node container of operator structure. + */ +class OpNode : public relay::ExprNode { + public: + /*! 
\brief name of the operator */ + std::string name; + /*! \brief the type of the operator */ + mutable FuncType op_type; + /*! + * \brief detailed description of the operator + * This can be used to generate docstring automatically for the operator. + */ + std::string description; + /* \brief Information of input arguments to the operator */ + Array arguments; + /*! + * \brief The type key of the attribute field + * This can be empty, in which case it defaults to + */ + std::string attrs_type_key; + /*! + * \brief number of input arguments to the operator, + * -1 means it is variable length + */ + int32_t num_inputs = -1; + /*! + * \brief support level of the operator, + * The lower the more priority it contains. + * This is in analogies to BLAS levels. + */ + int32_t support_level = 10; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("op_type", &op_type); + v->Visit("description", &description); + v->Visit("arguments", &arguments); + v->Visit("attrs_type_key", &attrs_type_key); + v->Visit("num_inputs", &num_inputs); + v->Visit("support_level", &support_level); + } + + static constexpr const char* _type_key = "relay.Op"; + TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode); + + private: + // friend class + friend class GenericOpMap; + friend class OpRegistry; + // Program internal unique index of operator. + // Used to help index the program. + uint32_t index_{0}; +}; + +/*! + * \brief Operator reference class. + */ +class Op : public relay::Expr { + public: + /*! \brief default constructor */ + Op() {} + /*! \brief constructor from node pointer */ + explicit Op(std::shared_ptr n) : Expr(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const OpNode* operator->() const; + /*! + * \brief Get additional registered attribute about operators. + * If nothing has been registered, an empty OpMap will be returned. + * \param attr_name The name of the attribute. + * \return An OpMap of specified attr_name. + * \tparam ValueType The type of the attribute. + */ + template + inline static OpMap GetAttr(const std::string& attr_name); + /*! + * \brief Get an Op for a given operator name. + * Will raise an error if the op has not been registered. + * \param op_name Name of the operator. + * \return Pointer to a Op, valid throughout program lifetime. + */ + TVM_DLL static const Op& Get(const std::string& op_name); + + /*! \brief specify container node */ + using ContainerType = OpNode; + + private: + /*! + * \brief Get generic attrmap given attr name + * \param key The attribute key + * \return reference to GenericOpMap + */ + TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); +}; + +/*! \brief Helper structure to register operators */ +class OpRegistry { + public: + /*! \return the operator */ + const Op& op() const { return op_; } + /*! + * \brief setter function during registration + * Set the description of operator + * \param descr the description string. + * \return reference to self. + */ + inline OpRegistry& describe(const std::string& descr); // NOLINT(*) + /*! + * \brief Add argument information to the function. + * \param name Name of the argument. + * \param type Type of the argument. + * \param description Description of the argument. + * \return reference to self. + */ + inline OpRegistry& add_argument(const std::string& name, + const std::string& type, + const std::string& description); + /*! 
+ * \brief Attach the type function corresponding to the return type. + * \param type_rel_name The type function name to register for the return type. + * \param type_rel The backing relation which can solve an arbitrary relation + * on variables. + * \return reference to self. + */ + inline OpRegistry& add_type_rel(const std::string& type_rel_name, + TypeRelationFn type_rel); + + /*! + * \brief Attach the type function corresponding to the return type. + * \param type_rel_name The type function name to register for the return type. + * \param type_rel The backing relation which can solve an arbitrary relation + * on variables. + * \return reference to self. + */ + inline OpRegistry& add_type_rel( + const std::string& type_rel_name, + std::function(const Array&, int)> type_rel); + + /*! + * \brief Set the type key of attributes. + * \param type_key The type of of the attrs field.x + * \return reference to self. + */ + inline OpRegistry& set_attrs_type_key(const std::string& type_key); + /*! + * \brief Set the num_inputs + * \param n The number of inputs to be set. + * \return reference to self. + */ + inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*) + /*! + * \brief Set the support level of op. + * \param level The support level. + * \return reference to self. + */ + inline OpRegistry& set_support_level(int32_t level); // NOLINT(*) + /*! + * \brief Register additional attributes to operator. + * \param attr_name The name of the attribute. + * \param value The value to be set. + * \param plevel The priority level of this set, + * an higher priority level attribute + * will replace lower priority level attribute. + * Must be bigger than 0. + * + * Cannot set with same plevel twice in the code. + * + * \tparam ValueType The type of the value to be set. + */ + template + inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) + const ValueType& value, int plevel = 10); + + // set the name of the op to be the same as registry + inline OpRegistry& set_name() { // NOLINT(*) + if (get()->name.length() == 0) { + get()->name = name; + } + return *this; + } + /*! \return The global single retistry */ + TVM_DLL static ::dmlc::Registry* Registry(); + + private: + friend class ::dmlc::Registry; + // the name + std::string name; + /*! \brief The operator */ + Op op_; + // private constructor + OpRegistry(); + // return internal pointer to op. + inline OpNode* get(); + // update the attribute OpMap + TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value, + int plevel); +}; + +/*! + * \brief Generic map to store additional information of Op. + */ +class GenericOpMap { + public: + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op& op) const; + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline const TVMRetValue& operator[](const Op& op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + * \tparam ValueType The content value type. + */ + template + inline ValueType get(const Op& op, ValueType def_value) const; + + private: + friend class OpRegistry; + // the attribute field. 
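+  // name, i.e. the key that was passed to Op::GetAttr; kept so that failed
+  // lookups in operator[] below can report which attribute was missing.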
+ std::string attr_name_; + // internal data + std::vector > data_; + // The value + GenericOpMap() = default; +}; + +/*! + * \brief Map used to store meta-information about Op. + * \tparam ValueType The type of the value stored in map. + */ +template +class OpMap { + public: + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op& op) const; + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline ValueType operator[](const Op& op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + */ + inline ValueType get(const Op& op, ValueType def_value) const; + + private: + friend class Op; + // constructor + explicit OpMap(const GenericOpMap& map) : map_(map) {} + /*! \brief The internal map field */ + const GenericOpMap& map_; +}; + +// internal macros to make +#define RELAY_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp + +/*! + * \def RELAY_REGISTER_OP + * \brief Register a new operator, or set attribute of the corresponding op. + * + * \param OpName The name of registry + * + * \code + * + * RELAY_REGISTER_OP("add") + * .describe("add two inputs together") + * .set_num_inputs(2) + * .set_attr("gpu_kernel", AddKernel); + * + * \endcode + */ +#define RELAY_REGISTER_OP(OpName) \ + DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::relay::OpRegistry::Registry() \ + ->__REGISTER_OR_GET__(OpName) \ + .set_name() + +// implementations +inline const OpNode* Op::operator->() const { + return static_cast(node_.get()); +} + +template +inline OpMap Op::GetAttr(const std::string& key) { + return OpMap(Op::GetGenericAttr(key)); +} + +inline OpNode* OpRegistry::get() { + return const_cast(op_.operator->()); +} + +inline OpRegistry& OpRegistry::describe( + const std::string& descr) { // NOLINT(*) + get()->description = descr; + return *this; +} + +inline OpRegistry& OpRegistry::add_argument(const std::string& name, + const std::string& type, + const std::string& description) { + std::shared_ptr n = std::make_shared(); + n->name = name; + n->type_info = type; + n->description = description; + get()->arguments.push_back(AttrFieldInfo(n)); + return *this; +} + +inline OpRegistry& OpRegistry::add_type_rel( + const std::string& type_func_name, + std::function(const Array&, int)> type_fn) { + auto pfunc = + runtime::TypedPackedFunc(const Array&, int)>(type_fn); + return add_type_rel(type_func_name, pfunc); +} + +inline OpRegistry& OpRegistry::add_type_rel(const std::string& type_func_name, + TypeRelationFn type_fn) { + auto type_func = TypeRelationNode::make(type_func_name, 0, type_fn); + + std::vector type_params; + std::vector arg_types; + + // Add inputs. + int i = 0; + for (auto arg : get()->arguments) { + std::string name = "in"; + name += std::to_string(i++); + auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType); + type_params.push_back(param); + arg_types.push_back(param); + } + + auto ty_call_args = Array(arg_types); + + // Add output type. 
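+  // The relation is encoded as a FuncType whose result is a TypeCall over
+  // fresh type parameters: one "inN" per declared argument (added above) plus
+  // the trailing "out" parameter added below. For a two-argument op the
+  // registered type is shaped roughly like
+  //   fn<in0, in1, out>(in0, in1) -> rel(in0, in1, out)
+  // (a sketch, with "rel" standing for the registered relation).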
+ auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType); + type_params.push_back(out_param); + ty_call_args.push_back(out_param); + + auto type_result = TypeCallNode::make(type_func, ty_call_args); + + auto func_type = FuncTypeNode::make(arg_types, type_result, type_params, {}); + + get()->op_type = func_type; + + return *this; +} + +inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) + get()->num_inputs = n; + return *this; +} + +inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) + const std::string& type_key) { + get()->attrs_type_key = type_key; + return *this; +} + +inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) + get()->support_level = n; + return *this; +} + +template +inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) + const std::string& attr_name, const ValueType& value, int plevel) { + CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; + TVMRetValue rv; + rv = value; + UpdateAttr(attr_name, rv, plevel); + return *this; +} + +// member functions of OpMap +inline int GenericOpMap::count(const Op& op) const { + if (op.defined()) { + const uint32_t idx = op->index_; + return idx < data_.size() ? (data_[idx].second != 0) : 0; + } else { + return 0; + } +} + +inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { + CHECK(op.defined()); + const uint32_t idx = op->index_; + CHECK(idx < data_.size() && data_[idx].second != 0) + << "Attribute " << attr_name_ << " has not been registered for Operator " + << op->name; + return data_[idx].first; +} + +template +inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { + CHECK(op.defined()); + const uint32_t idx = op->index_; + if (idx < data_.size() && data_[idx].second != 0) { + return data_[idx].first; + } else { + return value; + } +} + +template +inline int OpMap::count(const Op& op) const { + return map_.count(op); +} + +template +inline ValueType OpMap::operator[](const Op& op) const { + return map_[op]; +} +template +inline ValueType OpMap::get(const Op& op, + ValueType def_value) const { + return map_.get(op, def_value); +} + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_OP_H_ diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h new file mode 100644 index 000000000000..730118e4eaed --- /dev/null +++ b/include/tvm/relay/pass.h @@ -0,0 +1,65 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/pass.h + * \brief The set of Relay passes written in C++. + */ +#ifndef TVM_RELAY_PASS_H_ +#define TVM_RELAY_PASS_H_ + +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Infer the type of an expression with the provided environment. + * + * The result of type checking is a new expression with unambigous + * type information filled in, as well as it's checked type field + * populated with the result type. + * + * \param env The environment used for global settings and referencing + * global functions. + * + * \param e The expression to type check. + * + * \return A type checked expression with its checked_type field populated. + */ +Expr InferType(const Environment& env, const Expr& e); +Expr InferType(const Environment& env, const GlobalVar & v, const Function & e); + +/*! + * \brief Check that types are well formed by applying "kinding rules". + * + * This pass ensures we do not do things that violate the design of the + * type system when writing down types. + * + * For example tensors are not allowed to contain functions in Relay. 
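+ * For instance, a hypothetical Tensor whose element type was itself a
+ * function type would be rejected by this check.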
+ *
+ * We check this by ensuring the `dtype` field of a Tensor always contains
+ * a data type such as `int`, `float`, `uint`.
+ *
+ * \param env The global environment.
+ * \param t The type to check.
+ * \return true if the rules are satisfied, otherwise false
+ */
+bool KindCheck(const Environment& env, const Type& t);
+
+/*! \brief Check that no LocalVar is shadowed.
+ *
+ * Roughly speaking, a LocalVar is considered shadowed if it is introduced
+ * while it is already in scope.
+ *
+ * For example, the expression `let x = 1 in let x = 2 in 3` shadows x.
+ *
+ * However, `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` does not
+ * shadow x, f, or g.
+ * x is not shadowed because each x is introduced in a different scope:
+ * the x used by f is invisible to the x used by g.
+ *
+ * \param e The expression to check.
+ *
+ * \return true iff e has no shadowing.
+ */
+bool LocalVarWellFormed(const Expr& e);
+
+} // namespace relay
+} // namespace tvm
+#endif // TVM_RELAY_PASS_H_
diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h
new file mode 100644
index 000000000000..b6d98bd68940
--- /dev/null
+++ b/include/tvm/relay/pass/alpha_eq.h
@@ -0,0 +1,55 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/pass/alpha_eq.h
+ * \brief Check expressions and types for structural equivalence.
+ */
+#ifndef TVM_RELAY_PASS_ALPHA_EQ_H_
+#define TVM_RELAY_PASS_ALPHA_EQ_H_
+
+#include
+#include
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Compare two expressions for structural equivalence.
+
+    This comparison operator respects scoping and compares
+    expressions without regard to variable choice.
+
+    For example: `let x = 1 in x` is equal to `let y = 1 in y`.
+
+    See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
+    for more details.
+
+    \param e1 The left hand expression.
+    \param e2 The right hand expression.
+
+    \return true if equal, otherwise false
+
+*/
+bool AlphaEqual(const Expr& e1, const Expr& e2);
+
+/*! \brief Compare two types for structural equivalence.
+
+    This comparison operator respects scoping and compares
+    types without regard to variable choice.
+
+    For example: `forall s, Tensor[f32, s]` is equal to
+    `forall w, Tensor[f32, w]`.
+
+    See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
+    for more details.
+
+    \param t1 The left hand type.
+    \param t2 The right hand type.
+
+    \return true if equal, otherwise false
+
+*/
+bool AlphaEqual(const Type& t1, const Type& t2);
+
+} // namespace relay
+} // namespace tvm
+#endif // TVM_RELAY_PASS_ALPHA_EQ_H_
+
diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h
new file mode 100644
index 000000000000..277c3875a17f
--- /dev/null
+++ b/include/tvm/relay/source_map.h
@@ -0,0 +1,55 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file source_map.h
+ * \brief A representation of source files and a data structure for
+ * storing them.
+ */
+#ifndef TVM_RELAY_SOURCE_MAP_H_
+#define TVM_RELAY_SOURCE_MAP_H_
+
+#include
+#include
+#include
+
+namespace tvm {
+namespace relay {
+
+/*! \brief A fragment of a source file used for error reporting.
+ *
+ * These can be registered by the frontends and are used for
+ * displaying errors.
+ */
+struct SourceFragment {
+  /*! \brief The file name which the source fragment originates from. */
+  std::string file_name;
+  /*! \brief The lines of source corresponding to the fragment. */
+  std::vector<std::string> source_lines;
+
+  SourceFragment(const std::string& file_name, const std::string& source);
+
+  SourceFragment(const SourceFragment& sf) {
+    this->file_name = sf.file_name;
+    this->source_lines = sf.source_lines;
+  }
+
+  /*! \brief Extract the source lines denoted by a span, up to `lines` lines. */
+  std::string SourceAt(Span sp, int lines);
+};
+
+/*! \brief Maps from source names to SourceFragments.
+ */
+class SourceMap {
+  /*! \brief Map from unique token to a fragment of a source file. */
+  std::unordered_map map_;
+
+ public:
+  SourceMap() : map_() {}
+  /*! \brief Add a source fragment with the file name and source. */
+  SourceName AddSource(const std::string& file_name, const std::string& source);
+  /*! \brief Retrieve a source fragment by source name. */
+  const SourceFragment& GetSource(SourceName id) const;
+};
+
+} // namespace relay
+} // namespace tvm
+#endif // TVM_RELAY_SOURCE_MAP_H_
diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h
new file mode 100644
index 000000000000..f485e0d8d62f
--- /dev/null
+++ b/include/tvm/relay/type.h
@@ -0,0 +1,315 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/type.h
+ * \brief Relay typed AST nodes.
+ */
+#ifndef TVM_RELAY_TYPE_H_
+#define TVM_RELAY_TYPE_H_
+
+#include
+#include
+#include
+#include
+
+#include "./base.h"
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Base type of the Relay type hierarchy. */
+class TypeNode : public RelayNode {
+ public:
+  static constexpr const char* _type_key = "relay.Type";
+  TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node);
+};
+
+/*!
+ * \brief Type is the base type of the Relay type hierarchy.
+ *
+ * Relay's type system contains the following two key concepts:
+ *
+ * - TensorType: type of certain Tensor values in the expression.
+ * - FunctionType: the type of the function.
+ *
+ * There are also advanced types to support generic (polymorphic) types,
+ * which can be ignored when first reading the code base.
+ */
+class Type : public NodeRef {
+ public:
+  Type() {}
+  explicit Type(std::shared_ptr p) : NodeRef(p) {}
+
+  using ContainerType = TypeNode;
+};
+
+/*!
+ * \brief Base of all Tensor types
+ * This container can hold TensorType or GenericTensorType.
+ */
+class BaseTensorTypeNode : public TypeNode {
+ public:
+  static constexpr const char* _type_key = "relay.BaseTensorType";
+  TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type);
+
+/*!
+ * \brief This is the most commonly used type in Relay.
+ * A TensorType has a fixed dimensionality and data type.
+ *
+ * The elements of shape can be either IntImm (constant integer),
+ * or any symbolic integer expression.
+ * The symbolic integer allows generic shape inference in certain cases.
+ * \sa TensorTypeNode The container class of TensorType.
+ */
+class TensorType;
+/*! \brief TensorType container node */
+class TensorTypeNode : public BaseTensorTypeNode {
+ public:
+  /*!
+   * \brief The shape of the tensor,
+   *  represented by ShapeExpr (tvm::Expr).
+   */
+  Array<ShapeExpr> shape;
+  /*! \brief The content data type */
+  DataType dtype;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static TensorType make(Array<ShapeExpr> shape, DataType dtype);
+
+  /*! \brief Construct a signed integer type */
+  TVM_DLL static TensorType Int(int bits, int lanes = 1);
+
+  /*! \brief Construct an unsigned integer type */
+  TVM_DLL static TensorType UInt(int bits, int lanes = 1);
+
+  /*! \brief Construct a floating-point type */
+  TVM_DLL static TensorType Float(int bits, int lanes = 1);
+
+  /*! \brief Construct a boolean type */
+  TVM_DLL static TensorType Bool(int lanes = 1);
+
+  static constexpr const char* _type_key = "relay.TensorType";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
+
+/*!
+ * \brief Type parameter in a function.
+ * This can be viewed as a template parameter of a C++ template function.
+ *
+ * For example, in the following pseudo code,
+ * the TypeParam of f is TypeParam(kind=kShapeVar, var=n).
+ * This function can take in a Tensor with shape=(3, 3) and
+ * returns a Tensor with shape=(9,).
+ *
+ * \code
+ *
+ *  template
+ *  f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
+ *
+ * \endcode
+ * \sa TypeParamNode The actual container class of TypeParam
+ */
+class TypeParam;
+/*! \brief TypeParam container node */
+class TypeParamNode : public TypeNode {
+ public:
+  /*! \brief possible kinds of TypeParam */
+  enum Kind : int {
+    /*! \brief template variable in shape expression */
+    kShapeVar = 0,
+    kShape = 1,
+    kBaseType = 2,
+    kType = 3,
+  };
+  /*!
+   * \brief The variable itself is only meaningful when
+   *  kind is kShapeVar; otherwise, we only use the name.
+   */
+  tvm::Var var;
+  /*! \brief The kind of type parameter */
+  Kind kind;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("var", &var);
+    v->Visit("kind", &kind);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static TypeParam make(std::string name, Kind kind);
+
+  static constexpr const char* _type_key = "relay.TypeParam";
+  TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
+
+/*!
+ * \brief Potential constraints in a type.
+ * \note This is reserved for future use.
+ */
+class TypeConstraint;
+/*! \brief TypeConstraint container node. */
+class TypeConstraintNode : public Node {
+ public:
+  static constexpr const char* _type_key = "relay.TypeConstraint";
+  TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, Node);
+};
+
+RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, NodeRef);
+
+class FuncType;
+/*!
+ * \brief Function type in Relay.
+ *
+ * Relay supports polymorphic function types.
+ * This can be roughly viewed as a template function in C++.
+ *
+ * \sa TypeParam, TypeConstraint
+ */
+class FuncTypeNode : public TypeNode {
+ public:
+  /*! \brief The types of the arguments. */
+  tvm::Array<Type> arg_types;
+  /*! \brief The type of the return value. */
+  Type ret_type;
+  // The following fields are used in polymorphic (template) functions.
+  // For normal functions, the following two fields will be empty.
+  /*! \brief The type parameters of the function */
+  tvm::Array<TypeParam> type_params;
+  /*!
+   * \brief Potential constraints the type needs to obey.
+   * \note This field is reserved for future use.
+   */
+  tvm::Array<TypeConstraint> type_constraints;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("arg_types", &arg_types);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("type_params", &type_params);
+    v->Visit("type_constraints", &type_constraints);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static FuncType make(tvm::Array<Type> arg_types, Type ret_type,
+                               tvm::Array<TypeParam> type_params,
+                               tvm::Array<TypeConstraint> type_constraints);
+
+  static constexpr const char* _type_key = "relay.FuncType";
+  TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type);
+
+using TypeRelationFn =
+    runtime::TypedPackedFunc<Array<Type>(const Array<Type>&, int)>;
+
+/*!
+ * \brief Opaque type relation, is an input-output relation on types. + */ +class TypeRelation; +/*! + * \brief TypeRelation container. + * \note This node is not directly serializable. + * The type function need to be lookedup in the environment. + */ +class TypeRelationNode : public RelayNode { + public: + /*! \brief The name of the function */ + std::string name; + /*! \brief Number of input type arguments, can be -1, which means VarArgs */ + int num_args; + /*! + * \brief The function on input and output variables which + * this is not directly serializable, + * need to be looked-up in the environment. + */ + TypeRelationFn func_; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("num_args", &num_args); + } + + TVM_DLL static TypeRelation make(std::string name, int num_args, + TypeRelationFn func_); + + static constexpr const char* _type_key = "relay.TypeRelation"; + TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, RelayNode); +}; + +RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, Type); + +/*! + * \brief Call a type function with some number of arguments. + */ +class TypeCall; +/*! + * \brief TypeCall container. + */ +class TypeCallNode : public TypeNode { + public: + /*! \brief The type function to be called. */ + Type func; + + /*! \brief The type arguments to the type function. */ + tvm::Array args; + + TypeCallNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("args", &args); + } + + TVM_DLL static TypeCall make(Type func, tvm::Array args); + + static constexpr const char* _type_key = "relay.TypeCall"; + TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); + +/*! + * \brief The type of tuple values. + */ +class TupleType; +/*! + * \brief TupleType container. + */ +class TupleTypeNode : public TypeNode { + public: + /*! \brief The type of each field in the tuple. */ + tvm::Array fields; + + TupleTypeNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } + + TVM_DLL static TupleType make(tvm::Array fields); + + static constexpr const char* _type_key = "relay.TypeTuple"; + TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type); + +// The following fields contains advanced typing +// Only keep the class name and reserved for future usage. +class GenericTensorType; +// stores a DataType. +class GenericDataType; +// stores a DataType. +class GenericShape; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TYPE_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py new file mode 100644 index 000000000000..c254c7e9ce7a --- /dev/null +++ b/python/tvm/relay/__init__.py @@ -0,0 +1,35 @@ +# pylint: disable=wildcard-import +"""The Relay IR namespace containing the IR definition and compiler.""" +from . import base +from . import type as tpe +from . import expr +from . import to_tvm +from . import env +from . import ir_pass +from . 
import ir_builder +# Operators +from .op import Op +from .op.tensor import * + +# Span +Span = base.Span + +# Type +Type = tpe.Type +TensorType = tpe.TensorType +Kind = tpe.Kind +TypeParam = tpe.TypeParam +TypeConstraint = tpe.TypeConstraint +FuncType = tpe.FuncType + +# Expr +Constant = expr.Constant +Tuple = expr.Tuple +LocalVar = expr.LocalVar +GlobalVar = expr.GlobalVar +Param = expr.Param +Function = expr.Function +Call = expr.Call +Let = expr.Let +If = expr.If +Var = LocalVar diff --git a/python/tvm/relay/_env.py b/python/tvm/relay/_env.py new file mode 100644 index 000000000000..25b8715a7816 --- /dev/null +++ b/python/tvm/relay/_env.py @@ -0,0 +1,5 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable +"""The interface to the Environment exposed from C++.""" +from tvm._ffi.function import _init_api + +_init_api("relay._env", __name__) diff --git a/python/tvm/relay/_env.pyi b/python/tvm/relay/_env.pyi new file mode 100644 index 000000000000..c6b5d0f6c4bd --- /dev/null +++ b/python/tvm/relay/_env.pyi @@ -0,0 +1,5 @@ +from typing import Union, Tuple, Dict, List +from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId +from relay.ir import ShapeExtension, Operator, Defn + +class Environment(NodeBase): ... \ No newline at end of file diff --git a/python/tvm/relay/_ir_pass.py b/python/tvm/relay/_ir_pass.py new file mode 100644 index 000000000000..61fdcfa38c2f --- /dev/null +++ b/python/tvm/relay/_ir_pass.py @@ -0,0 +1,5 @@ +"""FFI exposing the Relay type inference and checking.""" + +from tvm._ffi.function import _init_api + +_init_api("relay._ir_pass", __name__) diff --git a/python/tvm/relay/_ir_pass.pyi b/python/tvm/relay/_ir_pass.pyi new file mode 100644 index 000000000000..1bb42ab854c2 --- /dev/null +++ b/python/tvm/relay/_ir_pass.pyi @@ -0,0 +1,6 @@ +from .env import Environment +from . import ir + +def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ... +def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ... +def _get_checked_type(expr: ir.Expr) -> ir.Type: ... diff --git a/python/tvm/relay/_make.py b/python/tvm/relay/_make.py new file mode 100644 index 000000000000..20a582e76d6a --- /dev/null +++ b/python/tvm/relay/_make.py @@ -0,0 +1,9 @@ +""" +The constructors for all Relay AST nodes exposed from C++. + +This module includes MyPy type signatures for all of the +exposed modules. +""" +from .._ffi.function import _init_api + +_init_api("relay._make", __name__) diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py new file mode 100644 index 000000000000..0f3d2bc58d71 --- /dev/null +++ b/python/tvm/relay/base.py @@ -0,0 +1,30 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck +"""The base node types for the Relay language.""" +from __future__ import absolute_import as _abs +from .._ffi.node import NodeBase, register_node as _register_tvm_node +from . import _make + +NodeBase = NodeBase + +def register_relay_node(type_key=None): + """register relay node type + + Parameters + ---------- + type_key : str or cls + The type key of the node + """ + if not isinstance(type_key, str): + return _register_tvm_node( + "relay." 
+ type_key.__name__)(type_key)
+    return _register_tvm_node(type_key)
+
+
+@register_relay_node
+class Span(NodeBase):
+    source: "FileSource"
+    lineno: int
+    col_offset: int
+
+    def __init__(self, source, lineno, col_offset):
+        self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py
new file mode 100644
index 000000000000..93cbe1bca284
--- /dev/null
+++ b/python/tvm/relay/env.py
@@ -0,0 +1,64 @@
+# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
+"""A global environment storing everything needed to interpret or compile a Relay program."""
+from .base import register_relay_node, NodeBase
+from . import _make
+from . import _env
+
+@register_relay_node
+class Environment(NodeBase):
+    """The global Relay environment containing functions,
+    options and more.
+    """
+    def __init__(self, funcs) -> None:
+        """Construct an environment.
+
+        Parameters
+        ----------
+        funcs: list of relay.Function
+
+        Returns
+        -------
+        env: A new :py:class:`~relay.env.Environment`.
+        """
+        self.__init_handle_by_constructor__(_make.Environment, funcs)
+
+    def add(self, var, func) -> None:
+        """Add a function to the environment.
+
+        Parameters
+        ----------
+        var: GlobalVar
+            The global variable which names the function.
+
+        func: Function
+            The function.
+        """
+        if isinstance(var, str):
+            var = _env.Environment_GetGlobalVar(self, var)
+
+        _env.Environment_Add(self, var, func)
+
+    def merge(self, other):
+        """Merge two environments.
+
+        Parameters
+        ----------
+        other: Environment
+            The environment to merge into the current Environment.
+        """
+        return _env.Environment_Merge(self, other)
+
+    def global_var(self, var):
+        """Get a global variable by name."""
+        return _env.Environment_GetGlobalVar(self, var)
+
+    def lookup(self, var):
+        """Look up a global function by name or by variable."""
+        if isinstance(var, str):
+            return _env.Environment_Lookup_str(self, var)
+        else:
+            return _env.Environment_Lookup(self, var)
+
+    def transform(self, transformer):
+        """Apply a transformer function to the environment."""
+        _env.Environment_Transform(self, transformer)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
new file mode 100644
index 000000000000..3cdaed89d2fb
--- /dev/null
+++ b/python/tvm/relay/expr.py
@@ -0,0 +1,141 @@
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
+"""The expression nodes of Relay."""
+from typing import List
+import tvm
+from .base import Span, NodeBase, register_relay_node
+from .type import Type, TypeParam
+from ._ir_pass import _get_checked_type
+from . import _make
+
+
+class ExprBuilder():
+    """A set of methods useful for building expressions
+    from other expressions.
+    """
+    def __call__(self, *args, **kwargs):
+        converted_args = []
+        for arg in args:
+            if isinstance(arg, Param):
+                converted_args.append(arg.var)
+            else:
+                converted_args.append(arg)
+
+        # Call with the converted arguments, not the raw ones.
+        return Call(self, converted_args, None, None)
+
+
+class Expr(NodeBase, ExprBuilder):
+    """The base type for all Relay expressions."""
+
+    def checked_type(self):
+        return _get_checked_type(self)
+
+
+@register_relay_node
+class Constant(Expr):
+    """A constant tensor in Relay, see tvm/relay/expr.h for more details.
+    """
+    data: tvm.nd.NDArray
+
+    def __init__(self, data: tvm.nd.NDArray) -> None:
+        self.__init_handle_by_constructor__(_make.Constant, data)
+
+
+@register_relay_node
+class Tuple(Expr):
+    """A heterogeneous sequence of values,
+    see tvm/relay/expr.h for more details.
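+
+    Example
+    -------
+    A sketch of a two element tuple (both fields are Relay expressions):
+
+    .. code-block:: python
+
+        t = Tuple([LocalVar("x"), LocalVar("y")])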
+ """ + fields: List[Expr] + + def __init__(self, fields: List[Expr]) -> None: + self.__init_handle_by_constructor__(_make.Tuple, fields) + + +@register_relay_node +class LocalVar(Expr): + """A local variable in Relay.""" + name_hint: str + + def __init__(self, name_hint: str) -> None: + self.__init_handle_by_constructor__(_make.LocalVar, name_hint) + + +@register_relay_node +class GlobalVar(Expr): + """A global variable in Relay.""" + name_hint: str + + def __init__(self, name_hint: str) -> None: + self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) + + +@register_relay_node +class Param(Expr): + """A function type in Relay, see tvm/relay/type.h for more details. + """ + var: LocalVar + type: Type + + def __init__(self, var: LocalVar, ty: Type) -> None: + self.__init_handle_by_constructor__(_make.Param, var, ty) + + +@register_relay_node +class Function(Expr): + """A function in Relay, see tvm/relay/expr.h for more details.""" + type_params: List[TypeParam] + params: List[Param] + ret_type: Type + body: Expr + + def __init__(self, + params: List[Param], + ret_type: Type, + body: Expr, + type_params: List[TypeParam] = None) -> None: + if not type_params: + type_params = [] + self.__init_handle_by_constructor__( + _make.Function, params, ret_type, body, type_params) + + +@register_relay_node +class Call(Expr): + """A function call in Relay, see tvm/relay/expr.h for more details.""" + op: Expr + args: List[Expr] + # todo(@jroesch): add attrs + + def __init__(self, op: Expr, args: List[Expr], attrs, ty_args=None) -> None: + if not ty_args: + ty_args = [] + + self.__init_handle_by_constructor__( + _make.Call, op, args, attrs, ty_args) + + +@register_relay_node +class Let(Expr): + """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" + var: LocalVar + value: Expr + body: Expr + # should be type annotation + value_type: Type + + def __init__(self, var: LocalVar, value: Expr, body: Expr, value_type: Type) -> None: + self.__init_handle_by_constructor__( + _make.Let, var, value, body, value_type) + + +@register_relay_node +class If(Expr): + """A conditional expression in Relay, see tvm/relay/expr.h for more details.""" + cond: Expr + true_value: Expr + false_value: Expr + span: Span + + def __init__(self, cond: Expr, true_value: Expr, false_value: Expr) -> None: + self.__init_handle_by_constructor__( + _make.If, cond, true_value, false_value) diff --git a/python/tvm/relay/from_nnvm.py b/python/tvm/relay/from_nnvm.py new file mode 100644 index 000000000000..9700ea955f59 --- /dev/null +++ b/python/tvm/relay/from_nnvm.py @@ -0,0 +1,7 @@ +#pylint: disable-all +"""Convert an nnvm.graph.Graph into a tvm.relay.Expr""" +import nnvm + +def from_nnvm(graph): + """Convert an nnvm.graph.Graph into a tvm.relay.Expr""" + raise Exception("NYI") diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py new file mode 100644 index 000000000000..c0c2e76c1157 --- /dev/null +++ b/python/tvm/relay/ir_builder.py @@ -0,0 +1,295 @@ +"""IR builder for the Relay IR. + +Enables users to construct Relay programs with a Python API. +""" +from typing import Any +import numpy as np +import tvm +from .type import FuncType, TensorType +from .expr import Expr, Constant, Let, LocalVar, Param, Function, If +from .env import Environment + + +def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: + """Convert Python values into the appropriate types + for the Relay evaluator. 
+ """ + if isinstance(arg, int): + return tvm.nd.array(np.array(arg, dtype='int32'), ctxt) + elif isinstance(arg, float): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, bool): + return tvm.nd.array(np.array(arg, dtype='float32'), ctxt) + elif isinstance(arg, np.ndarray): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, tvm.ndarray.NDArray): + return arg + else: + # raise Exception(f"can't convert {type(arg)} to a Relay AST") + raise Exception(f"unsupported argument type {type(arg)}") + + +def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: + if isinstance(arg, Expr): + return arg + elif isinstance(arg, tuple): + raise Exception("..") + elif isinstance(arg, PartialFunc): + return arg.to_func() + else: + value = convert(arg, ctxt) + return Constant(value) + + +class WithScope(object): + """A wrapper for builder methods which introduce scoping.""" + + def __init__(self, enter_value, exit_cb): + self._enter_value = enter_value + self._exit_cb = exit_cb + + def __enter__(self): + return self._enter_value + + def __exit__(self, ptype, value, trace): + if value: + raise value + else: + self._exit_cb() + + +class PartialFunc(): + """A wrapper around functions while they are being built.""" + def __init__(self, params, ret_type, body, type_params): + self.params = params + self.ret_type = ret_type + self.body = body + self.type_params = type_params + + def param_ids(self): + return [p.var for p in self.params] + + def to_func(self): + return Function( + self.params, + self.ret_type, + self.body, + self.type_params) + +#pylint: disable=invalid-name +def _mk_let(bindings, ret_value): + let_expr = ret_value + for var, (value, ty) in reversed(list(bindings.items())): + let_expr = Let(var, value, let_expr, ty) + + return let_expr + + +class IRBuilder(): + """The IRBuilder class. + + Enables users to build up a Relay environment and program. 
+ """ + def __init__(self): + self.bindings = [{}] + self.scopes = [{}] + self.params = [] + self.ret_values = [None] + self.env = Environment({}) + + def enter_scope(self, params=None): + if not params: + params = [] + + self.bindings.append({}) + self.scopes.append({}) + self.params.append(params) + self.ret_values.append(None) + + def exit_scope(self): + bindings = self.bindings.pop() + scopes = self.scopes.pop() + params = self.params.pop() + ret_value = self.ret_values.pop() + return bindings, scopes, params, ret_value + + #pylint: disable=invalid-name + def bind(self, name, value, ty): + lv = LocalVar(name) + self.scopes[-1][name] = lv + self.bindings[-1][lv] = (value, ty) + return lv + + def let(self, name, value, value_type=None): + if isinstance(value, Param): + value = value.var + + if not isinstance(value, Expr): + value = into_ast(value) + + return self.bind(name, value, value_type) + + def _convert_params(self, raw_params): + relay_params = [] + for raw_param in raw_params: + if isinstance(raw_param, Param): + var = raw_param.var + param = raw_param + elif isinstance(raw_param, tuple): + var, ty = raw_param + if isinstance(var, str): + var = LocalVar(var) + param = Param(var, ty) + elif isinstance(param, str): + var = LocalVar(raw_param) + ty = None + param = Param(var, ty) + else: + raise Exception("unknown parameter type") + + self.scopes[-1][var.name_hint] = var + relay_params.append(param) + + return relay_params + + def function(self, *params): + """Construct a Relay function.""" + + relay_params = self._convert_params(params) + + # self.params.append(relay_params) + + self.enter_scope() + + pfunc = PartialFunc(relay_params, None, None, []) + + def _on_exit(): + bindings, _, _, ret_value = self.exit_scope() + body = _mk_let(bindings, ret_value) + pfunc.body = body + + return WithScope(pfunc, _on_exit) + + def ret(self, x): + if not self.ret_values[-1]: + self.ret_values[-1] = into_ast(x) + else: + raise Exception( + "return value already set, a function can only have one return value") + + def if_scope(self, cond): + """Construct the if branch an if expression with scoping.""" + self.enter_scope() + + def _on_exit(): + bindings, _, _, ret_value = self.exit_scope() + assert self.ret_values[-1] is None + true_branch = _mk_let(bindings, ret_value) + self.ret_values[-1] = If(cond, true_branch, None) + + return WithScope(10, _on_exit) + + def else_scope(self): + """Construct the else branch of an if expression with scoping.""" + self.enter_scope() + + def _on_exit(): + bindings, _, _, ret_value = self.exit_scope() + partial_if = self.ret_values[-1] + assert isinstance( + partial_if, If) and partial_if.false_value is None + false_branch = _mk_let(bindings, ret_value) + self.ret_values[-1] = If( + partial_if.cond, + partial_if.true_value, + false_branch) + + return WithScope(10, _on_exit) + + def param(self, name, ty=None): + if not ty: + ty = float_type() + + return Param(LocalVar(name), ty) + + # def params(*args): + # i = 0 + # while i < args.size(): + # arg = args[i] + # if isinstance(arg, str): + + def global_var(self, name: str): + return self.env.global_var(name) + + def decl(self, name: str, *params, ret_type=None): + self.enter_scope() + + def _on_exit(): + bindings, _, _, ret_value = self.exit_scope() + exp = _mk_let(bindings, ret_value) + self.env.add(name, Function(params, ret_type, exp)) + + return WithScope(10, _on_exit) + + # def while_loop(cond) + + def get(self): + """Get the full program""" + bindings = self.bindings.pop() + scope = self.scopes.pop() + + if 
self.bindings: + raise Exception("IRBuilder: binding error") + + if self.scopes: + raise Exception("IRBuilder: scoping error") + + if bindings and scope and not self.ret_values: + raise Exception("IRBuilder: no return value set") + + return _mk_let(bindings, self.ret_values[-1]), self.env + + +def bool_dtype(): + return 'uint1' + + +def int_dtype(bits=32): + return f'int{bits}' + + +def float_dtype(bits=32): + return f'float{bits}' + + +def uint_dtype(bits=32): + return f'uint{bits}' + + +def int_type(bits=32, _lanes=1): + # TODO(@jroesch, @tqchen) How do we set lanes? + return TensorType(tvm.convert([]), int_dtype(bits)) + + +def uint_type(bits=32, _lanes=1): + return TensorType(tvm.convert([]), uint_dtype(bits)) + + +def float_type(bits=32, _lanes=1): + return TensorType(tvm.convert([]), float_dtype(bits)) + + +def bool_type(_lanes=1): + return TensorType(tvm.convert([]), bool_dtype()) + + +def tensor_type(*shape, dtype='float32'): + return TensorType(tvm.convert(shape), dtype) + + +def func_type(args, ret_type, type_params=None, type_constraints=None): + if not type_params: + type_params = [] + if not type_constraints: + type_constraints = [] + return FuncType(args, ret_type, type_params, type_constraints) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py new file mode 100644 index 000000000000..b075704c212a --- /dev/null +++ b/python/tvm/relay/ir_pass.py @@ -0,0 +1,232 @@ +# pylint: disable=no-else-return, +# pylint: disable=unidiomatic-typecheck +"""The optimizer for Relay. + +Exposes an interface for configuring the optimizer and scripting +it directly in Python. +""" +from typing import TypeVar, Generic, Union +from typing import Dict, Tuple, List, Callable +import tvm + +from .expr import Expr +from .expr import Function, Let, Call, LocalVar +from .expr import GlobalVar, If, Constant +from .type import Type, TypeParam +from .env import Environment +from .op import Op +from .op.op import specialize_op +# import relay.make as relay_mk +# from relay import ir +# from relay.env import Environment +# from relay.tyck import check_expr +# from relay.first_order_reverse_ad import fo_with_gradient +# from relay.anf import to_anf +from . import _ir_pass + +# Expose checking expression, should rename to infer_type. 
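+#
+# A rough usage sketch (assumes an `env` and a well-formed `expr` are in
+# scope; `check_expr` is bound just below):
+#
+#   ty = check_expr(env, expr)  # run type inference and return expr's type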
+# pylint: disable=invalid-name
+check_expr = _ir_pass.check_expr
+
+# The visitors below dispatch on the Relay Tuple expression; import it under
+# an alias so it does not clash with typing.Tuple, which this module uses in
+# type annotations.
+from .expr import Tuple as RelayTuple
+
+# # pylint: disable=invalid-name
+# concretize = _opt.concretize
+
+# # pylint: disable=invalid-name
+# optimize = _opt.optimize
+
+# # pylint: disable=invalid-name
+# type_specialize = _opt.type_specialize
+
+# # pylint: disable=invalid-name
+# compile_ops_to_module = _opt.compile_ops_to_module
+
+
+@tvm.register_func("relay.mangle")
+def mangle(name: str, types: List[Type]) -> str:
+    for typ in types:
+        name += str(typ) + "_"
+    return name
+
+
+T = TypeVar('T')
+
+
+class AbstractExprVisitor(Generic[T]):
+    """A functional visitor over Expr in Python."""
+
+    # pylint: disable=no-else-return
+    def visit(self, expr: Expr) -> T:
+        """Apply the visitor to an expression."""
+        if isinstance(expr, Function):
+            return self.visit_function(expr)
+        elif isinstance(expr, Call):
+            return self.visit_call(expr)
+        elif isinstance(expr, Let):
+            return self.visit_let(expr)
+        elif isinstance(expr, LocalVar):
+            return self.visit_local_var(expr)
+        elif isinstance(expr, GlobalVar):
+            return self.visit_global_var(expr)
+        elif isinstance(expr, If):
+            return self.visit_if(expr)
+        elif isinstance(expr, RelayTuple):
+            return self.visit_tuple(expr)
+        elif isinstance(expr, Constant):
+            return self.visit_constant(expr)
+        else:
+            raise Exception(f"unhandled case: {type(expr)}")
+
+    def visit_function(self, _: Function) -> T:
+        raise Exception("Abstract method, please implement me.")
+
+    def visit_let(self, _: Let) -> T:
+        raise Exception("Abstract method, please implement me.")
+
+    def visit_call(self, _: Call) -> T:
+        raise Exception("Abstract method, please implement me.")
+
+    def visit_local_var(self, _: LocalVar) -> T:
+        raise Exception("Abstract method, please implement me.")
+
+    def visit_type(self, typ: Type) -> Type:
+        return typ
+
+    def visit_if(self, _: If) -> T:
+        raise Exception("Abstract method, please implement me.")
+
+    def visit_tuple(self, _: RelayTuple) -> T:
+        raise Exception("Abstract method, please implement me.")
+
+    def visit_constant(self, _: Constant) -> T:
+        raise Exception("Abstract method, please implement me.")
+
+    def visit_global_var(self, _: GlobalVar) -> T:
+        raise Exception("Abstract method, please implement me.")
+
+    @classmethod
+    def to_pass(cls) -> Callable[[Environment], Callable[[GlobalVar, Function], Function]]:
+        def _outer_wrapper(env):
+            visitor = cls(env)
+
+            def _inner_wrapper(_, func):
+                return visitor.visit(func)
+            return _inner_wrapper
+        return _outer_wrapper
+
+
+class ExprVisitor(AbstractExprVisitor[Expr]):
+    """A structure-preserving visitor over Expr in Python."""
+
+    def visit_function(self, fn: Function) -> Expr:
+        new_body = self.visit(fn.body)
+        return Function(
+            list(fn.params),
+            fn.ret_type, new_body,
+            fn.type_params)
+
+    def visit_let(self, let: Let) -> Expr:
+        new_var = self.visit(let.var)
+        new_value_type = self.visit_type(let.value_type)
+        new_val = self.visit(let.value)
+        new_body = self.visit(let.body)
+        return Let(new_var, new_val, new_body, new_value_type)
+
+    def visit_call(self, call: Call) -> Expr:
+        new_fn = self.visit(call.op)
+        new_args = [self.visit(arg) for arg in call.args]
+        return Call(new_fn, new_args, call.attrs)
+
+    def visit_local_var(self, local_var: LocalVar) -> Expr:
+        return local_var
+
+    def visit_global_var(self, global_var: GlobalVar) -> Expr:
+        return global_var
+
+    def visit_if(self, ite: If) -> Expr:
+        return If(
+            self.visit(ite.cond),
+            self.visit(ite.true_value),
+            self.visit(ite.false_value))
+
+    def visit_tuple(self, tup: RelayTuple) -> Expr:
+        return RelayTuple([self.visit(field) for field
in tup.fields]) + + def visit_constant(self, const: Constant) -> Expr: + return const + + +MMCacheKey = Tuple[Union[GlobalVar, str], List[Type]] + + +class Monomorphize(ExprVisitor): + """A monomorphization pass. + + Implements what is known as "monomorphization" in + classic compiler literature. This pass removes + polymorphism replacing calls to functions and + operators with type specialized versions. + """ + monomorph_map: Dict[MMCacheKey, Union[Op, Function]] + + # pylint: disable=super-init-not-called + def __init__(self, env: Environment) -> None: + self.env = env + # Stores (GlobalVar, Type), should eventually store attributes. + self.monomorph_map = {} + + # pylint: disable=no-else-return + def visit_call(self, call: Call) -> Expr: + cache_key = (call.op, call.type_args) + new_args = [self.visit(arg) for arg in call.args] + + if cache_key in self.monomorph_map: + op = self.monomorph_map[cache_key] + new_args = [self.visit(arg) for arg in call.args] + return Call(op, new_args, call.attrs) + else: + if isinstance(call.op, Op): + poly_name = call.op.name + mono_name = mangle(poly_name, call.type_args) + for arg in call.type_args: + if isinstance(arg, TypeParam): + # raise Exception("...") # Fix me in the morning!!! + return call + + mono_op = specialize_op(poly_name, mono_name, call.type_args) + self.monomorph_map[cache_key] = mono_op + return Call(mono_op, new_args, call.attrs, []) + elif isinstance(call.op, GlobalVar): + return call + # defn = self.env.lookup(call.op) + # new_id = self.env.global_id(defn.id.name + str(1)) + # cache_key = (call.op, call.type_args) + # self.monomorph_map[cache_key] = new_id + # new_body = self.visit(type_specialize(call.type_args, defn.body)) + # new_body = Function( + # [], new_body.params, new_body.ret_type, new_body.body) + # new_ty = check_expr(self.env, new_body) + # # TODO(@jroesch): move into C++ + # # TODO(@joresch): implement and call name mangler + # defn = Defn(new_id, new_ty, new_body) + # self.env.add(defn) + # self.visit_item(defn) + # return Call(new_id, call.args, call.attrs) + + elif isinstance(call.op, Function): + return call + # new_func = type_specialize(call.type_args, call.op) + # new_func = self.visit(new_func) + # new_func = Function([], + # new_func.params, + # new_func.ret_type, + # new_func.body) + # check_expr(self.env, new_func) + # return Call(new_func, call.args, call.attrs) + else: + new_fn = self.visit(call.op) + return Call(new_fn, new_args, call.attrs) + + +# TODO(@jroesch): Fix up my type +__tgt_host__ = __tgt__ = "llvm" +__relay_tvm_context__ = tvm.cpu() diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py new file mode 100644 index 000000000000..5c3a8ac249a6 --- /dev/null +++ b/python/tvm/relay/op/__init__.py @@ -0,0 +1,12 @@ +#pylint: disable=wildcard-import +"""Relay core operators.""" +# operator defs +from .op import get, register, Op, compile_ops + +# Operators +from .tensor import * + +# operator registry +from . 
import _tensor +from ..expr import Expr +from ..base import register_relay_node diff --git a/python/tvm/relay/op/_make.py b/python/tvm/relay/op/_make.py new file mode 100644 index 000000000000..79c86cbb0254 --- /dev/null +++ b/python/tvm/relay/op/_make.py @@ -0,0 +1,4 @@ +"""Constructor APIs""" +from ..._ffi.function import _init_api + +_init_api("relay.op._make", __name__) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py new file mode 100644 index 000000000000..2a0ecc6c8550 --- /dev/null +++ b/python/tvm/relay/op/_tensor.py @@ -0,0 +1,60 @@ +#pylint: disable=invalid-name +"""Backend compiler related feature regsitration""" +from topi import add +from .op import register +from ..type import FuncType, TensorType +from ...schedule import create_schedule +from ...api import placeholder + +def type_to_placeholder(name, ty): + """Convert a single type into the correct placeholder.""" + if isinstance(ty, TensorType): + return placeholder(ty.shape, name=name, dtype=ty.dtype) + else: + raise Exception("can only pass Tensor values to TVM operators") + +def func_ty_to_placeholders(func_ty): + """Build input placeholders based on a function type.""" + if isinstance(func_ty, FuncType): + arg_types = func_ty.arg_types + ret_type = func_ty.ret_type + args = [] + var = 0 + for arg in arg_types: + var += 1 + args.append(type_to_placeholder(f"Input{var}", arg)) + return args, ret_type + else: + raise Exception("error") + +# def lookup_in_topi(name): +# try: +# f = eval(f"topi.{name}") +# except: +# f = eval(f"topi.nn.{name}") + +# return f + +# @tvm.register_func("nnvm.relay._default_op_compiler") +# def _default_op_compile(op_name: str, func_ty: ir.Type, attrs: ir.Attributes=None) -> Any: +# Inputs, ret_ty = func_ty_to_placeholders(func_ty) +# op = lookup_in_topi(op_name) +# Output = op(*Inputs) + +# if Output.dtype == 'uint1': +# import pdb; pdb.set_trace() +# Output = Output.astype('uint8') + +# schedule = tvm.create_schedule(Output.op) +# return [schedule, Inputs + [Output]] + +#pylint: disable=duplicate-argument-name +def add_compiler(_, func_type, *__): + """The compilation code for the TVM compiler.""" + inputs, _ = func_ty_to_placeholders(func_type) + # op = lookup_in_topi(op_name) + output = add(*inputs) + schedule = create_schedule(output.op) + return [schedule, inputs + [output]] + +register("add", "FRelayOpCompiler", add_compiler) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py new file mode 100644 index 000000000000..14570b62269b --- /dev/null +++ b/python/tvm/relay/op/op.py @@ -0,0 +1,148 @@ +"""The base node types for the Relay language.""" +from ..._ffi.function import _init_api + +from ..base import register_relay_node +from ..expr import Expr +from ..._ffi.function import register_func +from ... import lower, build + + +@register_relay_node +class Op(Expr): + """A Relay operator definition.""" + def __init__(self): + raise RuntimeError("Cannot create op, use get instead") + + def get_attr(self, attr_name): + """Get additional attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name. 
+ + Returns + ------- + value : object + The attribute value + """ + return _OpGetAttr(self, attr_name) + + +def get(op_name): + """Get the Op for a given name + + Parameters + ---------- + op_name : str + The operator name + + Returns + ------- + op : Op + The op of the corresponding name + """ + return _GetOp(op_name) + + +def register(op_name, attr_key, value=None, level=10): + """Register an operator property of an operator. + + + Parameters + ---------- + op_name : str + The name of operator + + attr_key : str + The attribute name. + + value : object, optional + The value to set + + level : int, optional + The priority level + + Returns + ------- + fregister : function + Register function if value is not specified. + """ + def _register(v): + """internal register function""" + _Register(op_name, attr_key, v, level) + return v + return _register(value) if value else _register + + +def compile_ops(op_names): + """Register an operator property of an operator. + + + Parameters + ---------- + op_names : List[str] + A list of operator names to compile to machine code. + + Returns + ------- + A module containing the compiled TVM operators. + """ + return _CompileOpsToModule(*op_names) + +# TODO(@jroesch): We should port to C++, just need to figure out how to write this code. + + +@register_func("relay.op._compile_ops") +def _compile_ops(op_impls): + lowered = [] + for local, sch, inputs in op_impls: + lfn = lower(sch, inputs, name=local.name_hint) + lowered.append(lfn) + + # TOOD(@jroesch): Where should we read these settings from + return build(lowered, target='llvm', target_host='llvm') + + +_init_api("relay.op", __name__) + + +def specialize_op(op_name, new_op_name, type_args): + """Specializes an operator to a set of types and assigns it new_op_name. + + The idea is to take operators with generic types such as broadcasting + addition: + + add : forall (T : Type) (U : Type), (U, T) -> Broadcast(U, T) + + This is a function which is polymorphic over two types `T` and `U` and + takes a value of type `T` and one of `U` and returns `Broadcast` of U + and T. + + Broadcast is a type relation which relates U and T to an output type. + + The idea is that the above type is shorthand for: + + add : forall (T : Type) (U : Type) (O : Type), Broadcast(U, T, O) => (U, T) -> O + + That is a function from U and T to O where the typing relation between the values + is specified by Broadcast. + + We implement a basic Broadcasting rule in `type_relations.h` but users can specify + their own. + + If we know T=Tensor[(10, 10), dtype], U=Tensor[(10, 10), dtype] then the result + should be Tensor[(10, 10), dtype]. + + We can use SpecializeOp to implement this change of operator. + + Parameters + ---------- + op_name : str + The operator to be specialized. + + Returns + ------- + The specialized operator. + """ + return _SpecializeOp(op_name, new_op_name, type_args) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py new file mode 100644 index 000000000000..57fbccf488dc --- /dev/null +++ b/python/tvm/relay/op/tensor.py @@ -0,0 +1,100 @@ +"""Basic tensor operations.""" +from __future__ import absolute_import as _abs +from . import _make + +# We create a wrapper function for each operator in the +# python side to call into the positional _make.OpName function. 
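+# For instance, a wrapper has roughly this shape (a sketch; `tanh` is a
+# hypothetical operator name here):
+#
+#   def tanh(data):
+#       return _make.tanh(data)
+#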
+#
+# We make this decision so that we can:
+# - Declare a Python docstring for each function
+# - Enable keyword arguments easily
+# - Not put too much burden on the FFI to support complicated features
+#   like default values and keyword arguments
+
+
+def log(data):
+    """Take log of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.log(data)
+
+
+def exp(data):
+    """Take exp of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.exp(data)
+
+
+def sqrt(data):
+    """Take sqrt of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.sqrt(data)
+
+
+def add(lhs, rhs):
+    """Compute the elementwise sum of lhs and rhs.
+
+    Parameters
+    ----------
+    lhs : relay.Expr
+        The left hand side input data
+    rhs : relay.Expr
+        The right hand side input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.add(lhs, rhs)
+
+
+def subtract(lhs, rhs):
+    """Compute the elementwise difference of lhs and rhs.
+
+    Parameters
+    ----------
+    lhs : relay.Expr
+        The left hand side input data
+    rhs : relay.Expr
+        The right hand side input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    # Assumes the C++ side registers a relay._make.subtract constructor.
+    return _make.subtract(lhs, rhs)
+
+
+def equal(lhs, rhs):
+    """Compare lhs and rhs for elementwise equality.
+
+    Parameters
+    ----------
+    lhs : relay.Expr
+        The left hand side input data
+    rhs : relay.Expr
+        The right hand side input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.equal(lhs, rhs)
diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py
new file mode 100644
index 000000000000..615a39301142
--- /dev/null
+++ b/python/tvm/relay/to_tvm.py
@@ -0,0 +1,235 @@
+"""A compiler from Relay programs to TVM's graph runtime.
+"""
+import json
+from typing import Dict, Any, List, Tuple, Set
+
+import attr
+from .ir_pass import AbstractExprVisitor
+from .op import compile_ops, Op
+from .type import TensorType
+from .expr import LocalVar, Function, Let, Call
+
+
+@attr.s(auto_attribs=True)
+class NodeRef:
+    ident: int
+    index: int = 0
+    version: int = 0
+
+    def to_json(self) -> Any:
+        return [self.ident, self.index, self.version]
+
+
+@attr.s(auto_attribs=True)
+class Node():
+    name: str
+    attrs: Dict[str, Any]
+    is_output: bool
+
+    def to_json(self) -> Any:
+        raise Exception("Abstract method, please implement me.")
+
+
+@attr.s(auto_attribs=True)
+class InputNode(Node):
+    """An input node in the graph representation we lower to before NNVM's graph."""
+    is_output: bool = False
+
+    def to_json(self):
+        return {
+            "op": "null",
+            "name": self.name,
+            "inputs": []
+        }
+
+
+@attr.s(auto_attribs=True)
+class OpNode(Node):
+    """An operator node in the graph representation we lower to before NNVM's graph."""
+    op_name: str
+    inputs: List[NodeRef]
+    op_attrs: Dict[str, Any]
+    is_output: bool = False
+
+    def to_json(self) -> Any:
+        attrs = dict.copy(self.op_attrs)
+        # Extend ops with extra info.
+        attrs['func_name'] = self.op_name
+        # When do we flatten?
+        attrs['flatten_data'] = "0"
+        # Fix me!
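+        # As an illustration, the dict emitted below looks roughly like
+        # (values are hypothetical):
+        #   {"op": "tvm_op", "name": "add_0",
+        #    "attrs": {"func_name": "add", "flatten_data": "0",
+        #              "num_inputs": "2", "num_outputs": "1"},
+        #    "inputs": [[0, 0, 0], [1, 0, 0]]}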
+ attrs['num_inputs'] = str(len(self.inputs)) + attrs['num_outputs'] = "1" + + return { + "op": "tvm_op", + "name": self.name, + "attrs": attrs, + "inputs": self.inputs + } + + +def shape_to_json(shape): + return [sh.value for sh in shape] + + +def from_tensor(typ: TensorType) -> Tuple[str, List[int]]: + return (typ.dtype, shape_to_json(typ.shape)) + + +class TVMRTSCompiler(AbstractExprVisitor[NodeRef]): + """The compiler from Relay to the TVM runtime system.""" + nodes: List[Node] + id_map: Dict[LocalVar, NodeRef] + all_ops: Set[Op] + + def __init__(self) -> None: + self.nodes = [] + self.id_map = {} + self.all_ops = set() + + def add_node(self, node: Node) -> NodeRef: + self.nodes.append(node) + ident = len(self.nodes) - 1 + return NodeRef(ident) + + def add_binding(self, ident: LocalVar, ref: NodeRef) -> None: + self.id_map[ident] = ref + + def let_bind(self, ident: LocalVar, node: Node) -> NodeRef: + ref = self.add_node(node) + self.add_binding(ident, ref) + return ref + + def get_node(self, ref: NodeRef) -> Node: + return self.nodes[ref.ident] + + def lookup(self, ident: LocalVar) -> NodeRef: + return self.id_map[ident] + + def compile(self, func: Function) -> None: + """Compile a single function into a graph.""" + # TODO: (@jroesch) Restore me + # assert len(fn.ty_params) == 0 + + # First we convert all the parameters into input nodes. + params = func.params + + for param in params: + dtype, shape = from_tensor(param.type) + node = InputNode(f"{param.var.name_hint}", { + "shape": shape, + "dtype": dtype, + }) + self.let_bind(param.var, node) + + # Then we compile the body into a graph which can depend + # on input variables. + output_ref = self.visit(func.body) + + # Finally we retreive return value of program, which will + # become our output node. + self.get_node(output_ref).is_output = True + + def visit_let(self, let: Let) -> NodeRef: + """Visit the Let binding, by first traversing its value, + then setting the metadata on the returned NodeRef. + + Finally visit the body, and return the NodeRef corresponding + to it. + """ + ident = let.var + val = let.value + body = let.body + + # Need to add type info? + val_ref = self.visit(val) + dtype, shape = from_tensor(val.checked_type()) + val_node = self.get_node(val_ref) + val_node.attrs["dtype"] = dtype + val_node.attrs["shape"] = shape + self.add_binding(ident, val_ref) + return self.visit(body) + + def visit_local_var(self, ident: LocalVar) -> NodeRef: + return self.lookup(ident) + + def visit_call(self, call: Call) -> NodeRef: + """Transform a ::tvm.relay.Call into an operator in the TVM graph.""" + inputs = [] + for arg in call.args: + inputs.append(self.visit(arg).to_json()) + + assert isinstance(call.op, Op) + self.all_ops.add(call.op.name) + + op_name = call.op.name + attrs = {'shape': shape_to_json(call.checked_type().shape), + 'dtype': call.checked_type().dtype} + op_node = OpNode("call_name", attrs, op_name, inputs, {}) + return self.add_node(op_node) + + def to_json(self) -> str: + """Convert the sequence of nodes stored by the compiler into the + JSON format defined in: https://docs.tvm.ai/dev/nnvm_json_spec.html. + """ + nodes = [] + # First we compute "nodes" field. + for node in self.nodes: + nodes.append(node.to_json()) + + arg_nodes = [] + heads = [] + # Compute "arg_nodes" and "heads" fields. + for i, node in enumerate(self.nodes): + if isinstance(node, InputNode): + arg_nodes.append(i) + + if node.is_output: + # Need to fix this. + heads.append(NodeRef(i).to_json()) + + # Compute "node_row_ptr". 
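+        # ("node_row_ptr" is the graph runtime's cumulative output-count
+        # index: entry i is the total number of outputs of nodes 0..i-1,
+        # i.e. [0, 1, 2, ...] when every node has a single output.)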
+ # TODO + + # Compute "attrs" field. + attrs = {} + + # A + shapes = [] + storage_ids = [] + dtype = [] + dltype = [] + + for i, node in enumerate(self.nodes): + storage_ids.append(i) + shapes.append(node.attrs['shape']) + if node.attrs['dtype'] == 'float32': + dtype.append(0) + dltype.append('float32') + + attrs["shape"] = ["list_shape", shapes] + attrs["storage_id"] = ["list_int", storage_ids] + attrs["dtype"] = ["list_int", dtype] + attrs["dltype"] = ["list_str", dltype] + + json_dict = { + "nodes": nodes, + "arg_nodes": arg_nodes, + "heads": heads, + "attrs": attrs + } + + return json.dumps(json_dict) + + +def compile_to_tvm(func): + """Compile a single function to the components needed by the + TVM RTS. + """ + comp = TVMRTSCompiler() + comp.compile(func) + op_names = list(comp.all_ops) + mod = compile_ops(op_names) + graph_json = comp.to_json() + return graph_json, mod, None # params currently isn't supported by API diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py new file mode 100644 index 000000000000..22c853ef512f --- /dev/null +++ b/python/tvm/relay/type.py @@ -0,0 +1,158 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The type nodes of the Relay language.""" +from typing import List +from enum import IntEnum +from tvm import expr +from .base import Span, NodeBase, register_relay_node +from . import _make + + +class Type(NodeBase): + """The base type for all Relay types.""" + + def __eq__(self, other) -> bool: + """Compare two Relay types for structural equivalence using + alpha equivalence. + """ + return bool(_make._type_alpha_eq(self, other)) + + def __ne__(self, other) -> bool: + return not self.__eq__(other) + + def same_as(self, other) -> bool: + """Compares two Relay types by referential equality.""" + return super().__eq__(other) + + +@register_relay_node +class TensorType(Type): + """A concrete TensorType in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to tensor's with a known dype and shape. For + example a tensor of `float32` and `(5, 5)`. + """ + shape: List[expr.Expr] + dtype: str + span: Span + + def __init__(self, shape: List[expr.Expr], dtype: str) -> None: + """Construct a tensor type. + + Parameters + ---------- + shape: list of tvm.Expr + dtype: str + + Returns + ------- + tensor_type: The TensorType + """ + self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) + + +class Kind(IntEnum): + """The kind of a type parameter, represents a variable shape, + base type, type, or dimension. + + This controls what a type parameter is allowed to be instantiated + with. For example one's of kind BaseType can only be `float32`, `int32`, + and so on. + """ + ShapeVar = 0 + Shape = 1 + BaseType = 2 + Type = 3 + + +@register_relay_node +class TypeParam(Type): + """A type parameter used for generic types in Relay, + see tvm/relay/type.h for more details. + + A type parameter represents a type placeholder which will + be filled in later on. This allows the user to write + functions which are generic over types. + """ + var: expr.Var + kind: Kind + span: Span + + def __init__(self, var: expr.Var, kind: Kind) -> None: + """Construct a TypeParam. + + Parameters + ---------- + var: tvm.expr.Var + The tvm.Var which backs the type parameter. + + kind: Kind + The kind of the type parameter. + + Returns + ------- + type_param: TypeParam + The type parameter. 
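+
+        Examples
+        --------
+        A sketch of declaring a shape variable ``n`` (uses ``tvm.var`` and
+        the ``Kind`` enum defined above):
+
+        .. code-block:: python
+
+            n = TypeParam(tvm.var("n"), Kind.Shape)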
+ """ + self.__init_handle_by_constructor__(_make.TypeParam, var, kind) + + +@register_relay_node +class TypeConstraint(Type): + """Abstract class representing a type constraint.""" + pass + + +@register_relay_node +class FuncType(Type): + """A function type in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to functions in Relay. They consist of + a list of type parameters which enable the definition of generic + fucntions, a set of type constraints which we omit for the time + being, a sequence of argument types, and a return type. + + We informally write them as: + `forall (type_params), (arg_types) -> ret_type where type_constraints` + """ + type_params: List[TypeParam] + type_constraints: List[TypeConstraint] + arg_types: List[Type] + ret_type: Type + span: Span + + def __init__(self, + arg_types: List[Type], + ret_type: Type, + type_params: List[TypeParam], + type_constraints: List[TypeConstraint]) -> None: + """Construct a function type. + + Parameters + ---------- + arg_types: list of Type + ret_type: Type + type_params: list of TypeParam + type_constraints: list of TypeConstraint + + Returns + ------- + func_type: FuncType + The function type. + """ + self.__init_handle_by_constructor__( + _make.FuncType, arg_types, ret_type, type_params, type_constraints) + + +@register_relay_node +class TypeCall(Type): + def __init__(self, type_rel, args) -> None: + self.__init_handle_by_constructor__( + _make.TypeCall, type_rel, args) + + +@register_relay_node +class IncompleteType(Type): + """An incomplete type.""" + + def __init__(self, kind: Kind) -> None: + self.__init_handle_by_constructor__(_make.IncompleteType, kind) diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index f169ff1b64ac..f0d60f514a37 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -6,8 +6,10 @@ from . import make as _make from . import expr as _expr + class TensorSlice(NodeGeneric, _expr.ExprOp): """Auxiliary data structure for enable slicing syntax from tensor.""" + def __init__(self, tensor, indices): if not isinstance(indices, tuple): indices = (indices,) @@ -31,9 +33,11 @@ def dtype(self): itervar_cls = None + @register_node class Tensor(NodeBase, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" + def __call__(self, *indices): ndim = self.ndim if len(indices) != ndim: @@ -104,6 +108,7 @@ def name(self): class Operation(NodeBase): """Represent an operation that generate a tensor""" + def output(self, index): """Get the index-th output of the operation diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc new file mode 100644 index 000000000000..d48b9a4c3e0c --- /dev/null +++ b/src/relay/ir/base.cc @@ -0,0 +1,58 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file base.cc + * \brief The core base types for Relay. + */ +#include +#include + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +SourceName SourceNameNode::make(std::string name) { + std::shared_ptr n = std::make_shared(); + n->name = std::move(name); + return SourceName(n); +} + +// TVM_REGISTER_API("relay._make.SourceName") +// .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { +// *ret = SourceNameNode::make(args[0]); +// }); + +// This causes a crash? 
+ +// TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +// .set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { +// p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; +// }); + +Span SpanNode::make(SourceName source, int lineno, int col_offset) { + std::shared_ptr n = std::make_shared(); + n->source = std::move(source); + n->lineno = lineno; + n->col_offset = col_offset; + return Span(n); +} + +TVM_REGISTER_API("relay._make.Span") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SpanNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { + p->stream << node->name; + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const SpanNode *node, tvm::IRPrinter *p) { + p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", " + << node->col_offset << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc new file mode 100644 index 000000000000..b5f0d663d26a --- /dev/null +++ b/src/relay/ir/environment.cc @@ -0,0 +1,211 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file environment.cc + * \brief The global environment in Relay. + */ +#include +#include +#include +#include +#include "./../pass/resolve.h" +// #include "tvm/relay/util/rang.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace runtime; + +Environment EnvironmentNode::make(tvm::Map global_funcs) { + std::shared_ptr n = std::make_shared(); + n->functions = std::move(global_funcs); + return Environment(n); +} + +GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { + auto global_id = global_map_.find(str); + if (global_id != global_map_.end()) { + return (*global_id).second; + } else { + auto id = GlobalVarNode::make(str); + this->global_map_.Set(str, id); + return id; + } +} + +/*! \brief Add a new item to the global environment + * \note if the update flag is not set adding a duplicate + * definition will trigger an exception, otherwise we will + * update the definition if and only if it is type compatible. + */ +void EnvironmentNode::Add(const GlobalVar &var, const Function &func, + bool update) { + // Type check the item before we add it to the environment. + auto env = GetRef(this); + + Expr checked_expr = InferType(env, var, func); + + if (const FunctionNode *func_node = checked_expr.as()) { + auto checked_func = GetRef(func_node); + auto type = checked_func->checked_type(); + + CHECK(IsFullyResolved(type)); + + if (functions.find(var) != functions.end()) { + if (!update) { + throw dmlc::Error("already have definition for XXXX."); + } + + auto old_type = functions[var].as()->checked_type(); + + if (!AlphaEqual(type, old_type)) { + throw dmlc::Error( + "Environment#update changes type, not possible in this mode."); + } + + this->functions.Set(var, checked_func); + } else { + this->functions.Set(var, checked_func); + } + } else { + throw Error("internal error: unknown item type, unreachable code"); + } +} + +void EnvironmentNode::Update(const GlobalVar &var, const Function &func) { + this->Add(var, func, true); +} + +void EnvironmentNode::Remove(const GlobalVar &) { + // Clarify with @tqchen about how to use COW to do this. 
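+  // (A possible copy-on-write approach, sketched only: copy the underlying
+  // function map, erase `var` from the copy, and swap the copy back in, so
+  // existing readers keep the old snapshot.)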
+ throw Error("NYI"); + // this->items.erase(id); +} + +Function EnvironmentNode::Lookup(const GlobalVar &var) { + auto func = functions.find(var); + if (func != functions.end()) { + return (*func).second; + } else { + throw Error(std::string("there is no definition of ") + var->name_hint); + } +} + +Function EnvironmentNode::Lookup(const std::string &str) { + GlobalVar id = this->GetGlobalVar(str); + return this->Lookup(id); +} + +void EnvironmentNode::Merge(const Environment &env) { + for (auto pair : env->functions) { + this->functions.Set(pair.first, pair.second); + } +} + +inline SourceName EnvironmentNode::AddSource(std::string file_name, + std::string source) { + return this->source_map_.AddSource(file_name, source); +} + +void EnvironmentNode::AddDiagnostic(SpannedError error) { + this->errors_.push_back(error); +} + +void EnvironmentNode::DisplayErrors() { + throw Error("need to restore error printing"); + // for (auto err : this->errors_) { + // auto sp = err.sp; + // auto source_file = this->source_map_.GetSource(err.sp->file_id); + // auto file_name = source_file.file_name; + // auto source_at_span = source_file.SourceAt(err.sp, 1); + // std::string error_marker = "error:"; + // auto line_info = + // std::to_string(sp->lineno) + ":" + std::to_string(sp->col_offset); + + // std::cout << rang::style::bold << rang::fg::red << error_marker + // << rang::fg::reset << file_name << ":" << line_info + // << rang::style::reset << " " << source_at_span << std::endl; + + // // Build the cursor. + + // // Fix this code, hardwired to compute alignment of pointer. + // size_t spaces = error_marker.size() + line_info.size() + file_name.size() + // + + // sp->col_offset - 3; + + // std::string cursor = "~~~~^~~~~"; + // for (size_t i = 0; i < spaces; i++) { + // std::cout << " "; + // } + // std::cout << rang::fg::red << cursor << " " << err.msg << + // rang::style::reset + // << std::endl; + // } +} + +void EnvironmentNode::Transform(EnvironmentNode::Transformer transformer) { + Array to_process; + for (auto var_and_func : this->functions) { + to_process.push_back(var_and_func.first); + } + + auto for_each = transformer(GetRef(this)); + for (auto var : to_process) { + auto func = this->functions[var]; + auto transformed = for_each(var, func); + this->Add(var, transformed, true); + } +} + +TVM_REGISTER_API("relay._make.Environment") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = EnvironmentNode::make(args[0]); + }); + +TVM_REGISTER_API("relay._env.Environment_Add") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Add(args[1], args[2], false); + }); + +TVM_REGISTER_API("relay._env.Environment_GetGlobalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + *ret = env->GetGlobalVar(args[1]); + }); + +TVM_REGISTER_API("relay._env.Environment_Lookup") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + GlobalVar var = args[1]; + *ret = env->Lookup(var); + }); + +TVM_REGISTER_API("relay._env.Environment_Lookup_str") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + std::string var_name = args[1]; + auto var = env->GetGlobalVar(var_name); + *ret = env->Lookup(var); + }); + +TVM_REGISTER_API("relay._env.Environment_Merge") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Merge(args[1]); + }); + +TVM_REGISTER_API("relay._env.Environment_Transform") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + 
env->Transform(args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const EnvironmentNode *node, + tvm::IRPrinter *p) { + p->stream << "EnvironmentNode( " << node->functions << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc new file mode 100644 index 000000000000..8dce7c054c8e --- /dev/null +++ b/src/relay/ir/expr.cc @@ -0,0 +1,203 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/ir/expr.cc + * \brief The expression AST nodes of Relay. + */ +#include +#include + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +Constant ConstantNode::make(runtime::NDArray data) { + std::shared_ptr n = std::make_shared(); + n->data = std::move(data); + return Constant(n); +} + +TVM_REGISTER_API("relay._make.Constant") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ConstantNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const ConstantNode *node, + tvm::IRPrinter *p) { + p->stream << "ConstantNode(TODO)"; + }); + +TensorType ConstantNode::tensor_type() const { + auto dl_dtype = data->dtype; + auto dtype = HalideIR::Type(static_cast(dl_dtype.code), + dl_dtype.bits, dl_dtype.lanes); + + Array shape; + for (int i = 0; i < data->ndim; i++) { + shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), data->shape[i])); + } + + return TensorTypeNode::make(shape, dtype); +} + +Tuple TupleNode::make(tvm::Array fields) { + std::shared_ptr n = std::make_shared(); + n->fields = std::move(fields); + return Tuple(n); +} + +TVM_REGISTER_API("relay._make.Tuple") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TupleNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TupleNode *node, tvm::IRPrinter *p) { + p->stream << "TupleNode(" << node->fields << ")"; + }); + +LocalVar LocalVarNode::make(std::string name_hint) { + std::shared_ptr n = std::make_shared(); + n->name_hint = std::move(name_hint); + return LocalVar(n); +} + +TVM_REGISTER_API("relay._make.LocalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LocalVarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const LocalVarNode *node, + tvm::IRPrinter *p) { + p->stream << "LocalVarNode(" << node->name_hint << ")"; + }); + +GlobalVar GlobalVarNode::make(std::string name_hint) { + std::shared_ptr n = std::make_shared(); + n->name_hint = std::move(name_hint); + return GlobalVar(n); +} + +TVM_REGISTER_API("relay._make.GlobalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = GlobalVarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const GlobalVarNode *node, + tvm::IRPrinter *p) { + p->stream << "GlobalVarNode(" << node->name_hint << ")"; + }); + +Param ParamNode::make(LocalVar var, Type type) { + std::shared_ptr n = std::make_shared(); + n->var = std::move(var); + n->type = std::move(type); + return Param(n); +} + +TVM_REGISTER_API("relay._make.Param") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ParamNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const ParamNode *node, tvm::IRPrinter *p) { + p->stream << "ParamNode(" << node->var << ", " << node->type << ")"; + }); + +Function FunctionNode::make(tvm::Array params, Type ret_type, Expr body, + tvm::Array type_params) { + std::shared_ptr n = 
std::make_shared(); + n->params = std::move(params); + n->ret_type = std::move(ret_type); + n->body = std::move(body); + n->type_params = std::move(type_params); + return Function(n); +} + +Type FunctionNode::fn_type() const { + Array param_types; + for (auto param : this->params) { + param_types.push_back(param->type); + } + + return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); +} + +TVM_REGISTER_API("relay._make.Function") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const FunctionNode *node, + tvm::IRPrinter *p) { + p->stream << "FunctionNode(" << node->params << ", " << node->ret_type + << ", " << node->body << ", " << node->type_params << ")"; + }); + +Call CallNode::make(Expr op, Array args, Attrs attrs, + Array type_args) { + std::shared_ptr n = std::make_shared(); + n->op = std::move(op); + n->args = std::move(args); + n->attrs = std::move(attrs); + n->type_args = std::move(type_args); + return Call(n); +} + +TVM_REGISTER_API("relay._make.Call") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = CallNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const CallNode *node, tvm::IRPrinter *p) { + p->stream << "CallNode(" << node->op << ", " << node->args << ", " + << node->attrs << ", " << node->type_args << ")"; + }); + +Let LetNode::make(LocalVar var, Expr value, Expr body, Type value_type) { + std::shared_ptr n = std::make_shared(); + n->var = std::move(var); + n->value = std::move(value); + n->body = std::move(body); + n->value_type = std::move(value_type); + return Let(n); +} + +TVM_REGISTER_API("relay._make.Let") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LetNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { + p->stream << "LetNode(" << node->var << ", " << node->value + << ", " << node->body << ", " << node->value_type << ")"; + }); + +If IfNode::make(Expr cond, Expr true_value, Expr false_value) { + std::shared_ptr n = std::make_shared(); + n->cond = std::move(cond); + n->true_value = std::move(true_value); + n->false_value = std::move(false_value); + return If(n); +} + +TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IfNode::make(args[0], args[1], args[2]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IfNode *node, tvm::IRPrinter *p) { + p->stream << "IfNode(" << node->cond << ", " << node->true_value + << node->false_value << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc new file mode 100644 index 000000000000..7c005acb8648 --- /dev/null +++ b/src/relay/ir/op.cc @@ -0,0 +1,224 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/op.cc + * \brief Resolve incomplete types to complete types. + */ +#include +#include +#include +#include + +#include +#include + +#include "./../pass/type_subst.h" + +namespace dmlc { +// enable registry +DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); +} // namespace dmlc + +namespace tvm { +namespace relay { + +::dmlc::Registry* OpRegistry::Registry() { + return ::dmlc::Registry::Get(); +} + +// single manager of operator information. +struct OpManager { + // mutex to avoid registration from multiple threads. 
+  std::mutex mutex;
+  // global operator counter
+  std::atomic<int> op_counter{0};
+  // storage of additional attribute table.
+  std::unordered_map<std::string, std::unique_ptr<GenericOpMap>> attr;
+  // frontend functions
+  std::vector<PackedFunc*> frontend_funcs;
+  // get singleton of the op manager
+  static OpManager* Global() {
+    static OpManager inst;
+    return &inst;
+  }
+};
+
+// find operator by name
+const Op& Op::Get(const std::string& name) {
+  const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
+  CHECK(reg != nullptr) << "Operator " << name << " is not registered";
+  return reg->op();
+}
+
+OpRegistry::OpRegistry() {
+  OpManager* mgr = OpManager::Global();
+  std::shared_ptr<OpNode> n = std::make_shared<OpNode>();
+  n->index_ = mgr->op_counter++;
+  op_ = Op(n);
+}
+
+// Get attribute map by key
+const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
+  OpManager* mgr = OpManager::Global();
+  std::lock_guard<std::mutex> lock(mgr->mutex);
+  auto it = mgr->attr.find(key);
+  if (it == mgr->attr.end()) {
+    LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered";
+  }
+  return *it->second.get();
+}
+
+void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value,
+                            int plevel) {
+  OpManager* mgr = OpManager::Global();
+  std::lock_guard<std::mutex> lock(mgr->mutex);
+  std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
+  if (op_map == nullptr) {
+    op_map.reset(new GenericOpMap());
+  }
+  uint32_t index = op_->index_;
+  if (op_map->data_.size() <= index) {
+    op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
+  }
+  std::pair<TVMRetValue, int>& p = op_map->data_[index];
+  CHECK(p.second != plevel)
+      << "Attribute " << key << " of operator " << this->name
+      << " is already registered with same plevel=" << plevel;
+  if (p.second < plevel) {
+    op_map->data_[index] = std::make_pair(value, plevel);
+  }
+}
+
+// Frontend APIs
+TVM_REGISTER_API("relay.op._ListOpNames")
+    .set_body_typed<Array<tvm::Expr>()>([]() {
+      Array<tvm::Expr> ret;
+      for (const std::string& name :
+           dmlc::Registry<OpRegistry>::ListAllNames()) {
+        ret.push_back(tvm::Expr(name));
+      }
+      return ret;
+    });
+
+TVM_REGISTER_API("relay.op._GetOp").set_body_typed(Op::Get);
+
+TVM_REGISTER_API("relay.op._OpGetAttr")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      Op op = args[0];
+      std::string attr_name = args[1];
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        *rv = op_map[op];
+      }
+    });
+
+TVM_REGISTER_API("relay.op._Register")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      std::string op_name = args[0];
+      std::string attr_key = args[1];
+      runtime::TVMArgValue value = args[2];
+      int plevel = args[3];
+      auto& reg =
+          OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
+      // enable registration and override of certain properties
+      if (attr_key == "num_inputs" && plevel > 128) {
+        reg.set_num_inputs(value);
+      } else if (attr_key == "attrs_type_key" && plevel > 128) {
+        reg.set_attrs_type_key(value);
+      } else {
+        // normal attr table override.
+        if (args[2].type_code() == kFuncHandle) {
+          // do an eager copy of the PackedFunc
+          PackedFunc f = args[2];
+          // If we get a function from the frontend, avoid deleting it.
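+          // (Assumption, for illustration: the PackedFunc may wrap a
+          // frontend callable such as a Python function; the heap copy
+          // below is intentionally never freed so the callable stays alive
+          // for the process lifetime instead of racing destruction at
+          // shutdown.)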
+          OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
+          reg.set_attr(attr_key, f, plevel);
+        } else {
+          reg.set_attr(attr_key, args[2], plevel);
+        }
+      }
+    });
+
+bool IsGeneric(const FuncType& func_ty) {
+  return func_ty->type_params.size() != 0;
+}
+
+using namespace runtime;
+
+Module CompileOpsToModule(const std::vector<std::string>& op_names) {
+  PackedFunc compile_ops = GetPackedFunc("relay.op._compile_ops");
+  tvm::Array<tvm::Array<NodeRef>> args;
+
+  auto compiler_map = Op::GetAttr<PackedFunc>("FRelayOpCompiler");
+
+  for (auto op_name : op_names) {
+    Op op = Op::Get(op_name);
+
+    if (!IsGeneric(op->op_type)) {
+      auto compiler = compiler_map[op];
+      tvm::Array<NodeRef> pair = compiler(op->name, op->op_type);
+      // TODO(@jroesch): I can't pass strings across what should be the
+      // interface here.
+      tvm::Array<NodeRef> triple = {LocalVarNode::make(op->name), pair[0],
+                                    pair[1]};
+      args.push_back(triple);
+    } else {
+      throw dmlc::Error("it is impossible to compile generic operators.");
+    }
+  }
+
+  // Nothing to do, bail out early.
+  // TVM will complain if we try to generate a module of size 0.
+  if (args.size() == 0) {
+    return Module(nullptr);
+  }
+
+  return compile_ops(args);
+}
+
+TVM_REGISTER_API("relay.op._CompileOpsToModule")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      std::vector<std::string> names;
+      for (auto i = 0; i < args.num_args; i++) {
+        names.push_back(args[i]);
+      }
+      *ret = CompileOpsToModule(names);
+    });
+
+Op SpecializeOp(const std::string& op_name, const std::string& new_op_name,
+                Array<Type> type_args) {
+  auto registry = ::tvm::relay::OpRegistry::Registry();
+  auto op_reg = registry->__REGISTER_OR_GET__(op_name);
+  auto new_op_reg = registry->__REGISTER__(new_op_name).set_name();
+
+  auto fn_ty = op_reg.op()->op_type;
+
+  tvm::Map<TypeParam, Type> subst_map;
+
+  CHECK(fn_ty->type_params.size() == type_args.size());
+
+  // Build a substitution map from the function type and type arguments.
+  // Eventually allow the type vars to be passed in.
+  for (size_t i = 0; i < type_args.size(); i++) {
+    subst_map.Set(fn_ty->type_params[i], type_args[i]);
+  }
+
+  Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {});
+  inst_ty = TypeSubst(inst_ty, subst_map);
+  FuncType new_op_ty = GetRef<FuncType>(inst_ty.as<FuncTypeNode>());
+  new_op_reg.op()->op_type = new_op_ty;
+
+  // Now we want to copy over some attributes.
+  PackedFunc compiler =
+      Op::GetAttr<PackedFunc>("FRelayOpCompiler")[op_reg.op()];
+  new_op_reg.set_attr("FRelayOpCompiler", compiler);
+
+  return new_op_reg.op();
+}
+
+TVM_REGISTER_API("relay.op._SpecializeOp")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      *ret = SpecializeOp(args[0], args[1], args[2]);
+    });
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc
new file mode 100644
index 000000000000..2975c60cc0c1
--- /dev/null
+++ b/src/relay/ir/type.cc
@@ -0,0 +1,153 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file src/tvm/ir/type.cc
+ * \brief The type system AST nodes of Relay.
+ */ +#include +#include + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +TensorType TensorTypeNode::make(Array shape, DataType dtype) { + std::shared_ptr n = std::make_shared(); + n->shape = std::move(shape); + n->dtype = std::move(dtype); + return TensorType(n); +} + +TensorType TensorTypeNode::Int(int bits, int lanes) { + return TensorTypeNode::make({}, HalideIR::Int(bits, lanes)); +} + +TensorType TensorTypeNode::UInt(int bits, int lanes) { + return TensorTypeNode::make({}, HalideIR::UInt(bits, lanes)); +} + +TensorType TensorTypeNode::Float(int bits, int lanes) { + return TensorTypeNode::make({}, HalideIR::Float(bits, lanes)); +} + +TensorType TensorTypeNode::Bool(int lanes) { + return TensorTypeNode::make({}, HalideIR::Bool(lanes)); +} + +TVM_REGISTER_API("relay._make.TensorType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Array shape = args[0]; + *ret = TensorTypeNode::make(shape, args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TensorTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape + << ")"; + }); + +TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { + std::shared_ptr n = std::make_shared(); + n->var = tvm::Var(name); + n->kind = std::move(kind); + return TypeParam(n); +} + +TVM_REGISTER_API("relay._make.TypeParam") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[1]; + *ret = + TypeParamNode::make(args[0], static_cast(kind)); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeParamNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeParamNode(" << node->var->name_hint << ", " + << node->kind << ")"; + }); + +FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, + tvm::Array type_params, + tvm::Array type_constraints) { + std::shared_ptr n = std::make_shared(); + n->arg_types = std::move(arg_types); + n->ret_type = std::move(ret_type); + n->type_params = std::move(type_params); + n->type_constraints = std::move(type_constraints); + return FuncType(n); +} + +TVM_REGISTER_API("relay._make.FuncType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const FuncTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "FuncTypeNode(" << node->type_params << ", " + << node->arg_types << ", " << node->ret_type << ", " + << node->type_constraints << ")"; + }); + +TypeRelation TypeRelationNode::make(std::string name, int num_args, + TypeRelationFn func) { + std::shared_ptr n = std::make_shared(); + n->name = std::move(name); + n->num_args = std::move(num_args); + n->func_ = std::move(func); + return TypeRelation(n); +} + +TVM_REGISTER_API("relay._make.TypeRelation") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TypeRelationNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeRelationNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeRelationNode(" << node->name << ", " << node->num_args + << ")"; + }); + +TypeCall TypeCallNode::make(Type func, Array args) { + std::shared_ptr n = std::make_shared(); + n->func = std::move(func); + n->args = std::move(args); + return TypeCall(n); +} + +TVM_REGISTER_API("relay._make.TypeCall") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TypeCallNode::make(args[0], args[1]); 
+ }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeCallNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; + }); + +TupleType TupleTypeNode::make(Array fields) { + std::shared_ptr n = std::make_shared(); + n->fields = std::move(fields); + return TupleType(n); +} + +TVM_REGISTER_API("relay._make.TupleType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TupleTypeNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TupleTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "TupleTypeNode(" << node->fields << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc new file mode 100644 index 000000000000..a18259c72117 --- /dev/null +++ b/src/relay/op/tensor/elemwise.cc @@ -0,0 +1,124 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file elemwise.cc + * \brief Elementwise operators. + */ +#include +#include +#include "../type_relations.h" + +namespace tvm { +namespace relay { + +// Quick helper macro +// - Expose a positional make function to construct the node. +// - Register op to the registry. +// +// We make the decision to always only expose positional argument. +// We will do rewrapping in the frontend to support language +// sugars such as keyword arguments and default value. +// +#define RELAY_REGISTER_UNARY_OP(OpName) \ + TVM_REGISTER_API("relay.op._make." OpName) \ + .set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {data}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") + + +RELAY_REGISTER_UNARY_OP("log") +.describe(R"code(Returns the log input array, computed element-wise. + +.. math:: + log(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.add_type_rel("Log", IdentityRel); + +// data : Tensor[shape, dtype] +// result: Tensor[shape, dtype] + + +RELAY_REGISTER_UNARY_OP("exp") +.describe(R"code(Returns the exp input array, computed element-wise. + +.. math:: + \exp(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.add_type_rel("Exp", IdentityRel); + + +RELAY_REGISTER_UNARY_OP("sqrt") +.describe(R"code(Returns the sqrt input array, computed element-wise. + +.. math:: + sqrt(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.add_type_rel("Sqrt", IdentityRel); + +// Addition +TVM_REGISTER_API("relay.op._make.add") + .set_body_typed([](Expr lhs, Expr rhs) { + static const Op& op = Op::Get("add"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + }); + +RELAY_REGISTER_OP("add") + .set_num_inputs(2) + .add_argument("lhs", "Tensor", "The left hand side tensor.") + .add_argument("rhs", "Tensor", "The right hand side tensor.") + .set_support_level(1) + .add_type_rel("Broadcast", BroadcastRel); + + // def broadcast(s1, s2): + // ... 
+  //
+  // input1: Tensor[dtype, s1]
+  // input2: Tensor[dtype, s2]
+  // output: Tensor[dtype, broadcast(s1, s2)]
+
+// Subtraction
+TVM_REGISTER_API("relay.op._make.subtract")
+    .set_body_typed([](Expr lhs, Expr rhs) {
+      static const Op& op = Op::Get("subtract");
+      return CallNode::make(op, {lhs, rhs}, Attrs(), {});
+    });
+
+RELAY_REGISTER_OP("subtract")
+    .set_num_inputs(2)
+    .add_argument("lhs", "Tensor", "The left hand side tensor.")
+    .add_argument("rhs", "Tensor", "The right hand side tensor.")
+    .set_support_level(1)
+    .add_type_rel("Broadcast", BroadcastRel);
+
+  // def broadcast(s1, s2):
+  //     ...
+  //
+  // input1: Tensor[dtype, s1]
+  // input2: Tensor[dtype, s2]
+  // output: Tensor[dtype, broadcast(s1, s2)]
+
+// Equality comparison
+TVM_REGISTER_API("relay.op._make.equal")
+    .set_body_typed([](Expr lhs, Expr rhs) {
+      static const Op& op = Op::Get("equal");
+      return CallNode::make(op, {lhs, rhs}, Attrs(), {});
+    });
+
+RELAY_REGISTER_OP("equal")
+    .set_num_inputs(2)
+    .add_argument("lhs", "Tensor", "The left hand side tensor.")
+    .add_argument("rhs", "Tensor", "The right hand side tensor.")
+    .set_support_level(1)
+    .add_type_rel("BroadcastComp", BroadcastCompRel);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
new file mode 100644
index 000000000000..e2b2cba1e0ef
--- /dev/null
+++ b/src/relay/op/type_relations.cc
@@ -0,0 +1,135 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file type_relations.cc
+ * \brief A set of utilities and common functionality
+ * for type relations.
+ */
+#include
+#include
+#include
+#include "../pass/incomplete_type.h"
+
+namespace tvm {
+namespace relay {
+
+TensorType as_ttype(const Type& t) {
+  if (auto tt_node = t.as<TensorTypeNode>()) {
+    return GetRef<TensorType>(tt_node);
+  } else {
+    return TensorType(nullptr);
+  }
+}
+
+// TODO(@jroesch) what size value do we extract?
+int to_int(const tvm::Expr& e) {
+  CHECK(e.defined());
+  auto imm = e.as<tvm::ir::IntImm>();
+  CHECK(imm) << "expected an integer dimension, found " << e << std::endl;
+  return imm->value;
+}
+
+Array<Type> IdentityRel(const Array<Type>& types, int num_args) {
+  CHECK_EQ(types.size(), 2);
+  auto t1 = as_ttype(types[0]);
+  if (t1 && types[1].as<IncompleteTypeNode>()) {
+    return {t1, t1};
+  } else {
+    return types;
+  }
+}
+
+static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2,
+                              DataType output_dtype) {
+  RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2
+                  << std::endl;
+  auto sh1 = t1->shape;
+  auto sh2 = t2->shape;
+  RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2
+                  << std::endl;
+  if (sh1.size() == 0 && sh2.size() == 0) {
+    return TensorTypeNode::make({}, output_dtype);
+    // We have non-zero shapes so broadcast rules apply.
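+    // Illustrative examples of the broadcasting rule implemented below
+    // (shapes are aligned right-to-left; dimensions match if equal or if
+    // either is 1, and the result takes the larger extent):
+    //   (2, 3) with (3,)      -> (2, 3)
+    //   (4, 1, 5) with (3, 1) -> (4, 3, 5)
+    //   (2, 3) with (4,)      -> error: 3 vs 4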
+  } else {
+    auto suffix_len = static_cast<int64_t>(std::min(sh1.size(), sh2.size()));
+    auto full_len = static_cast<int64_t>(std::max(sh1.size(), sh2.size()));
+
+    auto rev_sh1 = sh1.rbegin();
+    auto rev_sh2 = sh2.rbegin();
+
+    while (rev_sh1 != sh1.rend() && rev_sh2 != sh2.rend()) {
+      auto dim1 = to_int(*rev_sh1);
+      auto dim2 = to_int(*rev_sh2);
+      if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) {
+        CHECK(false) << "Dimension mismatch: "
+                     << "dim1: " << dim1 << " dim2: " << dim2 << std::endl;
+      }
+      rev_sh1++;
+      rev_sh2++;
+    }
+
+    Array<ShapeExpr> larger;
+    Array<ShapeExpr> smaller;
+
+    // Pad the shorter shape with leading 1s so both have full_len dims.
+    for (int i = 0; i < (full_len - suffix_len); i++) {
+      smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1));
+    }
+
+    if (sh1.size() < sh2.size()) {
+      for (auto sh : sh1) {
+        smaller.push_back(sh);
+      }
+      larger = sh2;
+    } else if (sh1.size() > sh2.size()) {
+      // NB: append sh2 after the padding; assigning would discard it.
+      for (auto sh : sh2) {
+        smaller.push_back(sh);
+      }
+      larger = sh1;
+    } else {
+      larger = sh1;
+      smaller = sh2;
+    }
+
+    CHECK_EQ(larger.size(), smaller.size());
+
+    Array<ShapeExpr> out_shape;
+    for (size_t i = 0; i < smaller.size(); i++) {
+      auto left = smaller[i].as<tvm::ir::IntImm>();
+      auto right = larger[i].as<tvm::ir::IntImm>();
+      CHECK(left);
+      CHECK(right);
+      int64_t dim = std::max(left->value, right->value);
+      out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim));
+    }
+
+    return TensorTypeNode::make(out_shape, output_dtype);
+  }
+}
+
+Array<Type> BroadcastRel(const Array<Type>& types, int num_args) {
+  CHECK_EQ(types.size(), 3);
+  RELAY_LOG(INFO) << "In1: " << types[0] << " In2: " << types[1]
+                  << " Out: " << types[2] << std::endl;
+  if (auto t1 = as_ttype(types[0])) {
+    if (auto t2 = as_ttype(types[1])) {
+      CHECK_EQ(t1->dtype, t2->dtype);
+      return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)};
+    }
+  }
+
+  return types;
+}
+
+/* A relation which specifies broadcasting rules for operations which
+   compute boolean results.
+*/
+Array<Type> BroadcastCompRel(const Array<Type>& types, int num_args) {
+  CHECK_EQ(types.size(), 3);
+  if (auto t1 = as_ttype(types[0])) {
+    if (auto t2 = as_ttype(types[1])) {
+      return {t1, t2, ConcreteBroadcast(t1, t2, HalideIR::Bool())};
+    }
+  }
+
+  return types;
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
new file mode 100644
index 000000000000..3597246b5a4a
--- /dev/null
+++ b/src/relay/op/type_relations.h
@@ -0,0 +1,23 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/op/type_relations.h
+ * \brief A set of utilities and common functionality
+ * for type relations.
+ */
+#ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_
+#define TVM_RELAY_OP_TYPE_RELATIONS_H_
+
+#include
+#include
+
+namespace tvm {
+namespace relay {
+
+Array<Type> IdentityRel(const Array<Type>& types, int num_args);
+Array<Type> BroadcastRel(const Array<Type>& types, int num_args);
+Array<Type> BroadcastCompRel(const Array<Type>& types, int num_args);
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_OP_TYPE_RELATIONS_H_
diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc
new file mode 100644
index 000000000000..555d4f2db99d
--- /dev/null
+++ b/src/relay/pass/alpha_eq.cc
@@ -0,0 +1,290 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file src/tvm/relay/pass/alpha_eq.cc
+ * \brief Check expressions and types for alpha equivalence.
+ */ +#include "tvm/relay/pass/alpha_eq.h" +#include "tvm/relay/expr_visitor.h" +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +using namespace tvm::runtime; + +struct TypeAlphaEq : TypeVisitor { + tvm::Map eq_map; + bool equal; + + TypeAlphaEq() : eq_map(), equal(true) {} + + void DataTypeEqual(const DataType & dt1, const DataType & dt2) { + equal = equal && dt1 == dt2; + } + void ShapeEqual(Array s1, Array s2) { + } + + void VisitType_(const TensorTypeNode *tt1, const Type &t2) override { + if (const TensorTypeNode *tt2 = t2.as()) { + DataTypeEqual(tt1->dtype, tt2->dtype); + ShapeEqual(tt1->shape, tt2->shape); + } else { + equal = false; + } + } + + void VisitType_(const IncompleteTypeNode *bt1, const Type &t2) override { + if (const IncompleteTypeNode *bt2 = t2.as()) { + equal = equal && bt1 == bt2; + return; + } else { + equal = false; + } + } + + void VisitType_(const TypeParamNode *ti1, const Type &t2) override { + if (const TypeParamNode *ti2 = t2.as()) { + auto tid1 = GetRef(ti1); + auto tid2 = GetRef(ti2); + + // We handle open terms with this rule assuming variables are identical. + // + // Not sure if we should do this. + if (tid1 == tid2) { + return; + } + + // Check that they are same kind + if (tid1->kind != tid2->kind) { + equal = false; + return; + } + + // Next we see if there is mapping for local1 into the rhs term. + // If there is we check to see if those are equal. + if (eq_map.find(tid1) != eq_map.end()) { + equal = equal && eq_map[tid1] == tid2; + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitType_(const FuncTypeNode *op, const Type &t2) override { + if (const FuncTypeNode *ta2 = t2.as()) { + if (op->arg_types.size() != ta2->arg_types.size()) { + equal = false; + return; + } + + for (size_t i = 0; i < op->arg_types.size(); i++) { + this->VisitType(op->arg_types[i], ta2->arg_types[i]); + if (!equal) { + return; + } + } + + this->VisitType(op->ret_type, ta2->ret_type); + } else { + equal = false; + } + } + + void VisitType_(const TypeRelationNode *tr1, const Type &t2) override { + if (const TypeRelationNode *tr2 = t2.as()) { + equal = tr1 == tr2; + } else { + equal = false; + } + } + +// void VisitType_(const TupleTypeNode *op, const Type &t2) override { +// if (const TupleTypeNode *pt = t2.as()) { +// if (op->fields.size() != pt->fields.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < op->fields.size(); i++) { +// if (!equal) { +// return; +// } +// this->VisitType(op->fields[i], pt->fields[i]); +// } +// } else { +// equal = false; +// } +// } + + void VisitType_(const TypeCallNode *tyn1, const Type &t2) override { + TypeCall tycall = GetRef(tyn1); + if (const TypeCallNode *tyn2 = t2.as()) { + if (tycall->func != tyn2->func) { + equal = false; + return; + } + + if (tycall->args.size() != tyn2->args.size()) { + equal = false; + return; + } + + for (size_t i = 0U; i < tycall->args.size(); i++) { + this->VisitType(tycall->args[i], tyn2->args[i]); + } + } else { + equal = false; + } + } +}; + +bool AlphaEqual(const Type &t1, const Type &t2) { + TypeAlphaEq aeq; + aeq.VisitType(t1, t2); + return aeq.equal; +} + +// struct AlphaEq : ExprVisitor { +// public: +// tvm::Map eq_map; +// bool equal; +// AlphaEq() : eq_map(), equal(true) {} + +// void VisitExpr_(const LocalIdNode *e1, const Expr &e2) override { +// if (const LocalIdNode *id2 = e2.as()) { +// auto local1 = GetRef(e1); +// auto local2 = GetRef(id2); +// // +// // We handle open terms with this rule assuming variables are identical. 
+// // +// // Not sure if we should do this. +// if (local1 == local2) { +// equal = true; +// return; +// } + +// // Next we see if there is mapping for local1 into the rhs term. +// // If there is we check to see if those are equal. +// if (eq_map.find(local1) != eq_map.end()) { +// equal = equal && eq_map[local1] == local2; +// } else { +// equal = false; +// } +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const GlobalIdNode *g1, const Expr &e2) override { +// if (const GlobalIdNode *g2 = e2.as()) { +// equal = equal && g1 == g2; +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const OperatorIdNode *i1, const Expr &e2) override { +// if (const OperatorIdNode *i2 = e2.as()) { +// equal = equal && i1 == i2; +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const TupleNode *pl1, const Expr &e2) override { +// Tuple prod1 = GetRef(pl1); +// if (const TupleNode *pl2 = e2.as()) { +// Tuple prod2 = GetRef(pl2); +// if (prod1->fields.size() != prod2->fields.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < prod1->fields.size(); i++) { +// this->VisitExpr(prod1->fields[i], prod2->fields[i]); +// } +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const ParamNode *p1, const Expr &e2) override { +// if (const ParamNode *p2 = e2.as()) { +// eq_map.Set(p1->id, p2->id); +// equal = equal && alpha_eq(p1->type, p2->type); +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const FunctionNode *func1, const Expr &e2) override { +// if (const FunctionNode *func2 = e2.as()) { +// if (func1->params.size() != func2->params.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < func1->params.size(); i++) { +// this->VisitExpr(func1->params[i], func2->params[i]); +// } + +// this->VisitExpr(func1->body, func2->body); +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const CallNode *op, const Expr &e2) override { +// if (const CallNode *call = e2.as()) { +// this->VisitExpr(op->fn, call->fn); + +// if (op->args.size() != call->args.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < op->args.size(); i++) { +// this->VisitExpr(op->args[i], call->args[i]); +// } + +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const LetNode *op, const Expr &e2) override { +// if (const LetNode *let = e2.as()) { +// eq_map.Set(op->id, let->id); +// this->VisitExpr(op->value, let->value); +// this->VisitExpr(op->body, let->body); +// } else { +// equal = false; +// } +// } +// }; + +// bool alpha_eq(const Expr &e1, const Expr &e2) { +// AlphaEq eq; +// eq.VisitExpr(e1, e2); +// return eq.equal; +// } + +// // TODO(@jroesch): move to correct namespace? +// TVM_REGISTER_API("relay._make._alpha_eq") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Expr e1 = args[0]; +// Expr e2 = args[1]; +// *ret = alpha_eq(e1, e2); +// }); + +TVM_REGISTER_API("relay._make._type_alpha_eq") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Type t1 = args[0]; + Type t2 = args[1]; + *ret = AlphaEqual(t1, t2); + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/incomplete_type.h b/src/relay/pass/incomplete_type.h new file mode 100644 index 000000000000..78771dc6e9b7 --- /dev/null +++ b/src/relay/pass/incomplete_type.h @@ -0,0 +1,38 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file incomplete_type.h + * \brief A way to defined arbitrary function signature with dispatch on types. 
+ */
+
+#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
+#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
+
+#include
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Represents an incomplete type: a hole to be filled in
+ * during type inference.
+ */
+class IncompleteType;
+
+/*! \brief IncompleteType container node */
+class IncompleteTypeNode : public TypeNode {
+ public:
+  TypeParamNode::Kind kind;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); }
+
+  TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
+
+  static constexpr const char* _type_key = "relay.IncompleteType";
+  TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type);
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc
new file mode 100644
index 000000000000..522eb93483fb
--- /dev/null
+++ b/src/relay/pass/kind_check.cc
@@ -0,0 +1,42 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file kind_check.cc
+ *
+ * \brief Check that types are well formed by applying "kinding rules".
+ *
+ * This pass ensures we do not do things that violate the design of the
+ * type system when writing down types.
+ *
+ * For example, tensors are not allowed to contain functions in Relay.
+ *
+ * We check this by ensuring the `dtype` field of a Tensor always
+ * contains a data type such as `int`, `float`, `uint`.
+ */
+#include
+#include
+#include "./type_visitor.h"
+
+namespace tvm {
+namespace relay {
+
+using namespace tvm::runtime;
+
+struct KindChecker : TypeVisitor<> {
+  bool valid;
+
+  KindChecker() : valid(true) {}
+
+  bool Check(const Type &t) {
+    this->VisitType(t);
+    return valid;
+  }
+};
+
+bool KindCheck(const Environment& env, const Type &t) {
+  KindChecker kc;
+  return kc.Check(t);
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/pass/local_var_well_formed.cc b/src/relay/pass/local_var_well_formed.cc
new file mode 100644
index 000000000000..7cfd93a3ff05
--- /dev/null
+++ b/src/relay/pass/local_var_well_formed.cc
@@ -0,0 +1,56 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file local_var_well_formed.cc
+ * \brief Check that no local variable is shadowed.
+ */
+#include
+#include
+#include
+
+namespace tvm {
+namespace relay {
+
+struct ShadowDetected { };
+
+struct DetectShadow : ExprVisitor {
+  struct Insert {
+    DetectShadow * ds;
+    LocalVar lv;
+    Insert(DetectShadow * ds, const LocalVar & lv) : ds(ds), lv(lv) {
+      if (ds->s.count(lv) != 0) {
+        throw ShadowDetected();
+      }
+      ds->s.insert(lv);
+    }
+    Insert(const Insert &) = delete;
+    Insert(Insert &&) = default;
+    ~Insert() {
+      ds->s.erase(lv);
+    }
+  };
+  std::unordered_set<LocalVar, NodeHash, NodeEqual> s;
+  void VisitExpr_(const LetNode & l) {
+    Insert ins(this, l.var);
+    // We only do letrec for FunctionNode, but a shadowing let inside a
+    // let binding is dangerous, so we forbid it.
+    (*this)(l.value);
+    (*this)(l.body);
+  }
+  void VisitExpr_(const FunctionNode & f) {
+    std::vector<Insert> ins;
+    for (const Param & p : f.params) {
+      ins.push_back(Insert(this, p->var));
+    }
+    (*this)(f.body);
+  }
+};
+
+bool LocalVarWellFormed(const Expr & e) {
+  try {
+    DetectShadow()(e);
+    return true;
+  } catch (const ShadowDetected &) {
+    return false;
+  }
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc
new file mode 100644
index 000000000000..bc63d939959e
--- /dev/null
+++ b/src/relay/pass/resolve.cc
@@ -0,0 +1,101 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file resolve.cc
+ * \brief Resolve incomplete types to complete types.
+ */
+
+#include
+#include
+#include "./resolve.h"
+#include "./type_visitor.h"
+
+namespace tvm {
+namespace relay {
+
+// TODO(@jroesch): We should probably generalize the subst code.
+struct ResolveTypeType : TypeFVisitor {
+  const TypeUnifier &unifier;
+
+  explicit ResolveTypeType(const TypeUnifier &unifier) : unifier(unifier) {}
+
+  Type VisitType(const Type &t) override {
+    if (!t.defined()) {
+      auto inc_ty = IncompleteTypeNode::make(TypeParamNode::Kind::kType);
+      unifier->insert(inc_ty);
+      return inc_ty;
+    } else {
+      return TypeFVisitor::VisitType(t);
+    }
+  }
+
+  Type VisitType_(const IncompleteTypeNode *op) override {
+    return unifier->subst(GetRef<IncompleteType>(op));
+  }
+};
+
+struct ResolveTypeExpr : ExprFVisitor {
+  const TypeUnifier &unifier;
+
+  explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {}
+
+  Expr VisitExpr(const Expr &e) {
+    // NB: a bit tricky here.
+    //
+    // We want to store the resolved type without having
+    // to re-typecheck the entire term.
+    //
+    // Since we know that e : T under some holes,
+    // it is the case that if we resolve the types
+    // present in e, then we can type it under T
+    // with the holes filled in.
+    //
+    // We visit e as normal, building a new term,
+    // then resolve e's old type and write
+    // it back into the new node.
+    auto new_e = ExprFVisitor::VisitExpr(e);
+    CHECK(e->checked_type_.defined());
+    auto resolved_cty = VisitType(e->checked_type_);
+    new_e->checked_type_ = resolved_cty;
+    return new_e;
+  }
+
+  Type VisitType(const Type &t) {
+    return ResolveTypeType(unifier).VisitType(t);
+  }
+};
+
+Type Resolve(const TypeUnifier &unifier, const Type &ty) {
+  CHECK(ty.defined());
+  return ResolveTypeType(unifier).VisitType(ty);
+}
+
+Expr Resolve(const TypeUnifier &unifier, const Expr &expr) {
+  return ResolveTypeExpr(unifier).VisitExpr(expr);
+}
+
+struct FullyResolved : TypeVisitor<> {
+  bool resolved;
+
+  FullyResolved() : resolved(true) {}
+
+  void VisitType(const Type &t) override {
+    if (!t.defined()) {
+      resolved = false;
+    } else {
+      return TypeVisitor<>::VisitType(t);
+    }
+  }
+
+  void VisitType_(const IncompleteTypeNode *ty_var) override {
+    resolved = false;
+  }
+};
+
+bool IsFullyResolved(const Type &t) {
+  auto fr = FullyResolved();
+  fr.VisitType(t);
+  return fr.resolved;
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/pass/resolve.h b/src/relay/pass/resolve.h
new file mode 100644
index 000000000000..deb6558322b8
--- /dev/null
+++ b/src/relay/pass/resolve.h
@@ -0,0 +1,23 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/resolve.h
+ * \brief Resolve incomplete types to complete types.
+ */
+#ifndef TVM_RELAY_PASS_RESOLVE_H_
+#define TVM_RELAY_PASS_RESOLVE_H_
+
+#include
+#include
+#include "./unifier.h"
+
+namespace tvm {
+namespace relay {
+
+Type Resolve(const TypeUnifier & unifier, const Type & ty);
+Expr Resolve(const TypeUnifier & unifier, const Expr & expr);
+bool IsFullyResolved(const Type & t);
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_PASS_RESOLVE_H_
diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h
new file mode 100644
index 000000000000..9180703b49e8
--- /dev/null
+++ b/src/relay/pass/type_functor.h
@@ -0,0 +1,95 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file type_functor.h
+ * \brief A way to define arbitrary function signatures with dispatch on types.
+ */
+#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_
+#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_
+
+#include
+#include
+#include "./incomplete_type.h"
+
+namespace tvm {
+namespace relay {
+
+template <typename FType>
+class TypeFunctor;
+
+// functions to be overridden.
+#define TYPE_FUNCTOR_DEFAULT \
+  { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                                 \
+  vtable.template set_dispatch<OP>(                                    \
+      [](const NodeRef& n, TSelf* self, Args... args) {                \
+        return self->VisitType_(static_cast<const OP*>(n.node_.get()), \
+                                std::forward<Args>(args)...);          \
+      });
+
+template <typename R, typename... Args>
+class TypeFunctor<R(const Type& n, Args...)> {
+ private:
+  using TSelf = TypeFunctor<R(const Type& n, Args...)>;
+  using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~TypeFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const Type& n, Args... args) {
+    return VisitType(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitType(const Type& n, Args... args) {
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overridden by subclass
+  virtual R VisitType_(const TensorTypeNode* op,
+                       Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+
+  virtual R VisitTypeDefault_(const Node* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->type_key();
+    return R();
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
+    return vtable;
+  }
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_PASS_TYPE_FUNCTOR_H_
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
new file mode 100644
index 000000000000..4873b0a55580
--- /dev/null
+++ b/src/relay/pass/type_infer.cc
@@ -0,0 +1,488 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file type_infer.cc
+ * \brief Relay type inference and checking.
+ *
+ * This file implements one of the most important passes over the
+ * Relay IR. In order to do many transformations and generate the
+ * most efficient code, we need to obtain type information for the
+ * IR.
+ *
+ * Like computation graphs, the IR leaves most type information
+ * implicit and relies on performing analysis of the program to
+ * generate this information.
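+ * For example, in `let x = log(y)` the type of `x` is nowhere written
+ * down; it has to be recovered from the type of `y` and the signature
+ * of `log` (an illustrative example).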
+ *
+ * Given an expression `e`, this pass infers a type `t` for
+ * the expression while simultaneously checking the property `e : t`
+ * (i.e. we can show e has type t).
+ *
+ * If we cannot infer a type, or there are conflicting typing
+ * constraints, we will trigger an error.
+ */
+
+#include
+#include
+#include
+#include
+#include "./incomplete_type.h"
+#include "./resolve.h"
+#include "./type_subst.h"
+#include "./type_visitor.h"
+#include "./unifier.h"
+
+namespace tvm {
+namespace relay {
+
+using namespace tvm::runtime;
+
+struct TypeContext {
+  std::vector<std::unordered_map<LocalVar, Type, NodeHash, NodeEqual>> stack;
+
+  TypeContext() { stack.push_back({}); }
+
+  void insert(const LocalVar &id, const Type &t) { stack.back()[id] = t; }
+
+  Type lookup(const LocalVar &id) {
+    for (auto frame = stack.rbegin(); frame != stack.rend(); ++frame) {
+      if (frame->find(id) != frame->end()) {
+        return frame->at(id);
+      }
+    }
+    throw FatalTypeError("Could not resolve local id");
+  }
+
+  struct LocalFrame {
+    TypeContext &tc;
+    explicit LocalFrame(TypeContext &tc) : tc(tc) { tc.stack.push_back({}); }
+    ~LocalFrame() { tc.stack.pop_back(); }
+  };
+};
+
+struct TypeNormalizer : TypeFVisitor {
+  TypeUnifier unifier;
+  explicit TypeNormalizer(const TypeUnifier &unifier) : unifier(unifier) {}
+
+  Type VisitType_(const TypeCallNode *ty_call_node) {
+    auto ty_call = GetRef<TypeCall>(ty_call_node);
+
+    Array<Type> normalized_args;
+
+    for (auto arg : ty_call->args) {
+      normalized_args.push_back(VisitType(arg));
+    }
+
+    auto all_concrete = true;
+    for (auto arg : normalized_args) {
+      all_concrete = all_concrete && !arg.as<IncompleteTypeNode>();
+    }
+
+    if (all_concrete) {
+      return normalized_args[normalized_args.size() - 1];
+    } else {
+      if (auto ty_rel_node = ty_call->func.as<TypeRelationNode>()) {
+        // NB: we subtract 1 for the output argument.
+        auto new_args =
+            ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1);
+        CHECK(new_args.size() == normalized_args.size());
+        tvm::Array<Type> final_args;
+
+        for (size_t i = 0; i < new_args.size(); i++) {
+          final_args.push_back(unifier->unify(normalized_args[i], new_args[i]));
+        }
+
+        return TypeCallNode::make(ty_call->func, final_args);
+      } else {
+        throw InternalError("found a non-type-relation in the "
+                            "type call function position");
+      }
+    }
+  }
+};
+
+struct CheckedExpr {
+  Expr expr;
+  Type type;
+  CheckedExpr(Expr e, Type t) : expr(e), type(t) {}
+  CheckedExpr() {}
+};
+
+class TypeInferencer : private ExprFunctor<CheckedExpr(const Expr &)> {
+ private:
+  TypeContext local_stack;
+
+ public:
+  Environment env;
+  TypeUnifier unifier;
+
+  // Should be in header?
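+  // Illustrative use (hypothetical, not code in this commit): type-check
+  // a sub-expression in a fresh scope that is popped automatically, even
+  // if Infer throws:
+  //   Type t = with_frame<Type>([&]() { return Infer(body).type; });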
+  template <typename T>
+  T with_frame(const std::function<T()> &f) {
+    TypeContext::LocalFrame fr(local_stack);
+    return f();
+  }
+
+  TypeInferencer();
+  TypeInferencer(Environment env, TypeUnifier unifier)
+      : env(env), unifier(unifier) {}
+  explicit TypeInferencer(Environment env);
+
+  CheckedExpr Infer(const Expr &expr);
+
+  FuncType instantiate(FuncType fn_ty, tvm::Array<Type> &ty_args);
+
+  Type Normalize(const Type &t);
+
+  void report_error(const std::string &msg, Span sp);
+  [[noreturn]] void fatal_error(const std::string &msg, Span sp);
+
+  Type unify(const Type &t1, const Type &t2, Span sp);
+  Type resolve(const Type &t);
+  Expr resolve(const Expr &e);
+  CheckedExpr VisitFunction(const Function &f, bool generalize);
+
+ private:
+  CheckedExpr VisitExpr_(const LocalVarNode *op) override;
+  CheckedExpr VisitExpr_(const GlobalVarNode *op) override;
+  CheckedExpr VisitExpr_(const ConstantNode *op) override;
+  CheckedExpr VisitExpr_(const TupleNode *op) override;
+  CheckedExpr VisitExpr_(const ParamNode *op) override;
+  CheckedExpr VisitExpr_(const FunctionNode *op) override;
+  CheckedExpr VisitExpr_(const CallNode *op) override;
+  CheckedExpr VisitExpr_(const LetNode *op) override;
+  CheckedExpr VisitExpr_(const IfNode *op) override;
+  CheckedExpr VisitExpr_(const OpNode *op) override;
+};
+
+TypeInferencer::TypeInferencer() {
+  this->env = EnvironmentNode::make({});
+  this->unifier = TypeUnifierNode::make(UnionFindNode::make({}));
+}
+
+TypeInferencer::TypeInferencer(Environment env) : env(env) {
+  this->unifier = TypeUnifierNode::make(UnionFindNode::make({}));
+}
+
+Type TypeInferencer::Normalize(const Type &t) {
+  auto nt = this->resolve(t);
+  auto normalizer = TypeNormalizer(this->unifier);
+  return normalizer.VisitType(nt);
+}
+
+CheckedExpr TypeInferencer::Infer(const Expr &expr) {
+  RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl;
+  CheckedExpr checked_expr = this->VisitExpr(expr);
+  RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type
+                  << std::endl;
+  Type final_type = Normalize(checked_expr.type);
+  RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type
+                  << std::endl;
+  checked_expr.expr->checked_type_ = final_type;
+  return checked_expr;
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) {
+  auto var = GetRef<LocalVar>(op);
+  return {var, this->local_stack.lookup(var)};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) {
+  GlobalVar var = GetRef<GlobalVar>(op);
+  Expr e = this->env->Lookup(var);
+  return {var, e->checked_type()};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) {
+  return {GetRef<Constant>(const_node), const_node->tensor_type()};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) {
+  Tuple pl = GetRef<Tuple>(op);
+
+  std::vector<Expr> field_exprs;
+  std::vector<Type> field_types;
+  for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) {
+    auto checked_field = Infer(*field);
+    field_exprs.push_back(checked_field.expr);
+    field_types.push_back(checked_field.type);
+  }
+
+  return {TupleNode::make(field_exprs), TupleTypeNode::make(field_types)};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) {
+  // We should trigger an error here and move the param handling code
+  // directly into function checking.
+  auto rtype = resolve(param->type);
+  // This is a special case ... not sure if there is a better way
+  // to handle this.
+  param->var->checked_type_ = rtype;
+  return {ParamNode::make(param->var, rtype), rtype};
+}
+
+CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) {
+  // First we add the parameters to the context, allowing us to check their
+  // types.
+
+  // TODO(@jroesch): support polymorphism
+
+  std::vector<Type> param_types;
+  std::vector<Param> params;
+
+  return this->with_frame<CheckedExpr>([&]() -> CheckedExpr {
+    for (auto param : f->params) {
+      CheckedExpr checked_param = this->Infer(param);
+      Type arg_type;
+      param_types.push_back(checked_param.type);
+      params.push_back(GetRef<Param>(checked_param.expr.as<ParamNode>()));
+      this->local_stack.insert(param->var, checked_param.type);
+    }
+
+    auto checked_body = this->Infer(f->body);
+    auto inferred_rtype = checked_body.type;
+    auto annotated_rtype = resolve(f->ret_type);
+
+    auto unified_rtype = this->unify(inferred_rtype, annotated_rtype, f->span);
+
+    return {FunctionNode::make(params, unified_rtype, checked_body.expr, {}),
+            FuncTypeNode::make(param_types, unified_rtype, {}, {})};
+  });
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) {
+  return this->VisitFunction(GetRef<Function>(op), false);
+}
+
+FuncType TypeInferencer::instantiate(FuncType fn_ty,
+                                     tvm::Array<Type> &ty_args) {
+  tvm::Map<TypeParam, Type> subst_map;
+
+  // Build a substitution map from the function type and type arguments.
+  // Eventually allow the type vars to be passed in.
+  for (auto ty_param : fn_ty->type_params) {
+    IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind);
+    this->unifier->insert(fresh);
+    ty_args.push_back(fresh);
+    subst_map.Set(ty_param, fresh);
+  }
+
+  Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {});
+  inst_ty = TypeSubst(inst_ty, subst_map);
+
+  CHECK(KindCheck(this->env, inst_ty));
+
+  return GetRef<FuncType>(inst_ty.as<FuncTypeNode>());
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) {
+  Call c = GetRef<Call>(op);
+
+  auto checked_op = this->Infer(c->op);
+
+  RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl
+                  << "fn_ty=" << checked_op.type << std::endl;
+
+  auto fn_ty_node = checked_op.type.as<FuncTypeNode>();
+
+  if (!fn_ty_node) {
+    this->fatal_error("only expressions with function types can be called",
+                      c->op->span);
+  }
+
+  // We now have a function type.
+  FuncType fn_ty = GetRef<FuncType>(fn_ty_node);
+
+  tvm::Array<Type> ty_args;
+  if (ty_args.size() != 0) {
+    throw Error("found manually supplied type args, not supported");
+  }
+
+  fn_ty = instantiate(fn_ty, ty_args);
+
+  std::vector<Type> arg_types;
+  std::vector<Expr> checked_args;
+
+  for (auto arg : c->args) {
+    auto checked_arg = this->Infer(arg);
+    arg_types.push_back(checked_arg.type);
+    checked_args.push_back(checked_arg.expr);
+  }
+
+  auto type_arity = fn_ty->arg_types.size();
+  auto number_of_args = arg_types.size();
+
+  if (type_arity != number_of_args) {
+    if (type_arity < number_of_args) {
+      this->fatal_error("the function is provided too many arguments", c->span);
+    } else {
+      this->fatal_error("the function is provided too few arguments", c->span);
+    }
+  }
+
+  for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
+    this->unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span);
+  }
+
+  // After we unify the arguments, we should know more about the type
+  // arguments; run a quick pass over them to find the new
+  // representatives.
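+  // e.g. (illustrative) if a fresh type hole ?0 was unified with
+  // Tensor[f32] by the loop above, subst(?0) now returns Tensor[f32],
+  // so the call records a concrete type argument.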
+
+  for (size_t i = 0; i < ty_args.size(); i++) {
+    ty_args.Set(i, this->unifier->subst(ty_args[i]));
+  }
+
+  auto new_call =
+      CallNode::make(checked_op.expr, checked_args, c->attrs, ty_args);
+
+  return {new_call, fn_ty->ret_type};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) {
+  Let let = GetRef<Let>(op);
+
+  CheckedExpr checked_value;
+  Type annotated_ty = resolve(let->value_type);
+
+  // If we are let-defining a function, we want to be able to
+  // recursively name the function in order to support recursive
+  // local definitions.
+  if (let->value.as<FunctionNode>()) {
+    with_frame<void>([&]() {
+      local_stack.insert(let->var, annotated_ty);
+      checked_value = Infer(let->value);
+    });
+  } else {
+    checked_value = Infer(let->value);
+  }
+
+  Type unified_ty = this->unify(checked_value.type, annotated_ty, let->span);
+
+  // Update the type context with the unified type now that we have
+  // solved this equation.
+  local_stack.insert(let->var, unified_ty);
+
+  auto checked_body = with_frame<CheckedExpr>([&]() {
+    local_stack.insert(let->var, unified_ty);
+    return Infer(let->body);
+  });
+
+  auto checked_let = LetNode::make(let->var, checked_value.expr,
+                                   checked_body.expr, let->value_type);
+
+  return {checked_let, checked_body.type};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) {
+  If ifn = GetRef<If>(op);
+
+  // Ensure the type of the guard is Tensor[Bool, ()],
+  // i.e. a rank-0 boolean tensor.
+  auto checked_cond = this->Infer(ifn->cond);
+  auto cond_type = checked_cond.type;
+
+  this->unify(cond_type, TensorTypeNode::make({}, HalideIR::Bool()),
+              ifn->cond->span);
+  auto checked_true = this->Infer(ifn->true_value);
+  auto checked_false = this->Infer(ifn->false_value);
+  auto unified_type =
+      this->unify(checked_true.type, checked_false.type, ifn->span);
+  auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr,
+                                 checked_false.expr);
+  return {checked_if, unified_type};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) {
+  auto op = GetRef<Op>(op_node);
+  return {op, op->op_type};
+}
+
+Type TypeInferencer::resolve(const Type& t) {
+  if (t.defined()) {
+    return ::tvm::relay::Resolve(this->unifier, t);
+  } else {
+    return IncompleteTypeNode::make(TypeParamNode::Kind::kType);
+  }
+}
+
+Expr TypeInferencer::resolve(const Expr& e) {
+  CHECK(e.defined());
+  return ::tvm::relay::Resolve(this->unifier, e);
+}
+
+Expr InferType(const Environment &env, const Expr &e) {
+  TypeInferencer ti(env);
+  auto checked_expr = ti.Infer(e);
+  return ti.resolve(checked_expr.expr);
+}
+
+Expr InferType(const Environment &env, const GlobalVar &var,
+               const Function &func) {
+  TypeInferencer ti(env);
+  auto func_copy = FunctionNode::make(func->params, func->ret_type, func->body,
+                                      func->type_params);
+  func_copy->checked_type_ = ti.resolve(func_copy->fn_type());
+  env->functions.Set(var, func_copy);
+  auto checked_expr = ti.Infer(func);
+  auto map_node = env->functions.CopyOnWrite();
+  map_node->data.erase(var.node_);
+  return ti.resolve(checked_expr.expr);
+}
+
+inline void TypeInferencer::report_error(const std::string &msg, Span sp) {
+  this->env->AddDiagnostic({msg, sp});
+}
+
+void TypeInferencer::fatal_error(const std::string &msg, Span sp) {
+  this->env->AddDiagnostic({msg, sp});
+  throw FatalTypeError(
+      "internal error: this exception should "
+      "be handled and errors reported with Environment::display_errors\n" +
+      msg);
+}
+
+Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) {
+  try {
+    return Normalize(this->unifier->unify(t1, t2));
+  } catch (const dmlc::Error &e)
{ + std::stringstream ss; + ss << "Error unifying `"; + ss << t1; + // ss << PrintType(env, t1, WrapWidth(40)); + ss << "` and `"; + ss << t2; + // ss << PrintType(env, t2, WrapWidth(40)); + ss << "`: " << e.what(); + this->fatal_error(ss.str(), sp); + } +} + +TVM_REGISTER_API("relay._ir_pass.check_expr") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + Expr e = args[1]; + *ret = InferType(env, e); + }); + +// TODO(@jroesch): put in a better namespace. +TVM_REGISTER_API("relay._ir_pass._get_checked_type") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Expr e = args[0]; + *ret = e->checked_type(); + }); + +IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { + std::shared_ptr n = + std::make_shared(); + n->kind = std::move(kind); + return IncompleteType(n); +} + +TVM_REGISTER_API("relay._make.IncompleteType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[0]; + *ret = IncompleteTypeNode::make(static_cast(kind)); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IncompleteTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/type_subst.cc b/src/relay/pass/type_subst.cc new file mode 100644 index 000000000000..91713976bcaa --- /dev/null +++ b/src/relay/pass/type_subst.cc @@ -0,0 +1,39 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_subst.cc + * \brief Function for substituting a concrete type in place of a type ID + */ +#include "./type_subst.h" +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +struct TypeSubstV : TypeFVisitor { + tvm::Map subst_map; + + explicit TypeSubstV(tvm::Map subst_map) + : subst_map(subst_map) {} + + Type VisitType_(const TypeParamNode *op) override { + auto id = GetRef(op); + if (subst_map.find(id) != subst_map.end()) { + return this->subst_map[id]; + } else { + return id; + } + } +}; + +Type TypeSubst(const Type &type, const TypeParam &target, const Type &subst) { + TypeSubstV ty_sub({ {target, subst} }); + return ty_sub.VisitType(type); +} + +Type TypeSubst(const Type &type, tvm::Map subst_map) { + TypeSubstV ty_sub(subst_map); + return ty_sub.VisitType(type); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/type_subst.h b/src/relay/pass/type_subst.h new file mode 100644 index 000000000000..5b6956f8e451 --- /dev/null +++ b/src/relay/pass/type_subst.h @@ -0,0 +1,19 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/pass/type_subst.h + * \brief Utility functions for substituting types. + */ +#ifndef TVM_RELAY_PASS_TYPE_SUBST_H_ +#define TVM_RELAY_PASS_TYPE_SUBST_H_ + +#include + +namespace tvm { +namespace relay { + +Type TypeSubst(const Type & type, const TypeParam & target, const Type & subst); +Type TypeSubst(const Type &type, tvm::Map subst_map); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_TYPE_SUBST_H_ diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h new file mode 100644 index 000000000000..d65d6c567b23 --- /dev/null +++ b/src/relay/pass/type_visitor.h @@ -0,0 +1,110 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_visitor.h + * \brief A wrapper around TypeFunctor for common use cases. + */ +#ifndef TVM_RELAY_PASS_TYPE_VISITOR_H_ +#define TVM_RELAY_PASS_TYPE_VISITOR_H_ + +#include +#include "./type_functor.h" + +namespace tvm { +namespace relay { + +/*! 
\brief A type visitor for passes that make use of internal
+ * mutable state.
+ *
+ * We recursively visit each sub-type contained inside the visited type.
+ */
+template <typename... Args>
+struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
+  void VisitType_(const TypeParamNode* op, Args... args) override {}
+
+  void VisitType_(const FuncTypeNode* op, Args... args) override {
+    // TODO(@jroesch): handle poly
+    // this->VisitType(op->var, args...);
+    // this->VisitType(op->boundType, args...);
+    for (auto arg_type : op->arg_types) {
+      this->VisitType(arg_type, args...);
+    }
+    this->VisitType(op->ret_type, args...);
+  }
+
+  void VisitType_(const TensorTypeNode* op, Args... args) override {}
+
+  void VisitType_(const TupleTypeNode* op, Args... args) override {
+    for (const Type& t : op->fields) {
+      this->VisitType(t, args...);
+    }
+  }
+
+  void VisitType_(const TypeCallNode* op, Args... args) override {
+    this->VisitType(op->func, args...);
+
+    for (const Type& t : op->args) {
+      this->VisitType(t, args...);
+    }
+  }
+
+  void VisitType_(const TypeRelationNode* op, Args... args) override {}
+  void VisitType_(const IncompleteTypeNode* op, Args... args) override {}
+};
+
+// A functional visitor for rebuilding an AST in place.
+struct TypeFVisitor : TypeFunctor<Type(const Type& n)> {
+  Type VisitType_(const TensorTypeNode* op) override {
+    // TODO(@jroesch): maybe we should recursively visit
+    return TensorTypeNode::make(op->shape, op->dtype);
+  }
+
+  Type VisitType_(const TypeParamNode* op) override {
+    return GetRef<TypeParam>(op);
+  }
+
+  Type VisitType_(const FuncTypeNode* op) override {
+    // TODO(@jroesch): handle poly
+
+    // auto new_id = this->VisitType(op->var);
+    // if (const TypeParamNode* tin = new_id.as<TypeParamNode>()) {
+    //   return TypeQuantifierNode::make(GetRef<TypeParam>(tin),
+    //                                   this->VisitType(op->boundType));
+
+    std::vector<Type> args;
+    for (auto arg_type : op->arg_types) {
+      args.push_back(VisitType(arg_type));
+    }
+
+    return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type),
+                              {}, {});  // fix me
+  }
+
+  Type VisitType_(const TupleTypeNode* op) override {
+    std::vector<Type> new_fields;
+    for (const Type& t : op->fields) {
+      new_fields.push_back(this->VisitType(t));
+    }
+    return TupleTypeNode::make(new_fields);
+  }
+
+  Type VisitType_(const TypeRelationNode* op) override {
+    return GetRef<TypeRelation>(op);
+  }
+
+  Type VisitType_(const TypeCallNode* op) override {
+    auto func = this->VisitType(op->func);
+    std::vector<Type> new_args;
+    for (const Type& t : op->args) {
+      new_args.push_back(this->VisitType(t));
+    }
+    return TypeCallNode::make(func, new_args);
+  }
+
+  Type VisitType_(const IncompleteTypeNode* op) override {
+    return GetRef<IncompleteType>(op);
+  }
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_PASS_TYPE_VISITOR_H_
diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc
new file mode 100644
index 000000000000..f5e337eb17f7
--- /dev/null
+++ b/src/relay/pass/unifier.cc
@@ -0,0 +1,374 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/src/relay/pass/unifier.cc
+ * \brief The type unifier which solves a system of equations between
+ * incomplete types.
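+ *
+ * For example (illustrative syntax): unifying `fn (?0) -> ?0` with
+ * `fn (Tensor[f32]) -> ?1` records `?0 = Tensor[f32]` and then points
+ * `?1` at `?0`, so both holes resolve to `Tensor[f32]`.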
+ */ + +#include "./unifier.h" +#include +#include +#include +#include +#include "./type_visitor.h" +// #include "tvm/relay/typeck/kindchecker.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +UnionFind UnionFindNode::make(tvm::Map uf_map) { + std::shared_ptr n = std::make_shared(); + n->uf_map = uf_map; + return UnionFind(n); +} + +void UnionFindNode::insert(const IncompleteType &v) { this->uf_map.Set(v, v); } + +void UnionFindNode::debug() { + for (auto entry : this->uf_map) { + RELAY_LOG(INFO) << entry.first << " = " << entry.second << std::endl; + } +} + +void UnionFindNode::AssertAlphaEqual(const Type &l, const Type &r) { + if (!AlphaEqual(l, r)) { + std::stringstream ss; + ss << "Incompatible parent types in UF:" << l << " and " << r; + throw UnionFindError(ss.str()); + } +} + +void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { + RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << "t=" << t << std::endl; + auto parent1 = this->find(v1); + + // if t is a type var, then unify parents + const IncompleteTypeNode *tvn2 = t.as(); + if (tvn2) { + auto v2 = GetRef(tvn2); + auto parent2 = this->find(v2); + + // if parents are exactly equal, then we're done + if (parent1 == parent2) { + return; + } + + // if first parent is a type var, then can just set its union find map to + // second parent + if (const IncompleteTypeNode *pvn1 = parent1.as()) { + auto pv1 = GetRef(pvn1); + this->uf_map.Set(pv1, parent2); + return; + } + + // if second parent is a type var but first isn't, can set second type var + if (const IncompleteTypeNode *pvn2 = parent2.as()) { + auto pv2 = GetRef(pvn2); + this->uf_map.Set(pv2, parent1); + return; + } + + // if both parents are not type vars themselves, check alpha-equality + AssertAlphaEqual(parent1, parent2); + return; + } + + // if t is not a type var, then unify with v1's parent if parent is a type + // var; else, check alpha-equality for compatibility + if (const IncompleteTypeNode *pvn1 = parent1.as()) { + auto pv1 = GetRef(pvn1); + this->uf_map.Set(pv1, t); + return; + } + + AssertAlphaEqual(parent1, t); +} + +Type UnionFindNode::find(const IncompleteType &v) { + // The node has no mapping, so its representative is just itself. 
+
+TVM_REGISTER_API("relay._make.UnionFind")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      if (args.size() == 0) {
+        *ret = UnionFindNode::make({});
+      } else {
+        *ret = UnionFindNode::make(args[0]);
+      }
+    });
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+    .set_dispatch<UnionFindNode>([](const UnionFindNode* node,
+                                    tvm::IRPrinter* p) {
+      p->stream << "UnionFindNode(" << node->uf_map << ")";
+    });
+
+TypeUnifier TypeUnifierNode::make(UnionFind uf) {
+  std::shared_ptr<TypeUnifierNode> n = std::make_shared<TypeUnifierNode>();
+  n->uf = uf;
+  return TypeUnifier(n);
+}
+
+void TypeUnifierNode::insert(const IncompleteType& v) { this->uf->insert(v); }
+
+Type TypeUnifierNode::unify(const Type& t1, const Type& t2) {
+  RELAY_LOG(INFO) << "TypeUnifierNode::unify: t1=" << t1 << " t2=" << t2
+                  << std::endl;
+
+  Type unified = this->VisitType(t1, t2);
+  // if (!check_kind(unified)) {
+  //   throw UnificationError("Invalid kinds in unified type");
+  // }
+  return unified;
+}
+
+struct IncompleteTypeSubst : TypeFVisitor {
+  const TypeUnifierNode* unifier;
+
+  explicit IncompleteTypeSubst(const TypeUnifierNode* unifier)
+      : unifier(unifier) {}
+
+  // type var: look it up in the type map and recurse
+  Type VisitType_(const IncompleteTypeNode* op) override {
+    auto tv = GetRef<IncompleteType>(op);
+    auto parent = unifier->uf->find(tv);
+    if (parent == tv) {
+      return tv;
+    }
+    return this->VisitType(parent);
+  }
+};
+
+Type TypeUnifierNode::subst(const Type& t) {
+  IncompleteTypeSubst tvsubst(this);
+  // normalize first so substitutions in quantifiers will be correct
+  Type ret = tvsubst.VisitType(t);
+  // if (!check_kind(ret)) {
+  //   std::stringstream ss;
+  //   ss << "Invalid kinds in substituted type!";
+  //   ss << t << std::endl;
+  //   ss << ret << std::endl;
+  //   throw SubstitutionError(ss.str());
+  // }
+  return ret;
+}
+
+Type TypeUnifierNode::VisitType(const Type& t1, const Type t2) {
+  // When the right hand side is a type variable, immediately unify.
+  if (const IncompleteTypeNode* tvn2 = t2.as<IncompleteTypeNode>()) {
+    return this->unifyWithIncompleteType(t1, GetRef<IncompleteType>(tvn2));
+  } else if (t2.as<TypeCallNode>()) {
+    // The TypeCallNode case is special and not symmetric, so we flip the
+    // arguments to make sure we hit the TypeCall case whenever there is a
+    // type call on either side.
+    return TypeFunctor<Type(const Type& t1, const Type t2)>::VisitType(t2, t1);
+  } else {
+    return TypeFunctor<Type(const Type& t1, const Type t2)>::VisitType(t1, t2);
+  }
+}
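+
+// Dispatch example (sketch): with the flip above,
+//
+//   unify(TensorType(...), TypeCall(f, [s]))
+//
+// is handled by VisitType_(const TypeCallNode*, ...) with the tensor type
+// as the right-hand argument, so only one visitor case has to deal with
+// type calls regardless of which side they appear on.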
+
+Type TypeUnifierNode::unifyWithIncompleteType(const Type& t1,
+                                              const IncompleteType tv2) {
+  RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2
+                  << std::endl;
+  // TODO(@jroesch): fix unify to return the new representative directly
+  this->uf->unify(tv2, t1);
+  auto rep = this->uf->find(tv2);
+  RELAY_LOG(INFO) << "unifyWithIncompleteType: rep=" << rep << std::endl;
+  return rep;
+}
+
+Type TypeUnifierNode::VisitType_(const IncompleteTypeNode* t1, const Type rt2) {
+  IncompleteType tv1 = GetRef<IncompleteType>(t1);
+  RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode t1=" << tv1 << " = " << rt2
+                  << std::endl;
+  this->uf->unify(tv1, rt2);
+  auto rep = this->uf->find(tv1);
+  RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode rep=" << rep << std::endl;
+  return rep;
+}
+
+Type TypeUnifierNode::VisitType_(const TypeParamNode* t1, const Type rt2) {
+  TypeParam ti1 = GetRef<TypeParam>(t1);
+
+  // for type params, we only check equality
+  if (const TypeParamNode* tin2 = rt2.as<TypeParamNode>()) {
+    TypeParam ti2 = GetRef<TypeParam>(tin2);
+
+    if (ti1 != ti2) {
+      throw UnificationError("Attempting to unify non-matching TypeParams");
+    }
+
+    return ti1;
+  }
+
+  // cannot unify a TypeParam with a non-TypeParam
+  throw UnificationError("Unable to unify TypeParamNode");
+}
+
+Type TypeUnifierNode::VisitType_(const FuncTypeNode* t1, const Type rt2) {
+  FuncType ft1 = GetRef<FuncType>(t1);
+
+  if (const FuncTypeNode* tan2 = rt2.as<FuncTypeNode>()) {
+    FuncType ft2 = GetRef<FuncType>(tan2);
+
+    if (ft1->type_params.size() != ft2->type_params.size()) {
+      throw UnificationError(
+          "unable to unify functions with differing numbers of type parameters");
+    }
+
+    if (ft1->type_params.size() != 0) {
+      throw dmlc::Error("NYI: unification of polymorphic function types");
+    }
+
+    // TypeParam id1 = tq1->id;
+    // TypeParam id2 = tq2->id;
+
+    // if (id1->kind != id2->kind) {
+    //   throw UnificationError(
+    //       "Cannot unify quantifiers over ids of different kinds");
+    // }
+
+    // TypeParam fresh = TypeParamNode::make(id1->name, id1->kind);
+
+    // auto bt1 = type_subst(tq1->boundType, id1, fresh);
+    // auto bt2 = type_subst(tq2->boundType, id2, fresh);
+
+    // Type unified_bound_type = this->VisitType(bt1, bt2);
+
+    if (ft1->arg_types.size() != ft2->arg_types.size()) {
+      throw UnificationError("unable to unify functions of different arities");
+    }
+
+    tvm::Array<Type> unified_args;
+    for (size_t i = 0; i < ft1->arg_types.size(); i++) {
+      unified_args.push_back(
+          this->VisitType(ft1->arg_types[i], ft2->arg_types[i]));
+    }
+
+    Type unified_ret_type = this->VisitType(ft1->ret_type, ft2->ret_type);
+
+    return FuncTypeNode::make(unified_args, unified_ret_type, {}, {});
+  }
+
+  throw UnificationError("unable to unify function types");
+}
+
+Type TypeUnifierNode::VisitType_(const TensorTypeNode* t1, const Type rt2) {
+  TensorType tt1 = GetRef<TensorType>(t1);
+
+  if (const TensorTypeNode* ttn2 = rt2.as<TensorTypeNode>()) {
+    TensorType tt2 = GetRef<TensorType>(ttn2);
+
+    if (!AlphaEqual(tt1, tt2)) {
+      throw UnificationError("dtypes do not match");
+    }
+
+    RELAY_LOG(INFO) << "Unify Tensor Shape s1=" << tt1->shape
+                    << " s2=" << tt2->shape << std::endl;
+    try {
+      // Type unified_shape = this->VisitType(tt1->shape, tt2->shape);
+      return rt2;
+    } catch (const UnificationError& err) {
+      CHECK(false) << "Need to check constraint " << tt1->shape << " = "
+                   << tt2->shape << std::endl;
+    }
+
+    // TODO(@jroesch): the shapes should be unified here, not just returned
+    return rt2;
+    // return TensorTypeNode::make(unified_bt, tt2->shape);
+  }
+
+  // nothing else can unify
+  throw UnificationError("Cannot unify TensorTypeNode");
+}
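+
+// Worked example (sketch): for incomplete types t0 and t1,
+//
+//   unify(fn(t0) -> t1, fn(int32) -> int32)
+//
+// unifies the argument lists pointwise (t0 with int32) and then the
+// return types (t1 with int32), so a later subst() resolves both holes
+// to int32 and the unified type is fn(int32) -> int32.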
+
+Type TypeUnifierNode::VisitType_(const TupleTypeNode* t1, const Type rt2) {
+  TupleType pt1 = GetRef<TupleType>(t1);
+
+  // When unifying tuple types we just solve each field in order.
+  if (const TupleTypeNode* ptn2 = rt2.as<TupleTypeNode>()) {
+    TupleType pt2 = GetRef<TupleType>(ptn2);
+
+    std::vector<Type> unified_fields;
+    if (pt1->fields.size() != pt2->fields.size()) {
+      throw UnificationError("Tuple types are of different dimensions");
+    }
+
+    for (size_t i = 0U; i < pt1->fields.size(); i++) {
+      Type unified = this->VisitType(pt1->fields[i], pt2->fields[i]);
+      unified_fields.push_back(unified);
+    }
+
+    return TupleTypeNode::make(unified_fields);
+  }
+
+  // otherwise we cannot unify
+  throw UnificationError("Cannot unify TupleTypeNode");
}
+
+Type TypeUnifierNode::VisitType_(const TypeRelationNode* tr1, const Type t2) {
+  if (const TypeRelationNode* tr2 = t2.as<TypeRelationNode>()) {
+    if (tr1 == tr2) {
+      return GetRef<TypeRelation>(tr1);
+    } else {
+      throw UnificationError("Cannot unify different type relations");
+    }
+  } else {
+    throw UnificationError(
+        "Cannot unify a type relation with another kind of type");
+  }
+}
+
+Type TypeUnifierNode::VisitType_(const TypeCallNode* tcn1, const Type t2) {
+  TypeCall ty_call1 = GetRef<TypeCall>(tcn1);
+
+  if (const TypeCallNode* tcn2 = t2.as<TypeCallNode>()) {
+    Type unified_func = this->VisitType(ty_call1->func, tcn2->func);
+
+    // For now, we only unify calls with the same number of arguments.
+    if (ty_call1->args.size() != tcn2->args.size()) {
+      throw UnificationError(
+          "Cannot unify calls with different numbers of arguments");
+    }
+
+    // Unify the arguments pairwise, if possible.
+    tvm::Array<Type> new_args;
+    for (size_t i = 0U; i < ty_call1->args.size(); i++) {
+      Type unified_member = this->VisitType(ty_call1->args[i], tcn2->args[i]);
+      new_args.push_back(unified_member);
+    }
+
+    return TypeCallNode::make(unified_func, new_args);
+  } else {
+    // Unifying a type call with a non-call type: unify the call's output
+    // (by convention its last argument) with the other type.
+    auto args = ty_call1->args;
+    return this->VisitType(args[args.size() - 1], t2);
+  }
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h
new file mode 100644
index 000000000000..0671a40c0d74
--- /dev/null
+++ b/src/relay/pass/unifier.h
@@ -0,0 +1,134 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file include/tvm/relay/pass/unifier.h
+ * \brief The type unifier, which solves a system of equations between
+ * incomplete types.
+ */
+#ifndef TVM_RELAY_PASS_UNIFIER_H_
+#define TVM_RELAY_PASS_UNIFIER_H_
+
+#include
+#include
+#include "./type_functor.h"
+
+namespace tvm {
+namespace relay {
+
+struct UnionFindError : dmlc::Error {
+  explicit UnionFindError(const std::string& msg) : Error(msg) {}
+};
+
+struct UnificationError : dmlc::Error {
+  explicit UnificationError(const std::string& msg) : Error(msg) {}
+};
+
+struct SubstitutionError : dmlc::Error {
+  explicit SubstitutionError(const std::string& msg) : Error(msg) {}
+};
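+
+// A typical caller treats these as type errors in the user's program;
+// for example (sketch):
+//
+//   try {
+//     unifier->unify(expected, actual);
+//   } catch (const UnificationError& e) {
+//     // report a diagnostic at the offending expression's span
+//   }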
+
+/*! \brief A union-find data structure for the type-checker. */
+class UnionFind;  // forward declaration
+
+class UnionFindNode : public Node {
+ public:
+  tvm::Map<IncompleteType, Type> uf_map;
+
+  UnionFindNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf_map", &uf_map); }
+
+  TVM_DLL static UnionFind make(tvm::Map<IncompleteType, Type> uf_map);
+
+  // insert v into the union-find
+  void insert(const IncompleteType& v);
+
+  // infers that v1 and v2 must be of the same type
+  void unify(const IncompleteType& v1, const Type& v2);
+
+  // returns the representative of v's union-find group
+  Type find(const IncompleteType& v);
+
+  void debug();
+
+  void AssertAlphaEqual(const Type& l, const Type& r);
+
+  static constexpr const char* _type_key = "relay.UnionFind";
+  TVM_DECLARE_NODE_TYPE_INFO(UnionFindNode, Node);
+};
+
+class UnionFind : public NodeRef {
+ public:
+  UnionFind() {}
+  explicit UnionFind(std::shared_ptr<Node> p) : NodeRef(p) {}
+
+  // The union-find structure is mutable, so we do not use the standard
+  // macros and instead expose the pointer via `->`.
+  UnionFindNode* operator->() const {
+    return static_cast<UnionFindNode*>(node_.get());
+  }
+
+  using ContainerType = UnionFindNode;
+};
+
+class TypeUnifier;
+class TypeUnifierNode
+    : public Node,
+      private TypeFunctor<Type(const Type& t1, const Type t2)> {
+ public:
+  UnionFind uf;
+
+  TypeUnifierNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf", &uf); }
+
+  TVM_DLL static TypeUnifier make(UnionFind uf);
+
+  /*! \brief Introduces a new type var into the unifier. */
+  void insert(const IncompleteType& v);
+
+  /*! \brief Unifies two types if possible; throws a UnificationError
+   * if it cannot. */
+  Type unify(const Type& t1, const Type& t2);
+
+  /*! \brief Attempts to substitute all type vars in t with concrete types;
+   * throws a SubstitutionError if it cannot concretize. */
+  Type subst(const Type& t);
+
+  // /*! \brief Checks the kinds in the given type. */
+  // Type CheckKinds(const Type& t);
+
+  static constexpr const char* _type_key = "relay.TypeUnifier";
+  TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node);
+
+ private:
+  /*! \brief Unify an incomplete type with another type. */
+  Type unifyWithIncompleteType(const Type& t1, const IncompleteType tvn2);
+  /*! \brief Implements unification between two types with incomplete
+   * portions. */
+  Type VisitType(const Type& t1, const Type t2) override;
+
+  // Visitor cases
+  Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override;
+  Type VisitType_(const TensorTypeNode* t1, const Type t2) override;
+  Type VisitType_(const TypeParamNode* t1, const Type t2) override;
+  Type VisitType_(const FuncTypeNode* t1, const Type t2) override;
+  Type VisitType_(const TupleTypeNode* t1, const Type t2) override;
+  Type VisitType_(const TypeRelationNode* s1, const Type t2) override;
+  Type VisitType_(const TypeCallNode* s1, const Type t2) override;
+};
+
+class TypeUnifier : public NodeRef {
+ public:
+  TypeUnifier() {}
+  explicit TypeUnifier(std::shared_ptr<Node> p) : NodeRef(p) {}
+
+  // Not const, so the unifier can be mutated as a member of the typechecker.
+  inline TypeUnifierNode* operator->() const {
+    return static_cast<TypeUnifierNode*>(node_.get());
+  }
+
+  using ContainerType = TypeUnifierNode;
+};
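+
+// End-to-end usage sketch: solve `t0 = (int32, t1)` and `t1 = float32`,
+// where t0/t1 are IncompleteTypes and int32/float32 stand in for the
+// corresponding scalar tensor types.
+//
+//   UnionFind uf = UnionFindNode::make({});
+//   TypeUnifier unifier = TypeUnifierNode::make(uf);
+//   unifier->insert(t0);
+//   unifier->insert(t1);
+//   unifier->unify(t0, TupleTypeNode::make({int32, t1}));
+//   unifier->unify(t1, float32);
+//   unifier->subst(t0);  // yields (int32, float32)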
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_PASS_UNIFIER_H_
diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc
new file mode 100644
index 000000000000..9d3316cf38cf
--- /dev/null
+++ b/src/relay/source_map.cc
@@ -0,0 +1,75 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file source_map.cc
+ * \brief Source maps for Relay.
+ */
+
+#include
+#include
+#include
+
+namespace tvm {
+namespace relay {
+
+using tvm::IRPrinter;
+using namespace tvm::runtime;
+
+SourceFragment::SourceFragment(const std::string& file_name,
+                               const std::string& source)
+    : file_name(file_name), source_lines({}) {
+  RELAY_LOG(INFO) << "SourceFragment::SourceFragment source=" << source
+                  << std::endl;
+  std::stringstream source_stream;
+  source_stream.str(source.c_str());
+  std::string line;
+
+  while (std::getline(source_stream, line)) {
+    RELAY_LOG(INFO) << "SourceFragment::SourceFragment: line=" << line
+                    << std::endl;
+    std::string copy(line);
+    source_lines.push_back(copy);
+  }
+}
+
+std::string SourceFragment::SourceAt(Span sp, int max_lines = 1) {
+  std::stringstream out;
+
+  // We need to move from 1-based indexing to zero-based indexing.
+  int starting_line = sp->lineno - 1;
+
+  if (starting_line >= static_cast<int>(this->source_lines.size())) {
+    throw dmlc::Error("SourceFragment: index out of bounds");
+  }
+
+  // Print at most max_lines lines, clamped to what remains in the fragment.
+  auto lines = std::min(static_cast<size_t>(max_lines),
+                        source_lines.size() - starting_line);
+
+  for (size_t i = 0; i < lines; i++) {
+    out << std::endl << this->source_lines.at(starting_line + i);
+  }
+
+  auto source_slice = out.str();
+
+  RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice
+                  << std::endl;
+  return source_slice;
+}
+
+SourceName SourceMap::AddSource(const std::string& file_name,
+                                const std::string& source) {
+  auto new_id = SourceNameNode::make(file_name);
+  SourceFragment sfile(file_name, source);
+  this->map_.insert({new_id, sfile});
+  return new_id;
+}
+
+const SourceFragment& SourceMap::GetSource(SourceName id) const {
+  auto item = map_.find(id);
+  if (item != map_.end()) {
+    return (*item).second;
+  } else {
+    throw dmlc::Error("could not find requested source fragment");
+  }
+}
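+
+// Usage sketch (file_name and program_text are stand-ins):
+//
+//   SourceName id = source_map.AddSource(file_name, program_text);
+//   SourceFragment frag = source_map.GetSource(id);
+//   std::string snippet = frag.SourceAt(span, 1);  // offending line(s)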
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_alpha_eq.py b/tests/python/relay/test_alpha_eq.py
new file mode 100644
index 000000000000..6c0e7779eae6
--- /dev/null
+++ b/tests/python/relay/test_alpha_eq.py
@@ -0,0 +1,573 @@
+"""Test alpha-equivalence of expressions and types."""
+# from relay.ir import alpha_eq, ShapeOp, Kind
+# from relay.typing import TYPE_DEFAULTS
+# from relay import ir
+
+# INT_TYPE_WIDTH = TYPE_DEFAULTS["INT_WIDTH"]
+# INT_TYPE_LANES = TYPE_DEFAULTS["INT_LANES"]
+
+# def int_type(width=32) -> ir.Type:
+#     return TensorType(IntType(width), ShapeSeq([]))
+
+# def float_type(width=32) -> ir.Type:
+#     return TensorType(FloatType(width), ShapeSeq([]))
+
+# def bool_type() -> ir.Type:
+#     return TensorType(BoolType(), ShapeSeq([]))
+
+# def nest_quantifiers(ids, body) -> ir.Type:
+#     ret = body
+#     for tid in reversed(ids):
+#         ret = TypeQuantifier(tid, ret)
+#     return ret
+
+# def test_local_id_not_eq() -> None:
+#     assert not alpha_eq(LocalId("x"), LocalId("y"))
+
+# def test_local_id_eq() -> None:
+#     x = LocalId("x")
+#     assert alpha_eq(x, x)
+
+# def test_global_id_not_eq() -> None:
+#     left = GlobalId("xyz")
+#     right = GlobalId("xyz")
+#     assert not alpha_eq(left, right)
+
+# def test_global_id_eq() -> None:
+#     ident = GlobalId("xyz")
+#     assert alpha_eq(ident, ident)
+
+# def test_operator_id_not_eq() -> None:
+#     left = OperatorId("xyz")
+#     right = OperatorId("xyz")
+#     # equality on operator id is pointer equality
+#     assert not alpha_eq(left, right)
+
+# def test_operator_id_eq() -> None:
+#     x = OperatorId("xyz")
+#     assert alpha_eq(x, x)
+
+# def test_float_literal_eq() -> None:
+#     x = FloatLit(1.0)
+#     y = FloatLit(1.0)
+#     assert alpha_eq(x, y)
+
+# def test_float_literal_not_eq() -> None:
+#     x = FloatLit(1.0)
+#     y = FloatLit(2.0)
+#     assert not alpha_eq(x, y)
+
+# def test_int_literal_eq() -> None:
+#     x = IntLit(1)
+#     y = IntLit(1)
+#     assert alpha_eq(x, y)
+
+# def test_int_literal_not_eq() -> None:
+#     x = IntLit(1)
+#     y = IntLit(2)
+#     assert not alpha_eq(x, y)
+
+# def test_bool_literal_eq() -> None:
+#     x = BoolLit(True)
+#     y = BoolLit(True)
+#     assert alpha_eq(x, y)
+
+# def test_bool_literal_not_eq() -> None:
+#     x = BoolLit(True)
+#     y = BoolLit(False)
+#     assert not alpha_eq(x, y)
+
+# def test_tensor_literal_eq() -> None:
+#     x = TensorLit([IntLit(1), IntLit(2)])
+#     y = TensorLit([IntLit(1), IntLit(2)])
+#     assert alpha_eq(x, y)
+
+# def test_tensor_literal_not_eq() -> None:
+#     x = TensorLit([IntLit(1), IntLit(2)])
+#     y = TensorLit([IntLit(1), IntLit(3)])
+#     z = TensorLit([IntLit(1)])
+#     assert not alpha_eq(x, y)
+#     assert not alpha_eq(x, z)
+
+# def test_product_literal_eq() -> None:
+#     x = Tuple([IntLit(1), IntLit(2)])
+#     y = Tuple([IntLit(1), IntLit(2)])
+#     assert alpha_eq(x, y)
+
+# def test_product_literal_not_eq() -> None:
+#     x = Tuple([IntLit(1), IntLit(2)])
+#     y = Tuple([IntLit(2), IntLit(2)])
+#     z = Tuple([IntLit(1), IntLit(2), IntLit(3)])
+#     assert not alpha_eq(x, y)
+#     assert not alpha_eq(x, z)
+
+# def test_projection_eq() -> None:
+#     prod = Tuple([IntLit(3), FloatLit(3.5)])
+
+#     assert alpha_eq(Projection(prod, 0), Projection(prod, 0))
+#     assert alpha_eq(Projection(prod, 1), Projection(prod, 1))
+
+# def test_projection_not_eq() -> None:
+#     prod1 = Tuple([IntLit(3), IntLit(4)])
+#     prod2 = Tuple([IntLit(3)])
+#     prod3 = Tuple([IntLit(3), IntLit(4), FloatLit(3.5)])
+
+#     assert not alpha_eq(Projection(prod1, 0), Projection(prod1, 1))
+#     assert not alpha_eq(Projection(prod1, 0), Projection(prod2, 0))
+#     assert not alpha_eq(Projection(prod1, 0), Projection(prod3, 0))
+#     assert not alpha_eq(Projection(prod1, 1), Projection(prod3, 1))
+
+# def test_cast_not_eq() -> None:
+#     left = Cast(IntType(1), IntLit(2))
+#     right = Cast(IntType(1), IntLit(1))
+#     assert not alpha_eq(left, right)
+
+#     # same literal, different type
+#     left = Cast(IntType(1), IntLit(2))
+#     right = Cast(IntType(2), IntLit(2))
+#     assert not alpha_eq(left, right)
+
+# def test_cast_eq() -> None:
+#     left = Cast(IntType(1), IntLit(2))
+#     right = Cast(IntType(1), IntLit(2))
+#     assert alpha_eq(left, right)
+
+# def test_param_not_eq() -> None:
+#     left = Param(LocalId("foo"), int_type())
+#     right = Param(LocalId("foo"), bool_type())
+#     assert not alpha_eq(left, right)
+
+# def test_param_eq() -> None:
+#     left = Param(LocalId("foo"), int_type())
+#     right = Param(LocalId("bar"), int_type())
+#     assert alpha_eq(left, right)
+
+# def test_function_not_eq() -> None:
+#     params1 = [Param(LocalId("x"), int_type())]
+#     fn1 = Function([], params1, int_type(), LocalId("x"))
+#     params2 = [Param(LocalId("y"), bool_type())]
+#     fn2 = Function([], params2, int_type(), LocalId("y"))
+#     assert not alpha_eq(fn1, fn2)
+
+#     params3 = [Param(LocalId("x"), int_type()), Param(LocalId("y"), int_type())]
+#     fn3 = Function([], params3, int_type(), LocalId("z"))
+#     assert not alpha_eq(fn1, fn3)
+
+# def test_function_eq() -> None:
+#     x = LocalId("x")
+#     y = LocalId("y")
+#     params1 = [Param(x, int_type())]
+#     fn1 = Function([], params1, int_type(), x)
+#     params2 = [Param(y, int_type())]
+#     fn2 = Function([], params2, int_type(), y)
+#     assert alpha_eq(fn1, fn2)
+
+# def test_call_not_eq() -> None:
+#     x = LocalId("x")
+#     y = LocalId("y")
+#     params1 = [Param(x, int_type())]
+#     fn1 = Function([], params1, int_type(), x)
+#     args1 = [IntLit(1)]
+#     call1 = Call(fn1, args1)
+
+#     args2 = [IntLit(2)]
+#     call2 = Call(fn1, args2)
+#     assert not alpha_eq(call1, call2)
+
+#     params2 = [Param(y, int_type())]
+#     fn2 = Function([], params2, float_type(), FloatLit(0.0))
+#     call3 = Call(fn2, args1)
+#     assert not alpha_eq(call1, call3)
+#     assert not alpha_eq(call2, call3)
+
+# def test_call_eq() -> None:
+#     x = LocalId("x")
+#     y = LocalId("y")
+#     params1 = [Param(x, int_type())]
+#     fn1 = Function([], params1, int_type(), x)
+#     args = [IntLit(1)]
+#     call1 = Call(fn1, args)
+
+#     params2 = [Param(y, int_type())]
+#     fn2 = Function([], params2, int_type(), y)
+#     call2 = Call(fn2, args)
+#     assert alpha_eq(call1, call2)
+
+# def test_debug_not_eq() -> None:
+#     left = Debug(IntLit(1))
+#     right = Debug(IntLit(2))
+#     assert not alpha_eq(left, right)
+
+# def test_debug_eq() -> None:
+#     left = Debug(IntLit(1))
+#     right = Debug(IntLit(1))
+#     assert alpha_eq(left, right)
+
+# def test_let_not_eq() -> None:
+#     x = LocalId("x")
+#     y = LocalId("y")
+#     let1 = Let(x, int_type(), IntLit(10), IntLit(11))
+#     let2 = Let(y, int_type(), IntLit(10), IntLit(12))
+#     assert not alpha_eq(let1, let2)
+
+#     let3 = Let(x, int_type(), IntLit(10), x)
+#     let4 = Let(y, int_type(), IntLit(12), y)
+#     assert not alpha_eq(let3, let4)
+
+# def test_let_eq() -> None:
+#     x = LocalId("x")
+#     y = LocalId("y")
+#     let1 = Let(x, int_type(), IntLit(10), x)
+#     let2 = Let(y, int_type(), IntLit(10), y)
+#     assert alpha_eq(let1, let2)
+
+# def test_ref_eq() -> None:
+#     r1 = Ref(IntLit(5))
+#     r2 = Ref(IntLit(5))
+#     assert alpha_eq(r1, r2)
+
+# def test_ref_not_eq() -> None:
+#     r1 = Ref(IntLit(5))
+#     r2 = Ref(FloatLit(3.5))
+#     r3 = Ref(r1)
+#     assert not alpha_eq(r1, r2)
+#     assert not alpha_eq(r1, r3)
+#     assert not alpha_eq(r2, r3)
+
+# def test_val_ref_eq() -> None:
+#     vr1 = ReadRef(Ref(IntLit(35)))
+#     vr2 = ReadRef(Ref(Tuple([IntLit(12), FloatLit(2.5)])))
+#     assert alpha_eq(vr1, vr1)
+#     assert alpha_eq(vr2, vr2)
+
+# def test_val_ref_not_eq() -> None:
+#     vr1 = ReadRef(Ref(IntLit(5)))
+#     vr2 = ReadRef(Ref(vr1))
+#     vr3 = ReadRef(Ref(FloatLit(5.0)))
+#     assert not alpha_eq(vr1, vr2)
+#     assert not alpha_eq(vr1, vr3)
+#     assert not alpha_eq(vr2, vr3)
+
+# def test_set_ref_eq() -> None:
+#     sr1 = WriteRef(Ref(FloatLit(5.0)), FloatLit(6.0))
+#     sr2 = WriteRef(Ref(Tuple([IntLit(3), BoolLit(False)])),
+#                    Tuple([IntLit(5), BoolLit(True)]))
+#     assert alpha_eq(sr1, sr1)
+#     assert alpha_eq(sr2, sr2)
+
+# def test_set_ref_not_eq() -> None:
+#     r1 = Ref(FloatLit(5.0))
+#     r2 = Ref(IntLit(5))
+#     r3 = Ref(IntLit(6))
+
+#     assert not alpha_eq(WriteRef(r1, FloatLit(6.0)),
+#                         WriteRef(r2, IntLit(6)))
+#     assert not alpha_eq(WriteRef(r2, IntLit(6)), WriteRef(r2, IntLit(7)))
+#     assert not alpha_eq(WriteRef(r2, IntLit(7)), WriteRef(r3, IntLit(7)))
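+
+# A sketch of the bound-variable rule that the let/function tests above
+# exercise: alpha-equivalence compares bound variables by position, not
+# by name, so `x` and `y` are interchangeable once bound. Illustrative
+# only; `toy_alpha_eq` is a stand-in for the real alpha_eq.
+#
+#   def toy_alpha_eq(e1, e2, env=None):
+#       env = env or {}  # maps bound vars of e1 to the vars of e2
+#       if isinstance(e1, Let) and isinstance(e2, Let):
+#           return (toy_alpha_eq(e1.value, e2.value, env) and
+#                   toy_alpha_eq(e1.body, e2.body, {**env, e1.var: e2.var}))
+#       if isinstance(e1, LocalId) and isinstance(e2, LocalId):
+#           return env.get(e1, e1) is e2  # bound: via env; free: identity
+#       ...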
+
+# # Type alpha-equality tests
+
+# def test_base_type_eq() -> None:
+#     assert alpha_eq(IntType(32), IntType(32))
+#     assert alpha_eq(BoolType(), BoolType())
+#     assert alpha_eq(FloatType(32), FloatType(32))
+
+# def test_tensor_type_eq() -> None:
+#     tt1 = TensorType(
+#         IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)]))
+#     tt2 = TensorType(
+#         FloatType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)]))
+#     assert alpha_eq(tt1, tt1)
+#     assert alpha_eq(tt2, tt2)

+# def test_tensor_type_not_eq() -> None:
+#     tt1 = TensorType(
+#         IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)]))
+#     tt2 = TensorType(
+#         FloatType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)]))
+#     tt3 = TensorType(
+#         IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)]))
+#     assert not alpha_eq(tt1, tt2)
+#     assert not alpha_eq(tt1, tt3)
+
+# def test_ref_type_eq() -> None:
+#     rt1 = RefType(int_type())
+#     rt2 = RefType(float_type())
+#     assert alpha_eq(rt1, rt1)
+#     assert alpha_eq(rt2, rt2)
+
+# def test_ref_type_not_eq() -> None:
+#     rt1 = RefType(int_type())
+#     rt2 = RefType(float_type())
+#     assert not alpha_eq(rt1, rt2)
+
+# def test_product_type_eq() -> None:
+#     pt1 = TupleType([int_type(), RefType(float_type())])
+#     pt2 = TupleType([float_type(), float_type(), int_type()])
+#     assert alpha_eq(pt1, pt1)
+#     assert alpha_eq(pt2, pt2)
+
+# def test_product_type_not_eq() -> None:
+#     pt1 = TupleType([int_type(), int_type()])
+#     pt2 = TupleType([int_type(), int_type(), float_type()])
+#     pt3 = TupleType([bool_type(), float_type()])
+#     assert not alpha_eq(pt1, pt2)
+#     assert not alpha_eq(pt1, pt3)
+
+# def test_type_id_eq() -> None:
+#     id1 = TypeParam("id1", Kind.Shape)
+#     id2 = TypeParam("id2", Kind.BaseType)
+#     id3 = TypeParam("id2", Kind.Type)
+
+#     assert alpha_eq(id1, id1)
+#     assert alpha_eq(id2, id2)
+#     assert alpha_eq(id3, id3)
+
+# def test_type_id_not_eq() -> None:
+#     # name is just a hint, we use pointer equality as the rule
+#     # (unless there is a quantifier to give context)
+#     id1 = TypeParam("id1", Kind.Shape)
+#     id2 = TypeParam("id1", Kind.Shape)
+#     id3 = TypeParam("id3", Kind.BaseType)
+
+#     assert not alpha_eq(id1, id2)
+#     assert not alpha_eq(id1, id3)
+
+# def test_arrow_type_eq() -> None:
+#     ar1 = TypeArrow([int_type()], bool_type())
+#     ar2 = TypeArrow([int_type(), int_type()], TupleType([]))
+#     assert alpha_eq(ar1, ar1)
+#     assert alpha_eq(ar2, ar2)
+
+# def test_arrow_type_not_eq() -> None:
+#     t1 = int_type()
+#     t2 = bool_type()
+#     t3 = [int_type(), bool_type()]
+
+#     assert not alpha_eq(TypeArrow([t1], t2), TypeArrow([t1], t1))
+#     assert not alpha_eq(TypeArrow(t3, t1), TypeArrow([t2], t1))
+#     assert not alpha_eq(TypeArrow([t1], TypeArrow([t1], t1)),
+#                         TypeArrow([t1], t1))
+
+# def test_type_quantifier_eq() -> None:
+#     id1 = TypeParam("id1", Kind.Shape)
+#     id2 = TypeParam("id2", Kind.Shape)
+#     tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1))
+#     tq2 = TypeQuantifier(id2, TensorType(IntType(32), id2))
+
+#     assert alpha_eq(tq1, tq1)
+#     assert alpha_eq(tq1, tq2)
+
+# def test_nested_type_quantifier_eq() -> None:
+#     id1 = TypeParam("id1", Kind.BaseType)
+#     id2 = TypeParam("id2", Kind.Shape)
+#     id3 = TypeParam("id3", Kind.BaseType)
+#     id4 = TypeParam("id4", Kind.Shape)
+#     tq1 = TypeQuantifier(id1, TypeQuantifier(id2, TensorType(id1, id2)))
+#     tq2 = TypeQuantifier(id3, TypeQuantifier(id4, TensorType(id3, id4)))
+
+#     assert alpha_eq(tq1, tq1)
+#     assert alpha_eq(tq1, tq2)
+
+# def test_type_quantifier_not_eq() -> None:
+#     id1 = TypeParam("id1", Kind.Shape)
+#     id2 = TypeParam("id2", Kind.BaseType)
+#     id3 = TypeParam("id3", Kind.Shape)
+
+#     tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1))
+#     tq2 = TypeQuantifier(id2, TensorType(id2, ShapeSeq([ShapeSingleton(3)])))
+#     tq3 = TypeQuantifier(id1, TensorType(IntType(32), id3))
+#     tq4 = TypeQuantifier(id1, TensorType(FloatType(32), id1))
+
+#     assert not alpha_eq(tq1, tq2)
+#     assert not alpha_eq(tq1, tq3)
+#     assert not alpha_eq(tq1, tq4)
+#     assert not alpha_eq(tq2, tq3)
+#     assert not alpha_eq(tq2, tq4)
+
+# def test_shape_singleton_eq() -> None:
+#     single1 = ShapeSingleton(10)
+#     single2 = ShapeSingleton(10)
+
+#     assert alpha_eq(single1, single1)
+#     assert alpha_eq(single1, single2)
+
+# def test_shape_singleton_not_eq() -> None:
+#     single1 = ShapeSingleton(10)
+#     single2 = ShapeSingleton(11)
+
+#     assert not alpha_eq(single1, single2)
+
+# def test_shape_attr_eq() -> None:
+#     attr1 = ShapeAttr("x")
+#     attr2 = ShapeAttr("x")
+
+#     assert alpha_eq(attr1, attr1)
+#     assert alpha_eq(attr1, attr2)
+
+# def test_shape_attr_not_eq() -> None:
+#     id1 = "x"
+#     id2 = "y"
+#     attr1 = ShapeAttr(id1)
+#     attr2 = ShapeAttr(id2)
+
+#     assert not alpha_eq(attr1, attr2)
+
+# def test_shape_seq_eq() -> None:
+#     empty = ShapeSeq([])
+#     seq1 = ShapeSeq([ShapeSingleton(5)])
+#     seq2 = ShapeSeq([ShapeSingleton(5)])
+
+#     assert alpha_eq(empty, empty)
+#     assert alpha_eq(seq1, seq2)
+
+# def test_shape_seq_not_eq() -> None:
+#     empty = ShapeSeq([])
+#     seq = ShapeSeq([ShapeSingleton(5)])
+#     single = ShapeSingleton(5)
+
+#     assert not alpha_eq(empty, seq)
+#     assert not alpha_eq(seq, single)
+
+# def test_shape_projection_eq() -> None:
+#     proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0)
+#     proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0)
+
+#     assert alpha_eq(proj1, proj2)
+
+# def test_shape_projection_not_eq() -> None:
+#     proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0)
+#     proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 1)
+#     proj3 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 0)
+#     proj4 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 1)
+
+#     assert not alpha_eq(proj1, proj2)
+#     assert not alpha_eq(proj1, proj3)
+#     assert not alpha_eq(proj1, proj4)
+#     assert not alpha_eq(proj2, proj3)
+#     assert not alpha_eq(proj2, proj4)
+#     assert not alpha_eq(proj3, proj4)
+
+# def test_shape_binary_op_eq() -> None:
+#     empty = ShapeSeq([])
+#     single = ShapeSingleton(5)
+#     seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)])
+
+#     op1 = ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty)
+#     op2 = ShapeBinaryOp(ShapeOp.SHSUB, single, single)
+#     op3 = ShapeBinaryOp(ShapeOp.SHMUL, seq, seq)
+#     op4 = ShapeBinaryOp(ShapeOp.SHDIV, seq, seq)
+
+#     assert alpha_eq(op1, op1)
+#     assert alpha_eq(op2, op2)
+#     assert alpha_eq(op3, op3)
+#     assert alpha_eq(op4, op4)
+
+# def test_shape_binary_op_not_eq() -> None:
+#     empty = ShapeSeq([])
+#     single = ShapeSingleton(5)
+#     seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)])
+
+#     assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), empty)
+#     assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHMUL, seq, ShapeSingleton(1)), seq)
+#     assert not alpha_eq(
+#         ShapeBinaryOp(ShapeOp.SHPLUS, single, single),
+#         ShapeBinaryOp(ShapeOp.SHPLUS,
+#                       ShapeSeq([single]),
+#                       ShapeSeq([single])))
+#     assert not alpha_eq(
+#         ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty),
+#         ShapeBinaryOp(ShapeOp.SHSUB, empty, empty))
+#     assert not alpha_eq(
+#         ShapeBinaryOp(ShapeOp.SHMUL, empty, empty),
+#         ShapeBinaryOp(ShapeOp.SHDIV, empty, empty))
+
+# def test_shape_nested_in_quantifier() -> None:
+#     b1 = TypeParam("b", Kind.BaseType)
+#     x1 = TypeParam("x", Kind.Shape)
+#     y1 = TypeParam("y", Kind.Shape)
+
+#     b2 = TypeParam("b", Kind.BaseType)
+#     x2 = TypeParam("x", Kind.Shape)
+#     y2 = TypeParam("y", Kind.Shape)
+
+#     b3 = TypeParam("b", Kind.BaseType)
+#     x3 = TypeParam("x", Kind.Shape)
+#     y3 = TypeParam("y", Kind.Shape)
+
+#     tq1 = nest_quantifiers(
+#         [b1, x1, y1],
+#         TypeArrow(
+#             [TensorType(b1, x1), TensorType(b1, y1)],
+#             TensorType(
+#                 b1,
+#                 ShapeBinaryOp(ShapeOp.SHPLUS,
+#                               ShapeSeq([x1, ShapeProjection(y1, 1),
+#                                         ShapeSingleton(5), ShapeAttr("att")]),
+#                               ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+#     tq2 = nest_quantifiers(
+#         [b2, x2, y2],
+#         TypeArrow(
+#             [TensorType(b2, x2), TensorType(b2, y2)],
+#             TensorType(
+#                 b2,
+#                 ShapeBinaryOp(ShapeOp.SHPLUS,
+#                               ShapeSeq([x2, ShapeProjection(y2, 1),
+#                                         ShapeSingleton(5), ShapeAttr("att")]),
+#                               ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+#     # different attr, var order, position, and constant
+#     tq3 = nest_quantifiers(
+#         [b3, x3, y3],
+#         TypeArrow(
+#             [TensorType(b3, x3), TensorType(b3, y3)],
+#             TensorType(
+#                 b3,
+#                 ShapeBinaryOp(ShapeOp.SHPLUS,
+#                               ShapeSeq([x3, ShapeProjection(y3, 1),
+#                                         ShapeSingleton(4), ShapeAttr("att")]),
+#                               ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+#     tq4 = nest_quantifiers(
+#         [b3, x3, y3],
+#         TypeArrow(
+#             [TensorType(b3, x3), TensorType(b3, y3)],
+#             TensorType(
+#                 b3,
+#                 ShapeBinaryOp(ShapeOp.SHPLUS,
+#                               ShapeSeq([x3, ShapeProjection(y3, 2),
+#                                         ShapeSingleton(5), ShapeAttr("att2")]),
+#                               ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+#     tq5 = nest_quantifiers(
+#         [b3, x3, y3],
+#         TypeArrow(
+#             [TensorType(b3, x3), TensorType(b3, y3)],
+#             TensorType(
+#                 b3,
+#                 ShapeBinaryOp(ShapeOp.SHMUL,
+#                               ShapeSeq([x3, ShapeProjection(y3, 1),
+#                                         ShapeSingleton(5), ShapeAttr("att")]),
+#                               ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+#     tq6 = nest_quantifiers(
+#         [b3, y3, x3],
+#         TypeArrow(
+#             [TensorType(b3, x3), TensorType(b3, y3)],
+#             TensorType(
+#                 b3,
+#                 ShapeBinaryOp(ShapeOp.SHPLUS,
+#                               ShapeSeq([x3, ShapeProjection(y3, 1),
+#                                         ShapeSingleton(5), ShapeAttr("att")]),
+#                               ShapeSeq([ShapeSingleton(1) for i in range(6)])))))
+
+#     assert alpha_eq(tq1, tq2)
+#     assert not alpha_eq(tq1, tq3)
+#     assert not alpha_eq(tq2, tq3)
+#     assert not alpha_eq(tq1, tq4)
+#     assert not alpha_eq(tq2, tq4)
+#     assert not alpha_eq(tq1, tq5)
+#     assert not alpha_eq(tq2, tq5)
+#     assert not alpha_eq(tq1, tq6)
+#     assert not alpha_eq(tq2, tq6)
diff --git a/tests/python/relay/test_ir_builder.py b/tests/python/relay/test_ir_builder.py
new file mode 100644
index 000000000000..666d7ff25659
--- /dev/null
+++ b/tests/python/relay/test_ir_builder.py
@@ -0,0 +1,23 @@
+import numpy as np
+from tvm.relay.expr import Let, Constant
+from tvm.relay.ir_builder import IRBuilder
+
+def test_let():
+    b = IRBuilder()
+    x = b.let('x', 1)
+    b.ret(x)
+    prog = b.get()
+    assert isinstance(prog, Let)
+    var = prog.var
+    value = prog.value
+    assert var.name_hint == 'x'
+    assert var == prog.body
+    assert isinstance(value, Constant)
+    assert value.data.asnumpy() == np.array(1)
+    assert prog.value_type == None
+
+# def test_function():
+#     b = IRBuilder()
+
+if __name__ == "__main__":
+    test_let()
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
new file mode 100644
index 000000000000..676aa347950b
--- /dev/null
+++ b/tests/python/relay/test_ir_nodes.py
@@ -0,0 +1,159 @@
+""" test ir"""
+import tvm
+from tvm import relay
+from tvm.expr import *
+
+# Span
+def test_span() -> None:
+    span = relay.Span(None, 1, 1)
+    assert span.source == None
+    assert span.lineno == 1
+    assert span.col_offset == 1
+    assert span.same_as(span)
+    assert span == span
+    assert isinstance(span, relay.base.Span)
+    str(span)
+
+# Types
+
+def test_tensor_type() -> None:
+    shape = tvm.convert([1, 2, 3])
+    dtype = 'float32'
+    tt = relay.TensorType(shape, dtype)
+    assert tt.dtype == dtype
+    assert tt.shape == shape
+    assert tt.span == None
+    str(tt)
+
+
+def test_type_param() -> None:
+    tp = relay.TypeParam('name', relay.Kind.Shape)
+    assert tp.kind == relay.Kind.Shape
+    tp.span  # TODO(@jroesch): allow us to set the span
+    str(tp)
+
+
+def test_func_type() -> None:
+    type_params = tvm.convert([])
+    type_constraints = tvm.convert([])  # TODO: fill me in
+    arg_types = tvm.convert([])
+    ret_type = None
+    tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
+    assert tf.type_params == type_params
+    assert tf.type_constraints == type_constraints
+    assert tf.arg_types == arg_types
+    assert tf.ret_type == ret_type
+    assert tf.span == None
+    # TODO: make sure we can set the span
+    str(tf)
+
+
+def test_constant() -> None:
+    arr = tvm.nd.array(10)
+    const = relay.Constant(arr)
+    assert const.data == arr
+    assert const.span == None
+    str(const)
+
+
+def test_tuple() -> None:
+    fields = tvm.convert([])
+    tup = relay.Tuple(fields)
+    assert tup.fields == fields
+    assert tup.span == None
+    str(tup)
+
+
+def test_local_var() -> None:
+    name_hint = 's'
+    lv = relay.LocalVar(name_hint)
+    assert lv.name_hint == name_hint
+    # assert lv.span == None  todo(@jroesch): what do we do about spans
+    str(lv)
+
+
+def test_global_var() -> None:
+    name_hint = 'g'
+    gv = relay.GlobalVar(name_hint)
+    assert gv.name_hint == name_hint
+    # assert gv.span == None  todo(@jroesch): what do we do about spans
+    str(gv)
+
+
+def test_param() -> None:
+    lv = relay.LocalVar('x')
+    ty = None
+    param = relay.Param(lv, ty)
+    assert param.var == lv
+    assert param.type == ty
+    assert param.span == None
+    str(param)
+
+
+def test_function() -> None:
+    param_names = ['a', 'b', 'c', 'd']
+    params = tvm.convert([relay.Param(relay.LocalVar(n), None) for n in param_names])
+    ret_type = None
+    body = None
+    type_params = tvm.convert([])
+    fn = relay.Function(params, ret_type, body, type_params)
+    assert fn.params == params
+    assert fn.body == body
+    assert fn.type_params == type_params
+    assert fn.span == None
+    str(fn)
+
+
+def test_call() -> None:
+    op = relay.LocalVar('f')
+    arg_names = ['a', 'b', 'c', 'd']
+    args = tvm.convert([relay.LocalVar(n) for n in arg_names])
+    call = relay.Call(op, args, None, None)
+    assert call.op == op
+    assert call.args == args
+    assert call.span == None
+    str(call)
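+
+# Note on argument order: relay.Let(var, value, body, value_type)
+# corresponds to the surface syntax `let var : value_type = value; body`,
+# so the Let constructed in the next test reads as `let x = 10; x`.
+# For example (sketch):
+#
+#   let = relay.Let(lv, value, lv, None)
+#   # let.var == lv, let.value == value, let.body == lv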
"exp", "sqrt"]: + y = getattr(relay, op_name)(x) + assert y.op.name == op_name + assert y.op.support_level == 1 + assert y.args[0] == x + + +if __name__ == "__main__": + test_op_attr() + test_op_level1() + diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py new file mode 100644 index 000000000000..f9a3d098a3e2 --- /dev/null +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -0,0 +1,180 @@ +"""Test that type checker correcly computes types + for expressions. +""" +import tvm +import numpy as np +from nnvm import graph +from tvm.relay.ir_pass import check_expr +from tvm.relay.ir_builder import IRBuilder, float_type, int_type +from tvm.relay.ir_builder import func_type, tensor_type, into_ast +from tvm.relay.env import Environment +from tvm.relay.ir_pass import Monomorphize +from tvm.relay.op import log, add, equal, subtract +from tvm.relay.expr import Function +from tvm.relay import to_tvm +from tvm.contrib import graph_runtime +import nnvm + + +def has_type(expr, typ, env=Environment({})): + checked_expr = check_expr(env, expr) + return checked_expr.checked_type() == typ + + +def decl_has_type(env, name, typ): + func = env.lookup(name) + return func.checked_type() == typ + + +def run(env, expr, inputs, shape): + if not isinstance(expr, Function): + expr = Function([], None, expr, []) + + env.add("main", expr) + env.transform(Monomorphize.to_pass()) + main = env.lookup("main") + graph, lib, _ = to_tvm.compile_to_tvm(main) + # We use NNVM to load the graph right now because it populates node_row_ptr field. + nnvm_graph = nnvm.graph.load_json(graph) + module = graph_runtime.create(nnvm_graph, lib, tvm.cpu(0)) + module.set_input(None, None, **inputs) + module.run() + out_nd_array = tvm.nd.array(np.empty(shape, dtype='float32')) + return module.get_output(0, out=out_nd_array) + + +def test_monomorphic_let(): + "Program: let x = 1; return x" + b = IRBuilder() + x = b.let('x', 1.0, value_type=float_type(64)) + b.ret(x) + + prog, env = b.get() + assert has_type(prog, float_type(64)) + run(env, prog, [], float_type(64)) + + +def test_single_op(): + "Program: fn (x : float32) { let t1 = f(x); t1 }" + b = IRBuilder() + with b.function(('x', float_type())) as func: + x, = func.param_ids() + t1 = b.let('t1', log(x)) + b.ret(t1) + assert has_type(func.to_func(), func_type([float_type()], float_type())) + + +def test_add_op(): + """ + Program: + fn (x, y) { + return x + y; + } + """ + b = IRBuilder() + x = b.param('x', tensor_type(5, 5, 5)) + y = b.param('y', tensor_type(5, 5, 5)) + with b.function(x, y) as func: + b.ret(add(x.var, y.var)) + b.ret(func) + prog, env = b.get() + ttype = tensor_type(5, 5, 5) + expected_ty = func_type([ttype, ttype], ttype) + assert has_type(func.to_func(), expected_ty) + x_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + y_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 5, 5)) + np.testing.assert_allclose( + x_data.asnumpy() + y_data.asnumpy(), result.asnumpy()) + +def test_add_broadcast_op(): + """ + Program: + fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { + return x + y; + } + """ + b = IRBuilder() + x = b.param('x', tensor_type(10, 4)) + y = b.param('y', tensor_type(5, 10, 1)) + with b.function(x, y) as func: + b.ret(add(x.var, y.var)) + b.ret(func) + prog, env = b.get() + ttype = tensor_type(5, 5, 5) + expected_ty = func_type([ttype, ttype], ttype) + assert 
+
+def test_add_broadcast_op():
+    """
+    Program:
+       fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
+           return x + y;
+       }
+    """
+    b = IRBuilder()
+    x = b.param('x', tensor_type(10, 4))
+    y = b.param('y', tensor_type(5, 10, 1))
+    with b.function(x, y) as func:
+        b.ret(add(x.var, y.var))
+    b.ret(func)
+    prog, env = b.get()
+    expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
+                            tensor_type(5, 10, 4))
+    assert has_type(func.to_func(), expected_ty)
+    x_data = tvm.nd.array(np.random.rand(10, 4).astype('float32'))
+    y_data = tvm.nd.array(np.random.rand(5, 10, 1).astype('float32'))
+    result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 10, 4))
+    np.testing.assert_allclose(
+        x_data.asnumpy() + y_data.asnumpy(), result.asnumpy())
+
+def test_dual_op():
+    """Program:
+       fn (x : Tensor[f32, (10, 10)]) {
+           let t1 = log(x);
+           let t2 = add(t1, x);
+           return t2;
+       }
+    """
+    b = IRBuilder()
+    with b.function(('x', tensor_type(10, 10))) as func:
+        x, = func.param_ids()
+        t1 = b.let('t1', log(x))
+        t2 = b.let('t2', add(t1, x))
+        b.ret(t2)
+    assert has_type(func.to_func(), func_type([float_type()], float_type()))
+
+
+def test_decl():
+    """Program:
+       def f(x : Tensor[f32, (10, 10)]) {
+           let lx = log(x);
+           return lx;
+       }
+    """
+    b = IRBuilder()
+    x = b.param('x')
+    with b.decl('f', x):
+        lx = b.let('lx', log(x))
+        b.ret(lx)
+    _, env = b.get()
+    assert decl_has_type(env, 'f', func_type([float_type()], float_type()))
+
+
+def test_recursion():
+    """
+    Program:
+       def f(n: i32, data: f32) -> f32 {
+          if (n == 0) {
+              return f(n - 1, log(data));
+          } else {
+              return data;
+          }
+       }
+       f(2, 10000);
+    """
+    b = IRBuilder()
+    f = b.global_var('f')
+    n = b.param('n', ty=int_type())
+    data = b.param('data', ty=float_type())
+    with b.decl(f, n, data):
+        with b.if_scope(equal(n, into_ast(0.0))):
+            b.ret(f(subtract(n, into_ast(1)), log(data)))
+        with b.else_scope():
+            b.ret(data)
+    b.ret(f(into_ast(2.0), into_ast(10000.0)))
+    assert decl_has_type(b.env, 'f', func_type(
+        [int_type(), float_type()], float_type()))
+    # TODO(@jroesch): need an evaluator or new runtime
+    # to execute this.
+
+if __name__ == "__main__":
+    test_monomorphic_let()
+    test_single_op()
+    test_add_op()
+    test_add_broadcast_op()
+    # test_dual_op()
+    # test_decl()
+    # test_recursion()