Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PatternLang] Add ShapePattern and DataTypePattern #5760

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,64 @@ class TypePattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
};

class ShapePattern;
/*!
* \brief Pattern for Shapes.
*/
class ShapePatternNode : public DFPatternNode {
public:
/*! \brief The pattern. */
DFPattern pattern;
/*! \brief The type to match */
Array<PrimExpr> shape;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("shape", &shape);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.ShapePattern";
TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode);
};

/*!
* \brief A pattern which matches a type in another pattern
*/
class ShapePattern : public DFPattern {
public:
TVM_DLL ShapePattern(DFPattern pattern, Array<PrimExpr> type);
TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode);
};

class DataTypePattern;
/*!
* \brief Pattern for Types.
*/
class DataTypePatternNode : public DFPatternNode {
public:
/*! \brief The pattern. */
DFPattern pattern;
/*! \brief The type to match */
DataType dtype;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("dtype", &dtype);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.DataTypePattern";
TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode);
};

/*!
* \brief A pattern which matches a type in another pattern
*/
class DataTypePattern : public DFPattern {
public:
TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype);
TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode);
};

class AttrPattern;
/*!
* \brief Pattern for Attributes.
Expand Down
10 changes: 8 additions & 2 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
Expand All @@ -106,13 +108,15 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
return vtable;
}
Expand All @@ -130,13 +134,15 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const AltPatternNode* op) override;
void VisitDFPattern_(const AttrPatternNode* op) override;
void VisitDFPattern_(const CallPatternNode* op) override;
void VisitDFPattern_(const ConstantPatternNode* op) override;
void VisitDFPattern_(const DataTypePatternNode* op) override;
void VisitDFPattern_(const DominatorPatternNode* op) override;
void VisitDFPattern_(const ExprPatternNode* op) override;
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
void VisitDFPattern_(const ConstantPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;

protected:
Expand Down
120 changes: 115 additions & 5 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,38 @@ def has_type(self, ttype: tvm.ir.type.Type):
"""
return has_type(ttype, self)

def has_dtype(self, dtype: str):
"""
Add a type constraint to this pattern

Parameters
----------
dtype: str
The dtype to match

Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting DataTypePattern
"""
return has_dtype(dtype, self)

def has_shape(self, shape: List[tvm.ir.PrimExpr]):
"""
Add a type constraint to this pattern

Parameters
----------
shape: List[tvm.ir.PrimExpr]
The shape to match

Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting ShapePattern
"""
return has_shape(shape, self)

def match(self, expr: Expr) -> bool:
"""
Match this pattern to an expression
Expand Down Expand Up @@ -293,18 +325,18 @@ def wildcard() -> "DFPattern":
return WildcardPattern()


def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern":
def has_type(ttype: tvm.ir.type.Type, pattern: "DFPattern" = None) -> "DFPattern":
"""
Syntatic sugar for creating a TypePattern

Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
The pattern that needs type annotation

ttype: tvm.ir.type.Type
The type to match

pattern: tvm.relay.dataflow_pattern.DFPattern
The pattern that needs type annotation

Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
Expand All @@ -315,6 +347,50 @@ def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern":
return TypePattern(pattern, ttype)


def has_dtype(dtype: str, pattern: "DFPattern" = None) -> "DFPattern":
"""
Syntatic sugar for creating a DataTypePattern

Parameters
----------
dtype: str
The dtype to match

pattern: tvm.relay.dataflow_pattern.DFPattern
The pattern that needs type annotation

Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting DataTypePattern
"""
if pattern is None:
pattern = wildcard()
return DataTypePattern(pattern, dtype)


def has_shape(shape: List[tvm.ir.PrimExpr], pattern: "DFPattern" = None) -> "DFPattern":
"""
Syntatic sugar for creating a ShapePattern

Parameters
----------
shape: List[tvm.ir.PrimExpr]
The shape to match

pattern: tvm.relay.dataflow_pattern.DFPattern
The pattern that needs type annotation

Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting ShapePattern
"""
if pattern is None:
pattern = wildcard()
return ShapePattern(pattern, shape)


def has_attr(attrs, pattern=None) -> "DFPattern":
"""
Syntatic sugar for creating an AttrPattern
Expand Down Expand Up @@ -514,7 +590,7 @@ def __init__(self):

@register_df_node
class TypePattern(DFPattern):
"""Get index-th item from a TuplePattern.
"""A pattern that matches another pattern with a certain type annotation.

Parameters
----------
Expand All @@ -529,6 +605,40 @@ def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type):
self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype)


@register_df_node
class DataTypePattern(DFPattern):
"""A pattern that matches another pattern with certain data type

Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
The input pattern that needs type annotation.

dtype: str
The dtype to match.
"""

def __init__(self, pattern: "DFPattern", dtype: str):
self.__init_handle_by_constructor__(ffi.DataTypePattern, pattern, dtype)


@register_df_node
class ShapePattern(DFPattern):
"""A pattern that matches another pattern with a certain tensor shape

Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
The input pattern that needs type annotation.

shape: List[tvm.ir.PrimExpr]
The shape to match.
"""

def __init__(self, pattern: "DFPattern", shape: List[tvm.ir.PrimExpr]):
self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape)


@register_df_node
class AttrPattern(DFPattern):
"""Get match an expression with a certain attributes.
Expand Down
20 changes: 19 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;

void ClearMap(size_t watermark);
Expand Down Expand Up @@ -393,6 +395,22 @@ bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& ex
return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
}

bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) {
auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
return (StructuralEqual()(op->shape, tensor_type->shape)) && VisitDFPattern(op->pattern, expr);
}
return false;
}

bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) {
auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr);
}
return false;
}

bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* var_node = expr.as<VarNode>()) {
Expand Down
40 changes: 40 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,46 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")";
});

ShapePattern::ShapePattern(DFPattern pattern, Array<PrimExpr> shape) {
ObjectPtr<ShapePatternNode> n = make_object<ShapePatternNode>();
n->pattern = std::move(pattern);
n->shape = std::move(shape);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(ShapePatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ShapePattern")
.set_body_typed([](DFPattern pattern, Array<PrimExpr> shape) {
return ShapePattern(pattern, shape);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShapePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ShapePatternNode*>(ref.get());
p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")";
});

DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) {
ObjectPtr<DataTypePatternNode> n = make_object<DataTypePatternNode>();
n->pattern = std::move(pattern);
n->dtype = std::move(dtype);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(DataTypePatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DataTypePattern")
.set_body_typed([](DFPattern pattern, DataType dtype) {
return DataTypePattern(pattern, dtype);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DataTypePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const DataTypePatternNode*>(ref.get());
p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")";
});

AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) {
ObjectPtr<AttrPatternNode> n = make_object<AttrPatternNode>();
n->pattern = std::move(pattern);
Expand Down
7 changes: 7 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) {
VisitDFPattern(arg);
}
}

void DFPatternVisitor::VisitDFPattern_(const DataTypePatternNode* op) {
VisitDFPattern(op->pattern);
}

void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {
VisitDFPattern(op->parent);
VisitDFPattern(op->path);
Expand All @@ -57,6 +62,8 @@ void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {

void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); }

void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) {
VisitDFPattern(op->tuple);
}
Expand Down
Loading