Skip to content

Commit

Permalink
fix: warn out if catching SystemExit [DET-2956]
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyuann committed Aug 25, 2020
1 parent 3c8c047 commit 9484250
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 48 deletions.
5 changes: 5 additions & 0 deletions docs/release-notes/1116-sys-exit.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
:orphan:

**Improvements**

- Remind users to remove ``sys.exit()`` if a ``SystemExit`` exception is caught.
3 changes: 2 additions & 1 deletion harness/determined/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
LoopTrialController,
TrialController,
)
from determined._local_execution import (
from determined._execution import (
_execute_user_func,
_make_local_execution_env,
_local_execution_manager,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import pathlib
import sys
from typing import Any, Dict, Iterator, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import determined as det
from determined import constants, gpu, horovod, workload
Expand All @@ -16,6 +16,18 @@ def _get_gpus() -> Tuple[bool, List[str], List[int]]:
return use_gpu, gpu_uuids, gpu_ids


@contextlib.contextmanager
def _execute_user_func() -> Any:
try:
yield
except SystemExit as e:
raise det.errors.InvalidExperimentException(
"User code raised a SystemExit exception. "
"This might be raised by calling sys.exit(). "
"Please remove calls on sys.exit() from your script."
) from e


def _make_local_execution_exp_config(input_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""
Create a local experiment configuration based on an input configuration and
Expand Down
16 changes: 12 additions & 4 deletions harness/determined/estimator/_estimator_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,19 @@ def from_trial(
)
trial_inst = cast(EstimatorTrial, trial_inst)

with det._execute_user_func():
est, train_spec, val_spec, serving_input_receiver_fns = (
trial_inst.build_estimator(),
trial_inst.build_train_spec(),
trial_inst.build_validation_spec(),
trial_inst.build_serving_input_receiver_fns()
)

return EstimatorTrialController(
trial_inst.build_estimator(),
trial_inst.build_train_spec(),
trial_inst.build_validation_spec(),
trial_inst.build_serving_input_receiver_fns(),
est,
train_spec,
val_spec,
serving_input_receiver_fns,
context,
env,
*args,
Expand Down
28 changes: 19 additions & 9 deletions harness/determined/keras/_tf_keras_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,19 @@ def from_trial(
check.is_instance(trial_inst, TFKerasTrial, "TFKerasTrialController needs a TFKerasTrial")
trial = cast(TFKerasTrial, trial_inst)

session = TFKerasTrialController._configure_session(env, hvd_config, trial.session_config())

# Set Keras session.
with det._execute_user_func():
session_config = trial.session_config()
session = TFKerasTrialController._configure_session(env, hvd_config, session_config)

# Build data loaders.
with det._execute_user_func():
train_data_loader, val_data_loader = (
trial.build_training_data_loader(),
trial.build_validation_data_loader()
)
training_x, training_y, training_sample_weight = keras._get_x_y_and_sample_weight(
input_data=trial.build_training_data_loader()
input_data=train_data_loader
)
training_data = keras._adapt_keras_data(
x=training_x,
Expand All @@ -322,9 +331,8 @@ def from_trial(
batch_size=context.get_per_slot_batch_size(),
drop_leftovers=True,
)

val_x, val_y, val_sample_weight = keras._get_x_y_and_sample_weight(
input_data=trial.build_validation_data_loader()
input_data=val_data_loader
)
validation_data = keras._adapt_keras_data(
x=val_x,
Expand All @@ -334,17 +342,19 @@ def from_trial(
drop_leftovers=False,
)

trial.build_model()
# Build and compile model.
with det._execute_user_func():
trial.build_model()
check.is_not_none(context.model, "Please call wrap_model(...).")

check.is_not_none(context.compile_args, "Please call model.compile(...).")
compile_args = cast(inspect.BoundArguments, context.compile_args)

TFKerasTrialController.compile_model(
context=context, compile_args=compile_args, env=env, hvd_config=hvd_config
)

tf_keras_callbacks = trial.keras_callbacks()
# Initialize callbacks.
with det._execute_user_func():
tf_keras_callbacks = trial.keras_callbacks()

return TFKerasTrialController(
context.model,
Expand Down
80 changes: 47 additions & 33 deletions harness/determined/pytorch/_pytorch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def __init__(self, trial_inst: det.Trial, *args: Any, **kwargs: Any) -> None:
check.is_instance(trial_inst, PyTorchTrial, "PyTorchTrialController needs an PyTorchTrial")
self.trial = cast(PyTorchTrial, trial_inst)
self.context = cast(PyTorchTrialContext, self.context)
self.callbacks = self.trial.build_callbacks()
with det._execute_user_func():
self.callbacks = self.trial.build_callbacks()

self._apply_backwards_compatibility()

Expand Down Expand Up @@ -166,10 +167,11 @@ def _apply_backwards_compatibility(self) -> None:
"and context.step_optimizer(optimizer) in train_batch.",
)

model = self.context.wrap_model(self.trial.build_model())
optim = self.context.wrap_optimizer(self.trial.optimizer(model))
with det._execute_user_func():
model = self.context.wrap_model(self.trial.build_model())
optim = self.context.wrap_optimizer(self.trial.optimizer(model))
lr_scheduler = self.trial.create_lr_scheduler(optim)

lr_scheduler = self.trial.create_lr_scheduler(optim)
if lr_scheduler is not None:
opt = getattr(lr_scheduler._scheduler, "optimizer", None)
if opt is not None:
Expand Down Expand Up @@ -250,12 +252,13 @@ def _set_data_loaders(self) -> None:
nreplicas = hvd.size() if self.hvd_config.use else 1
rank = hvd.rank() if self.hvd_config.use else 0

self.training_loader = self.trial.build_training_data_loader().get_data_loader(
with det._execute_user_func():
training_dataset = self.trial.build_training_data_loader()
validation_dataset = self.trial.build_validation_data_loader()
self.training_loader = training_dataset.get_data_loader(
repeat=True, skip=skip_batches, num_replicas=nreplicas, rank=rank
)
self.context._epoch_len = len(self.training_loader)

validation_dataset = self.trial.build_validation_data_loader()
if self._evaluate_batch_defined():
self.validation_loader = validation_dataset.get_data_loader(
repeat=False, skip=0, num_replicas=nreplicas, rank=rank
Expand Down Expand Up @@ -373,9 +376,12 @@ def _train_for_step(

self.context._current_batch_idx = batch_idx
self.context._loss_ids = {}
tr_metrics = self.trial.train_batch(
batch=batch, epoch_idx=self.get_epoch_idx(batch_idx), batch_idx=batch_idx,
)
with det._execute_user_func():
tr_metrics = self.trial.train_batch(
batch=batch,
epoch_idx=self.get_epoch_idx(batch_idx),
batch_idx=batch_idx,
)
if isinstance(tr_metrics, torch.Tensor):
tr_metrics = {"loss": tr_metrics}
check.is_instance(
Expand Down Expand Up @@ -428,14 +434,17 @@ def _compute_validation_metrics(self) -> workload.Response:
for model in self.context.models:
model.eval()

for callback in self.callbacks.values():
logging.warning(
"on_validation_step_start is now deprecated, please use on_validation_start instead"
)
callback.on_validation_step_start()

for callback in self.callbacks.values():
callback.on_validation_start()
with det._execute_user_func():
for callback in self.callbacks.values():
logging.warning(
"on_validation_step_start is now deprecated, please use on_validation_start instead"
)
callback.on_validation_step_start()

with det._execute_user_func():
for callback in self.callbacks.values():
callback.on_validation_start()

num_inputs = 0
metrics = {} # type: Optional[Dict[str, Any]]
Expand All @@ -450,7 +459,8 @@ def _compute_validation_metrics(self) -> workload.Response:
batch = self.context.to_device(batch)
num_inputs += data_length(batch)

vld_metrics = self.trial.evaluate_batch(batch=batch)
with det._execute_user_func():
vld_metrics = self.trial.evaluate_batch(batch=batch)
# Verify validation metric names are the same across batches.
if keys is None:
keys = vld_metrics.keys()
Expand Down Expand Up @@ -483,13 +493,14 @@ def _compute_validation_metrics(self) -> workload.Response:
check.true(self._evaluate_full_dataset_defined())
self.validation_loader = cast(torch.utils.data.DataLoader, self.validation_loader)
if self.is_chief:
metrics = self.trial.evaluate_full_dataset(data_loader=self.validation_loader)
with det._execute_user_func():
metrics = self.trial.evaluate_full_dataset(data_loader=self.validation_loader)

check.is_instance(
metrics, dict, f"eval() must return a dictionary, got {type(metrics)}."
)

metrics = self._convert_metrics_to_numpy(metrics)
metrics = self._convert_metrics_to_numpy(cast(Dict, metrics))
num_inputs = self.context.get_per_slot_batch_size() * len(self.validation_loader)

if self.hvd_config.use and any(
Expand All @@ -505,14 +516,14 @@ def _compute_validation_metrics(self) -> workload.Response:
)
metrics = hvd.broadcast_object(metrics, root_rank=0)

for callback in self.callbacks.values():
logging.warning(
"on_validation_step_end is now deprecated, please use on_validation_end instead"
)
callback.on_validation_step_end(cast(Dict[str, Any], metrics))

for callback in self.callbacks.values():
callback.on_validation_end(cast(Dict[str, Any], metrics))
with det._execute_user_func():
for callback in self.callbacks.values():
logging.warning(
"on_validation_step_end is now deprecated, please use on_validation_end instead"
)
callback.on_validation_step_end(cast(Dict[str, Any], metrics))
for callback in self.callbacks.values():
callback.on_validation_end(cast(Dict[str, Any], metrics))

if not self.is_chief:
return workload.Skipped()
Expand All @@ -522,7 +533,8 @@ def _compute_validation_metrics(self) -> workload.Response:
def _prepare_metrics_reducers(self, keys: Any) -> Dict[str, Reducer]:
metrics_reducers = {} # type: Dict[str, Reducer]
if isinstance(self.trial.evaluation_reducer(), Dict):
metrics_reducers = cast(Dict[str, Any], self.trial.evaluation_reducer())
with det._execute_user_func():
metrics_reducers = cast(Dict[str, Any], self.trial.evaluation_reducer())
check.eq(
metrics_reducers.keys(),
keys,
Expand All @@ -531,8 +543,9 @@ def _prepare_metrics_reducers(self, keys: Any) -> Dict[str, Reducer]:
f"Expected keys: {keys}, provided keys: {metrics_reducers.keys()}.",
)
elif isinstance(self.trial.evaluation_reducer(), Reducer):
for key in keys:
metrics_reducers[key] = cast(Reducer, self.trial.evaluation_reducer())
with det._execute_user_func():
for key in keys:
metrics_reducers[key] = cast(Reducer, self.trial.evaluation_reducer())

for key in keys:
check.true(
Expand Down Expand Up @@ -757,8 +770,9 @@ def _save(self, path: pathlib.Path) -> workload.Response:
checkpoint, str(path.joinpath("state_dict.pth")), pickle_module=cloudpickle
)

for callback in self.callbacks.values():
callback.on_checkpoint_end(str(path))
with det._execute_user_func():
for callback in self.callbacks.values():
callback.on_checkpoint_end(str(path))

return cast(
workload.Response,
Expand Down

0 comments on commit 9484250

Please sign in to comment.