Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Model Compression] Expand export_model arguments: dummy input and onnx opset_version #3968

Merged
merged 5 commits into from
Jul 26, 2021
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions nni/compression/pytorch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ def _wrap_modules(self, layer, config):
wrapper.to(layer.module.weight.device)
return wrapper

def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None,
dummy_input=None, opset_version=None):
"""
Export pruned model weights, masks and onnx model(optional)

Expand All @@ -388,9 +389,19 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
this shape is used for creating a dummy input tensor for torch.onnx.export
if the input has a complex structure (e.g., a tuple), please directly create the input and
pass it to dummy_input instead
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
device of the model, where to place the dummy input tensor for exporting onnx file;
the tensor is placed on cpu if ```device``` is None
only useful when both onnx_path and input_shape are passed
dummy_input: torch.Tensor or tuple
dummy input to the onnx model; used when input_shape is not enough to specify dummy input
user should ensure that the dummy_input is on the same device as the model
opset_version: int
opset_version parameter for torch.onnx.export; only useful when onnx_path is not None
if not passed, torch.onnx.export will use its default opset_version
"""
assert model_path is not None, 'model_path must be specified'
mask_dict = {}
Expand All @@ -411,17 +422,31 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N

torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)

if mask_path is not None:
torch.save(mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path)

if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
assert input_shape is not None or dummy_input is not None,\
'input_shape or dummy_input must be specified to export onnx model'
# create dummy_input using input_shape if input_shape is not passed
if dummy_input is None:
_logger.warning("""The argument input_shape and device will be removed in the next release.
Please create a dummy input and pass it to dummy_input instead.""")
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape).to(device)
else:
input_data = dummy_input
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if user both set dummy_input and device

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think device should be ignored in that case. Have updated the docstring.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recommand we also input_data = dummy_input.to(device), or this may confuse user, if user also set device but we ignore it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I think this operation may fail when e.g., dummy_input is a tuple

if opset_version is not None:
torch.onnx.export(self.bound_model, input_data, onnx_path, opset_version=opset_version)
else:
torch.onnx.export(self.bound_model, input_data, onnx_path)
if dummy_input is None:
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
else:
_logger.info('Model in onnx saved to %s', onnx_path)

self._wrap_model()

Expand Down