Fix detection of duplicate torch tensors #379

Closed
wants to merge 1 commit into from

Conversation


@fxmarty fxmarty commented Nov 6, 2023

As per title. Two tensors with different data may share the same data_ptr when they are views of another tensor. For example:

import torch

a = torch.rand(8, 10, 50)

b = a[2:4]
c = a[4:6]

print(a.untyped_storage().data_ptr())
print(b.untyped_storage().data_ptr())
print(c.untyped_storage().data_ptr())

prints

93888779151936
93888779151936
93888779151936
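
For completeness, the storage offsets do differ even though the pointer is shared, which is what the proposed key relies on. A small sketch of my own in plain PyTorch (not part of the original description):

import torch

a = torch.rand(8, 10, 50)
b = a[2:4]
c = a[4:6]

# Same storage pointer, but different offsets into that storage,
# so (data_ptr, storage_offset) is enough to tell the views apart.
print(b.storage_offset())  # 1000 (= 2 * 10 * 50)
print(c.storage_offset())  # 2000 (= 4 * 10 * 50)
print(torch.equal(b, c))   # False: the views hold different data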

@@ -71,7 +75,8 @@ def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     for k, v in state_dict.items():
         if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
             # Need to add device as key because of multiple GPU.
-            tensors[(v.device, storage_ptr(v), storage_size(v))].add(k)
+            # Need to add storage_offset as key because views may share the same data_ptr.
+            tensors[(v.device, storage_ptr(v), storage_size(v), storage_offset(v))].add(k)
Collaborator

I think this line is correct.

It's trying to see which tensors have shared memory, which is indicated by the storage. Adding the offset will make you miss some shared tensors.
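
To make that concern concrete, here is a minimal sketch of my own (not from the reviewer), assuming the private _find_shared_tensors helper from safetensors.torch:

import torch
from safetensors.torch import _find_shared_tensors

a = torch.rand(100)
b = a[10:]  # shares a's storage, but b.storage_offset() == 10

# Keyed on (device, ptr, size), a and b end up in the same group.
# Once storage_offset() is added to the key, they end up in separate groups,
# even though they overlap in memory - exactly the "missed" sharing described above.
print(_find_shared_tensors({"a": a, "b": b}))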

Author
@fxmarty fxmarty Nov 6, 2023

In my understanding, what this function does is find which tensors in the state_dict are "identical". Relying on data_ptr() alone is not enough here, as data_ptr is the same for all views of a tensor - views that may represent different data. For example, bar = a[2:4] is not the same tensor as foo = a[4:6].

I'll try to come up with a minimal example to showcase what I mean.

Author

@Narsil @LysandreJik Here is a minimal repro of the issue:

import torch
import torch.nn as nn
from safetensors.torch import _find_shared_tensors

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(20, 30)

        # q.weight and v.weight are views into qkv.weight's storage, not copies.
        self.q = nn.Linear(20, 10)
        self.q.weight = torch.nn.Parameter(self.qkv.weight[:10])

        self.v = nn.Linear(20, 10)
        self.v.weight = torch.nn.Parameter(self.qkv.weight[10:20])

    def forward(self, x):
        return x

model = MyModel()
shared_params = _find_shared_tensors(model.state_dict())
print("shared_params", shared_params)

This prints: shared_params [{'v.weight', 'qkv.weight', 'q.weight'}, {'qkv.bias'}, {'q.bias'}, {'v.bias'}]

Author
@fxmarty fxmarty Nov 9, 2023

Okay, but I understand that you are arguing that _find_shared_tensors is meant to group tensors that share memory, regardless of whether some of them are views whose actual data differs. Is that the case?

Despite the fix in huggingface/transformers#27314, Transformers still calls safetensors.torch.save_file, which complains that tensors still share memory: https://github.com/fxmarty/safetensors/blob/f04d064884be5ede7b6f7d844ce22b793607d091/bindings/python/py_src/safetensors/torch.py#L474. Couldn't Transformers call save_model, as the error suggests? Though I doubt this would help, as save_model calls _remove_duplicate_names, which in turn calls _find_shared_tensors.
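
For reference, a usage sketch of mine (not taken from the thread; the model and filename are arbitrary) of the save_model entry point the error message points to:

import torch.nn as nn
from safetensors.torch import save_model

# Unlike save_file, which errors out when tensors in the state dict share memory,
# save_model first drops duplicate names via _remove_duplicate_names
# (which itself relies on _find_shared_tensors) before writing.
model = nn.Linear(20, 30)
save_model(model, "model.safetensors")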

Collaborator
@Narsil Narsil commented Nov 17, 2023

Fixed directly in transformers.

@Narsil Narsil closed this Nov 17, 2023