diff --git a/docs/requirements.txt b/docs/requirements.txt
index 3969823f8c..4a11c767e5 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,6 +1,6 @@
 -f https://download.pytorch.org/whl/cpu/torch-1.4.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
 torch>=1.4.0
-pytorch-ignite==0.3.0
+pytorch-ignite==0.4.2
 numpy>=1.17
 itk
 nibabel
diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py
index cd1c3e404b..20cb29227e 100644
--- a/monai/engines/evaluator.py
+++ b/monai/engines/evaluator.py
@@ -25,8 +25,8 @@
     from ignite.engine import Engine
     from ignite.metrics import Metric
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
-    Metric, _ = optional_import("ignite.metrics", "0.3.0", exact_version, "Metric")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")
+    Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")


 class Evaluator(Workflow):
diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py
index 0e6e01d186..7110a09c0f 100644
--- a/monai/engines/multi_gpu_supervised_trainer.py
+++ b/monai/engines/multi_gpu_supervised_trainer.py
@@ -19,15 +19,15 @@
 from monai.engines.utils import get_devices_spec
 from monai.utils import exact_version, optional_import

-create_supervised_trainer, _ = optional_import("ignite.engine", "0.3.0", exact_version, "create_supervised_trainer")
-create_supervised_evaluator, _ = optional_import("ignite.engine", "0.3.0", exact_version, "create_supervised_evaluator")
-_prepare_batch, _ = optional_import("ignite.engine", "0.3.0", exact_version, "_prepare_batch")
+create_supervised_trainer, _ = optional_import("ignite.engine", "0.4.2", exact_version, "create_supervised_trainer")
+create_supervised_evaluator, _ = optional_import("ignite.engine", "0.4.2", exact_version, "create_supervised_evaluator")
+_prepare_batch, _ = optional_import("ignite.engine", "0.4.2", exact_version, "_prepare_batch")
 if TYPE_CHECKING:
     from ignite.engine import Engine
     from ignite.metrics import Metric
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
-    Metric, _ = optional_import("ignite.metrics", "0.3.0", exact_version, "Metric")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")
+    Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")


 def _default_transform(_x: torch.Tensor, _y: torch.Tensor, _y_pred: torch.Tensor, loss: torch.Tensor) -> float:
""" - if self._is_done(self.state): - self.state.iteration = 0 # to avoid creating new State instance in ignite Engine.run self.scaler = torch.cuda.amp.GradScaler() if self.amp else None super().run() diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 8cbe928603..208e689122 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -19,15 +19,15 @@ from monai.transforms import apply_transform from monai.utils import ensure_tuple, exact_version, optional_import -IgniteEngine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine") -State, _ = optional_import("ignite.engine", "0.3.0", exact_version, "State") -Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events") +IgniteEngine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") +State, _ = optional_import("ignite.engine", "0.4.2", exact_version, "State") +Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine from ignite.metrics import Metric else: - Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine") - Metric, _ = optional_import("ignite.metrics", "0.3.0", exact_version, "Metric") + Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import @@ -99,7 +99,7 @@ def set_sampler_epoch(engine: Engine): iteration=0, epoch=0, max_epochs=max_epochs, - epoch_length=-1, + epoch_length=len(data_loader), output=None, batch=None, metrics={}, @@ -154,7 +154,7 @@ def run(self) -> None: Execute training, validation or evaluation based on Ignite Engine. 
""" - super().run(data=self.data_loader, epoch_length=len(self.data_loader)) + super().run(data=self.data_loader, max_epochs=self.state.max_epochs) def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index a77f8cac8f..029dcf4ee4 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -16,12 +16,12 @@ from monai.utils import exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events") -Checkpoint, _ = optional_import("ignite.handlers", "0.3.0", exact_version, "Checkpoint") +Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +Checkpoint, _ = optional_import("ignite.handlers", "0.4.2", exact_version, "Checkpoint") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") class CheckpointLoader: diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index aea8ad7e4c..713358426e 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -14,12 +14,12 @@ from monai.utils import exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events") -ModelCheckpoint, _ = optional_import("ignite.handlers", "0.3.0", exact_version, "ModelCheckpoint") +Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +ModelCheckpoint, _ = optional_import("ignite.handlers", "0.4.2", exact_version, "ModelCheckpoint") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") class CheckpointSaver: diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 9eeb6134a6..e446773144 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -15,11 +15,11 @@ from monai.data import CSVSaver from monai.utils import exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events") +Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") class ClassificationSaver: diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py index a9c59b8985..9fd2f64885 100644 --- a/monai/handlers/lr_schedule_handler.py +++ b/monai/handlers/lr_schedule_handler.py @@ -16,11 +16,11 @@ from monai.utils import ensure_tuple, exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events") +Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") class LrScheduleHandler: diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index 3e901dcd1d..fcb2317ef6 100644 --- a/monai/handlers/mean_dice.py 
diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py
index 9eeb6134a6..e446773144 100644
--- a/monai/handlers/classification_saver.py
+++ b/monai/handlers/classification_saver.py
@@ -15,11 +15,11 @@
 from monai.data import CSVSaver
 from monai.utils import exact_version, optional_import

-Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
+Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
 if TYPE_CHECKING:
     from ignite.engine import Engine
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")


 class ClassificationSaver:
diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py
index a9c59b8985..9fd2f64885 100644
--- a/monai/handlers/lr_schedule_handler.py
+++ b/monai/handlers/lr_schedule_handler.py
@@ -16,11 +16,11 @@
 from monai.utils import ensure_tuple, exact_version, optional_import

-Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
+Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
 if TYPE_CHECKING:
     from ignite.engine import Engine
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")


 class LrScheduleHandler:
diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py
index 3e901dcd1d..fcb2317ef6 100644
--- a/monai/handlers/mean_dice.py
+++ b/monai/handlers/mean_dice.py
@@ -16,10 +16,10 @@
 from monai.metrics import DiceMetric
 from monai.utils import MetricReduction, exact_version, optional_import

-NotComputableError, _ = optional_import("ignite.exceptions", "0.3.0", exact_version, "NotComputableError")
-Metric, _ = optional_import("ignite.metrics", "0.3.0", exact_version, "Metric")
-reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.3.0", exact_version, "reinit__is_reduced")
-sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.3.0", exact_version, "sync_all_reduce")
+NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError")
+Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")
+reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced")
+sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce")


 class MeanDice(Metric):  # type: ignore[valid-type, misc] # due to optional_import
diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py
index f19da1facf..3198fcce6a 100644
--- a/monai/handlers/metric_logger.py
+++ b/monai/handlers/metric_logger.py
@@ -14,11 +14,11 @@
 from monai.utils import exact_version, optional_import

-Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
+Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
 if TYPE_CHECKING:
     from ignite.engine import Engine
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")


 class MetricLogger:
diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py
index 9143aab94c..22080a1bb5 100644
--- a/monai/handlers/roc_auc.py
+++ b/monai/handlers/roc_auc.py
@@ -17,8 +17,8 @@
 from monai.metrics import compute_roc_auc
 from monai.utils import Average, exact_version, optional_import

-Metric, _ = optional_import("ignite.metrics", "0.3.0", exact_version, "Metric")
-reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.3.0", exact_version, "reinit__is_reduced")
+Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")
+reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced")


 class ROCAUC(Metric):  # type: ignore[valid-type, misc] # due to optional_import
diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py
index 5c822c016b..444768d555 100644
--- a/monai/handlers/segmentation_saver.py
+++ b/monai/handlers/segmentation_saver.py
@@ -17,11 +17,11 @@
 from monai.data import NiftiSaver, PNGSaver
 from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import

-Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
+Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
 if TYPE_CHECKING:
     from ignite.engine import Engine
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")


 class SegmentationSaver:
diff --git a/monai/handlers/smartcache_handler.py b/monai/handlers/smartcache_handler.py
index 309481c4f9..2c96f00316 100644
--- a/monai/handlers/smartcache_handler.py
+++ b/monai/handlers/smartcache_handler.py
@@ -14,11 +14,11 @@
 from monai.data import SmartCacheDataset
 from monai.utils import exact_version, optional_import

-Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
+Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
 if TYPE_CHECKING:
     from ignite.engine import Engine
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")


 class SmartCacheHandler:
diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index 2546e40204..b38d5ade9e 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -17,11 +17,11 @@
 from monai.utils import exact_version, is_scalar, optional_import

-Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
+Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
 if TYPE_CHECKING:
     from ignite.engine import Engine
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")

 DEFAULT_KEY_VAL_FORMAT = "{}: {:.4f} "
 DEFAULT_TAG = "Loss"
diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py
index 0ab323a6e0..a9d7d661ec 100644
--- a/monai/handlers/tensorboard_handlers.py
+++ b/monai/handlers/tensorboard_handlers.py
@@ -18,12 +18,12 @@
 from monai.utils import exact_version, is_scalar, optional_import
 from monai.visualize import plot_2d_or_3d_image

-Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
+Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
 if TYPE_CHECKING:
     from ignite.engine import Engine
     from torch.utils.tensorboard import SummaryWriter
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")
     SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter")

 DEFAULT_TAG = "Loss"
diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py
index 051ec74be7..3b59d4a779 100644
--- a/monai/handlers/utils.py
+++ b/monai/handlers/utils.py
@@ -16,7 +16,7 @@
 if TYPE_CHECKING:
     from ignite.engine import Engine
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")


 def stopping_fn_from_metric(metric_name: str) -> Callable[[Engine], Any]:
diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py
index e03d4909ea..45261c1548 100644
--- a/monai/handlers/validation_handler.py
+++ b/monai/handlers/validation_handler.py
@@ -14,11 +14,11 @@
 from monai.engines.evaluator import Evaluator
 from monai.utils import exact_version, optional_import

-Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
+Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
 if TYPE_CHECKING:
     from ignite.engine import Engine
 else:
-    Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
+    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")


 class ValidationHandler:
diff --git a/requirements-dev.txt b/requirements-dev.txt
index a2bc33c326..b0e21b18fb 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,6 +1,6 @@
 # Full requirements for developments
 -r requirements-min.txt
-pytorch-ignite==0.3.0
+pytorch-ignite==0.4.2
 gdown>=3.6.4
 scipy
 itk
diff --git a/setup.cfg b/setup.cfg
index 3c0cda087f..34eb9cb285 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -24,7 +24,7 @@
     scikit-image>=0.14.2
     pillow
    tensorboard
-    pytorch-ignite==0.3.0
+    pytorch-ignite==0.4.2
     gdown>=3.6.4
     torchvision
     itk
@@ -40,7 +40,7 @@ tensorboard =
 gdown =
     gdown>=3.6.4
 ignite =
-    pytorch-ignite==0.3.0
+    pytorch-ignite==0.4.2
 torchvision =
     torchvision
 itk =
diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py
index 7a414b9133..438e73bf3a 100644
--- a/tests/test_handler_checkpoint_loader.py
+++ b/tests/test_handler_checkpoint_loader.py
@@ -32,11 +32,12 @@ def test_one_save_one_load(self):
         data2 = net2.state_dict()
         data2["weight"] = torch.tensor([0.2])
         net2.load_state_dict(data2)
-        engine = Engine(lambda e, b: None)
         with tempfile.TemporaryDirectory() as tempdir:
+            engine = Engine(lambda e, b: None)
             CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine)
             engine.run([0] * 8, max_epochs=5)
-            path = tempdir + "/net_final_iteration=40.pth"
+            path = tempdir + "/net_final_iteration=40.pt"
+            engine = Engine(lambda e, b: None)
             CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
             engine.run([0] * 8, max_epochs=1)
             torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
@@ -52,12 +53,13 @@ def test_two_save_one_load(self):
         data2 = net2.state_dict()
         data2["weight"] = torch.tensor([0.2])
         net2.load_state_dict(data2)
-        engine = Engine(lambda e, b: None)
         with tempfile.TemporaryDirectory() as tempdir:
+            engine = Engine(lambda e, b: None)
             save_dict = {"net": net1, "opt": optimizer}
             CheckpointSaver(save_dir=tempdir, save_dict=save_dict, save_final=True).attach(engine)
             engine.run([0] * 8, max_epochs=5)
-            path = tempdir + "/checkpoint_final_iteration=40.pth"
+            path = tempdir + "/checkpoint_final_iteration=40.pt"
+            engine = Engine(lambda e, b: None)
             CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
             engine.run([0] * 8, max_epochs=1)
             torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
@@ -73,11 +75,12 @@ def test_save_single_device_load_multi_devices(self):
         data2["weight"] = torch.tensor([0.2])
         net2.load_state_dict(data2)
         net2 = torch.nn.DataParallel(net2)
-        engine = Engine(lambda e, b: None)
         with tempfile.TemporaryDirectory() as tempdir:
+            engine = Engine(lambda e, b: None)
             CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine)
             engine.run([0] * 8, max_epochs=5)
-            path = tempdir + "/net_final_iteration=40.pth"
+            path = tempdir + "/net_final_iteration=40.pt"
+            engine = Engine(lambda e, b: None)
             CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
             engine.run([0] * 8, max_epochs=1)
             torch.testing.assert_allclose(net2.state_dict()["module.weight"].cpu(), 0.1)
diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py
index 2c9e649bb6..5952316cc2 100644
--- a/tests/test_handler_checkpoint_saver.py
+++ b/tests/test_handler_checkpoint_saver.py
@@ -22,7 +22,7 @@

 from monai.handlers import CheckpointSaver

-TEST_CASE_1 = [True, False, None, 1, True, 0, None, ["test_checkpoint_final_iteration=40.pth"]]
+TEST_CASE_1 = [True, False, None, 1, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]]

 TEST_CASE_2 = [
     False,
@@ -32,10 +32,10 @@
     True,
     0,
     None,
-    ["test_checkpoint_key_metric=32.pth", "test_checkpoint_key_metric=40.pth"],
+    ["test_checkpoint_key_metric=32.pt", "test_checkpoint_key_metric=40.pt"],
 ]

-TEST_CASE_3 = [False, False, None, 1, True, 2, 2, ["test_checkpoint_epoch=2.pth", "test_checkpoint_epoch=4.pth"]]
+TEST_CASE_3 = [False, False, None, 1, True, 2, 2, ["test_checkpoint_epoch=2.pt", "test_checkpoint_epoch=4.pt"]]

 TEST_CASE_4 = [
     False,
@@ -45,10 +45,10 @@
     False,
     10,
     2,
-    ["test_checkpoint_iteration=30.pth", "test_checkpoint_iteration=40.pth"],
+    ["test_checkpoint_iteration=30.pt", "test_checkpoint_iteration=40.pt"],
 ]

-TEST_CASE_5 = [True, False, None, 1, True, 0, None, ["test_checkpoint_final_iteration=40.pth"], True]
+TEST_CASE_5 = [True, False, None, 1, True, 0, None, ["test_checkpoint_final_iteration=40.pt"], True]


 class TestHandlerCheckpointSaver(unittest.TestCase):
@@ -115,7 +115,7 @@ def _train_func(engine, batch):

         with self.assertRaises(RuntimeError):
             engine.run(range(3), max_epochs=2)
-        self.assertTrue(os.path.exists(os.path.join(tempdir, "net_final_iteration=1.pth")))
+        self.assertTrue(os.path.exists(os.path.join(tempdir, "net_final_iteration=1.pt")))


 if __name__ == "__main__":
diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py
index 4aa6a89bb5..13d141dc73 100644
--- a/tests/test_handler_rocauc_dist.py
+++ b/tests/test_handler_rocauc_dist.py
@@ -20,6 +20,7 @@

 def main():
     dist.init_process_group(backend="nccl", init_method="env://")
+    torch.cuda.set_device(dist.get_rank())
     auc_metric = ROCAUC(to_onehot_y=True, softmax=True)

     if dist.get_rank() == 0:
diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py
index bcac66b1ad..ad7b78454e 100644
--- a/tests/test_integration_workflows.py
+++ b/tests/test_integration_workflows.py
@@ -260,106 +260,107 @@ def test_training(self):
             best_metric = run_training_test(self.data_dir, device=self.device, amp=(i == 2))
             print("best metric", best_metric)
             if i == 2:
-                np.testing.assert_allclose(best_metric, 0.924358, rtol=1e-2)
+                np.testing.assert_allclose(best_metric, 0.9218417823314666, rtol=1e-2)
             else:
-                np.testing.assert_allclose(best_metric, 0.9250373750925064, rtol=1e-2)
+                np.testing.assert_allclose(best_metric, 0.9219249188899994, rtol=1e-2)
             repeated[i].append(best_metric)
-            model_file = sorted(glob(os.path.join(self.data_dir, "net_key_metric*.pth")))[-1]
+            model_file = sorted(glob(os.path.join(self.data_dir, "net_key_metric*.pt")))[-1]
             infer_metric = run_inference_test(self.data_dir, model_file, device=self.device, amp=(i == 2))
             print("infer metric", infer_metric)
             # check inference properties
             if i == 2:
-                np.testing.assert_allclose(infer_metric, 0.924627597630024, rtol=1e-2)
+                np.testing.assert_allclose(infer_metric, 0.9216801583766937, rtol=1e-2)
             else:
-                np.testing.assert_allclose(infer_metric, 0.9246308669447899, rtol=1e-2)
+                np.testing.assert_allclose(infer_metric, 0.9217007756233215, rtol=1e-2)
             repeated[i].append(infer_metric)
+            output_files = sorted(glob(os.path.join(self.data_dir, "img*", "*.nii.gz")))
             if i == 2:
                 sums = [
-                    0.14114046096801758,
-                    0.1504497528076172,
-                    0.13713788986206055,
-                    0.13302993774414062,
-                    0.18422222137451172,
-                    0.16304492950439453,
-                    0.13993120193481445,
-                    0.16569805145263672,
-                    0.1551837921142578,
-                    0.1755976676940918,
-                    0.16045379638671875,
-                    0.16413402557373047,
-                    0.14251232147216797,
-                    0.10928630828857422,
-                    0.16003799438476562,
-                    0.19595718383789062,
-                    0.17368268966674805,
-                    0.05275678634643555,
-                    0.19002151489257812,
-                    0.1982269287109375,
-                    0.19471073150634766,
-                    0.20270061492919922,
-                    0.1594076156616211,
-                    0.13070344924926758,
-                    0.14964008331298828,
-                    0.13594627380371094,
-                    0.2263627052307129,
-                    0.16036462783813477,
-                    0.14667415618896484,
-                    0.10274696350097656,
-                    0.11820268630981445,
-                    0.12948942184448242,
-                    0.11093902587890625,
-                    0.15072298049926758,
-                    0.1591496467590332,
-                    0.1892232894897461,
-                    0.2160496711730957,
-                    0.17680883407592773,
-                    0.18494272232055664,
-                    0.035521507263183594,
+                    0.14187908172607422,
+                    0.15141582489013672,
+                    0.13805580139160156,
+                    0.1335921287536621,
+                    0.18461370468139648,
+                    0.1636824607849121,
+                    0.1409168243408203,
+                    0.1665506362915039,
+                    0.15652847290039062,
+                    0.17659997940063477,
+                    0.16120052337646484,
+                    0.1645350456237793,
+                    0.14385366439819336,
+                    0.11049985885620117,
+                    0.16086244583129883,
+                    0.19636201858520508,
+                    0.17445993423461914,
+                    0.05356025695800781,
+                    0.19049406051635742,
+                    0.19910669326782227,
+                    0.1953139305114746,
+                    0.2030935287475586,
+                    0.1603412628173828,
+                    0.1317133903503418,
+                    0.1511821746826172,
+                    0.13686084747314453,
+                    0.22674274444580078,
+                    0.16118431091308594,
+                    0.14728116989135742,
+                    0.10426139831542969,
+                    0.11961984634399414,
+                    0.13056421279907227,
+                    0.11201953887939453,
+                    0.15174198150634766,
+                    0.15967845916748047,
+                    0.1898040771484375,
+                    0.2165522575378418,
+                    0.17767810821533203,
+                    0.18523073196411133,
+                    0.03636026382446289,
                 ]
             else:
                 sums = [
-                    0.14113855361938477,
-                    0.1504507064819336,
-                    0.13713932037353516,
-                    0.13303327560424805,
-                    0.1842188835144043,
-                    0.16304492950439453,
-                    0.13993024826049805,
-                    0.1656951904296875,
-                    0.1551809310913086,
-                    0.17559528350830078,
-                    0.16044998168945312,
-                    0.16412973403930664,
-                    0.14251136779785156,
-                    0.10928821563720703,
-                    0.1600356101989746,
-                    0.1959514617919922,
-                    0.17368221282958984,
-                    0.05275869369506836,
-                    0.1900186538696289,
-                    0.19822216033935547,
-                    0.19471025466918945,
-                    0.2026987075805664,
-                    0.1594090461730957,
-                    0.1307048797607422,
-                    0.1496415138244629,
-                    0.13594770431518555,
-                    0.2263627052307129,
-                    0.16036462783813477,
-                    0.14667081832885742,
-                    0.10274934768676758,
-                    0.11820459365844727,
-                    0.1294875144958496,
-                    0.11093950271606445,
-                    0.15072107315063477,
-                    0.15914440155029297,
-                    0.1892228126525879,
-                    0.21604537963867188,
-                    0.1768054962158203,
-                    0.1849384307861328,
-                    0.0355219841003418,
+                    0.14187145233154297,
+                    0.15141057968139648,
+                    0.13805103302001953,
+                    0.13358211517333984,
+                    0.18460750579833984,
+                    0.16367673873901367,
+                    0.14090681076049805,
+                    0.16654396057128906,
+                    0.1565251350402832,
+                    0.17658662796020508,
+                    0.16119146347045898,
+                    0.1645212173461914,
+                    0.1438441276550293,
+                    0.11049842834472656,
+                    0.1608572006225586,
+                    0.1963505744934082,
+                    0.17445039749145508,
+                    0.0535578727722168,
+                    0.19048500061035156,
+                    0.19910240173339844,
+                    0.19530916213989258,
+                    0.20308685302734375,
+                    0.16033363342285156,
+                    0.13170766830444336,
+                    0.15117835998535156,
+                    0.13685131072998047,
+                    0.22673511505126953,
+                    0.16117477416992188,
+                    0.14727354049682617,
+                    0.10425710678100586,
+                    0.11961698532104492,
+                    0.13056182861328125,
+                    0.11200904846191406,
+                    0.1517343521118164,
+                    0.15967321395874023,
+                    0.18979740142822266,
+                    0.21654462814331055,
+                    0.17766666412353516,
+                    0.18522214889526367,
+                    0.03636026382446289,
                 ]
             for (output, s) in zip(output_files, sums):
                 ave = np.mean(nib.load(output).get_fdata())
diff --git a/tests/test_parallel_execution.py b/tests/test_parallel_execution.py
index e6abcc641c..fc28ea8647 100644
--- a/tests/test_parallel_execution.py
+++ b/tests/test_parallel_execution.py
@@ -34,14 +34,16 @@ class TestParallelExecution(unittest.TestCase):

     @expect_failure_if_no_gpu
     def test_single_gpu(self):
-        net = torch.nn.Conv2d(1, 1, 3, padding=1)
+        device = torch.device("cuda:0")
+        net = torch.nn.Conv2d(1, 1, 3, padding=1).to(device)
         opt = torch.optim.Adam(net.parameters(), 1e-3)
-        trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [torch.device("cuda:0")])
+        trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [device])
         trainer.run(fake_data_stream(), 2, 2)

     @expect_failure_if_no_gpu
     def test_multi_gpu(self):
-        net = torch.nn.Conv2d(1, 1, 3, padding=1)
+        device = torch.device("cuda")
+        net = torch.nn.Conv2d(1, 1, 3, padding=1).to(device)
         opt = torch.optim.Adam(net.parameters(), 1e-3)

         with warnings.catch_warnings():