feat: support stopping training in trial code [DET-3238] (determined-…
brainhart authored Jun 5, 2020
1 parent fa09a74 commit ee1314f
Showing 15 changed files with 151 additions and 47 deletions.
8 changes: 6 additions & 2 deletions e2e_tests/tests/fixtures/metric_maker/metric_maker.py
@@ -95,10 +95,14 @@ def train_for_step(self, step_id: int, batches_per_step: int) -> Dict[str, Any]:
         # Update the overall base value for the trial.
         self.value += self.gain_per_batch * batches_per_step

-        return {"batch_metrics": batch_metrics, "num_inputs": batches_per_step}
+        return {"metrics": {"batch_metrics": batch_metrics, "num_inputs": batches_per_step}}

     def compute_validation_metrics(self, step_id: int) -> Dict[str, Any]:
-        return {"validation_metrics": structure_to_metrics(self.value, self.validation_structure)}
+        return {
+            "metrics": {
+                "validation_metrics": structure_to_metrics(self.value, self.validation_structure)
+            }
+        }

     def set_random_seed(self, trial_seed) -> None:
         pass
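The pattern introduced above recurs throughout this commit: workload responses that used to be bare metrics dictionaries are now wrapped in an envelope whose "metrics" key carries the old payload, leaving room for a sibling "stop_requested" flag. A minimal sketch of the shape change, with illustrative values (not actual harness code):

# Illustrative response shapes, inferred from the diff above.
old_train_response = {"batch_metrics": [], "num_inputs": 32}
new_train_response = {
    "metrics": {"batch_metrics": [], "num_inputs": 32},  # the old payload, nested
    "stop_requested": False,  # True once the trial calls set_stop_requested(True)
}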
21 changes: 21 additions & 0 deletions e2e_tests/tests/fixtures/no_op/grid-graceful-trial-termination.yaml
@@ -0,0 +1,21 @@
+description: noop_adaptive
+checkpoint_storage:
+  type: shared_fs
+  host_path: /tmp
+  storage_path: determined-integration-checkpoints
+hyperparameters:
+  global_batch_size: 32
+  metrics_progression: decreasing
+  metrics_base: 0.5
+  metrics_sigma: 0
+  request_stop:
+    type: categorical
+    vals: [True, False]
+searcher:
+  name: grid
+  metric: validation_error
+  max_steps: 4
+reproducibility:
+  experiment_seed: 999
+max_restarts: 0
+entrypoint: model_def:NoOpTrial
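Because request_stop is a categorical hyperparameter with two values, the grid searcher expands this config into two trials: one that requests a stop on its first step and one that runs to completion. A rough sketch of that expansion (the dict layout is illustrative, not the searcher's internal representation):

# Hypothetical sketch of grid expansion over the categorical hyperparameter.
import itertools

grid = {"request_stop": [True, False]}
trials = [dict(zip(grid, combo)) for combo in itertools.product(*grid.values())]
assert trials == [{"request_stop": True}, {"request_stop": False}]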
20 changes: 16 additions & 4 deletions e2e_tests/tests/fixtures/no_op/model_def.py
@@ -55,6 +55,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         assert 0 <= self.metrics_sigma
         self.write_null = self.env.hparams.get("write_null", False)

+        self.request_stop = self.env.hparams.get("request_stop", False)
+
         if self.load_path is None:
             self.trained_steps = collections.Counter()
         else:
@@ -83,24 +85,34 @@ def current_metric(self) -> float:
         raise ValueError("Invalid `metrics_progression` {}".format(self.metrics_progression))

     def train_for_step(self, step_id: int, batches_per_step: int) -> Dict[str, Any]:
+        if self.request_stop:
+            self.context.set_stop_requested(True)
         self.chaos_failure(self.chaos_probability_train)
         time.sleep(self.train_batch_secs * batches_per_step)
         if self.write_null:
             with open("/dev/stdout", "wb") as f:
                 f.write(b"\x00")
         self.trained_steps[step_id] += 1
         metrics = {name: self.current_metric() for name in ["loss", *self.training_metrics()]}
-        return det.util.make_metrics(
-            self._batch_size * batches_per_step, [metrics] * batches_per_step
-        )
+        response = {
+            "metrics": det.util.make_metrics(
+                self._batch_size * batches_per_step, [metrics] * batches_per_step
+            ),
+            "stop_requested": self.context.get_stop_requested(),
+        }
+        return response

     def compute_validation_metrics(self, step_id: int) -> Dict[str, Any]:
         self.chaos_failure(self.chaos_probability_validate)
         time.sleep(self.validation_secs)
         metrics = {
             name: self.current_metric() for name in ["validation_error", *self.validation_metrics()]
         }
-        return {"validation_metrics": metrics, "num_inputs": self.validation_set_size}
+        response = {
+            "metrics": {"validation_metrics": metrics, "num_inputs": self.validation_set_size},
+            "stop_requested": self.context.get_stop_requested(),
+        }
+        return response

     def training_metrics(self) -> Dict[str, Any]:
         return {"metric_{}".format(i): None for i in range(1, self.num_training_metrics)}
6 changes: 6 additions & 0 deletions e2e_tests/tests/test_system.py
@@ -381,6 +381,12 @@ def test_log_null_bytes() -> None:
     assert len(logs) > 0


+@pytest.mark.e2e_cpu  # type: ignore
+def test_graceful_trial_termination() -> None:
+    config_obj = conf.load_config(conf.fixtures_path("no_op/grid-graceful-trial-termination.yaml"))
+    exp.run_basic_test_with_temp_config(config_obj, conf.fixtures_path("no_op"), 2)
+
+
 @pytest.mark.e2e_gpu  # type: ignore
 def test_s3_no_creds(secrets: Dict[str, str]) -> None:
     pytest.skip("Temporarily skipping this until we find a more secure way of testing this.")
17 changes: 17 additions & 0 deletions harness/determined/_train_context.py
@@ -16,6 +16,7 @@ def __init__(self, env: det.EnvContext, hvd_config: horovod.HorovodContext):
         self.env = env  # type: det.EnvContext
         self.hvd_config = hvd_config  # type: horovod.HorovodContext
         self.distributed = DistributedContext(env, hvd_config)
+        self._stop_requested = False

     def get_experiment_config(self) -> Dict[str, Any]:
         """
@@ -82,6 +83,22 @@ def get_hparam(self, name: str) -> Any:
             )
         return self.env.hparams[name]

+    def get_stop_requested(self) -> bool:
+        """
+        Return whether a trial stoppage has been requested.
+        """
+        return self._stop_requested
+
+    def set_stop_requested(self, stop_requested: bool) -> None:
+        """
+        Set a flag to request a trial stoppage. When this flag is set to True,
+        we finish the step, checkpoint, then exit.
+        """
+        if not isinstance(stop_requested, bool):
+            raise AssertionError("stop_requested must be a boolean")
+
+        self._stop_requested = stop_requested
+

 class TrialContext(_TrainContext):
     """
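These two methods are the user-facing half of the feature. A minimal usage sketch, assuming trial code that holds the context defined above; the helper, stopping condition, and threshold are made up for illustration:

from typing import Any

# Hypothetical helper (not part of this diff): request a graceful stop once a
# user-defined condition is met. `context` is the trial's TrialContext, which
# now carries the two methods added above.
def maybe_request_stop(context: Any, loss: float, threshold: float = 0.01) -> None:
    # Once the flag is set, the harness finishes the current step, saves a
    # checkpoint, and reports the workload with exited_reason USER_CANCELED.
    if loss < threshold and not context.get_stop_requested():
        context.set_stop_requested(True)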
14 changes: 10 additions & 4 deletions harness/determined/estimator/_estimator_trial.py
@@ -154,9 +154,11 @@ def after_run(
         check.is_not_none(self.train_response_func, "no response_func at end of train_for_step")
         self.train_response_func = cast(workload.ResponseFunc, self.train_response_func)
         if self.estimator_trial_controller.is_chief:
-            self.train_response_func(
-                det.util.make_metrics(self.batches_processed_in_step, self.step_metrics)
-            )
+            response = {
+                "metrics": det.util.make_metrics(self.batches_processed_in_step, self.step_metrics),
+                "stop_requested": self.estimator_trial_controller.context.get_stop_requested(),
+            }
+            self.train_response_func(response)
         else:
             self.train_response_func(workload.Skipped())

@@ -289,7 +291,11 @@ def control_loop(self) -> None:
                 # re-enters the train_and_evaluate() loop.
                 break
             elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
-                response_func(self._compute_validation_metrics())
+                response = {
+                    "metrics": self._compute_validation_metrics(),
+                    "stop_requested": self.estimator_trial_controller.context.get_stop_requested(),
+                }
+                response_func(response)
             elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                 check.len_eq(args, 1)
                 check.is_instance(args[0], pathlib.Path)
4 changes: 2 additions & 2 deletions harness/determined/experimental/_native.py
@@ -131,14 +131,14 @@ def _make_test_workloads(
     logging.info("Training one batch")
     yield from interceptor.send(workload.train_workload(1), [1])
     metrics = interceptor.metrics_result()
-    batch_metrics = metrics["batch_metrics"]
+    batch_metrics = metrics["metrics"]["batch_metrics"]
     check.eq(len(batch_metrics), config.batches_per_step())
     logging.debug(f"Finished training, metrics: {batch_metrics}")

     logging.info("Validating one step")
     yield from interceptor.send(workload.validation_workload(1), [])
     validation = interceptor.metrics_result()
-    v_metrics = validation["validation_metrics"]
+    v_metrics = validation["metrics"]["validation_metrics"]
     logging.debug(f"Finished validating, validation metrics: {v_metrics}")

     logging.info(f"Saving a checkpoint to {checkpoint_dir}.")
12 changes: 10 additions & 2 deletions harness/determined/keras/_tf_keras_trial.py
@@ -114,7 +114,11 @@ def on_train_batch_end(self, _: int, logs: Any = None) -> None:
         )

         if self.tf_keras_trial_controller.is_chief:
-            response_func(det.util.make_metrics(num_inputs, self.metrics))
+            response = {
+                "metrics": det.util.make_metrics(num_inputs, self.metrics),
+                "stop_requested": self.tf_keras_trial_controller.context.get_stop_requested(),
+            }
+            response_func(response)
         else:
             response_func(workload.Skipped())

@@ -481,7 +485,11 @@ def run(self) -> None:
                 break

             elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
-                response_func(self.compute_validation_metrics())
+                response = {
+                    "metrics": self.compute_validation_metrics(),
+                    "stop_requested": self.context.get_stop_requested(),
+                }
+                response_func(response)
             elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                 check.len_eq(args, 1)
                 check.is_instance(args[0], pathlib.Path)
57 changes: 34 additions & 23 deletions harness/determined/layers/_workload_manager.py
@@ -3,7 +3,7 @@
 import pathlib
 import sys
 from datetime import datetime, timezone
-from typing import List, Optional, cast
+from typing import Any, Dict, List, Optional, cast

 import determined as det
 from determined import tensorboard, workload
@@ -119,14 +119,17 @@ def yield_train_for_step(
             callback.on_trial_begin()
             callback.on_train_step_begin(wkld.step_id)

-        def _respond(metrics: workload.Response) -> None:
+        def _respond(in_response: workload.Response) -> None:

             # Only the chief container should actually respond to TRAIN_FOR_STEP.
             if self.rendezvous_info.get_rank() != 0:
                 respond(workload.Skipped())
                 return

-            check_not_isinstance(metrics, workload.Skipped, "Chief skipped a workload.")
-            metrics = cast(workload.Metrics, metrics)
+            check_not_isinstance(in_response, workload.Skipped, "Chief skipped a workload.")
+
+            in_response = cast(workload.Metrics, in_response)
+            metrics = in_response["metrics"]
             metrics = cast(workload.Metrics, metrics)

             batch_metrics = metrics["batch_metrics"]
@@ -140,16 +143,19 @@ def _respond(metrics: workload.Response) -> None:

             self.tensorboard_mgr.sync()

+            out_response = {
+                "type": "WORKLOAD_COMPLETED",
+                "workload": wkld,
+                "start_time": start_time,
+                "end_time": _current_timestamp(),
+                "metrics": metrics,
+            }
+
+            if in_response.get("stop_requested", False):
+                out_response["exited_reason"] = "USER_CANCELED"
+
             # Send the response up.
-            respond(
-                {
-                    "type": "WORKLOAD_COMPLETED",
-                    "workload": wkld,
-                    "start_time": start_time,
-                    "end_time": _current_timestamp(),
-                    "metrics": metrics,
-                }
-            )
+            respond(out_response)

         num_batches = self.env.experiment_config.get("batches_per_step", 100)
         yield wkld, [num_batches], _respond
@@ -159,14 +165,16 @@ def yield_compute_validation_metrics(
     ) -> workload.Stream:
         start_time = _current_timestamp()

-        def _respond(metrics: workload.Response) -> None:
+        def _respond(in_response: workload.Response) -> None:

             # Only the chief container should actually respond to COMPUTE_VALIDATION_METRICS.
             if self.rendezvous_info.get_rank() != 0:
                 respond(workload.Skipped())
                 return

-            check_not_isinstance(metrics, workload.Skipped, "Chief skipped a workload.")
+            check_not_isinstance(in_response, workload.Skipped, "Chief skipped a workload.")
+            in_response = cast(Dict[str, Any], in_response)
+            metrics = in_response["metrics"]
             metrics = cast(workload.Metrics, metrics)

             v_metrics = metrics["validation_metrics"]
@@ -235,15 +243,18 @@ def _respond(metrics: workload.Response) -> None:
             for metric_name in non_serializable_metrics:
                 del v_metrics[metric_name]

-            respond(
-                {
-                    "type": "WORKLOAD_COMPLETED",
-                    "workload": wkld,
-                    "start_time": start_time,
-                    "end_time": _current_timestamp(),
-                    "metrics": metrics,
-                }
-            )
+            out_response = {
+                "type": "WORKLOAD_COMPLETED",
+                "workload": wkld,
+                "start_time": start_time,
+                "end_time": _current_timestamp(),
+                "metrics": metrics,
+            }
+
+            if in_response.get("stop_requested", False):
+                out_response["exited_reason"] = "USER_CANCELED"
+
+            respond(out_response)

         for callback in self.callbacks:
             callback.on_validation_step_begin(wkld.step_id)
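The workload manager is where the trial-level flag becomes a master-visible exit reason. A condensed sketch of that translation, mirroring the _respond logic above (the function name and sample arguments are illustrative, not harness API):

from typing import Any, Dict

def to_out_response(in_response: Dict[str, Any], wkld: Any, start_time: str, end_time: str) -> Dict[str, Any]:
    # Pass the metrics through, and surface a graceful stop as the
    # USER_CANCELED exit reason that the master understands.
    out_response = {
        "type": "WORKLOAD_COMPLETED",
        "workload": wkld,
        "start_time": start_time,
        "end_time": end_time,
        "metrics": in_response["metrics"],
    }
    if in_response.get("stop_requested", False):
        out_response["exited_reason"] = "USER_CANCELED"
    return out_response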
12 changes: 10 additions & 2 deletions harness/determined/pytorch/_pytorch_trial.py
@@ -256,9 +256,17 @@ def run(self) -> None:
             if w.kind == workload.Workload.Kind.RUN_STEP:
                 check.eq(len(args), 1)
                 num_batches = cast(int, args[0])
-                response_func(self._train_for_step(w.step_id, num_batches))
+                response = {
+                    "metrics": self._train_for_step(w.step_id, num_batches),
+                    "stop_requested": self.context.get_stop_requested(),
+                }
+                response_func(response)
             elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
-                response_func(self._compute_validation_metrics())
+                response = {
+                    "metrics": self._compute_validation_metrics(),
+                    "stop_requested": self.context.get_stop_requested(),
+                }
+                response_func(response)
             elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                 check.eq(len(args), 1)
                 check.is_instance(args[0], pathlib.Path)
15 changes: 13 additions & 2 deletions harness/determined/tensorpack/_tensorpack_trial.py
@@ -100,6 +100,7 @@ def __init__(
         workloads: workload.Stream,
         is_chief: bool,
         machine_rank: int,
+        context: Any,
     ) -> None:
         self.metric_names = metric_names
         self.batch_metrics = []  # type: List[Dict[str, Any]]
@@ -108,6 +109,7 @@
         self.workloads = workloads
         self.is_chief = is_chief
         self.machine_rank = machine_rank
+        self.context = context

         # Store the response_func for train_for_step workloads while we do the training.
         self.train_response_func = None  # type: Optional[workload.ResponseFunc]
@@ -227,7 +229,11 @@ def _trigger_epoch(self) -> None:
         self.train_response_func = cast(workload.ResponseFunc, self.train_response_func)

         if self.is_chief:
-            self.train_response_func(det.util.make_metrics(None, self.batch_metrics))
+            response = {
+                "metrics": det.util.make_metrics(None, self.batch_metrics),
+                "stop_requested": self.context.get_stop_requested(),
+            }
+            self.train_response_func(response)
         else:
             self.train_response_func(workload.Skipped())

@@ -243,7 +249,11 @@ def _control_loop(self) -> None:
                 self.train_response_func = response_func
                 break
             elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
-                response_func(self._compute_validation_metrics())
+                response = {
+                    "metrics": self._compute_validation_metrics(),
+                    "stop_requested": self.context.get_stop_requested(),
+                }
+                response_func(response)
             elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                 check.len_eq(args, 1)
                 check.is_instance(args[0], pathlib.Path)
@@ -453,6 +463,7 @@ def _init_model(self, training_dataflow: Any, validation_dataflow: Any) -> None:
             self.workloads,
             self.is_chief,
             self.rendezvous_info.get_rank(),
+            self.context,
         )

         # TODO: check to make sure users don't pass in InferenceRunner
4 changes: 2 additions & 2 deletions harness/tests/experiment/keras/test_tf_keras_trial.py
@@ -153,7 +153,7 @@ def make_workloads() -> workload.Stream:
             # Calculate what the loss should be.
             loss = trial_class.calc_loss(w, batch)

-            assert metrics["avg_metrics"]["loss"] == pytest.approx(loss)
+            assert metrics["metrics"]["avg_metrics"]["loss"] == pytest.approx(loss)

             # Update what the weight should be.
             w = w - hparams["learning_rate"] * trial_class.calc_gradient(w, batch)
@@ -248,7 +248,7 @@ def make_workloads_2() -> workload.Stream:
         yield from interceptor.send(workload.validation_workload(), [])
         metrics = interceptor.metrics_result()

-        new_loss = metrics["validation_metrics"]["val_loss"]
+        new_loss = metrics["metrics"]["validation_metrics"]["val_loss"]
         assert new_loss == pytest.approx(old_loss)

         yield workload.terminate_workload(), [], workload.ignore_workload_response
