diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 47003c6faa25f..5cd483698a13f 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -177,7 +177,7 @@ class ApplyHistoryBest(MetaScheduleContext): pass -def extract_task( +def extract_task_from_relay( mod: Union[IRModule, RelayFunc], target: Target, params: Optional[Dict[str, NDArray]] = None, diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 048ad03ab3ed6..bcfa08cdfc7d4 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -23,7 +23,7 @@ from tvm.ir.module import IRModule from tvm.runtime import NDArray -from tvm.meta_schedule.integration import extract_task +from tvm.meta_schedule.integration import extract_task_from_relay from tvm.target.target import Target from tvm.te import Tensor, create_prim_func from tvm.tir import PrimFunc, Schedule @@ -651,7 +651,7 @@ def tune_relay( """ logger.info("Working directory: %s", work_dir) - extracted_tasks = extract_task(mod, target, params) + extracted_tasks = extract_task_from_relay(mod, target, params) # pylint: disable=protected-access tune_contexts = [] target = Parse._target(target) diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index f508c7d252e10..0ace4d2bd02c0 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -112,7 +112,7 @@ def test_meta_schedule_integration_extract_from_resnet(): layout="NHWC", dtype="float32", ) - extracted_tasks = ms.integration.extract_task(mod, target="llvm", params=params) + extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params) assert len(extracted_tasks) == 30 diff --git a/tests/python/unittest/test_meta_schedule_task_extraction.py b/tests/python/unittest/test_meta_schedule_task_extraction.py index 8d1eca51432e5..8523275f51869 100644 --- a/tests/python/unittest/test_meta_schedule_task_extraction.py +++ b/tests/python/unittest/test_meta_schedule_task_extraction.py @@ -91,7 +91,7 @@ def test_meta_schedule_extract_from_torch_model(model_name: str, batch_size: int dtype="float32", ) target = tvm.target.Target(target) - ms.integration.extract_task(mod, params=params, target=target) + ms.integration.extract_task_from_relay(mod, params=params, target=target) if __name__ == "__main__":