Skip to content

Commit

Permalink
Save ONNX model in file (#4671)
Browse files Browse the repository at this point in the history
Co-authored-by: Rayan Krishnan <t-rakr@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
  • Loading branch information
Thiago Crepaldi and Rayan Krishnan committed Aug 15, 2020
1 parent d46dec0 commit 4e24aac
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 51 deletions.
145 changes: 94 additions & 51 deletions orttraining/orttraining/python/training/orttrainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import io
import os
import onnx
from onnx import numpy_helper
import torch
from inspect import signature

Expand Down Expand Up @@ -147,6 +149,9 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):

self.model_desc = _ORTTrainerModelDesc(model_desc)
self.optim_config = optim_config

if not options:
options = ORTTrainerOptions()
self.options = options

# Set GPU device and memory limit
Expand Down Expand Up @@ -175,13 +180,31 @@ def eval_step(self, *input, **kwargs):
def save_as_onnx(self, path):
r"""Persists ONNX model into :py:attr:`path`
The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard containing
the full graph, including inference and training metadata.
The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard.
The graph includes full information, including inference and training metadata.
Args:
path (str): Full path, including filename, to save the model in the filesystem
path (str): Full path, including filename, to save the ONNX model in the filesystem
Raises:
RuntimeWarning: raised when neither `train_step` or `eval_step` was called at least once
ValueError: raised when `path` is not valid path
"""
pass
if not self._training_session:
raise RuntimeWarning("Training session is not initialized yet. "
"'train_step' or 'eval_step' methods must be executed at least once before calling 'save_as_onnx()'.")
state_tensors = self._training_session.get_state()
self._update_onnx_model_initializers(state_tensors)

assert isinstance(path, str), "'path' must be a valid path string"
dir_name = os.path.dirname(path)
file_name = os.path.basename(path)
if not dir_name or not os.path.exists(dir_name) or not file_name:
raise ValueError("'path' is not valid. It must contain an existing folder + filename")

with open(path, "wb") as f:
f.write(self._onnx_model.SerializeToString())


def train_step(self, *input, **kwargs):
r"""Train step method
Expand Down Expand Up @@ -335,53 +358,6 @@ def _convert_torch_model_loss_fn_to_onnx(self, inputs):

return onnx_model

def _init_onnx_model(self, inputs):
if self._onnx_model is not None:
return

if self._torch_model is not None:
# PyTorch model is moved to cpu to save GPU memory
self._torch_model.cpu()

# PyTorch buffers (created using 'register_buffer') shouldn't be trained
torch_buffers = list(dict(self._torch_model.named_buffers()).keys())
self.options.utils.frozen_weights.extend(torch_buffers)

# Export to ONNX
self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(inputs)

self._init_session()

def _init_session(self):
if self._onnx_model is None:
return

# Perform internal post-processing
if self.options._internal_use.enable_internal_postprocess:
self._onnx_model = postprocess.run_postprocess(self._onnx_model)

# Perform user-specified post-processing
if self.options._internal_use.extra_postprocess:
self.options._internal_use.extra_postprocess(self._onnx_model)

# Create training session used by train_step
self._create_ort_training_session()
return

def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs):
# Normalize input to tuple of samples
if type(inputs) == tuple and len(inputs) == 1 and type(inputs[0]) == list:
input = tuple(inputs[0])
else:
input = inputs

# Append input from 'kwargs'
for input_desc in inputs_desc:
if input_desc[0] in kwargs:
input = input + (kwargs[input_desc[0]],)

return input

# TODO: Test this througly along with train step, including
# various optimizer parameter groups, frozen weights, loss and lr
def _create_ort_training_session(self):
Expand Down Expand Up @@ -447,3 +423,70 @@ def _create_ort_training_session(self):
# I/O bindings
self._train_io_binding = self._training_session.io_binding()
self._eval_io_binding = self._training_session.io_binding()

def _init_onnx_model(self, inputs):
if self._onnx_model is not None:
return

if self._torch_model is not None:
# PyTorch model is moved to cpu to save GPU memory
self._torch_model.cpu()

# PyTorch buffers (created using 'register_buffer') shouldn't be trained
torch_buffers = list(dict(self._torch_model.named_buffers()).keys())
self.options.utils.frozen_weights.extend(torch_buffers)

# Export to ONNX
self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(inputs)

self._init_session()

def _init_session(self):
if self._onnx_model is None:
return

# Perform internal post-processing
if self.options._internal_use.enable_internal_postprocess:
self._onnx_model = postprocess.run_postprocess(self._onnx_model)

# Perform user-specified post-processing
if self.options._internal_use.extra_postprocess:
self.options._internal_use.extra_postprocess(self._onnx_model)

# Create training session used by train_step
self._create_ort_training_session()

def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs):
# Normalize input to tuple of samples
if type(inputs) == tuple and len(inputs) == 1 and type(inputs[0]) == list:
input = tuple(inputs[0])
else:
input = inputs

# Append input from 'kwargs'
for input_desc in inputs_desc:
if input_desc[0] in kwargs:
input = input + (kwargs[input_desc[0]],)

return input

def _update_onnx_model_initializers(self, state_tensors):
r""" Updates ONNX graph initializers with state_tensors's values
Usually called to save or load an ONNX model.
The tensors names of state_tensors are compared to all ONNX initializer tensors
and when the name matches, the ONNX graph is updated with the new value.
"""
assert isinstance(state_tensors, dict), "state_tensors must be a dict"

new_weights = []
replace_indices = []
for i, w in enumerate(self._onnx_model.graph.initializer):
if w.name in state_tensors:
new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name))
replace_indices.append(i)
replace_indices.sort(reverse=True)
for w_i in replace_indices:
del self._onnx_model.graph.initializer[w_i]
self._onnx_model.graph.initializer.extend(new_weights)
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,20 @@ def testInstantiateORTTrainer(step_fn):
assert output_type == _utils.dtype_onnx_to_torch(
trainer._onnx_model.graph.output[i].type.tensor_type.elem_type)

# Save current model as ONNX as a file
file_name = os.path.join('..','..','..','temp_onnx_model.onnx')
trainer.save_as_onnx(file_name)
assert os.path.exists(file_name)
with open(file_name, "rb") as f:
bin_str = f.read()
reload_onnx_model = onnx.load_model_from_string(bin_str)
os.remove(file_name)

# Create a new trainer from persisted ONNX model and compare with original ONNX model
trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config)
trainer_from_onnx.train_step(data, targets)
assert trainer_from_onnx._onnx_model is not None
assert (id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model))
assert (trainer_from_onnx._onnx_model == trainer._onnx_model)
assert (trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph)
assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(trainer._onnx_model.graph))

0 comments on commit 4e24aac

Please sign in to comment.