Force dtype to int64 to ensure that we don't index with non-long tensor #258

Open
TobiasMadsenQiagen opened this issue Jun 21, 2022 · 0 comments

In the triplet data loaders (`utils.py:load_triplet_data` and `utils.py:load_raw_triplet_data`), the imported data must be forced to dtype `int64` so that the resulting torch tensors are always long. Otherwise, torch may complain that a tensor used for indexing is not of type long when calling predict:

```
line 186, in __call__
    return self.emb[idx].to(self.device)
IndexError: tensors used as indices must be long, byte or bool tensors
```
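A minimal sketch of the proposed fix, assuming the loaders end up building an (N, 3) array of integer ids from tab-separated `head, relation, tail` lines. `load_triplet_ids` is a hypothetical helper, not dglke's actual code; the point is only that `dtype=np.int64` is passed explicitly instead of letting NumPy infer it.

```python
import numpy as np


def load_triplet_ids(lines, sep="\t"):
    """Parse "head<sep>relation<sep>tail" lines of integer ids into an (N, 3) int64 array.

    Hypothetical helper: dglke's real loaders (utils.py:load_triplet_data and
    utils.py:load_raw_triplet_data) may read and map their input differently.
    """
    rows = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        fields = line.split(sep)
        if len(fields) != 3:
            raise ValueError(f"expected 3 fields, got {len(fields)}: {line!r}")
        rows.append([int(f) for f in fields])
    # Passing dtype explicitly avoids NumPy's platform-dependent integer
    # inference; reshape keeps the (0, 3) shape for empty input.
    return np.asarray(rows, dtype=np.int64).reshape(-1, 3)
```

With the dtype fixed at load time, every tensor later built from these arrays is long regardless of the operating system.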

`np.asarray` infers the data type from its input. On the Windows system we tested on, the inferred type is `int32` as long as the input integers fit in 32 bits (i.e. do not exceed 2^31 - 1). On macOS and Ubuntu we did not observe the problem.
We tested with dglke 0.1.2.
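The inferred type comes from NumPy's default integer, which in NumPy 1.x follows the platform's C `long`: 32 bits on Windows, 64 bits on 64-bit Linux and macOS (NumPy 2.0 made the default 64-bit on 64-bit Windows as well). For arrays built elsewhere, a defensive cast just before creating the index tensor is another option. The sketch below uses a hypothetical helper, `as_index_array`, which is not part of dglke.

```python
import numpy as np


def as_index_array(values):
    """Return `values` as an int64 array suitable for building index tensors.

    Hypothetical helper (not part of dglke). Non-integer input raises
    TypeError instead of being silently truncated by the cast.
    """
    arr = np.asarray(values)
    # An empty list is inferred as float64; only check the dtype when there is data.
    if arr.size and not np.issubdtype(arr.dtype, np.integer):
        raise TypeError(f"expected integer ids, got dtype {arr.dtype}")
    # copy=False returns `arr` itself when it is already int64.
    return arr.astype(np.int64, copy=False)


inferred = np.asarray([1, 2, 3]).dtype    # platform- and version-dependent
fixed = as_index_array([1, 2, 3]).dtype   # int64 everywhere
```

Because `torch.from_numpy` keeps the array's dtype, `torch.from_numpy(as_index_array(ids))` yields a `torch.int64` (long) tensor on every platform.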
