Fix dtype in self_energies #347
Conversation
```diff
@@ -192,7 +192,7 @@ def sae(self, species):
         intercept = self.self_energies[-1]

         self_energies = self.self_energies[species]
-        self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device, dtype=self_energies.dtype)
+        self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device, dtype=torch.double)
```
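For context (not part of the diff itself): `species` uses `-1` for padding atoms, and the changed line zeroes out the self energies gathered for those padding slots. Below is a standalone sketch of the pattern with made-up values, not torchani's actual self energies:

```python
import torch

# Hypothetical per-species self energies stored in double precision,
# with the trailing entry playing the role of the intercept.
stored = torch.tensor([-0.5, -38.0, -54.5, -75.0, 0.0], dtype=torch.double)

# A padded species tensor; -1 marks padding atoms.
species = torch.tensor([[0, 1, 1, -1, -1]])

self_energies = stored[species]  # -1 indexes the last (intercept) entry
mask = species == torch.tensor(-1, device=species.device)

# The line under discussion: torch.tensor(0) by itself would be int64,
# which is why an explicit floating dtype is spelled out.
self_energies[mask] = torch.tensor(0, device=species.device, dtype=torch.double)
print(self_energies)
```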
How about removing the `dtype` and replacing the `0` with `0.0`?
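Continuing the sketch above, the suggestion would amount to something like the following on the changed line. Note that `torch.tensor(0.0)` defaults to float32, so this relies on the masked assignment casting the value to the dtype of `self_energies`; recent PyTorch does this, while older releases could be stricter about dtype matches here:

```python
# Suggested variant: drop the explicit dtype and use a float literal instead.
self_energies[species == torch.tensor(-1, device=species.device)] = \
    torch.tensor(0.0, device=species.device)
```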
@zasdfgbnm `self_energies.dtype` should be `double`, but `float32` is enforced in comp6.py; is there any particular reason for that?
See: https://github.com/aiqm/torchani/pull/347/checks?check_run_id=273594990#step:9:25
@farhadrgh No, there is no reason. I guess it is just randomly picking a dtype.
@farhadrgh Also, there is a tricky part here: the energy shifter is a module, so if you put it inside `model = torch.nn.Sequential(aev_computer, nn, energy_shifter)` and then call `model.to(torch.float)`, the shifter will also be cast to float. I don't have a good idea of how to solve that.
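A minimal illustration of that casting behavior, assuming `self_energies` is held as a registered buffer (a plain Python attribute would not be touched by `.to()`); the module below is a stand-in, not torchani's actual EnergyShifter:

```python
import torch

class ToyShifter(torch.nn.Module):
    def __init__(self, self_energies):
        super().__init__()
        # Registered buffers participate in Module.to(...) dtype casts.
        self.register_buffer('self_energies', self_energies)

    def forward(self, x):
        return x

shifter = ToyShifter(torch.tensor([-0.5, -38.0], dtype=torch.double))
model = torch.nn.Sequential(torch.nn.Identity(), shifter)
print(model[1].self_energies.dtype)  # torch.float64

# Casting the whole pipeline to single precision drags the buffer along,
# so a dtype hard-coded inside sae() can end up disagreeing with it.
model = model.to(torch.float)
print(model[1].self_energies.dtype)  # torch.float32
```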
I fixed the `dtype` in comp6.py (6930879); this will let the tests pass peacefully. I wonder why the test hadn't failed before!