Skip to content

Commit

Permalink
Rename extract_task to extract_task_from_relay
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Jan 7, 2022
1 parent b464012 commit 4df0194
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class ApplyHistoryBest(MetaScheduleContext):
pass


def extract_task(
def extract_task_from_relay(
mod: Union[IRModule, RelayFunc],
target: Target,
params: Optional[Dict[str, NDArray]] = None,
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from tvm.ir.module import IRModule
from tvm.runtime import NDArray
from tvm.meta_schedule.integration import extract_task
from tvm.meta_schedule.integration import extract_task_from_relay
from tvm.target.target import Target
from tvm.te import Tensor, create_prim_func
from tvm.tir import PrimFunc, Schedule
Expand Down Expand Up @@ -651,7 +651,7 @@ def tune_relay(
"""

logger.info("Working directory: %s", work_dir)
extracted_tasks = extract_task(mod, target, params)
extracted_tasks = extract_task_from_relay(mod, target, params)
# pylint: disable=protected-access
tune_contexts = []
target = Parse._target(target)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_meta_schedule_integration_extract_from_resnet():
layout="NHWC",
dtype="float32",
)
extracted_tasks = ms.integration.extract_task(mod, target="llvm", params=params)
extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
assert len(extracted_tasks) == 30


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_meta_schedule_extract_from_torch_model(model_name: str, batch_size: int
dtype="float32",
)
target = tvm.target.Target(target)
ms.integration.extract_task(mod, params=params, target=target)
ms.integration.extract_task_from_relay(mod, params=params, target=target)


if __name__ == "__main__":
Expand Down

0 comments on commit 4df0194

Please sign in to comment.