Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PatternLang] Add Syntatic Sugar to the C++ pattern API and support DataType Attribute Matching #7120

Merged
merged 5 commits into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 43 additions & 13 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>

#include <string>
#include <vector>

namespace tvm {
namespace relay {

Expand All @@ -46,6 +49,29 @@ class DFPatternNode : public Object {
*/
class DFPattern : public ObjectRef {
public:
/*! \brief Syntatic Sugar for creating a CallPattern */
DFPattern operator()(const std::vector<DFPattern>& args);
/*! \brief Syntatic Sugar for creating a CallPattern with an "add" op */
icemelon marked this conversation as resolved.
Show resolved Hide resolved
DFPattern operator+(const DFPattern& other);
/*! \brief Syntatic Sugar for creating a CallPattern with a "subtract" op */
DFPattern operator-(const DFPattern& other);
/*! \brief Syntatic Sugar for creating a CallPattern with a "multiply" op */
DFPattern operator*(const DFPattern& other);
/*! \brief Syntatic Sugar for creating a CallPattern with a "divide" op */
DFPattern operator/(const DFPattern& other);
/*! \brief Syntatic Sugar for creating an AltPattern */
DFPattern operator||(const DFPattern& other);
/*! \brief Syntatic Sugar for creating an AttrPattern */
DFPattern HasAttr(const Map<String, ObjectRef>& attrs);
/*! \brief Syntatic Sugar for creating a TypePattern */
DFPattern HasType(const Type& type);
/*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */
DFPattern HasDtype(const DataType& dtype);
/*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */
DFPattern HasDtype(const std::string& dtype);
/*! \brief Syntatic Sugar for creating a ShapePattern */
DFPattern HasShape(const Array<PrimExpr> shape);

TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode);
};

Expand Down Expand Up @@ -86,28 +112,19 @@ class VarPatternNode : public DFPatternNode {
* \brief The name of the Var (optional).
*/
String name;
/*!
* \brief type annotation of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None.
*/
Type type_annotation;

/*! \return The name hint of the variable */
const String& name_hint() const { return name; }

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("type_annotation", &type_annotation);
}
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); }

static constexpr const char* _type_key = "relay.dataflow_pattern.VarPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode);
};

class VarPattern : public DFPattern {
public:
TVM_DLL VarPattern(String name_hint, Type type_annotation);
TVM_DLL VarPattern(String name_hint);
TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode);
};

Expand Down Expand Up @@ -393,7 +410,7 @@ class AttrPatternNode : public DFPatternNode {
/*! \brief The pattern. */
DFPattern pattern;
/*! \brief The attribute to match */
Attrs attrs;
DictAttrs attrs;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
Expand All @@ -409,7 +426,7 @@ class AttrPatternNode : public DFPatternNode {
*/
class AttrPattern : public DFPattern {
public:
TVM_DLL AttrPattern(DFPattern pattern, Attrs attrs);
TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs);
TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode);
};

Expand Down Expand Up @@ -447,6 +464,19 @@ class DominatorPattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_METHODS(DominatorPattern, DFPattern, DominatorPatternNode);
};

/*! \brief Syntatic Sugar for creating a VarPattern with a name */
DFPattern IsVar(const String& name);
/*! \brief Syntatic Sugar for creating a ConstantPattern */
DFPattern IsConstant();
/*! \brief Syntatic Sugar for creating a ExprPattern */
DFPattern IsExpr(const Expr& expr);
/*! \brief Syntatic Sugar for creating a ExprPattern base on an Op*/
DFPattern IsOp(const String& op_name);
/*! \brief Syntatic Sugar for creating a TuplePattern*/
DFPattern IsTuple(const Array<DFPattern>& fields);
/*! \brief Syntatic Sugar for creating a TupleGetItemPattern*/
DFPattern IsTupleGetItem(const DFPattern tuple, int index = -1);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_DATAFLOW_PATTERN_H_
4 changes: 2 additions & 2 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,8 @@ class VarPattern(DFPattern):
The type annotation on the variable.
"""

def __init__(self, name_hint: str = "", type_annotation: Optional[tvm.ir.type.Type] = None):
self.__init_handle_by_constructor__(ffi.VarPattern, name_hint, type_annotation)
def __init__(self, name_hint: str = ""):
self.__init_handle_by_constructor__(ffi.VarPattern, name_hint)


@register_df_node
Expand Down
14 changes: 12 additions & 2 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
return val->data == rhs.operator std::string();
}
break;
case kTVMDataType:
if (auto* val = lhs.as<tir::StringImmNode>()) {
return rhs.operator std::string() == val->value;
} else if (auto* val = lhs.as<StringObj>()) {
return rhs.operator std::string() == val->data;
}
break;
case kTVMObjectHandle:
if (rhs.IsObjectRef<String>()) {
if (auto* val = lhs.as<tir::StringImmNode>()) {
Expand All @@ -140,7 +147,10 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
}

bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
bool matches = false;
bool matches = VisitDFPattern(attr_pattern->pattern, expr);
if (!matches) {
return matches;
}
auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
if (const auto* op_node = expr.as<OpNode>()) {
Op op = GetRef<Op>(op_node);
Expand Down Expand Up @@ -179,7 +189,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
}
}
}
return matches && VisitDFPattern(attr_pattern->pattern, expr);
return matches;
}

Array<DFPattern> reverse(const Array<DFPattern>& args) {
Expand Down
67 changes: 53 additions & 14 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief The dataflow pattern language for Relay.
*/
#include <tvm/relay/dataflow_pattern.h>
#include <tvm/runtime/data_type.h>

namespace tvm {
namespace relay {
Expand All @@ -44,29 +45,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->Print(node->expr);
});

VarPattern::VarPattern(String name_hint, Type type_annotation) {
VarPattern::VarPattern(String name_hint) {
ObjectPtr<VarPatternNode> n = make_object<VarPatternNode>();
n->name = std::move(name_hint);
n->type_annotation = std::move(type_annotation);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(VarPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern")
.set_body_typed([](String name_hint, Type type_annotation) {
return VarPattern(name_hint, type_annotation);
});
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern").set_body_typed([](String name_hint) {
return VarPattern(name_hint);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<VarPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const VarPatternNode*>(ref.get());
p->stream << "VarPattern(" << node->name_hint();
if (node->type_annotation.defined()) {
p->stream << ", ty=";
p->Print(node->type_annotation);
}
p->stream << ")";
p->stream << "VarPattern(" << node->name_hint() << ")";
});

TVM_REGISTER_NODE_TYPE(ConstantPatternNode);
Expand Down Expand Up @@ -241,7 +235,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")";
});

AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) {
AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) {
ObjectPtr<AttrPatternNode> n = make_object<AttrPatternNode>();
n->pattern = std::move(pattern);
n->attrs = std::move(attrs);
Expand All @@ -251,7 +245,7 @@ AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) {
TVM_REGISTER_NODE_TYPE(AttrPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AttrPattern")
.set_body_typed([](DFPattern pattern, Attrs attrs) { return AttrPattern(pattern, attrs); });
.set_body_typed([](DFPattern pattern, DictAttrs attrs) { return AttrPattern(pattern, attrs); });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AttrPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
Expand All @@ -263,6 +257,7 @@ DominatorPattern::DominatorPattern(DFPattern parent, DFPattern path, DFPattern c
ObjectPtr<DominatorPatternNode> n = make_object<DominatorPatternNode>();
n->parent = std::move(parent);
n->path = std::move(path);

n->child = std::move(child);
data_ = std::move(n);
}
Expand All @@ -281,5 +276,49 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ")";
});

// Syntatic Sugar
DFPattern DFPattern::operator()(const std::vector<DFPattern>& args) {
return CallPattern(GetRef<DFPattern>(this->get()), Array<DFPattern>(args));
}
DFPattern DFPattern::operator+(const DFPattern& other) {
return IsOp("add")({GetRef<DFPattern>(this->get()), other});
}
DFPattern DFPattern::operator-(const DFPattern& other) {
return IsOp("subtract")({GetRef<DFPattern>(this->get()), other});
}
DFPattern DFPattern::operator*(const DFPattern& other) {
return IsOp("multiply")({GetRef<DFPattern>(this->get()), other});
}
DFPattern DFPattern::operator/(const DFPattern& other) {
return IsOp("divide")({GetRef<DFPattern>(this->get()), other});
}
DFPattern DFPattern::operator||(const DFPattern& other) {
return AltPattern(GetRef<DFPattern>(this->get()), other);
}

DFPattern DFPattern::HasAttr(const Map<String, ObjectRef>& attrs) {
return AttrPattern(GetRef<DFPattern>(this->get()), DictAttrs(attrs));
}
DFPattern DFPattern::HasType(const Type& type) {
return TypePattern(GetRef<DFPattern>(this->get()), type);
}
DFPattern DFPattern::HasDtype(const DataType& dtype) {
return DataTypePattern(GetRef<DFPattern>(this->get()), dtype);
}
DFPattern DFPattern::HasDtype(const std::string& dtype) {
return HasDtype(DataType(runtime::String2DLDataType(dtype)));
}
DFPattern DFPattern::HasShape(const Array<PrimExpr> shape) {
return ShapePattern(GetRef<DFPattern>(this->get()), shape);
}
DFPattern IsVar(const String& name) { return VarPattern(name); }
DFPattern IsConstant() { return ConstantPattern(make_object<ConstantPatternNode>()); }
DFPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
DFPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
DFPattern IsTuple(const Array<DFPattern>& fields) { return TuplePattern(fields); }
DFPattern IsTupleGetItem(const DFPattern tuple, int index) {
return TupleGetItemPattern(tuple, index);
}

} // namespace relay
} // namespace tvm
9 changes: 3 additions & 6 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
namespace tvm {
namespace relay {

static Op reshape_op = Op::Get("reshape");
static Op reverse_reshape_op = Op::Get("contrib_reverse_reshape");

/*!
* \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
* and merges into one reshape op.
Expand All @@ -44,9 +41,9 @@ class SimplifyReshape {
public:
SimplifyReshape() {
x_ = WildcardPattern(make_object<WildcardPatternNode>());
auto reshape1 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
auto reshape2 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_})});
auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
pattern_ = reshape1({reshape2({x_})});
}

Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
Expand Down
Loading