Skip to content

Commit

Permalink
[MetaSchedule] Sample-Perfect-Tile (apache#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Nov 4, 2021
1 parent 8f58137 commit 8b69550
Show file tree
Hide file tree
Showing 13 changed files with 491 additions and 46 deletions.
10 changes: 10 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ class ScheduleNode : public runtime::Object {
*/
virtual ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
Optional<Integer> decision = NullOpt) = 0;
/*!
* \brief Sample the factors to perfect tile a specific loop
* \param loop_rv The loop to be tiled
* \param n The number of tiles to be sampled
* \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop
* \param decision The sampling decision
* \return A list of length `n`, the random perfect tile sizes sampled
*/
virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
Optional<Array<Integer>> decision = NullOpt) = 0;

/******** Schedule: Get blocks & loops ********/
/*!
Expand Down
33 changes: 33 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,39 @@ def sample_categorical(
decision,
)

def sample_perfect_tile(
self,
loop: LoopRV,
n: int,
max_innermost_factor: int = 16,
decision: Optional[List[int]] = None,
) -> List[ExprRV]:
"""Sample the factors to perfect tile a specific loop
Parameters
----------
loop : LoopRV
The loop to be tiled
n : int
The number of tiles to be sampled
max_innermost_factor : int
The maximum tile size allowed to be sampled in the innermost loop
decision: Optional[List[int]]
The sampling decision, if any
Returns
-------
result : List[ExprRV]
A list of length `n`, the random perfect tile sizes sampled
"""
return _ffi_api.ScheduleSamplePerfectTile( # type: ignore # pylint: disable=no-member
self,
loop,
n,
max_innermost_factor,
decision,
)

########## Schedule: Get blocks & loops ##########
def get_block(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,8 @@ Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,
if (const ForNode* loop = p->StmtAs<ForNode>()) {
if (loop->kind == ForKind::kThreadBinding) {
const String& thread_tag = loop->thread_binding.value()->thread_tag;
if (CanRelaxStorageUndereThread(extra_relax_scope,
runtime::ThreadScope::Create(thread_tag))) {
if (CanRelaxStorageUnderThread(extra_relax_scope,
runtime::ThreadScope::Create(thread_tag))) {
result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,16 @@ ExprRV ConcreteScheduleNode::SampleCategorical(const Array<Integer>& candidates,
throw;
}

Array<ExprRV> ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n,
int max_innermost_factor,
Optional<Array<Integer>> decision) {
TVM_TIR_SCHEDULE_BEGIN();
return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n,
max_innermost_factor, &decision));
TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_);
throw;
}

/******** Schedule: Get blocks & loops ********/

BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) {
Expand Down
25 changes: 17 additions & 8 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,10 @@ class ConcreteScheduleNode : public ScheduleNode {

public:
/******** Schedule: Sampling ********/
/*!
* \brief Sample an integer given the probability distribution
* \param candidates The candidates
* \param probs The probability distribution of the candidates
* \param decision The sampling decision, if it's given we would validate the decision, otherwise
* we would sample a decision from the distribution and set the decision accordingly.
* \return The random variable sampled from candidates
*/
ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
Optional<Integer> decision = NullOpt) override;
Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
Optional<Array<Integer>> decision = NullOpt) override;
/******** Schedule: Get blocks & loops ********/
BlockRV GetBlock(const String& name, const String& func_name = "main") override;
Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
Expand Down Expand Up @@ -162,6 +156,12 @@ class ConcreteScheduleNode : public ScheduleNode {
* \return The new random variable created
*/
inline ExprRV CreateRV(int64_t value);
/*!
* \brief Add a list of integers as random variables into the symbol table
* \param value The list of integers to be added to the symbol table
* \return The new random variables created
*/
inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value);
/*! \brief Remove a random variable from the symbol table */
inline void RemoveFromSymbolTable(const ObjectRef& rv);
};
Expand Down Expand Up @@ -295,6 +295,15 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) {
return std::move(rv);
}

inline Array<ExprRV> ConcreteScheduleNode::CreateRV(const std::vector<int64_t>& value) {
Array<ExprRV> results;
results.reserve(value.size());
for (int64_t v : value) {
results.push_back(CreateRV(v));
}
return results;
}

inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) {
auto it = this->symbol_table_.find(obj);
if (it != this->symbol_table_.end()) {
Expand Down
23 changes: 21 additions & 2 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ namespace tir {
* \param max_exclusive The maximum value of the range, exclusive.
* \return The random integer sampled in the given range.
*/
TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive,
int max_exclusive);
TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state,
int32_t min_inclusive, int32_t max_exclusive);
/*!
* \brief Sample once category from candidates according to the probability weights.
* \param self The schedule to update
Expand All @@ -46,6 +46,25 @@ TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state,
TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state,
const Array<Integer>& candidates, const Array<FloatImm>& probs,
Optional<Integer>* decision);
/*!
* \brief Sample the factors to perfect tile a specific loop
* \param rand_state The random state
* \param loop_sref The loop to be tiled
* \param n The number of tiles to be sampled
* \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop
* \param decision The sampling decision
* \return A list of length `n`, the random perfect tile sizes sampled
*/
TVM_DLL std::vector<int64_t> SamplePerfectTile(
support::LinearCongruentialEngine::TRandState* rand_state, //
int32_t extent, int32_t n_splits);
TVM_DLL std::vector<int64_t> SamplePerfectTile(
support::LinearCongruentialEngine::TRandState* rand_state, //
int32_t extent, int32_t n_split, int32_t max_innermost_factor);
TVM_DLL std::vector<int64_t> SamplePerfectTile(
support::LinearCongruentialEngine::TRandState* rand_state, //
const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor,
Optional<Array<Integer>>* decision);

/******** Schedule: Get blocks & loops ********/
/*!
Expand Down
Loading

0 comments on commit 8b69550

Please sign in to comment.