Skip to content

Commit

Permalink
Fixes and improvements (apache#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
altanh authored and yongwww committed Aug 14, 2022
1 parent 56447db commit f793ccc
Show file tree
Hide file tree
Showing 14 changed files with 156 additions and 86 deletions.
27 changes: 1 addition & 26 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_RELAX_BLOCK_BUILDER_H_

#include <tvm/ir/expr.h>
#include <tvm/relax/utils.h>
#include <tvm/relax/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
Expand All @@ -38,32 +39,6 @@ namespace relax {

class BlockBuilder;

/*!
* \brief Utility data structure for generating unique names for IR construction.
*/
class NameTable {
public:
/*!
* \brief Generate a unique name with a specified prefix.
* \param prefix The name prefix.
* \return The generated name.
*/
inline std::string GetUniqueName(std::string prefix) {
std::replace(prefix.begin(), prefix.end(), '.', '_');
std::string unique_prefix = prefix;
auto it = alloc_map_.find(prefix);
if (it != alloc_map_.end()) {
while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) {
}
}
alloc_map_[unique_prefix] = 0;
return unique_prefix;
}

private:
std::unordered_map<std::string, uint32_t> alloc_map_;
};

/*!
* \brief A builder that provides APIs to build Relax binding blocks.
*/
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ShapeExprNode : public ExprNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("values", &values);
v->Visit("shape_", &shape_);
v->Visit("checked_type_", &checked_type_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -94,11 +94,11 @@ class VarNode : public ExprNode {
const String& name_hint() const { return vid->name_hint; }

void VisitAttrs(AttrVisitor* v) {
v->Visit("_checked_type_", &checked_type_);
v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation);
v->Visit("span", &span);
v->Visit("shape_", &shape_);
v->Visit("checked_type_", &checked_type_);
}

bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -143,7 +143,7 @@ class DataflowVarNode : public VarNode {
v->Visit("type_annotation", &type_annotation);
v->Visit("span", &span);
v->Visit("shape_", &shape_);
v->Visit("checked_type_", &checked_type_);
v->Visit("_checked_type_", &checked_type_);
}

bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -330,7 +330,7 @@ class SeqExprNode : public ExprNode {
v->Visit("blocks", &blocks);
v->Visit("body", &body);
v->Visit("shape_", &shape_);
v->Visit("checked_type_", &checked_type_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("span", &span);
}

Expand Down Expand Up @@ -378,7 +378,7 @@ class FunctionNode : public BaseFuncNode {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("checked_type_", &checked_type_);
v->Visit("_checked_type_", &checked_type_);
v->Visit("shape_", &shape_);
v->Visit("span", &span);
}
Expand Down
47 changes: 35 additions & 12 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
/*!
* \brief A simple visitor wrapper around ExprFunctor.
* Recursively visit the content.
*
* ExprVisitor treats Expr as dataflow graph,
* and only visit each Expr node once.
*/
class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
public:
Expand All @@ -167,9 +164,6 @@ class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
virtual void VisitMatchShape(const MatchShape& binding);
virtual void VisitBindingBlock(const BindingBlock& block);
virtual void VisitDataflowBlock(const DataflowBlock& block);

protected:
std::unordered_map<const Object*, size_t> visit_counter_;
};

void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
Expand Down Expand Up @@ -221,19 +215,48 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
virtual Type VisitType(const Type& t);

virtual void VisitBinding(const Binding& binding);
virtual Var VisitVarBinding(const VarBinding& binding);
virtual void VisitVarBinding(const VarBinding& binding);
virtual void VisitMatchShape(const MatchShape& binding);

virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);

protected:
Expr MutateWithPrologue(const Expr& expr, bool is_dataflow);
/*! \brief Look up the value binded to a var. */

/*! \brief Look up the value of a variable. If the variable is bound, then returns the bound
* value. Otherwise, returns the rewritten expression for the variable.
*/
Expr LookupVar(Var var);
// A remapping table: pre var -> post var
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> memo_;

inline void UpdateMemo(Expr pre, Expr post) {
if (const VarNode* var = pre.as<VarNode>()) {
var_memo_[var->vid] = post;
} else {
expr_memo_[pre] = post;
}
}

inline Optional<Expr> LookupMemo(Expr pre) {
if (pre.as<VarNode>()) {
Id vid = Downcast<Var>(pre)->vid;
if (var_memo_.count(vid)) {
return var_memo_[vid];
}
} else {
if (expr_memo_.count(pre)) {
return expr_memo_[pre];
}
}
return NullOpt;
}

/*! \brief Variable memoization table using Id equality */
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;

/*! \brief Expr memoization table using pointer equality */
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> expr_memo_;

std::shared_ptr<NameTable> name_table_;
BlockBuilder builder_;
};
Expand All @@ -245,7 +268,7 @@ class DataflowMutator : public ExprMutator {
public:
void VisitBinding(const Binding& binding) final;

virtual Var VisitDataflowVarBinding(const VarBinding& binding);
virtual void VisitDataflowVarBinding(const VarBinding& binding);
};

} // namespace relax
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/relax/ir_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

/*!
* \file tvm/relax/ir_functor.h
* \brief A generic visitor for traversing Relax IR nodes.
* \brief A generic functor for working with Relax IR nodes.
* \sa tvm/relax/expr_functor.h for common IR rewriting use-cases.
*/
#ifndef TVM_RELAX_IR_FUNCTOR_H_
#define TVM_RELAX_IR_FUNCTOR_H_
Expand Down
63 changes: 63 additions & 0 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/utils.h
* \brief Utility classes and functions for working with the Relax IR.
*/
#ifndef TVM_RELAX_UTILS_H_
#define TVM_RELAX_UTILS_H_

#include <string>
#include <algorithm>
#include <unordered_map>

namespace tvm {
namespace relax {

/*!
* \brief Utility data structure for generating unique names for IR construction.
*/
class NameTable {
public:
/*!
* \brief Generate a unique name with a specified prefix.
* \param prefix The name prefix.
* \return The generated name.
*/
inline std::string GetUniqueName(std::string prefix) {
std::replace(prefix.begin(), prefix.end(), '.', '_');
std::string unique_prefix = prefix;
auto it = alloc_map_.find(prefix);
if (it != alloc_map_.end()) {
while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) {
}
}
alloc_map_[unique_prefix] = 0;
return unique_prefix;
}

private:
std::unordered_map<std::string, uint32_t> alloc_map_;
};

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_UTILS_H_
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _traverse_expr(node):
else:
node_entry["inputs"].append([in_node_idx, 0, 0])
infer_out = _infer_type(node)
out_type = infer_out._checked_type_
out_type = infer_out.checked_type_
if isinstance(out_type, TensorType):
node_entry["types"].append(out_type)
elif isinstance(out_type, TupleType):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/target/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def infer_type(node):
def call_node_infer_type(node):
"""infer the output types of call node"""
infer_out = infer_type(node)
out_type = infer_out._checked_type_
out_type = infer_out.checked_type_
if isinstance(out_type, TensorType):
types = [out_type]
elif isinstance(out_type, TupleType):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def checked_type(self):
checked_type : tvm.relay.Type
The checked type.
"""
ret = self.checked_type_
ret = self._checked_type_
if ret is None:
raise ValueError("The type checker has not populated the checked_type for this node")
return ret
Expand Down
4 changes: 2 additions & 2 deletions src/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/

#include <tvm/ir/type_functor.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/utils.h>
#include <tvm/relax/ir_functor.h>

#include <algorithm>
Expand Down Expand Up @@ -397,7 +397,7 @@ std::vector<Doc> RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) {
}
} else {
AttrPrinter attr_printer(&kwargs, this);
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&attr_printer);
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitAttrs(&attr_printer);
}
return kwargs;
}
Expand Down
Loading

0 comments on commit f793ccc

Please sign in to comment.