Skip to content

Commit

Permalink
[ARITH] detect iter affine map with predicate (#7752)
Browse files Browse the repository at this point in the history
  • Loading branch information
spectrometerHBH authored Mar 27, 2021
1 parent 474bc4e commit dc81767
Show file tree
Hide file tree
Showing 9 changed files with 547 additions and 156 deletions.
10 changes: 10 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,16 @@ class TVM_DLL Analyzer {
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
/*!
* \brief Whether can we prove lhs == rhs.
*
* \param lhs The input lhs.
* \param rhs The input rhs.
* \return Whether we can prove lhs == rhs.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
/*!
* \brief Whether can we prove condition.
*
Expand Down
5 changes: 4 additions & 1 deletion include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class IterMark : public ObjectRef {
TVM_DLL IterMark(PrimExpr source, PrimExpr extent);

TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMarkNode);
};

/*!
Expand Down Expand Up @@ -259,7 +260,6 @@ class IterSumExpr : public IterMapExpr {

/*!
* \brief Detect if indices can be written as
*
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
*
* Here y = some-quasi-affine-iter-map(input_iters)
Expand All @@ -272,12 +272,15 @@ class IterSumExpr : public IterMapExpr {
*
* \param indices The indices to detect pattern for.
* \param input_iters Map from variable to iterator's range.
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
*
* \return The detected pattern if a match exists,
* otherwise return an empty array.
*/
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);

} // namespace arith
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constrain
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

/*!
* \brief Calculate the expresion complexity based on number of symbols it contains.
* \param expr The expr to be calculated.
*/
TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
Expand Down
14 changes: 10 additions & 4 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,27 @@ def __init__(self, args, base):
self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)


def detect_iter_map(indices, input_iters):
"""Detect if indices can be written mapped iters from input_iters.
def detect_iter_map(indices, input_iters, predicate=True, require_bijective=False):
"""Detect if indices can be written as mapped iters from input iters
Parameters
----------
indices : List[PrimExpr]
The input indices.
The input indices
input_iters : Map[Var, Range]
The domain of each input iterators.
predicate : PrimExpr
The predicate constraints on the input iterators
require_bijective : bool
A boolean flag that indicates whether the mapping should be bijective
Returns
-------
results : List[IterSumExpr]
The iter map matching result.
Empty array if no match can be found.
"""
return _ffi_api.DetectIterMap(indices, input_iters)
return _ffi_api.DetectIterMap(indices, input_iters, predicate, require_bijective)
7 changes: 7 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
return false;
}

bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) {
const auto* clhs = lhs.as<IntImmNode>();
const auto* crhs = rhs.as<IntImmNode>();
if (clhs && crhs) return clhs->value == crhs->value;
return CanProve(lhs - rhs == 0);
}

bool Analyzer::CanProve(const PrimExpr& expr) {
if (const auto* ptr = expr.as<IntImmNode>()) {
return ptr->value != 0;
Expand Down
53 changes: 53 additions & 0 deletions src/arith/expr_complexity.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tir/analysis/expr_complexity.cc
* \brief Calculate expr complexity.
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr_functor.h>

namespace tvm {
namespace tir {

/*! \brief Count the size of the PrimExpr. */
class PrimExprSizeCounter : public ExprVisitor {
public:
PrimExprSizeCounter() = default;

static size_t Count(const PrimExpr& expr) {
PrimExprSizeCounter prim_expr_size_counter;
prim_expr_size_counter.VisitExpr(expr);
return prim_expr_size_counter.counter_;
}

private:
void VisitExpr(const PrimExpr& expr) final {
counter_++;
ExprVisitor::VisitExpr(expr);
}

size_t counter_{0};
};

size_t CalculateExprComplexity(const PrimExpr& expr) { return PrimExprSizeCounter::Count(expr); }

} // namespace tir
} // namespace tvm
Loading

0 comments on commit dc81767

Please sign in to comment.