Fix problem with hang by using deep copy
echuraev committed Sep 20, 2021
1 parent 9981b45 commit bf4d604
Showing 1 changed file with 43 additions and 14 deletions.
57 changes: 43 additions & 14 deletions python/tvm/auto_scheduler/measure.py
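Judging from the diff below, the fix moves device-side argument preparation out of the parent process: prepare_runner_args now returns only picklable values (numpy arrays for registered task inputs, None placeholders for everything else), and the measurement workers (_timed_eval_func and _rpc_run) deep-copy that list and allocate/fill the actual device buffers themselves. As a rough, TVM-free sketch of that pattern (all names below are invented for illustration and are not part of the commit):

import copy
import multiprocessing

import numpy as np


def _timed_worker(host_args):
    # Work on a private deep copy so the shared argument list is never mutated.
    local_args = copy.deepcopy(host_args)
    for i, arg in enumerate(local_args):
        if arg is None:
            # Placeholder: allocate and fill inside the worker instead of
            # shipping a device buffer across the process boundary.
            local_args[i] = np.random.uniform(size=(4, 4)).astype("float32")
    return sum(float(a.sum()) for a in local_args)


if __name__ == "__main__":
    # The parent prepares only picklable numpy arrays and None placeholders.
    prepared = [np.ones((4, 4), dtype="float32"), None]
    with multiprocessing.Pool(1) as pool:
        print(pool.apply(_timed_worker, (prepared,)))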
@@ -37,6 +37,7 @@
import tempfile
import multiprocessing
import logging
import copy

import tvm._ffi
from tvm.runtime import Object, module, ndarray
@@ -818,7 +819,7 @@ def prepare_input_map(args):
return tensor_input_map


def prepare_runner_args(inp, build_res, dev):
def prepare_runner_args(inp, build_res):
"""This function prepares the pre-defined arguments in `TASK_INPUT_BUFFER_TABLE` for local/rpc
runner in main process
@@ -840,9 +841,6 @@ def prepare_runner_args(inp, build_res, dev):
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"

task_input_names = inp.task.task_input_names
tensor_input_map = prepare_input_map(build_res.args)
if not task_input_names:
@@ -855,17 +853,15 @@ def prepare_runner_args(inp, build_res, dev):
if tensor_name in task_input_names:
task_input_buffer = get_task_input_buffer(inp.task.workload_key, tensor_name)
# convert tvm.NDArray to picklable numpy.ndarray
args.append(ndarray.array(task_input_buffer), dev)
args.append(task_input_buffer.numpy())
task_inputs_count += 1
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with `SearchTask(..., task_inputs={...})`"
)
else:
empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
random_fill(empty_array)
args.append(empty_array)
args.append(None)
if task_inputs_count != len(task_input_names):
raise RuntimeError("task_inputs not fully matched, check if there's any unexpected error")
return args
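After this change the returned list contains only picklable objects: registered task inputs are converted to numpy arrays via .numpy(), and every other tensor becomes a None placeholder that the runner process fills later. A hypothetical example of the resulting structure (tensor roles, shapes, and dtypes invented for illustration):

import numpy as np

prepared_args = [
    np.zeros((1, 3, 224, 224), dtype="float32"),  # registered task input, already host-side numpy
    None,                                          # ordinary tensor: random-filled later in the worker
    None,                                          # output tensor: also allocated in the worker
]
assert all(a is None or isinstance(a, np.ndarray) for a in prepared_args)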
@@ -874,6 +870,7 @@ def prepare_runner_args(inp, build_res, dev):
def _timed_eval_func(
inp_serialized,
build_res,
args,
number,
repeat,
min_repeat_ms,
@@ -910,10 +907,23 @@ def _timed_eval_func(

if error_no == 0:
try:
args = prepare_runner_args(inp, build_res, dev)
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
assert len(args) == len(build_res.args)
loc_args = copy.deepcopy(args)
# pylint: disable=consider-using-enumerate
for idx in range(len(loc_args)):
if loc_args[idx] is None:
build_res_arg = build_res.args[idx]
empty_array = ndarray.empty(
get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
)
random_fill(empty_array)
loc_args[idx] = empty_array
else:
loc_args[idx] = ndarray.array(loc_args[idx], dev)
dev.sync()
costs = time_f(*args).results
costs = time_f(*loc_args).results
# pylint: disable=broad-except
except Exception:
costs = (MAX_FLOAT,)
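The commit message points at the deep copy as the fix for the hang; what copy.deepcopy guarantees here is that the worker's loc_args owns independent array buffers, so filling or replacing them cannot feed back into the args list prepared in the main process, which may be reused across measurements. A standalone illustration of that difference (not TVM code):

import copy

import numpy as np

shared = [np.zeros(4, dtype="float32")]

shallow = list(shared)
shallow[0][:] = 1.0      # in-place fill through the shallow copy leaks back
print(shared[0])         # [1. 1. 1. 1.]

shared = [np.zeros(4, dtype="float32")]
deep = copy.deepcopy(shared)
deep[0][:] = 1.0         # the deep copy owns its own buffer
print(shared[0])         # [0. 0. 0. 0.]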
@@ -1002,13 +1012,15 @@ def local_run(
time.time(),
)
else:
args = prepare_runner_args(inp, build_res)
res = call_func_with_timeout(
worker,
timeout,
_timed_eval_func,
args=(
inp.serialize(),
build_res,
args,
number,
repeat,
min_repeat_ms,
@@ -1049,6 +1061,7 @@ def local_run(
def _rpc_run(
inp_serialized,
build_res,
args,
key,
host,
port,
@@ -1095,16 +1108,31 @@ def _rpc_run(
try:
stream = dev.create_raw_stream()
dev.set_raw_stream(stream)
random_fill = remote.get_function("tvm.contrib.random.random_fill")
assert (
random_fill
), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"

args = prepare_runner_args(inp, build_res, dev)
assert len(args) == len(build_res.args)
loc_args = copy.deepcopy(args)
# pylint: disable=consider-using-enumerate
for idx in range(len(loc_args)):
if loc_args[idx] is None:
build_res_arg = build_res.args[idx]
empty_array = ndarray.empty(
get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
)
random_fill(empty_array)
loc_args[idx] = empty_array
else:
loc_args[idx] = ndarray.array(loc_args[idx], dev)
dev.sync()

# First run for check that the kernel is correct
func.entry_func(*args)
func.entry_func(*loc_args)
dev.sync()

costs = time_f(*args).results
costs = time_f(*loc_args).results

# clean up remote files
remote.remove(build_res.filename)
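In the RPC path the same materialization now happens on the remote device: placeholders are allocated with ndarray.empty on the remote dev and filled by the remote tvm.contrib.random.random_fill packed function, so no device buffers cross the host/remote boundary in pickled form. A hedged sketch of that remote-fill pattern in isolation (host, port, and shape are placeholders; it assumes a running RPC server whose runtime was built with USE_RANDOM enabled):

from tvm import rpc
from tvm.runtime import ndarray

remote = rpc.connect("127.0.0.1", 9090)       # placeholder address of a running RPC server
dev = remote.cpu()
random_fill = remote.get_function("tvm.contrib.random.random_fill")

buf = ndarray.empty((1024,), "float32", dev)  # allocated on the remote device
random_fill(buf)                              # filled remotely; nothing is uploaded from the host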
@@ -1144,7 +1172,7 @@ def _rpc_run_worker(args):
res : MeasureResult
The measure result of this Runner thread.
"""
_, build_res, _, _, _, _, timeout, _, _, _, _, _, verbose = args
_, build_res, _, _, _, _, _, timeout, _, _, _, _, _, verbose = args
if build_res.error_no != MeasureErrorNo.NO_ERROR:
return (
(MAX_FLOAT,),
@@ -1250,6 +1278,7 @@ def rpc_runner_run(
(
inp.serialize(),
build_res,
prepare_runner_args(inp, build_res),
key,
host,
port,
