-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ONNX BERT example + minor fixes (#4757)
* Fix training issue when passing ONNX file into ORTTrainer Co-authored-by: Thiago Crepaldi <[email protected]> Co-authored-by: Rayan Krishnan <t-rakr@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
- Loading branch information
1 parent
49c88e4
commit fb4c90e
Showing
5 changed files
with
110 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
83 changes: 83 additions & 0 deletions
83
orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
|
||
# generate sample input for our example | ||
import inspect | ||
import onnx | ||
import os | ||
import pytest | ||
import torch | ||
|
||
from numpy.testing import assert_allclose | ||
|
||
from onnxruntime.capi._pybind_state import set_seed | ||
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\ | ||
ModelDescription as Legacy_ModelDescription,\ | ||
LossScaler as Legacy_LossScaler,\ | ||
ORTTrainer as Legacy_ORTTrainer | ||
from onnxruntime.capi.training import _utils, amp, optim, orttrainer, TrainStepInfo,\ | ||
model_desc_validation as md_val,\ | ||
orttrainer_options as orttrainer_options | ||
|
||
|
||
############################################################################### | ||
# Helper functions ############################################################ | ||
############################################################################### | ||
|
||
|
||
def generate_random_input_from_model_desc(desc): | ||
dtype = torch.int64 | ||
vocab_size = 30528 | ||
num_classes = [vocab_size, 2, 2, vocab_size, 2] | ||
device = "cuda:0" | ||
sample_input = [] | ||
for index, input in enumerate(desc['inputs']): | ||
sample_input.append(torch.randint(0, num_classes[index], tuple(input[1]), dtype=dtype).to(device)) | ||
return sample_input | ||
|
||
def bert_model_description(): | ||
vocab_size = 30528 | ||
batch_size = 16 | ||
seq_len = 1 | ||
model_desc = {'inputs': [('input_ids', [batch_size, seq_len]), | ||
('segment_ids', [batch_size, seq_len],), | ||
('input_mask', [batch_size, seq_len],), | ||
('masked_lm_labels', [batch_size, seq_len],), | ||
('next_sentence_labels', [batch_size, ],)], | ||
'outputs': [('loss', [], True)]} | ||
return model_desc | ||
|
||
|
||
############################################################################### | ||
# Testing starts here ######################################################### | ||
############################################################################### | ||
|
||
|
||
def testToyBERTModel(): | ||
#print(bert_model_description()) | ||
model_desc = bert_model_description() | ||
device = torch.device("cuda", 0) | ||
|
||
pytorch_transformer_path = os.path.join('..', '..', '..', 'onnxruntime', 'test', 'testdata') | ||
bert_onnx_model_path = os.path.join(pytorch_transformer_path, "bert_toy_postprocessed.onnx") | ||
model = onnx.load(bert_onnx_model_path) | ||
|
||
optim_config = optim.LambConfig() | ||
opts = orttrainer.ORTTrainerOptions({ | ||
'debug' : { | ||
'deterministic_compute': True | ||
}, | ||
'device' : { | ||
'id' : "cuda:0", | ||
} | ||
|
||
}) | ||
|
||
torch.manual_seed(1) | ||
set_seed(1) | ||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) | ||
|
||
sample_input = generate_random_input_from_model_desc(model_desc) | ||
|
||
output = trainer.train_step(*sample_input) | ||
#print(output) | ||
assert output.shape == torch.Size([]) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters