diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index bc0fcc9f2988..c1eebde15fba 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -86,6 +86,9 @@ namespace relay { } else if (type == DataType::UInt(8)) { \ typedef uint8_t DType; \ { __VA_ARGS__ } \ + } else if (type == DataType::Bool()) { \ + typedef bool DType; \ + { __VA_ARGS__ } \ } else if ((*tvm::runtime::Registry::Get("runtime._datatype_get_type_registered"))( \ static_cast(type.code()))) { \ typedef double DType; \ diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 3d925bcfc759..423f0a4f213d 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -117,7 +117,7 @@ def after_right(x, elem_op, value): assert tvm.ir.structural_equal(zz, after) for shape in [[10], [10, 10], [10, 10, 10]]: - for dtype in ["float32", "int32"]: + for dtype in ["float32", "int32", "bool"]: for value in [0, 1, 2]: validate(shape, value, dtype)