Skip to content

Commit

Permalink
Util to load torch model with weights from safetensors (#3304)
Browse files Browse the repository at this point in the history
Signed-off-by: Rishabh Thakur <[email protected]>
  • Loading branch information
quic-ristha authored Sep 3, 2024
1 parent 0d084f4 commit cd95a49
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 1 deletion.
64 changes: 64 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import importlib
import inspect
import itertools
import json
from typing import List, Tuple, Union, Dict, Callable, Any, Iterable, Optional, TextIO
import contextlib
import os
Expand All @@ -49,6 +50,7 @@
import warnings

import numpy as np
from safetensors.numpy import load as load_safetensor
import torch.nn
import torch
from torch.utils.data import DataLoader, Dataset
Expand Down Expand Up @@ -1396,3 +1398,65 @@ def place_model(model: torch.nn.Module, device: torch.device):
yield
finally:
model.to(device=original_device)


def load_torch_model_using_safetensors(model_name: str, path: str, filename: str) -> torch.nn.Module:
    """
    Load a pytorch model from its definition file and a safetensors weight file.

    The class ``model_name`` is imported from ``<path>/<filename>.py`` and instantiated without arguments.
    Its weights are then loaded (non-strictly) from ``<path>/<filename>.safetensors``, and the metadata found
    in the safetensors header is attached to the returned model as the ``mpp_meta`` attribute.

    NOTE: The model can only be saved by saving the state dict. Attempting to serialize the entire model will result
    in a mismatch between class types of the model defined and the class type that is imported programmatically.

    :param model_name: Name of the model class defined in the model definition file
    :param path: Directory where the pytorch model definition file and the safetensors weight file are saved
    :param filename: Common base name (without extension) of the model definition and the safetensors weight file
    :return: Imported pytorch model with embedded metadata
    :raises AssertionError: If the model definition file or the safetensors weight file does not exist
    """

    model_path = os.path.join(path, filename + '.py')
    if not os.path.exists(model_path):
        logger.error('Unable to find model file at path %s', model_path)
        raise AssertionError('Unable to find model file at path ' + model_path)

    # Import model's module and instantiate model
    spec = importlib.util.spec_from_file_location(filename, model_path)
    module = importlib.util.module_from_spec(spec)
    previous_module = sys.modules.get(filename)
    sys.modules[filename] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialized module registered when its top-level code fails;
        # put back whatever was registered under this name before.
        if previous_module is None:
            sys.modules.pop(filename, None)
        else:
            sys.modules[filename] = previous_module
        raise
    model = getattr(module, model_name)()

    # Load state dict using safetensors file
    state_dict_path = os.path.join(path, filename + '.safetensors')
    if not os.path.exists(state_dict_path):
        logger.error('Unable to find state dict file at path %s', state_dict_path)
        raise AssertionError('Unable to find state dict file at path ' + state_dict_path)
    state_dict, meta_data = _get_metadata_and_state_dict(state_dict_path)

    # Loading stays non-strict, but a mismatch is reported instead of being silently ignored.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        logger.warning('State dict at %s does not fully match model %s; missing keys: %s, unexpected keys: %s',
                       state_dict_path, model_name, load_result.missing_keys, load_result.unexpected_keys)

    # Sets the MPP meta data extracted from safetensors file into the model as an attribute
    # so that it can be extracted and saved at the time of weights export.
    model.mpp_meta = meta_data
    return model


def _get_metadata_and_state_dict(safetensor_file_path: str) -> Tuple[Dict[str, torch.Tensor], dict]:
    """
    Extracts the state dict from a numpy format safetensors file as well as its metadata.
    Converts the state dict from numpy arrays to torch tensors.

    :param safetensor_file_path: Path of the safetensor file.
    :return: Tuple of (state dict in torch.Tensor format, metadata). The metadata is an empty dict when the
     file header has no '__metadata__' entry.
    """

    with open(safetensor_file_path, "rb") as f:
        data = f.read()

    # The first 8 bytes are read as a little-endian unsigned integer giving the length of the JSON header
    # that immediately follows; the metadata lives under the header's '__metadata__' key.
    header_length = int.from_bytes(data[:8], "little", signed=False)
    header = json.loads(data[8:8 + header_length].decode())
    meta_data = header.get('__metadata__', {})

    # Load the state dict and convert it to torch tensors
    numpy_state_dict = load_safetensor(data)
    state_dict = {name: torch.from_numpy(array) for name, array in numpy_state_dict.items()}

    return state_dict, meta_data
84 changes: 83 additions & 1 deletion TrainingExtensions/torch/test/python/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================

import json
import os
import pytest
import unittest.mock
Expand All @@ -55,6 +55,7 @@

from aimet_torch.quantsim import QuantizationSimModel
from models.test_models import TinyModel, MultiInput, ModelWithReusedNodes, SingleResidual, EmbeddingModel
from safetensors.numpy import save_file as save_safetensor_file


class TestTrainingExtensionsUtils(unittest.TestCase):
Expand Down Expand Up @@ -714,6 +715,87 @@ def forward(self, *inputs):
with self.assertRaises(AssertionError):
_ = utils.load_pytorch_model('MiniModel', tmp_dir, 'mini_model', load_state_dict=False)

def test_load_pytorch_model_using_safeteneors(self):
    """ test load_torch_model_using_safetensors utility """

    class MiniModel(torch.nn.Module):

        def __init__(self):
            super(MiniModel, self).__init__()
            self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=2, bias=False)
            self.bn1 = torch.nn.BatchNorm2d(8)
            self.relu1 = torch.nn.ReLU(inplace=True)
            self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
            self.fc = torch.nn.Linear(128, 12)

        def forward(self, *inputs):
            x = self.conv1(inputs[0])
            x = self.bn1(x)
            x = self.relu1(x)
            x = self.maxpool(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x

    def to_numpy(tensor):
        # detach() is a no-op for tensors that do not require grad, so no branching is needed
        return tensor.detach().cpu().numpy()

    with tempfile.TemporaryDirectory() as tmp_dir:
        model_file_path = os.path.join(tmp_dir, 'mini_model.py')
        safetensors_file_path = os.path.join(tmp_dir, 'mini_model.safetensors')

        # Write a model definition file that mirrors the MiniModel class defined above
        with open(model_file_path, 'w') as f:
            print("""
import torch
import torch.nn
class MiniModel(torch.nn.Module):
    def __init__(self):
        super(MiniModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=2, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(8)
        self.relu1 = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        self.fc = torch.nn.Linear(128, 12)
    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
""", file=f)
        model = MiniModel()
        model.eval()
        dummy_input = torch.randn(1, 3, 8, 8)
        out1 = model(dummy_input)

        # Save the weights in numpy-format safetensors together with some metadata
        state_dict = model.state_dict()
        state_dict = {k: to_numpy(v) for k, v in state_dict.items()}
        metadata = {"dummy_meta_key": "dummy_meta_value"}
        metadata = {"metadata": json.dumps(metadata)}
        save_safetensor_file(state_dict, safetensors_file_path, metadata)

        new_model = utils.load_torch_model_using_safetensors('MiniModel', tmp_dir, 'mini_model')
        new_model.eval()
        out2 = new_model(dummy_input)
        assert torch.allclose(out1, out2)

        # The metadata stored in the safetensors file must be attached to the loaded model
        assert new_model.mpp_meta == metadata

        # Delete the safetensors weight file: loading must fail
        if os.path.exists(safetensors_file_path):
            os.remove(safetensors_file_path)

        with self.assertRaises(AssertionError):
            _ = utils.load_torch_model_using_safetensors('MiniModel', tmp_dir, 'mini_model')

        # Delete the model definition file: loading must fail
        if os.path.exists(model_file_path):
            os.remove(model_file_path)

        with self.assertRaises(AssertionError):
            _ = utils.load_torch_model_using_safetensors('MiniModel', tmp_dir, 'mini_model')

def test_disable_all_quantizers(self):
model = TinyModel().to(device="cpu")
dummy_input = torch.rand(1, 3, 32, 32)
Expand Down

0 comments on commit cd95a49

Please sign in to comment.