Skip to content

Commit

Permalink
Add ONNX BERT example + minor fixes (#4757)
Browse files Browse the repository at this point in the history
* Fix training issue when passing ONNX file into ORTTrainer

Co-authored-by: Thiago Crepaldi <[email protected]>
Co-authored-by: Rayan Krishnan <t-rakr@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
  • Loading branch information
3 people committed Aug 14, 2020
1 parent ab0249a commit 4fec029
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 11 deletions.
19 changes: 19 additions & 0 deletions orttraining/orttraining/python/training/orttrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,25 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
if self.options._internal_use.extra_postprocess:
self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model)

# When input model is already ONNX (and not exported from Pytorch within ORTTrainer),
# append 'dtype' from ONNX into model description's
for idx_i, i_desc in enumerate(self.model_desc.inputs):
dtype = None
for onnx_input in self._onnx_model.graph.input:
if onnx_input.name == i_desc.name:
dtype = _utils.dtype_onnx_to_torch(onnx_input.type.tensor_type.elem_type)
self.model_desc.add_type_to_input_description(idx_i, dtype)
break
assert dtype is not None, f"ONNX model with unknown input type ({i_desc.name})"
for idx_o, o_desc in enumerate(self.model_desc.outputs):
dtype = None
for onnx_output in self._onnx_model.graph.output:
if onnx_output.name == o_desc.name:
dtype = _utils.dtype_onnx_to_torch(onnx_output.type.tensor_type.elem_type)
self.model_desc.add_type_to_output_description(idx_o, dtype)
break
assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})"

# Set GPU device and memory limit
if 'cuda' in self.options.device.id.lower():
mem_limit = self.options.device.mem_limit
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@

# generate sample input for our example
import inspect
import onnx
import os
import pytest
import torch

from numpy.testing import assert_allclose

from onnxruntime.capi._pybind_state import set_seed
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\
ModelDescription as Legacy_ModelDescription,\
LossScaler as Legacy_LossScaler,\
ORTTrainer as Legacy_ORTTrainer
from onnxruntime.capi.training import _utils, amp, optim, orttrainer, TrainStepInfo,\
model_desc_validation as md_val,\
orttrainer_options as orttrainer_options


###############################################################################
# Helper functions ############################################################
###############################################################################


def generate_random_input_from_model_desc(desc):
dtype = torch.int64
vocab_size = 30528
num_classes = [vocab_size, 2, 2, vocab_size, 2]
device = "cuda:0"
sample_input = []
for index, input in enumerate(desc['inputs']):
sample_input.append(torch.randint(0, num_classes[index], tuple(input[1]), dtype=dtype).to(device))
return sample_input

def bert_model_description():
vocab_size = 30528
batch_size = 16
seq_len = 1
model_desc = {'inputs': [('input_ids', [batch_size, seq_len]),
('segment_ids', [batch_size, seq_len],),
('input_mask', [batch_size, seq_len],),
('masked_lm_labels', [batch_size, seq_len],),
('next_sentence_labels', [batch_size, ],)],
'outputs': [('loss', [], True)]}
return model_desc


###############################################################################
# Testing starts here #########################################################
###############################################################################


def testToyBERTModel():
#print(bert_model_description())
model_desc = bert_model_description()
device = torch.device("cuda", 0)

pytorch_transformer_path = os.path.join('..', '..', '..', 'onnxruntime', 'test', 'testdata')
bert_onnx_model_path = os.path.join(pytorch_transformer_path, "bert_toy_postprocessed.onnx")
model = onnx.load(bert_onnx_model_path)

optim_config = optim.LambConfig()
opts = orttrainer.ORTTrainerOptions({
'debug' : {
'deterministic_compute': True
},
'device' : {
'id' : "cuda:0",
}

})

torch.manual_seed(1)
set_seed(1)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

sample_input = generate_random_input_from_model_desc(model_desc)

output = trainer.train_step(*sample_input)
#print(output)
assert output.shape == torch.Size([])

Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
ModelDescription as Legacy_ModelDescription,\
LossScaler as Legacy_LossScaler,\
ORTTrainer as Legacy_ORTTrainer
from onnxruntime.capi.training import _utils, amp, debug, optim, orttrainer, TrainStepInfo,\
from onnxruntime.capi.training import _utils, amp, optim, orttrainer, TrainStepInfo,\
model_desc_validation as md_val,\
orttrainer_options as orttrainer_options

import _test_helpers

###############################################################################
# Helper functions ############################################################
Expand Down Expand Up @@ -644,7 +644,7 @@ def testORTDeterministicCompute(seed, device):

# Compare two different instances with identical setup
assert id(first_trainer._onnx_model) != id(second_trainer._onnx_model)
debug.assert_onnx_weights(first_trainer, second_trainer)
_test_helpers.assert_onnx_weights(first_trainer, second_trainer)


@pytest.mark.parametrize("seed,device,expected_loss", [
Expand All @@ -658,8 +658,7 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss):

# Setup ORTTrainer
loss_scaler = amp.DynamicLossScaler()
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device,
'mem_limit' : 100*1024*1024},
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
'mixed_precision' : {
'enabled' : True,
'loss_scaler' : loss_scaler},
Expand All @@ -679,10 +678,9 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss):
actual_loss.append(loss.cpu())

# Compare loss to ground truth computed from current ORTTrainer API
debug.assert_model_outputs(expected_loss, actual_loss, True, rtol=1e-4)
_test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=1e-4)
assert trainer._onnx_model is not None


###############################################################################
# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############
###############################################################################
Expand Down Expand Up @@ -733,7 +731,7 @@ def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
_, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))

# Compare legacy vs experimental APIs
debug.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4)
_test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4)


@pytest.mark.parametrize("seed,device", [
Expand Down Expand Up @@ -793,5 +791,5 @@ def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):

# Compare legacy vs experimental APIs
assert experimental_preds_dtype == legacy_preds_dtype
debug.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
debug.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)
_test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=1e-4, atol=1e-2)
_test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-4)
1 change: 0 additions & 1 deletion samples/python/pytorch_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torchtext
from torchtext.data.utils import get_tokenizer


def batchify(data, bsz, TEXT, device):
data = TEXT.numericalize([data.examples[0].text])
# Divide the dataset into bsz parts.
Expand Down

0 comments on commit 4fec029

Please sign in to comment.