Fix detection of duplicate torch tensors #379

Closed · wants to merge 1 commit
bindings/python/py_src/safetensors/torch.py (6 additions, 1 deletion)

@@ -41,6 +41,10 @@ def storage_size(tensor: torch.Tensor) -> int:
     return tensor.nelement() * _SIZE[tensor.dtype]
 
 
+def storage_offset(tensor: torch.Tensor) -> int:
+    return tensor.storage_offset()
+
+
 def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     filtered_tensors = []
     for shared in tensors:
@@ -71,7 +75,8 @@ def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     for k, v in state_dict.items():
         if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
             # Need to add device as key because of multiple GPU.
-            tensors[(v.device, storage_ptr(v), storage_size(v))].add(k)
+            # Need to add storage_offset as key because views may share the same data_ptr.
+            tensors[(v.device, storage_ptr(v), storage_size(v), storage_offset(v))].add(k)
Collaborator:
I think this line is correct as it stands.

It is trying to determine which tensors share memory, and sharing is indicated by the storage. Keying on the offset will make you miss some shared tensors.
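To illustrate the concern (a minimal sketch, not from the thread, assuming a PyTorch recent enough to expose Tensor.untyped_storage()): two views of one storage can carry different offsets, so a key that includes the offset splits a genuinely shared pair:

import torch

a = torch.arange(6.0)
b = a[2:]  # a view: same underlying storage as `a`, different storage_offset

# The two tensors alias each other: writing through `b` is visible in `a`.
b[0] = 100.0
assert a[2].item() == 100.0

# Same storage pointer, but different offsets.
assert a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
assert (a.storage_offset(), b.storage_offset()) == (0, 2)

# A grouping key that includes storage_offset would therefore separate
# `a` and `b`, and this shared pair would go undetected.
key_a = (a.device, a.untyped_storage().data_ptr(), a.storage_offset())
key_b = (b.device, b.untyped_storage().data_ptr(), b.storage_offset())
assert key_a != key_b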

Author (@fxmarty), Nov 6, 2023:
In my understanding, what this function does is find which tensors in the state_dict are "identical". Relying on data_ptr() is not enough in this case, as the storage pointer is the same for all views of a tensor, and those views may represent different data. For example, bar = a[2:4] is not the same tensor as foo = a[4:6].

I'll try to come up with a minimal example to showcase what I mean.
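To make the view behaviour concrete (again a sketch under the same PyTorch assumption): two disjoint slices share a storage pointer while holding different data:

import torch

a = torch.arange(8.0)
bar = a[2:4]  # views onto disjoint slices of the same storage
foo = a[4:6]

# A storage-pointer-only key treats `bar` and `foo` as the same tensor...
assert bar.untyped_storage().data_ptr() == foo.untyped_storage().data_ptr()

# ...even though they cover different regions and hold different values.
assert bar.storage_offset() != foo.storage_offset()
assert not torch.equal(bar, foo)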

Author (@fxmarty):

@Narsil @LysandreJik Here is a minimal repro of the issue:

import torch
import torch.nn as nn
from safetensors.torch import _find_shared_tensors

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(20, 30)

        # q.weight and v.weight are views onto disjoint rows of qkv.weight,
        # so all three parameters share one underlying storage.
        self.q = nn.Linear(20, 10)
        self.q.weight = torch.nn.Parameter(self.qkv.weight[:10])

        self.v = nn.Linear(20, 10)
        self.v.weight = torch.nn.Parameter(self.qkv.weight[10:20])

    def forward(self, x):
        return x

model = MyModel()
shared_params = _find_shared_tensors(model.state_dict())
print("shared_params", shared_params)

This prints: shared_params [{'v.weight', 'qkv.weight', 'q.weight'}, {'qkv.bias'}, {'q.bias'}, {'v.bias'}]

Author (@fxmarty), Nov 9, 2023:

Okay, but I take it you are arguing that _find_shared_tensors is meant to group tensors that share memory, regardless of whether some of them are views whose actual data differs. Is that the case?

Despite the fix huggingface/transformers#27314, Transformers still calls safetensors.torch.save_file and complains that tensors share memory: https://github.com/fxmarty/safetensors/blob/f04d064884be5ede7b6f7d844ce22b793607d091/bindings/python/py_src/safetensors/torch.py#L474. Couldn't Transformers call save_model, as the error suggests? Though I doubt this would help, since save_model calls _remove_duplicate_names, which in turn calls _find_shared_tensors.
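For reference on the save_file versus save_model distinction (a sketch of my own, assuming only the public safetensors.torch API; the Tied model here is hypothetical, not the repro model):

import torch.nn as nn
from safetensors.torch import save_file, save_model

# A toy model with genuinely tied weights (embedding / lm_head style).
class Tied(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4, bias=False)
        self.b = nn.Linear(4, 4, bias=False)
        self.b.weight = self.a.weight  # same Parameter object

model = Tied()

# save_file refuses state_dicts containing shared tensors and points the
# user at save_model instead.
try:
    save_file(model.state_dict(), "tied.safetensors")
except RuntimeError as err:
    print(err)

# save_model drops the duplicate names (via _remove_duplicate_names, which
# calls _find_shared_tensors) and then writes the remaining tensors.
save_model(model, "tied.safetensors")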

     tensors = list(sorted(tensors.values()))
     tensors = _filter_shared_not_shared(tensors, state_dict)
     return tensors