Skip to content

Commit

Permalink
[MetaSchedule] Rewrite Parallel-Vectorize-Unroll / Verify-GPU / Disal…
Browse files Browse the repository at this point in the history
…low-Dynamic-Loops (apache#499)

* wip

fix

* revoke change to gallery

* split postprocessors to separate files

* rename attrs

* minor

* minor tweak on utils.h

* refactor disallow-dynamic-loop

* refactor verify_gpu_code

* successfully give up refactoring parallelize-vectorize-unroll

* python structuring

* unittests

Co-authored-by: Junru Shao <[email protected]>
  • Loading branch information
spectrometerHBH and junrushao committed Nov 11, 2021
1 parent dcf7310 commit 996cf62
Show file tree
Hide file tree
Showing 19 changed files with 1,117 additions and 28 deletions.
16 changes: 16 additions & 0 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,23 @@ class Postproc : public runtime::ObjectRef {
PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyPostprocNode::FApply f_apply, //
PyPostprocNode::FAsString f_as_string);
/*!
* \brief Create a postprocessor that checks if all loops are static
* \return The postprocessor created
*/
TVM_DLL static Postproc DisallowDynamicLoop();
/*!
* \brief Create a postprocessor that rewrites the cooperative fetch annotation to
* actual vectorized cooperative fetching in loop bindings.
* \return The postprocessor created.
*/
TVM_DLL static Postproc RewriteCooperativeFetch();
/*!
* \brief Creates a postprocessor that applies parallelization, vectorization and auto unrolling
* according to the annotation of each block
* \return The postprocessor created
*/
TVM_DLL static Postproc RewriteParallelVectorizeUnroll();
/*!
* \brief Create a postprocessor that rewrites reduction block by moving the init block out.
* \return The postprocessor created.
Expand All @@ -134,6 +145,11 @@ class Postproc : public runtime::ObjectRef {
* \return The postprocessor created.
*/
TVM_DLL static Postproc RewriteUnboundBlock();
/*!
* \brief Creates a postprocessor that verifies if the GPU code is correct
* \return The postprocessor created
*/
TVM_DLL static Postproc VerifyGPUCode();
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};

Expand Down
17 changes: 17 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,23 @@ constexpr const int meta_schedule_cache_type_read = 0;
/*! \sa meta_schedule_cache_type */
constexpr const int meta_schedule_cache_type_write = 1;

/*! \brief Mark auto-parallel setting on the block. */
constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";

/*! \brief Mark auto-vectorize setting on the block. */
constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";

/*! \brief Mark auto-unroll setting on the block (explicit-unroll variant). */
constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";

/*! \brief Mark auto-unroll setting on the block (implicit-unroll variant). */
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";

/*! \brief Pragma: auto-unroll, max_step */
constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";

/*! \brief Pragma: unroll explicit */
constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# under the License.
"""The tvm.meta_schedule.postproc package."""
from .postproc import Postproc, PyPostproc
from .disallow_dynamic_loop import DisallowDynamicLoop
from .rewrite_cooperative_fetch import RewriteCooperativeFetch
from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll
from .rewrite_reduction_block import RewriteReductionBlock
from .rewrite_unbound_block import RewriteUnboundBlock
from .verify_gpu_code import VerifyGPUCode
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that checks if the IRModule has any loop with non-constant extent"""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.DisallowDynamicLoop")
class DisallowDynamicLoop(Postproc):
    """Postprocessor that checks whether the IRModule contains any loop whose
    extent is not a constant.

    The underlying object is created through the FFI constructor
    ``_ffi_api.PostprocDisallowDynamicLoop``; it takes no arguments.
    """

    def __init__(self) -> None:
        # Name the FFI constructor first, then hand it to the object initializer.
        ffi_ctor = (
            _ffi_api.PostprocDisallowDynamicLoop  # type: ignore # pylint: disable=no-member
        )
        self.__init_handle_by_constructor__(ffi_ctor)
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that applies parallelization, vectorization and auto unrolling
according to the annotation of each block"""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.RewriteParallelVectorizeUnroll")
class RewriteParallelVectorizeUnroll(Postproc):
    """Postprocessor that applies parallelization, vectorization and auto
    unrolling according to the annotation of each block.

    The underlying object is created through the FFI constructor
    ``_ffi_api.PostprocRewriteParallelVectorizeUnroll``; it takes no arguments.
    """

    def __init__(self) -> None:
        # Name the FFI constructor first, then hand it to the object initializer.
        ffi_ctor = (
            _ffi_api.PostprocRewriteParallelVectorizeUnroll  # type: ignore # pylint: disable=no-member
        )
        self.__init_handle_by_constructor__(ffi_ctor)
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/postproc/verify_gpu_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that verifies if the GPU code is correct"""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.VerifyGPUCode")
class VerifyGPUCode(Postproc):
    """Postprocessor that verifies if the GPU code is correct.

    The underlying object is created through the FFI constructor
    ``_ffi_api.PostprocVerifyGPUCode``; it takes no arguments.
    """

    def __init__(self) -> None:
        # Name the FFI constructor first, then hand it to the object initializer.
        ffi_ctor = (
            _ffi_api.PostprocVerifyGPUCode  # type: ignore # pylint: disable=no-member
        )
        self.__init_handle_by_constructor__(ffi_ctor)
20 changes: 10 additions & 10 deletions python/tvm/meta_schedule/tune_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ class TuneContext(Object):
The search strategy.
sch_rules: Optional[List[ScheduleRule]] = None,
The schedule rules.
postproc: Optional[List[Postproc"]] = None,
postprocs: Optional[List[Postproc"]] = None,
The postprocessors.
mutator: Optional[List[Mutator]] = None,
mutators: Optional[List[Mutator]] = None,
The mutators.
task_name : Optional[str] = None
The name of the tuning task.
Expand All @@ -81,8 +81,8 @@ class TuneContext(Object):
space_generator: Optional["SpaceGenerator"]
search_strategy: Optional["SearchStrategy"]
sch_rules: Optional[List["ScheduleRule"]]
postproc: Optional[List["Postproc"]]
mutator: Optional[List["Mutator"]]
postprocs: Optional[List["Postproc"]]
mutators: Optional[List["Mutator"]]
task_name: Optional[str]
rand_state: int
num_threads: int
Expand All @@ -94,8 +94,8 @@ def __init__(
space_generator: Optional["SpaceGenerator"] = None,
search_strategy: Optional["SearchStrategy"] = None,
sch_rules: Optional[List["ScheduleRule"]] = None,
postproc: Optional[List["Postproc"]] = None,
mutator: Optional[List["Mutator"]] = None,
postprocs: Optional[List["Postproc"]] = None,
mutators: Optional[List["Mutator"]] = None,
task_name: Optional[str] = None,
rand_state: int = -1,
num_threads: Optional[int] = None,
Expand All @@ -114,9 +114,9 @@ def __init__(
The search strategy.
sch_rules : List[ScheduleRule] = []
The schedule rules.
postproc : List[Postproc] = []
postprocs : List[Postproc] = []
The postprocessors.
mutator : List[Mutator] = []
mutators : List[Mutator] = []
The mutators.
task_name : Optional[str] = None
The name of the tuning task.
Expand All @@ -138,8 +138,8 @@ def __init__(
space_generator,
search_strategy,
sch_rules,
postproc,
mutator,
postprocs,
mutators,
task_name,
rand_state,
num_threads,
Expand Down
89 changes: 89 additions & 0 deletions src/meta_schedule/postproc/disallow_dynamic_loop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*! \brief Check if the loop is dynamic. */
struct DynamicExtentFinder : private StmtVisitor {
public:
static bool Find(const IRModule& mod) {
DynamicExtentFinder finder;
for (const auto& kv : mod->functions) {
const BaseFunc& func = kv.second;
if (const auto* prim_func = func.as<PrimFuncNode>()) {
finder(prim_func->body);
if (finder.found_) {
return true;
}
}
}
return false;
}

private:
void VisitStmt_(const ForNode* loop) final {
if (!loop->extent->IsInstance<IntImmNode>()) {
found_ = true;
} else {
StmtVisitor::VisitStmt_(loop);
}
}

void VisitStmt(const Stmt& stmt) final {
if (!found_) {
StmtVisitor::VisitStmt(stmt);
}
}

bool found_ = false;
};

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

using tir::Schedule;

/*!
 * \brief Postprocessor that rejects a schedule when its IRModule has any loop with a
 *  non-constant extent.
 */
class DisallowDynamicLoopNode : public PostprocNode {
 public:
  /*! \brief Inherited from PostprocNode; nothing is read from the tuning context. */
  void InitializeWithTuneContext(const TuneContext& context) final {}

  /*!
   * \brief Inherited from PostprocNode; check the schedule's module for dynamic loops.
   * \param sch The schedule whose module is inspected.
   * \return false if tir::DynamicExtentFinder finds a dynamic loop, true otherwise.
   */
  bool Apply(const tir::Schedule& sch) final {
    const bool has_dynamic_loop = tir::DynamicExtentFinder::Find(sch->mod());
    return !has_dynamic_loop;
  }

  static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop";
  TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode);
};

/*!
 * \brief Build a Postproc handle around a newly allocated DisallowDynamicLoopNode.
 * \return The postprocessor created.
 */
Postproc Postproc::DisallowDynamicLoop() {
  return Postproc(make_object<DisallowDynamicLoopNode>());
}

// Register the node type, then expose the Postproc::DisallowDynamicLoop factory as the
// global function "meta_schedule.PostprocDisallowDynamicLoop".
TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop")
.set_body_typed(Postproc::DisallowDynamicLoop);

} // namespace meta_schedule
} // namespace tvm
Loading

0 comments on commit 996cf62

Please sign in to comment.