Fix detection of duplicate torch tensors #379
Conversation
@@ -71,7 +75,8 @@ def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     for k, v in state_dict.items():
         if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
             # Need to add device as key because of multiple GPU.
-            tensors[(v.device, storage_ptr(v), storage_size(v))].add(k)
+            # Need to add storage_offset as key because views may share the same data_ptr.
+            tensors[(v.device, storage_ptr(v), storage_size(v), storage_offset(v))].add(k)
I think this line is correct. It's trying to see which tensors have shared memory, which is indicated by storage. Adding the offset will make you miss some shared tensors.
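A sketch of that concern (illustrative, reusing the storage_ptr / storage_size helpers that appear in the diff above): two views that share storage but start at different offsets would land in different groups once the offset is part of the key.

import torch
from safetensors.torch import storage_ptr, storage_size

a = torch.zeros(10)
b = a[2:]  # view into the same storage, starting two elements in

# Same underlying storage: identical pointer and storage size...
print(storage_ptr(a) == storage_ptr(b), storage_size(a) == storage_size(b))  # True True
# ...but different offsets, so a key that includes storage_offset() splits them apart.
print(a.storage_offset(), b.storage_offset())  # 0 2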
In my understanding, what this function does is find which tensors in the state_dict are "identical". Relying on data_ptr() is not enough in this case, as data_ptr is the same for views of a tensor, and those views may represent different data. For example, bar = a[2:4] is not the same tensor as foo = a[4:6].
I'll try to come up with a minimal example to showcase what I mean.
@Narsil @LysandreJik Here is a minimal repro of the issue:
import torch
import torch.nn as nn

from safetensors.torch import _find_shared_tensors


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(20, 30)
        # q.weight and v.weight are slices of qkv.weight, so all three weights
        # are views into the same underlying storage.
        self.q = nn.Linear(20, 10)
        self.q.weight = torch.nn.Parameter(self.qkv.weight[:10])
        self.v = nn.Linear(20, 10)
        self.v.weight = torch.nn.Parameter(self.qkv.weight[10:20])

    def forward(self, x):
        return x


model = MyModel()
shared_params = _find_shared_tensors(model.state_dict())
print("shared_params", shared_params)
This prints: shared_params [{'v.weight', 'qkv.weight', 'q.weight'}, {'qkv.bias'}, {'q.bias'}, {'v.bias'}]
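A quick way to see why those three weights end up in one group (a sketch, assuming the storage_ptr / storage_size helpers used in the diff are importable from safetensors.torch):

from safetensors.torch import storage_ptr, storage_size

sd = model.state_dict()
for name in ("qkv.weight", "q.weight", "v.weight"):
    t = sd[name]
    # All three live in the same underlying storage, so they share
    # storage_ptr and storage_size, which is all the current key checks.
    print(name, storage_ptr(t), storage_size(t), t.storage_offset())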
Okay, but I understand you to be arguing that _find_shared_tensors is meant to group tensors that share memory, no matter whether some of them are views and the actual data is different. Is that the case?
Despite the fix huggingface/transformers#27314, Transformers still calls safetensors.torch.save_file and complains that tensors still share memory: https://github.com/fxmarty/safetensors/blob/f04d064884be5ede7b6f7d844ce22b793607d091/bindings/python/py_src/safetensors/torch.py#L474. Couldn't Transformers call save_model as the error suggests? Though I doubt this would help, as save_model calls _remove_duplicate_names, which in turn calls _find_shared_tensors.
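For reference, a sketch of that call chain applied to the repro model above (the file name is illustrative, and whether this path actually helps Transformers is exactly what is being debated here):

from safetensors.torch import save_model, _remove_duplicate_names

# save_model drops duplicate names among tensors sharing storage before writing:
#   save_model -> _remove_duplicate_names -> _find_shared_tensors
print(_remove_duplicate_names(model.state_dict()))
save_model(model, "model.safetensors")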
Fixed directly in transformers.
As per title. Two tensors with different data may share the same data_ptr when they are views of another tensor. For example, see the sketch below.
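A minimal sketch of such a case (illustrative values, not necessarily the original snippet; storage_ptr is the helper from safetensors.torch used in the diff):

import torch
from safetensors.torch import storage_ptr

a = torch.arange(8, dtype=torch.float32)
foo = a[4:6]  # view over elements 4..5
bar = a[2:4]  # view over elements 2..3, holding different data than foo

# Both views point into the same underlying storage...
print(storage_ptr(foo) == storage_ptr(bar))  # True
# ...yet they contain different values, so treating them as duplicates is wrong.
print(foo.tolist(), bar.tolist())  # [4.0, 5.0] [2.0, 3.0]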