Skip to content

Commit

Permalink
feat: support Estimator early stopping hooks [DET-3239] (determined-a…
Browse files Browse the repository at this point in the history
  • Loading branch information
brainhart authored Jun 8, 2020
1 parent 3ab90a6 commit c8bb942
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 5 deletions.
33 changes: 31 additions & 2 deletions harness/determined/estimator/_estimator_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@
VERY_LARGE_NUMBER = 9999999999999999


class DeterminedEarlyStoppingHook(tf.compat.v1.train.SessionRunHook):  # type: ignore
    """
    Intercepts a stop requested from within the TF session loop and hands it
    to Determined, which then stops gracefully by finishing the current step
    and taking a checkpoint.
    """

    def __init__(self, context: Any) -> None:
        self.context = context

    def after_run(
        self, run_context: tf.estimator.SessionRunContext, run_values: tf.estimator.SessionRunValues
    ) -> None:
        # Only act when some hook (or user code) asked the session to stop.
        if not run_context.stop_requested:
            return
        # Clear TF's internal stop flag so the monitored session keeps running,
        # and let Determined schedule the graceful stop instead.
        run_context._stop_requested = False
        self.context.set_stop_requested(True)


class DeterminedControlHook(tf.estimator.SessionRunHook): # type: ignore
"""
DeterminedControlHook takes control of the train_and_evaluate() loop between
Expand Down Expand Up @@ -458,6 +475,7 @@ def wrapper(*args: Any, **kwargs: Any) -> tf.data.Dataset:

def _init_model(self) -> None:
self._init_train_hooks()
self._init_val_hooks()
self._init_paths()

self.estimator = tf.estimator.Estimator(
Expand Down Expand Up @@ -504,13 +522,20 @@ def _init_model(self) -> None:
self.train_spec = tf.estimator.TrainSpec(
input_fn=repeating_train_fn, hooks=self.train_hooks
)
steps = (
self.val_spec.steps // self.context.distributed.get_size()
if self.val_spec.steps is not None
else None
)
self.eval_spec = tf.estimator.EvalSpec(
input_fn=self.val_spec.input_fn, hooks=self.val_spec.hooks, steps=None
input_fn=self.val_spec.input_fn, hooks=self.val_hooks, steps=steps
)

def _init_train_hooks(self) -> None:
self.train_hooks = [*self.user_train_spec.hooks]

self.train_hooks.append(DeterminedEarlyStoppingHook(self.context))

if self.hvd_config.use:
self.train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

Expand All @@ -519,6 +544,10 @@ def _init_train_hooks(self) -> None:
# their chance.
self.train_hooks.append(DeterminedControlHook(self))

def _init_val_hooks(self) -> None:
self.val_hooks = [*self.val_spec.hooks]
self.val_hooks.append(DeterminedEarlyStoppingHook(self.context))

def _init_run_config(self, config: tf.estimator.RunConfig) -> tf.estimator.RunConfig:
logging.debug(f"Initializing RunConfig. Got RunConfig: {config} .")

Expand Down Expand Up @@ -601,7 +630,7 @@ def _init_paths(self) -> None:

def compute_validation_metrics(self) -> workload.Response:
metrics = self.estimator.evaluate(
input_fn=self.val_spec.input_fn, steps=self.val_spec.steps, hooks=self.val_spec.hooks
input_fn=self.eval_spec.input_fn, steps=self.eval_spec.steps, hooks=self.eval_spec.hooks
)

if self.hvd_config.use:
Expand Down
17 changes: 15 additions & 2 deletions harness/tests/experiment/fixtures/estimator_xor_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def map_dataset(x, y):
return _input_fn


class StopVeryEarly(tf.compat.v1.train.SessionRunHook):  # type: ignore
    """Test hook that requests a session stop after the very first run call."""

    def after_run(
        self, run_context: tf.estimator.SessionRunContext, run_values: tf.estimator.SessionRunValues
    ) -> None:
        # Request a stop immediately; tests use this to exercise Determined's
        # early-stopping handling in both the train and evaluation loops.
        run_context.request_stop()


class XORTrial(estimator.EstimatorTrial):
"""
Models a lightweight neural network model with one hidden layer to
Expand Down Expand Up @@ -100,21 +107,27 @@ def build_estimator(self) -> tf.estimator.Estimator:
)

def build_train_spec(self) -> tf.estimator.TrainSpec:
hooks = [StopVeryEarly()] if self.context.env.hparams.get("stop_early") == "train" else []
return tf.estimator.TrainSpec(
xor_input_fn(
context=self.context,
batch_size=self.context.get_per_slot_batch_size(),
shuffle=self.context.get_hparam("shuffle"),
)
),
hooks=hooks,
)

def build_validation_spec(self) -> tf.estimator.EvalSpec:
hooks = (
[StopVeryEarly()] if self.context.env.hparams.get("stop_early") == "validation" else []
)
return tf.estimator.EvalSpec(
xor_input_fn(
context=self.context,
batch_size=self.context.get_per_slot_batch_size(),
shuffle=False,
)
),
hooks=hooks,
)

def build_serving_input_receiver_fns(self) -> Dict[str, estimator.ServingInputReceiverFn]:
Expand Down
18 changes: 18 additions & 0 deletions harness/tests/experiment/tensorflow/test_estimator_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,24 @@ def make_workloads() -> workload.Stream:
with open(hparams["training_end"], "r") as fp:
assert fp.readline() == "success"

@pytest.mark.parametrize("stop_early,request_stop_step_id", [("train", 1), ("validation", 2)])
def test_early_stopping(self, stop_early: str, request_stop_step_id: int) -> None:
def make_workloads() -> workload.Stream:
trainer = utils.TrainAndValidate(request_stop_step_id=request_stop_step_id)
yield from trainer.send(steps=2, validation_freq=2, batches_per_step=5)
tm, vm = trainer.result()
yield workload.terminate_workload(), [], workload.ignore_workload_response

hparams = dict(self.hparams)
hparams["stop_early"] = stop_early
controller = utils.make_trial_controller_from_trial_implementation(
trial_class=estimator_xor_model.XORTrial,
hparams=hparams,
workloads=make_workloads(),
batches_per_step=5,
)
controller.run()


def test_local_mode() -> None:
    # Smoke-test the native Estimator example under Determined's local test mode.
    utils.run_local_test_mode(utils.fixtures_path("estimator_xor_model_native.py"))
Expand Down
17 changes: 16 additions & 1 deletion harness/tests/experiment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ class TrainAndValidate:
metrics from each.
"""

def __init__(self) -> None:
def __init__(self, request_stop_step_id: Optional[int] = None) -> None:
self._training_metrics = None # type: Optional[List[Dict[str, Any]]]
self._validation_metrics = None # type: Optional[List[Dict[str, Any]]]
self.request_stop_step_id = request_stop_step_id

def send(
self, steps: int, validation_freq: int, initial_step_id: int = 1, batches_per_step: int = 1
Expand All @@ -37,17 +38,31 @@ def send(
interceptor = workload.WorkloadResponseInterceptor()

for step_id in range(initial_step_id, initial_step_id + steps):
print(f"STEP ID: {step_id}")
stop_requested = False
yield from interceptor.send(workload.train_workload(step_id), [batches_per_step])
metrics = interceptor.metrics_result()
batch_metrics = metrics["metrics"]["batch_metrics"]
assert len(batch_metrics) == batches_per_step
self._training_metrics.extend(batch_metrics)
if metrics["stop_requested"]:
assert step_id == self.request_stop_step_id
stop_requested = True

if step_id % validation_freq == 0:
yield from interceptor.send(workload.validation_workload(step_id), [])
validation = interceptor.metrics_result()
print(validation)
v_metrics = validation["metrics"]["validation_metrics"]
self._validation_metrics.append(v_metrics)
if validation["stop_requested"]:
assert step_id == self.request_stop_step_id
stop_requested = True

if stop_requested:
break
else:
assert step_id != self.request_stop_step_id

def result(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
assert self._training_metrics is not None
Expand Down

0 comments on commit c8bb942

Please sign in to comment.