Add deterministic compute tests (#4716)
Co-authored-by: Rayan Krishnan <t-rakr@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: Thiago Crepaldi <[email protected]>
3 people committed Aug 14, 2020
1 parent 92c5ba4 commit d5f2c6e
Showing 3 changed files with 108 additions and 22 deletions.
31 changes: 31 additions & 0 deletions orttraining/orttraining/python/training/debug.py
@@ -0,0 +1,31 @@

import numpy as np
import os
import sys
import torch

from numpy.testing import assert_allclose
from onnxruntime.capi.training import orttrainer

def compare_onnx_weights(model_a, model_b, verbose=False, rtol=1e-4):
    r"""Compare whether the weights of the 'model_a' and 'model_b' ONNX models are
    within the relative tolerance 'rtol'.

    Compares the weights of two different ONNX models and raises an error when they diverge.

    Args:
        model_a, model_b (ORTTrainer): two instances of ORTTrainer with the same model structure
        verbose (bool, default is False): if True, the max absolute difference for each layer is
            printed for debugging
        rtol (float, default is 1e-4): relative tolerance for divergence
    """
    assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer)
    state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state()
    assert len(state_dict_a.items()) == len(state_dict_b.items())
    for (a_name, a_val), (b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()):
        np_a_vals = np.array(a_val).flatten()
        np_b_vals = np.array(b_val).flatten()
        assert np_a_vals.shape == np_b_vals.shape
        if verbose:
            print(f'Weight name: {a_name}: max absolute difference: {np.abs(np_a_vals - np_b_vals).max()}')
        assert_allclose(a_val, b_val, rtol=rtol, err_msg=f"Weight mismatch for {a_name}")
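
Editor's note: a minimal usage sketch of the helper above, not part of the commit. It assumes trainer_a and trainer_b are hypothetical ORTTrainer instances trained under identical seeds:

from onnxruntime.capi.training import debug

# assert_allclose raises inside compare_onnx_weights if any pair of
# corresponding weight tensors diverges beyond the relative tolerance.
debug.compare_onnx_weights(trainer_a, trainer_b, verbose=True, rtol=1e-4)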

4 changes: 2 additions & 2 deletions orttraining/orttraining/python/training/orttrainer.py
@@ -156,7 +156,7 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
        if 'cuda' in device_id.lower():
            set_cuda_mem_limit(int(self.options.device.mem_limit))
            if ':' in device_id:
-                set_cuda_device_id(device_id.split(':')[1])
+                set_cuda_device_id(int(device_id.split(':')[1]))

        self._train_step_info = TrainStepInfo(all_finite=True, step=0, optimizer_config=self.optim_config)
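
Editor's note on the fix above: str.split returns strings, so for a device id such as 'cuda:1' the parsed ordinal is '1', while the set_cuda_device_id binding evidently expects an integer, hence the added int(...) cast. A standalone sketch of the parsing, with an illustrative device id:

device_id = 'cuda:1'
ordinal = int(device_id.split(':')[1])  # '1' -> 1, an integer device ordinal
assert ordinal == 1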

@@ -345,7 +345,7 @@ def forward(self, *inputs):
        return CombineTorchModelLossFn(self._torch_model, self.loss_fn, input_names)

    def _convert_torch_model_loss_fn_to_onnx(self, inputs):
-        device = torch.device(self.options.device.id)
+        device = torch.device('cpu') #torch.device(self.options.device.id)
        if isinstance(inputs, torch.Tensor):
            inputs = [inputs]
        if isinstance(inputs, dict):
@@ -8,7 +8,8 @@

from onnxruntime.capi.training import orttrainer_options as orttrainer_options
from onnxruntime.capi.training import model_desc_validation as md_val
-from onnxruntime.capi.training import orttrainer, amp, optim, TrainStepInfo, _utils
+from onnxruntime.capi.training import orttrainer, amp, optim, TrainStepInfo, _utils, debug
+from onnxruntime.capi._pybind_state import set_seed


@pytest.mark.parametrize("test_input", [
@@ -456,20 +457,7 @@ def testLRSchedulerUpdateImpl(lr_scheduler, expected_values):
        assert_allclose(lr_list[0],
                        expected_values[step], rtol=rtol, err_msg="lr mismatch")

-@pytest.mark.parametrize("step_fn, lr_scheduler, expected_lr_values", [
-    ('train_step', None, None),
-    ('eval_step', None, None),
-    ('train_step', optim.lr_scheduler.ConstantWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
-                                                                  0.023843, 0.023843, 0.023843, 0.023843, 0.023843]),
-    ('train_step', optim.lr_scheduler.CosineWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
-                                                                0.010225, 0.002989, 0.0005158, 0.000040937, 0.0000008291]),
-    ('train_step', optim.lr_scheduler.LinearWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
-                                                                0.021675, 0.0157636, 0.0085983, 0.0031266, 0.00056847]),
-    ('train_step', optim.lr_scheduler.PolyWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
-                                                              0.0160749, 0.0096935, 0.0050622, 0.0021585, 0.000650833])
-])
-def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
+def generate_pytorch_transformer_model_sample(optim_config, options={}, step_fn='train_step', device='cpu'):
    # Loading the external TransformerModel model for testing
    # A manual import is done because this example is not part of the onnxruntime package,
    # but resides in the onnxruntime repo
@@ -489,6 +477,36 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
    my_loss = ort_utils.my_loss
    model_desc = ort_utils.transformer_model_description()

    # Set up relevant options
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Preparing data
    train_data, val_data, _ = utils.prepare_data(device, 20, 20)

    if step_fn == 'eval_step':
        data, targets = utils.get_batch(val_data, 0)
    elif step_fn == 'train_step':
        data, targets = utils.get_batch(train_data, 0)
    else:
        raise ValueError('Invalid step_fn')

    data, targets = data.to(trainer.options.device.id), targets.to(trainer.options.device.id)

    return model, model_desc, trainer, data, targets
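
Editor's note: the new helper bundles model construction, trainer setup, and batch preparation into one call. A minimal sketch of driving it, mirroring how the tests in this file use it (the optimizer config value is illustrative):

optim_config = optim.LambConfig()

# One call builds the transformer model, the ORTTrainer around it, and a
# sample batch already moved to the trainer's device; then take one step.
model, model_desc, trainer, data, targets = generate_pytorch_transformer_model_sample(optim_config)
output = trainer.train_step(data, targets)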

@pytest.mark.parametrize("step_fn, lr_scheduler, expected_lr_values", [
('train_step', None, None),
('eval_step', None, None),
('train_step', optim.lr_scheduler.ConstantWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.023843, 0.023843, 0.023843, 0.023843, 0.023843]),
('train_step', optim.lr_scheduler.CosineWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.010225, 0.002989, 0.0005158, 0.000040937, 0.0000008291]),
('train_step', optim.lr_scheduler.LinearWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.021675, 0.0157636, 0.0085983, 0.0031266, 0.00056847]),
('train_step', optim.lr_scheduler.PolyWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.0160749, 0.0096935, 0.0050622, 0.0021585, 0.000650833])
])
def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
max_train_step = 1
warmup = 0.5
initial_lr = 1
@@ -502,20 +520,17 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
    opts.update({'lr_scheduler' : lr_scheduler(max_train_step, warmup)})

    opts = orttrainer.ORTTrainerOptions(opts)
-    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)

-    # Preparing data
-    train_data, val_data, _ = utils.prepare_data('cpu', 20, 20)
+    # Using PyTorch Transformer model as example
+    model, model_desc, trainer, data, targets = generate_pytorch_transformer_model_sample(optim_config, opts, step_fn)

    # Export model to ONNX
    if step_fn == 'eval_step':
        step_fn = trainer.eval_step
-        data, targets = utils.get_batch(val_data, 0)
        output = trainer.eval_step(data, targets)
    elif step_fn == 'train_step':
        step_fn = trainer.train_step
        for i in range(max_train_step):
-            data, targets = utils.get_batch(train_data, 0)
            output = trainer.train_step(data, targets)
            if lr_scheduler:
                lr_list = trainer.options.lr_scheduler.get_last_lr()
@@ -573,3 +588,43 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values):
    assert (trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph)
    assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(trainer._onnx_model.graph))


@pytest.mark.parametrize("seed, device_id", [
(0, 'cpu'),
(42, 'cpu'),
(0, 'cuda:0'),
(24, 'cuda')
])
def testORTDeterministicCompute(seed, device_id):
optim_config = optim.LambConfig()
opts = orttrainer.ORTTrainerOptions({
'debug' : {
'deterministic_compute': True
},
'device' : {
'id' : device_id,
'mem_limit' : 10*1024*1024
}
})

torch.manual_seed(seed)
set_seed(seed)

# Using PyTorch Transformer model as example
model, model_desc, trainer, data, targets = generate_pytorch_transformer_model_sample(optim_config, opts, device=device_id)

# Run first model train step
output = trainer.train_step(data, targets)
assert trainer._onnx_model is not None

# Reset the seeds
torch.manual_seed(seed)
set_seed(seed)

# Run second model train step
_, _, second_trainer, _, _ = generate_pytorch_transformer_model_sample(optim_config, opts, device=device_id)
output = second_trainer.train_step(data, targets)
assert second_trainer._onnx_model is not None
assert id(trainer._onnx_model) != id(second_trainer._onnx_model)

debug.compare_onnx_weights(trainer, second_trainer)
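
Editor's note: the new test follows a general seed-reset-and-compare pattern. A minimal framework-only sketch of the same idea, with no ORT involved and all names local to this example (plain PyTorch, CPU execution assumed):

import torch

def run_once(seed):
    # Reset PyTorch's RNG so weight init and data generation repeat exactly
    torch.manual_seed(seed)
    model = torch.nn.Linear(4, 2)
    data = torch.randn(8, 4)
    return model(data)

out_a = run_once(42)
out_b = run_once(42)
assert torch.equal(out_a, out_b)  # identical seeds must yield identical outputs on CPU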
