From d5f2c6e8d37813fd2c956d129fc10f791acfb5b9 Mon Sep 17 00:00:00 2001 From: Rayan-Krishnan Date: Thu, 6 Aug 2020 10:30:20 -0700 Subject: [PATCH] Add deterministic compute tests (#4716) Co-authored-by: Rayan Krishnan Co-authored-by: Thiago Crepaldi --- .../orttraining/python/training/debug.py | 31 ++++++ .../orttraining/python/training/orttrainer.py | 4 +- .../orttraining_test_orttrainer_frontend.py | 95 +++++++++++++++---- 3 files changed, 108 insertions(+), 22 deletions(-) create mode 100644 orttraining/orttraining/python/training/debug.py diff --git a/orttraining/orttraining/python/training/debug.py b/orttraining/orttraining/python/training/debug.py new file mode 100644 index 0000000000000..b6a3016b2e602 --- /dev/null +++ b/orttraining/orttraining/python/training/debug.py @@ -0,0 +1,31 @@ + +import numpy as np +import os +import sys +import torch + +from numpy.testing import assert_allclose +from onnxruntime.capi.training import orttrainer + +def compare_onnx_weights(model_a, model_b, verbose=False, rtol=1e-4): + r"""Compare whether weights between 'model_a' and 'model_b' ONNX models are within + a certain tolerance 'rtol' + + Compares the weights of two different ONNX models and throws an error when they diverge + Args: + model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure + verbose (bool, default is False): Indicates if the max absolute difference for each layer should be + calculated and printed for debug information. + rtol (float, default is 1e-4): Tolerance for divergence. + """ + assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer) + state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state() + assert len(state_dict_a.items()) == len(state_dict_b.items()) + for (a_name, a_val), (b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()): + np_a_vals = np.array(a_val).flatten() + np_b_vals = np.array(b_val).flatten() + assert np_a_vals.shape == np_b_vals.shape + if verbose: + print(f'Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}') + assert_allclose(a_val, b_val, rtol=rtol, err_msg=f"Weight mismatch for {a_name}") + diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 7bb7436fe358a..fe9b1138acf94 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -156,7 +156,7 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): if 'cuda' in device_id.lower(): set_cuda_mem_limit(int(self.options.device.mem_limit)) if ':' in device_id: - set_cuda_device_id(device_id.split(':')[1]) + set_cuda_device_id(int(device_id.split(':')[1])) self._train_step_info = TrainStepInfo(all_finite=True, step=0, optimizer_config=self.optim_config) @@ -345,7 +345,7 @@ def forward(self, *inputs): return CombineTorchModelLossFn(self._torch_model, self.loss_fn, input_names) def _convert_torch_model_loss_fn_to_onnx(self, inputs): - device = torch.device(self.options.device.id) + device = torch.device('cpu') #torch.device(self.options.device.id) if isinstance(inputs, torch.Tensor): inputs = [inputs] if isinstance(inputs, dict): diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 67aa8429d2ff4..4cb487ed8aa33 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -8,7 +8,8 @@ from onnxruntime.capi.training import orttrainer_options as orttrainer_options from onnxruntime.capi.training import model_desc_validation as md_val -from onnxruntime.capi.training import orttrainer, amp, optim, TrainStepInfo, _utils +from onnxruntime.capi.training import orttrainer, amp, optim, TrainStepInfo, _utils, debug +from onnxruntime.capi._pybind_state import set_seed @pytest.mark.parametrize("test_input", [ @@ -456,20 +457,7 @@ def testLRSchedulerUpdateImpl(lr_scheduler, expected_values): assert_allclose(lr_list[0], expected_values[step], rtol=rtol, err_msg="lr mismatch") - -@pytest.mark.parametrize("step_fn, lr_scheduler, expected_lr_values", [ - ('train_step', None, None), - ('eval_step', None, None), - ('train_step', optim.lr_scheduler.ConstantWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843, - 0.023843, 0.023843, 0.023843, 0.023843, 0.023843]), - ('train_step', optim.lr_scheduler.CosineWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843, - 0.010225, 0.002989, 0.0005158, 0.000040937, 0.0000008291]), - ('train_step', optim.lr_scheduler.LinearWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843, - 0.021675, 0.0157636, 0.0085983, 0.0031266, 0.00056847]), - ('train_step', optim.lr_scheduler.PolyWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843, - 0.0160749, 0.0096935, 0.0050622, 0.0021585, 0.000650833]) -]) -def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values): +def generate_pytorch_transformer_model_sample(optim_config, options={}, step_fn='train_step', device='cpu'): # Loading external TransformerModel model for testing # A manual import is done as this example is not part of onnxruntime package, # but resides on the onnxruntime repo @@ -489,6 +477,36 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values): my_loss = ort_utils.my_loss model_desc = ort_utils.transformer_model_description() + # Set up relevant options + trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) + + # Preparing data + train_data, val_data, _ = utils.prepare_data(device, 20, 20) + + if step_fn == 'eval_step': + data, targets = utils.get_batch(val_data, 0) + elif step_fn == 'train_step': + data, targets = utils.get_batch(train_data, 0) + else: + raise ValueError('Invalid step_fn') + + data, targets = data.to(trainer.options.device.id), targets.to(trainer.options.device.id) + + return model, model_desc, trainer, data, targets + +@pytest.mark.parametrize("step_fn, lr_scheduler, expected_lr_values", [ + ('train_step', None, None), + ('eval_step', None, None), + ('train_step', optim.lr_scheduler.ConstantWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843, + 0.023843, 0.023843, 0.023843, 0.023843, 0.023843]), + ('train_step', optim.lr_scheduler.CosineWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843, + 0.010225, 0.002989, 0.0005158, 0.000040937, 0.0000008291]), + ('train_step', optim.lr_scheduler.LinearWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843, + 0.021675, 0.0157636, 0.0085983, 0.0031266, 0.00056847]), + ('train_step', optim.lr_scheduler.PolyWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843, + 0.0160749, 0.0096935, 0.0050622, 0.0021585, 0.000650833]) +]) +def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values): max_train_step = 1 warmup = 0.5 initial_lr = 1 @@ -502,20 +520,17 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values): opts.update({'lr_scheduler' : lr_scheduler(max_train_step, warmup)}) opts = orttrainer.ORTTrainerOptions(opts) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - # Preparing data - train_data, val_data, _ = utils.prepare_data('cpu', 20, 20) + # Using PyTorch Transformer model as example + model, model_desc, trainer, data, targets = generate_pytorch_transformer_model_sample(optim_config, opts, step_fn) # Export model to ONNX if step_fn == 'eval_step': step_fn = trainer.eval_step - data, targets = utils.get_batch(val_data, 0) output = trainer.eval_step(data, targets) elif step_fn == 'train_step': step_fn = trainer.train_step for i in range(max_train_step): - data, targets = utils.get_batch(train_data, 0) output = trainer.train_step(data, targets) if lr_scheduler: lr_list = trainer.options.lr_scheduler.get_last_lr() @@ -573,3 +588,43 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values): assert (trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph) assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(trainer._onnx_model.graph)) + +@pytest.mark.parametrize("seed, device_id", [ + (0, 'cpu'), + (42, 'cpu'), + (0, 'cuda:0'), + (24, 'cuda') +]) +def testORTDeterministicCompute(seed, device_id): + optim_config = optim.LambConfig() + opts = orttrainer.ORTTrainerOptions({ + 'debug' : { + 'deterministic_compute': True + }, + 'device' : { + 'id' : device_id, + 'mem_limit' : 10*1024*1024 + } + }) + + torch.manual_seed(seed) + set_seed(seed) + + # Using PyTorch Transformer model as example + model, model_desc, trainer, data, targets = generate_pytorch_transformer_model_sample(optim_config, opts, device=device_id) + + # Run first model train step + output = trainer.train_step(data, targets) + assert trainer._onnx_model is not None + + # Reset the seeds + torch.manual_seed(seed) + set_seed(seed) + + # Run second model train step + _, _, second_trainer, _, _ = generate_pytorch_transformer_model_sample(optim_config, opts, device=device_id) + output = second_trainer.train_step(data, targets) + assert second_trainer._onnx_model is not None + assert id(trainer._onnx_model) != id(second_trainer._onnx_model) + + debug.compare_onnx_weights(trainer, second_trainer)