Skip to content

Commit

Permalink
Merge pull request #66 from traja-team/fortasyn
Browse files Browse the repository at this point in the history
fixing accuracy > 1 error
  • Loading branch information
MaddyThakker authored Feb 13, 2021
2 parents 045a4bd + 51f7c0e commit f87978d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 3 additions & 1 deletion traja/models/losses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

class Criterion:
"""Implements the loss functions of Autoencoders, Variational Autoencoders and LSTM models
Expand Down Expand Up @@ -51,7 +52,8 @@ def classifier_criterion(self, predicted, target):
:return: Cross entropy loss
"""

# _, predicted = torch.max(predicted.data, 1)
predicted = predicted.to(device)
target = target.to(device)
loss = self.crossentropy_loss(predicted, target.view(-1))
return loss

Expand Down
5 changes: 3 additions & 2 deletions traja/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ def fit(
# Compute number of correct samples
total += ids.size(0)
_, predicted = torch.max(classifier_out.data, 1)
correct += (predicted == classes).sum().item()

correct += (predicted.cpu() == classes.cpu().T).sum().item()

if self.regress:
regressor_out = self.model(
Expand Down Expand Up @@ -401,7 +402,7 @@ def validate(self, validation_loader):
# Compute number of correct samples
total += ids.size(0)
_, predicted = torch.max(classifier_out.data, 1)
correct += (predicted == classes).sum().item()
correct += (predicted.cpu() == classes.cpu().T).sum().item()

if self.regress:
regressor_out = self.model(
Expand Down

0 comments on commit f87978d

Please sign in to comment.