def load_torch_model_using_safetensors(model_name: str, path: str, filename: str) -> torch.nn.Module:
    """
    Load a pytorch model whose definition is stored in ``<path>/<filename>.py`` and whose weights are stored in
    ``<path>/<filename>.safetensors``.

    NOTE: The model can only be saved by saving the state dict. Attempting to serialize the entire model will result
    in a mismatch between class types of the model defined and the class type that is imported programmatically.

    :param model_name: Name of the model class inside the model definition file
    :param path: Path where the pytorch model definition file and the safetensors file are saved
    :param filename: Filename (without extension) of the pytorch model definition and the safetensors weight file

    :return: Imported pytorch model with the safetensors metadata attached as the ``mpp_meta`` attribute
    :raises AssertionError: If the model definition file or the safetensors file does not exist
    :raises ImportError: If no import spec can be created for the model definition file
    """

    model_path = os.path.join(path, filename + '.py')
    if not os.path.exists(model_path):
        logger.error('Unable to find model file at path %s', model_path)
        raise AssertionError('Unable to find model file at path ' + model_path)

    # Verify that the weights exist before executing the model definition, so that a missing file is
    # reported without running user code or registering a module in sys.modules.
    state_dict_path = os.path.join(path, filename + '.safetensors')
    if not os.path.exists(state_dict_path):
        logger.error('Unable to find state dict file at path %s', state_dict_path)
        raise AssertionError('Unable to find state dict file at path ' + state_dict_path)

    # Import model's module and instantiate model
    spec = importlib.util.spec_from_file_location(filename, model_path)
    if spec is None or spec.loader is None:
        raise ImportError('Unable to create an import spec for model file ' + model_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[filename] = module
    try:
        spec.loader.exec_module(module)
    except Exception:
        # Do not leave a partially initialized module registered under this name
        sys.modules.pop(filename, None)
        raise
    model = getattr(module, model_name)()

    # Load state dict using safetensors file
    state_dict, meta_data = _get_metadata_and_state_dict(state_dict_path)

    # strict=False tolerates key mismatches; report them instead of dropping them silently
    incompatible_keys = model.load_state_dict(state_dict, strict=False)
    if incompatible_keys.missing_keys:
        logger.warning('Keys missing from safetensors file %s: %s',
                       state_dict_path, incompatible_keys.missing_keys)
    if incompatible_keys.unexpected_keys:
        logger.warning('Unexpected keys found in safetensors file %s: %s',
                       state_dict_path, incompatible_keys.unexpected_keys)

    # Sets the MPP meta data extracted from safetensors file into the model as an attribute
    # so that it can be extracted and saved at the time of weights export.
    model.mpp_meta = meta_data
    return model
def _get_metadata_and_state_dict(safetensor_file_path: str) -> Tuple[dict, dict]:
    """
    Extracts the state dict from a numpy format safetensors file as well as its metadata.
    Converts the state dict from numpy arrays to torch tensors.

    The file is read as an 8-byte little-endian unsigned header length followed by a JSON header of that
    length; the metadata is taken from the optional ``__metadata__`` entry of that header.

    :param safetensor_file_path: Path of the safetensor file.
    :return: state dict in torch.Tensor format and metadata (empty dict when the file carries none)
    :raises ValueError: If the file is too short for its length prefix or for the header the prefix announces,
        or if the header is not a JSON object
    """
    length_prefix_size = 8

    with open(safetensor_file_path, "rb") as f:
        data = f.read()

    if len(data) < length_prefix_size:
        raise ValueError('Safetensors file %s is too short to contain a header length' % safetensor_file_path)

    # Get the header length to extract the metadata
    header_length = int.from_bytes(data[:length_prefix_size], "little", signed=False)
    header_end = length_prefix_size + header_length
    if header_end > len(data):
        # Slicing would silently truncate here and fail later with an unrelated JSON error
        raise ValueError('Safetensors file %s is truncated: header of %d bytes exceeds file size'
                         % (safetensor_file_path, header_length))

    header = json.loads(data[length_prefix_size:header_end].decode("utf-8"))
    if not isinstance(header, dict):
        raise ValueError('Safetensors file %s has a malformed header' % safetensor_file_path)
    meta_data = header.get('__metadata__', {})

    # Load the state dict and convert it to torch tensor
    numpy_state_dict = load_safetensor(data)
    state_dict = {name: torch.from_numpy(array) for name, array in numpy_state_dict.items()}

    return state_dict, meta_data
def test_load_pytorch_model_using_safeteneors(self):
    """ test load_torch_model_using_safetensors utility """

    class MiniModel(torch.nn.Module):

        def __init__(self):
            super(MiniModel, self).__init__()
            self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=2, bias=False)
            self.bn1 = torch.nn.BatchNorm2d(8)
            self.relu1 = torch.nn.ReLU(inplace=True)
            self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
            self.fc = torch.nn.Linear(128, 12)

        def forward(self, *inputs):
            x = self.conv1(inputs[0])
            x = self.bn1(x)
            x = self.relu1(x)
            x = self.maxpool(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x

    def to_numpy(tensor):
        # detach() is a no-op for tensors that do not require grad, so no branching is needed
        return tensor.detach().cpu().numpy()

    with tempfile.TemporaryDirectory() as tmp_dir:
        model_def_path = os.path.join(tmp_dir, 'mini_model.py')
        weights_path = os.path.join(tmp_dir, 'mini_model.safetensors')

        # The loader imports the model class from this file, so it must mirror MiniModel above
        with open(model_def_path, 'w') as f:
            print("""
import torch
import torch.nn
class MiniModel(torch.nn.Module):

    def __init__(self):
        super(MiniModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=2, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(8)
        self.relu1 = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        self.fc = torch.nn.Linear(128, 12)

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
            """, file=f)

        model = MiniModel()
        model.eval()
        dummy_input = torch.randn(1, 3, 8, 8)
        out1 = model(dummy_input)

        state_dict = {k: to_numpy(v) for k, v in model.state_dict().items()}
        metadata = {"dummy_meta_key": "dummy_meta_value"}
        metadata = {"metadata": json.dumps(metadata)}
        save_safetensor_file(state_dict, weights_path, metadata)

        new_model = utils.load_torch_model_using_safetensors('MiniModel', tmp_dir, 'mini_model')
        new_model.eval()
        out2 = new_model(dummy_input)
        assert torch.allclose(out1, out2)

        # The metadata stored in the safetensors header must be attached to the loaded model
        assert new_model.mpp_meta == metadata

        # Delete the safetensors state dict file: loading must now fail
        os.remove(weights_path)
        with self.assertRaises(AssertionError):
            _ = utils.load_torch_model_using_safetensors('MiniModel', tmp_dir, 'mini_model')

        # Delete the model definition file as well: loading must still fail
        os.remove(model_def_path)
        with self.assertRaises(AssertionError):
            _ = utils.load_torch_model_using_safetensors('MiniModel', tmp_dir, 'mini_model')