Skip to content

Commit

Permalink
[TensorIR] GetProducer, GetConsumer (apache#506) (apache#9464)
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>

Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
  • Loading branch information
6 people authored and ylc committed Jan 13, 2022
1 parent efcf614 commit 426fdd2
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 1 deletion.
12 changes: 12 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,18 @@ class ScheduleNode : public runtime::Object {
* \return A list of child blocks
*/
virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
/*!
* \brief Get the producer of a specific block
* \param block_rv The block in the query
* \return A list of blocks, the producers of the given block
*/
virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0;
/*!
* \brief Get the consumers of a specific block
* \param block_rv The block to be queried
* \return A list of blocks, the consumers of the given block
*/
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
/******** Schedule: Transform loops ********/
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,36 @@ def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockR
"""
return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop) # type: ignore # pylint: disable=no-member

def get_producers(self, block: BlockRV) -> List[BlockRV]:
"""Get the producers of a specific block
Parameters
----------
block : BlockRV
The block in the query
Returns
-------
producers : List[BlockRV]
A list of producers of the given block
"""
return _ffi_api.ScheduleGetProducers(self, block) # type: ignore # pylint: disable=no-member

def get_consumers(self, block: BlockRV) -> List[BlockRV]:
"""Get the consumers of a specific block
Parameters
----------
block : BlockRV
The block in the query
Returns
-------
consumers : List[BlockRV]
A list of consumers of the given block
"""
return _ffi_api.ScheduleGetConsumers(self, block) # type: ignore # pylint: disable=no-member

########## Schedule: Transform loops ##########
def fuse(self, *loops: List[LoopRV]) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
Expand Down
14 changes: 14 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,20 @@ Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
return result;
}

Array<BlockRV> ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) {
TVM_TIR_SCHEDULE_BEGIN();
return CreateRV<BlockRV>(tir::GetProducers(state_, this->GetSRef(block_rv)));
TVM_TIR_SCHEDULE_END("get-producers", this->error_render_level_);
throw;
}

Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
TVM_TIR_SCHEDULE_BEGIN();
return CreateRV<BlockRV>(tir::GetConsumers(state_, this->GetSRef(block_rv)));
TVM_TIR_SCHEDULE_END("get-consumers", this->error_render_level_);
throw;
}

/******** Schedule: Transform loops ********/

LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class ConcreteScheduleNode : public ScheduleNode {
Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override;
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
Array<BlockRV> GetProducers(const BlockRV& block_rv) override;
Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
Expand Down
14 changes: 14 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
* \return A list of leaf blocks inside a specific block/loop
*/
Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref);
/*!
* \brief Get the producers of a specific block
* \param self The schedule state
* \param block_sref The block in the query
* \return A list of blocks, the producers of the given block
*/
Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref);
/*!
* \brief Get the consumers of a specific block
* \param self The schedule state
* \param block_rv The block in the query
* \return A list of blocks, the consumers of the given block
*/
Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref);
/******** Schedule: Transform loops ********/
/*!
* Split a loop into a list of consecutive loops. It requires:
Expand Down
78 changes: 78 additions & 0 deletions src/tir/schedule/primitive/get_block_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,34 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent
return std::move(collector.result);
}

Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) {
StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
/*require_stage_pipeline=*/false);
Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref);
Array<StmtSRef> results;
results.reserve(edges.size());
for (const Dependency& edge : edges) {
if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) {
results.push_back(edge->src);
}
}
return results;
}

Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) {
StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
/*require_stage_pipeline=*/false);
Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref);
Array<StmtSRef> results;
results.reserve(edges.size());
for (const Dependency& edge : edges) {
if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) {
results.push_back(edge->dst);
}
}
return results;
}

/******** InstructionKind Registration ********/

struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> {
Expand Down Expand Up @@ -159,9 +187,59 @@ struct GetChildBlocksTraits : public UnpackedInstTraits<GetChildBlocksTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};

struct GetProducersTraits : public UnpackedInstTraits<GetProducersTraits> {
static constexpr const char* kName = "GetProducers";
static constexpr bool kIsPure = true;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 0;
static constexpr size_t kNumDecisions = 0;

static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
return sch->GetProducers(block_rv);
}

static String UnpackedAsPython(Array<String> outputs, String block_rv) {
PythonAPICall py("get_producers");
py.Input("block", block_rv);
py.OutputList(outputs);
return py.Str();
}

template <typename>
friend struct ::tvm::tir::UnpackedInstTraits;
};

struct GetConsumersTraits : public UnpackedInstTraits<GetConsumersTraits> {
static constexpr const char* kName = "GetConsumers";
static constexpr bool kIsPure = true;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 0;
static constexpr size_t kNumDecisions = 0;

static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
return sch->GetConsumers(block_rv);
}

static String UnpackedAsPython(Array<String> outputs, String block_rv) {
PythonAPICall py("get_consumers");
py.Input("block", block_rv);
py.OutputList(outputs);
return py.Str();
}

template <typename>
friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits);

} // namespace tir
} // namespace tvm
4 changes: 4 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks")
<< ". Its value is: " << rv;
throw;
});
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
.set_body_method<Schedule>(&ScheduleNode::GetProducers);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
.set_body_method<Schedule>(&ScheduleNode::GetConsumers);
/******** (FFI) Transform loops ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
Expand Down
22 changes: 22 additions & 0 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,28 @@ Array<BlockRV> TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
return results;
}

Array<BlockRV> TracedScheduleNode::GetProducers(const BlockRV& block_rv) {
Array<BlockRV> results = ConcreteScheduleNode::GetProducers(block_rv);

static const InstructionKind& kind = InstructionKind::Get("GetProducers");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
/*inputs=*/{block_rv},
/*attrs=*/{},
/*outputs=*/{results.begin(), results.end()}));
return results;
}

Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {
Array<BlockRV> results = ConcreteScheduleNode::GetConsumers(block_rv);

static const InstructionKind& kind = InstructionKind::Get("GetConsumers");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
/*inputs=*/{block_rv},
/*attrs=*/{},
/*outputs=*/{results.begin(), results.end()}));
return results;
}

/******** Schedule: Transform loops ********/

LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
Array<LoopRV> GetLoops(const BlockRV& block_rv) final;
Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final;
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final;
Array<BlockRV> GetProducers(const BlockRV& block_rv) final;
Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
Expand Down
42 changes: 41 additions & 1 deletion tests/python/unittest/test_tir_schedule_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,31 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
for i, j in T.grid(128, 128):
with T.block("init"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = T.float32(0)
C[vi, vj] = 0.0
for k in range(0, 128):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


@T.prim_func
def matmul_relu(a: T.handle, b: T.handle, d: T.handle) -> None:
A = T.match_buffer(a, (1024, 1024))
B = T.match_buffer(b, (1024, 1024))
C = T.alloc_buffer((1024, 1024))
D = T.match_buffer(d, (1024, 1024))
for i, j, k in T.grid(1024, 1024, 1024):
with T.block("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(1024, 1024):
with T.block("relu"):
vi, vj = T.axis.remap("SS", [i, j])
D[vi, vj] = T.max(C[vi, vj], 0.0)


# pylint: enable=no-member,invalid-name,unused-variable


Expand Down Expand Up @@ -159,5 +177,27 @@ def test_get_child_blocks():
assert s.get(update) == s.get(blocks[1])


def test_get_producers():
sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
block = sch.get_block("relu")
(producer,) = sch.get_producers(block)
assert tvm.ir.structural_equal(
sch.get_sref(producer).stmt,
sch.get_sref(sch.get_block("matmul")).stmt,
)
verify_trace_roundtrip(sch, mod=matmul_relu)


def test_get_consumers():
sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
block = sch.get_block("matmul")
(consumer,) = sch.get_consumers(block)
assert tvm.ir.structural_equal(
sch.get_sref(consumer).stmt,
sch.get_sref(sch.get_block("relu")).stmt,
)
verify_trace_roundtrip(sch, mod=matmul_relu)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 426fdd2

Please sign in to comment.