Skip to content

Commit

Permalink
test: fix nightly nas and iris tf keras tests [DET-3264] (determined-…
Browse files Browse the repository at this point in the history
…ai#644)

* docs: update NAS example to use correct gradient clipping
* test: set random seed for nightly iris tf_keras test
  • Loading branch information
aaron276h authored Jun 4, 2020
1 parent 4ff9fa0 commit 84e875a
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 3 deletions.
1 change: 1 addition & 0 deletions e2e_tests/tests/nightly/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def test_object_detection_accuracy() -> None:
@pytest.mark.nightly # type: ignore
def test_iris_tf_keras() -> None:
config = conf.load_config(conf.official_examples_path("iris_tf_keras/const.yaml"))
config = conf.set_random_seed(config, 1591280374)
experiment_id = exp.run_basic_test_with_temp_config(
config, conf.official_examples_path("iris_tf_keras"), 1
)
Expand Down
2 changes: 1 addition & 1 deletion examples/experimental/nas_search/arch_search.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
description: NAS_Search
hyperparameters:
clip_grad_l2_norm: .25
clip_gradients_l2_norm: .25
global_batch_size: 64
bptt: 35
learning_rate: 20
Expand Down
5 changes: 4 additions & 1 deletion examples/experimental/nas_search/model_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torch.optim.lr_scheduler import _LRScheduler

import determined as det
from determined.pytorch import DataLoader, PyTorchTrial, LRScheduler
from determined.pytorch import ClipGradsL2Norm, DataLoader, PyTorchCallback, PyTorchTrial, LRScheduler


import data
Expand Down Expand Up @@ -343,3 +343,6 @@ def build_validation_data_loader(self) -> DataLoader:
),
collate_fn=data.PadSequence(),
)

def build_callbacks(self) -> Dict[str, PyTorchCallback]:
return {"clip_grads": ClipGradsL2Norm(self.context.get_hparam("clip_gradients_l2_norm"))}
2 changes: 1 addition & 1 deletion examples/experimental/nas_search/train_one_arch.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
description: NAS_ASHA
hyperparameters:
clip_grad_l2_norm: .25
clip_gradients_l2_norm: .25
global_batch_size: 64
bptt: 35
learning_rate: 20
Expand Down

0 comments on commit 84e875a

Please sign in to comment.