Skip to content

Commit

Permalink
Add ShapePattern and DataTypePattern (apache#5760)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed Jun 12, 2020
1 parent 14157dd commit 56dfec2
Show file tree
Hide file tree
Showing 8 changed files with 350 additions and 10 deletions.
58 changes: 58 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,64 @@ class TypePattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
};

class ShapePattern;
/*!
* \brief Pattern for Shapes.
*/
class ShapePatternNode : public DFPatternNode {
public:
/*! \brief The pattern. */
DFPattern pattern;
/*! \brief The type to match */
Array<PrimExpr> shape;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("shape", &shape);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.ShapePattern";
TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode);
};

/*!
* \brief A pattern which matches a type in another pattern
*/
class ShapePattern : public DFPattern {
public:
TVM_DLL ShapePattern(DFPattern pattern, Array<PrimExpr> type);
TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode);
};

class DataTypePattern;
/*!
* \brief Pattern for Types.
*/
class DataTypePatternNode : public DFPatternNode {
public:
/*! \brief The pattern. */
DFPattern pattern;
/*! \brief The type to match */
DataType dtype;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("dtype", &dtype);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.DataTypePattern";
TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode);
};

/*!
* \brief A pattern which matches a type in another pattern
*/
class DataTypePattern : public DFPattern {
public:
TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype);
TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode);
};

class AttrPattern;
/*!
* \brief Pattern for Attributes.
Expand Down
10 changes: 8 additions & 2 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
Expand All @@ -106,13 +108,15 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
return vtable;
}
Expand All @@ -130,13 +134,15 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const AltPatternNode* op) override;
void VisitDFPattern_(const AttrPatternNode* op) override;
void VisitDFPattern_(const CallPatternNode* op) override;
void VisitDFPattern_(const ConstantPatternNode* op) override;
void VisitDFPattern_(const DataTypePatternNode* op) override;
void VisitDFPattern_(const DominatorPatternNode* op) override;
void VisitDFPattern_(const ExprPatternNode* op) override;
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
void VisitDFPattern_(const ConstantPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;

protected:
Expand Down
120 changes: 115 additions & 5 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,38 @@ def has_type(self, ttype: tvm.ir.type.Type):
"""
return has_type(ttype, self)

def has_dtype(self, dtype: str):
"""
Add a type constraint to this pattern
Parameters
----------
dtype: str
The dtype to match
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting DataTypePattern
"""
return has_dtype(dtype, self)

def has_shape(self, shape: List[tvm.ir.PrimExpr]):
"""
Add a type constraint to this pattern
Parameters
----------
shape: List[tvm.ir.PrimExpr]
The shape to match
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting ShapePattern
"""
return has_shape(shape, self)

def match(self, expr: Expr) -> bool:
"""
Match this pattern to an expression
Expand Down Expand Up @@ -293,18 +325,18 @@ def wildcard() -> "DFPattern":
return WildcardPattern()


def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern":
def has_type(ttype: tvm.ir.type.Type, pattern: "DFPattern" = None) -> "DFPattern":
"""
Syntatic sugar for creating a TypePattern
Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
The pattern that needs type annotation
ttype: tvm.ir.type.Type
The type to match
pattern: tvm.relay.dataflow_pattern.DFPattern
The pattern that needs type annotation
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
Expand All @@ -315,6 +347,50 @@ def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern":
return TypePattern(pattern, ttype)


def has_dtype(dtype: str, pattern: "DFPattern" = None) -> "DFPattern":
"""
Syntatic sugar for creating a DataTypePattern
Parameters
----------
dtype: str
The dtype to match
pattern: tvm.relay.dataflow_pattern.DFPattern
The pattern that needs type annotation
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting DataTypePattern
"""
if pattern is None:
pattern = wildcard()
return DataTypePattern(pattern, dtype)


def has_shape(shape: List[tvm.ir.PrimExpr], pattern: "DFPattern" = None) -> "DFPattern":
"""
Syntatic sugar for creating a ShapePattern
Parameters
----------
shape: List[tvm.ir.PrimExpr]
The shape to match
pattern: tvm.relay.dataflow_pattern.DFPattern
The pattern that needs type annotation
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting ShapePattern
"""
if pattern is None:
pattern = wildcard()
return ShapePattern(pattern, shape)


def has_attr(attrs, pattern=None) -> "DFPattern":
"""
Syntatic sugar for creating an AttrPattern
Expand Down Expand Up @@ -514,7 +590,7 @@ def __init__(self):

@register_df_node
class TypePattern(DFPattern):
"""Get index-th item from a TuplePattern.
"""A pattern that matches another pattern with a certain type annotation.
Parameters
----------
Expand All @@ -529,6 +605,40 @@ def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type):
self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype)


@register_df_node
class DataTypePattern(DFPattern):
"""A pattern that matches another pattern with certain data type
Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
The input pattern that needs type annotation.
dtype: str
The dtype to match.
"""

def __init__(self, pattern: "DFPattern", dtype: str):
self.__init_handle_by_constructor__(ffi.DataTypePattern, pattern, dtype)


@register_df_node
class ShapePattern(DFPattern):
"""A pattern that matches another pattern with a certain tensor shape
Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
The input pattern that needs type annotation.
shape: List[tvm.ir.PrimExpr]
The shape to match.
"""

def __init__(self, pattern: "DFPattern", shape: List[tvm.ir.PrimExpr]):
self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape)


@register_df_node
class AttrPattern(DFPattern):
"""Get match an expression with a certain attributes.
Expand Down
20 changes: 19 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;

void ClearMap(size_t watermark);
Expand Down Expand Up @@ -393,6 +395,22 @@ bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& ex
return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
}

bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) {
auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
return (StructuralEqual()(op->shape, tensor_type->shape)) && VisitDFPattern(op->pattern, expr);
}
return false;
}

bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) {
auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr);
}
return false;
}

bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* var_node = expr.as<VarNode>()) {
Expand Down
40 changes: 40 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,46 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")";
});

ShapePattern::ShapePattern(DFPattern pattern, Array<PrimExpr> shape) {
ObjectPtr<ShapePatternNode> n = make_object<ShapePatternNode>();
n->pattern = std::move(pattern);
n->shape = std::move(shape);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(ShapePatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ShapePattern")
.set_body_typed([](DFPattern pattern, Array<PrimExpr> shape) {
return ShapePattern(pattern, shape);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShapePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ShapePatternNode*>(ref.get());
p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")";
});

DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) {
ObjectPtr<DataTypePatternNode> n = make_object<DataTypePatternNode>();
n->pattern = std::move(pattern);
n->dtype = std::move(dtype);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(DataTypePatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DataTypePattern")
.set_body_typed([](DFPattern pattern, DataType dtype) {
return DataTypePattern(pattern, dtype);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DataTypePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const DataTypePatternNode*>(ref.get());
p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")";
});

AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) {
ObjectPtr<AttrPatternNode> n = make_object<AttrPatternNode>();
n->pattern = std::move(pattern);
Expand Down
7 changes: 7 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) {
VisitDFPattern(arg);
}
}

void DFPatternVisitor::VisitDFPattern_(const DataTypePatternNode* op) {
VisitDFPattern(op->pattern);
}

void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {
VisitDFPattern(op->parent);
VisitDFPattern(op->path);
Expand All @@ -57,6 +62,8 @@ void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {

void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); }

void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) {
VisitDFPattern(op->tuple);
}
Expand Down
Loading

0 comments on commit 56dfec2

Please sign in to comment.