Skip to content

Commit

Permalink
feat: support TF Keras EarlyStopping callbacks [DET-3240] (#666)
Browse files Browse the repository at this point in the history
  • Loading branch information
brainhart authored Jun 8, 2020
1 parent 4056146 commit 05aa3d2
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 2 deletions.
24 changes: 24 additions & 0 deletions harness/determined/keras/_tf_keras_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,29 @@ def load_optimizer_weights(model: Model, load_path: pathlib.Path) -> None:
)


class DeterminedEarlyStoppingCallback(tf.keras.callbacks.Callback): # type: ignore
"""
DeterminedEarlyStoppingCallback converts a stop request, so that Determined
can handle the stop request by finishing the step and checkpointing.
"""

def __init__(self, tf_keras_trial_controller: "TFKerasTrialController") -> None:
self.tf_keras_trial_controller = tf_keras_trial_controller

def _convert_stop_training(self) -> None:
# We use stop_training to exit out of the training loop, but we set
# expect_terminate when we do so.
if self.model.stop_training and not self.tf_keras_trial_controller.expect_terminate:
self.model.stop_training = False
self.tf_keras_trial_controller.context.set_stop_requested(True)

def on_epoch_end(self, _: int, logs: Any = None) -> None:
self._convert_stop_training()

def on_train_end(self, _: int, logs: Any = None) -> None:
self._convert_stop_training()


class WaitForInstructionsCallback(tf.keras.callbacks.Callback): # type: ignore
"""
WaitForInstructionsCallback allows a separate process to control this trial.
Expand Down Expand Up @@ -507,6 +530,7 @@ def _launch_fit(self) -> None:
check.false(self.fit_loop_started)
self.fit_loop_started = True

self.tf_keras_callbacks.append(DeterminedEarlyStoppingCallback(self))
self.tf_keras_callbacks.append(WaitForInstructionsCallback(self))

profile_frequency = self.env.experiment_config.profile_frequency()
Expand Down
10 changes: 9 additions & 1 deletion harness/tests/experiment/fixtures/tf_keras_xor_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import cast
from typing import Any, List, cast

import tensorflow as tf
from tensorflow.keras.layers import Dense
Expand All @@ -11,6 +11,11 @@
from tests.experiment.utils import make_xor_data_sequences, xor_data # noqa: I202, I100


class StopVeryEarlyCallback(tf.keras.callbacks.Callback): # type: ignore
def on_epoch_end(self, _: int, logs: Any = None) -> None:
self.model.stop_training = True


def categorical_error(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
return 1.0 - categorical_accuracy(y_true, y_pred)

Expand Down Expand Up @@ -59,6 +64,9 @@ def build_validation_data_loader(self) -> keras.InputData:
_, test = make_xor_data_sequences(batch_size=4)
return keras.SequenceAdapter(test, workers=0)

def keras_callbacks(self) -> List[tf.keras.callbacks.Callback]:
return [StopVeryEarlyCallback()] if self.context.env.hparams.get("stop_early") else []


class XORTrialWithTrainingMetrics(XORTrial):
def build_model(self) -> Sequential:
Expand Down
15 changes: 15 additions & 0 deletions harness/tests/experiment/keras/test_tf_keras_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,21 @@ def controller_fn(workloads: workload.Stream) -> det.TrialController:
controller_fn=controller_fn, steps=3, validation_freq=1, batches_per_step=100
)

def test_early_stopping(self) -> None:
def make_workloads() -> workload.Stream:
trainer = utils.TrainAndValidate(request_stop_step_id=1)
yield from trainer.send(steps=100, validation_freq=2, batches_per_step=5)
tm, vm = trainer.result()
yield workload.terminate_workload(), [], workload.ignore_workload_response

hparams = dict(self.hparams)
hparams["stop_early"] = True

controller = utils.make_trial_controller_from_trial_implementation(
tf_keras_xor_model.XORTrial, hparams, make_workloads(), batches_per_step=5,
)
controller.run()


def test_surface_native_error():
cmd = ["python3", utils.fixtures_path("tf_keras_runtime_error.py")]
Expand Down
1 change: 0 additions & 1 deletion harness/tests/experiment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def send(
interceptor = workload.WorkloadResponseInterceptor()

for step_id in range(initial_step_id, initial_step_id + steps):
print(f"STEP ID: {step_id}")
stop_requested = False
yield from interceptor.send(workload.train_workload(step_id), [batches_per_step])
metrics = interceptor.metrics_result()
Expand Down

0 comments on commit 05aa3d2

Please sign in to comment.