Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export weight and enc api #3302

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import torch
import onnx
from packaging import version # pylint: disable=wrong-import-order
from safetensors.numpy import save_file as save_safetensor_file

import aimet_common.libpymo as libpymo
from aimet_common import quantsim
Expand Down Expand Up @@ -648,6 +649,26 @@ def export_onnx_model_and_encodings(path: str, filename_prefix: str, original_mo
excluded_layer_names, propagate_encodings,
quantizer_args=quantizer_args)

def export_weights_to_safetensors(self, path: str, filename_prefix: str):
    """
    Exports the updated weights in the safetensors format

    :param path: Directory in which to save the file
    :param filename_prefix: Filename to use for the saved file; '.safetensors'
        is appended automatically
    """

    def to_numpy(tensor):
        # detach() is safe (and effectively a no-op) for tensors that do not
        # require grad, so a single code path covers both cases.
        return tensor.detach().cpu().numpy()

    # Strip the quantization wrappers so only the underlying model's
    # (quantization-updated) parameters are serialized.
    unwrapped_model = QuantizationSimModel.get_original_model(self.model)
    data = {name: to_numpy(tensor) for name, tensor in unwrapped_model.state_dict().items()}

    # Preserve mixed-precision metadata when the sim model carries it;
    # safetensors expects a (possibly empty) dict here.
    metadata = getattr(self.model, 'mpp_meta', {})

    file_path = os.path.join(path, filename_prefix + '.safetensors')
    save_safetensor_file(data, file_path, metadata)

def save_encodings_to_json(self, path: str, filename_prefix: str):
"""
Save encodings in the model to json.
Expand Down
13 changes: 13 additions & 0 deletions TrainingExtensions/torch/test/python/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5212,6 +5212,19 @@ def forward(self, inp):
closest_wrapper = qsim._get_closest_producer_wrapper(qsim.connected_graph.ordered_ops[1], module_to_quant_wrapper)
assert closest_wrapper == qsim.model.permute


def test_export_to_safetensors():
    """Exporting weights must produce <path>/<prefix>.safetensors on disk."""
    torch.manual_seed(0)
    model = SmallMnistNoDropoutWithPassThrough()
    model.eval()
    dummy_data = torch.randn(1, 1, 32, 32)
    sim = QuantizationSimModel(model, dummy_data)
    sim.compute_encodings(lambda m, itr: m(dummy_data), None)
    with tempfile.TemporaryDirectory() as temp_dir:
        sim.export_weights_to_safetensors(temp_dir, 'sim_export')
        # The exporter appends the '.safetensors' suffix itself
        assert os.path.exists(os.path.join(temp_dir, 'sim_export.safetensors'))


@pytest.mark.cuda
@pytest.mark.parametrize('input_dims', (2, 3, 4))
def test_fused_qdq_linear(input_dims):
Expand Down
Loading