From f04d064884be5ede7b6f7d844ce22b793607d091 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 6 Nov 2023 14:11:34 +0100
Subject: [PATCH] fix detection of duplicate tensors

---
 bindings/python/py_src/safetensors/torch.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py
index 7fa59675..cdde1a23 100644
--- a/bindings/python/py_src/safetensors/torch.py
+++ b/bindings/python/py_src/safetensors/torch.py
@@ -41,6 +41,10 @@ def storage_size(tensor: torch.Tensor) -> int:
     return tensor.nelement() * _SIZE[tensor.dtype]
 
 
+def storage_offset(tensor: torch.Tensor) -> int:
+    return tensor.storage_offset()
+
+
 def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     filtered_tensors = []
     for shared in tensors:
@@ -71,7 +75,8 @@ def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     for k, v in state_dict.items():
         if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
             # Need to add device as key because of multiple GPU.
-            tensors[(v.device, storage_ptr(v), storage_size(v))].add(k)
+            # Need to add storage_offset as key because views may share the same data_ptr.
+            tensors[(v.device, storage_ptr(v), storage_size(v), storage_offset(v))].add(k)
     tensors = list(sorted(tensors.values()))
     tensors = _filter_shared_not_shared(tensors, state_dict)
     return tensors
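
For context, a minimal sketch of the situation the patch guards against: two disjoint views of the same tensor are backed by a single storage, so a storage-based pointer is identical for both, and their element counts (what `storage_size` computes here) can match as well. Without `storage_offset()` in the key, `_find_shared_tensors` would group such views as duplicates. The tensor names below are illustrative, and the snippet assumes PyTorch >= 2.0 for `untyped_storage()`:

```python
import torch

base = torch.arange(6)
first = base[:3]   # view starting at element 0
second = base[3:]  # view starting at element 3

# Both views share one underlying storage, so their storage data
# pointers are identical even though they cover disjoint elements.
assert first.untyped_storage().data_ptr() == second.untyped_storage().data_ptr()

# Both views also have the same element count, so a key of
# (device, ptr, size) alone would collide for them.
assert first.nelement() == second.nelement()

# storage_offset() tells the two views apart, which is what the
# patched key (device, ptr, size, offset) relies on.
assert first.storage_offset() == 0
assert second.storage_offset() == 3
```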