Skip to content

Commit

Permalink
fix: sparsemax on train and predict epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardocarvp authored and Optimox committed Feb 24, 2020
1 parent 939f01c commit 6f7c0e0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
6 changes: 4 additions & 2 deletions pytorch_tabnet/tab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,8 @@ def train_epoch(self, train_loader):
for data, targets in train_loader:
batch_outs = self.train_batch(data, targets)
if self.output_dim == 2:
y_preds.append(batch_outs["y_preds"][:, 1].cpu().detach().numpy())
y_preds.append(torch.nn.Softmax(dim=1)(batch_outs["y_preds"])[:, 1]
.cpu().detach().numpy())
else:
values, indices = torch.max(batch_outs["y_preds"], dim=1)
y_preds.append(indices.cpu().detach().numpy())
Expand Down Expand Up @@ -566,7 +567,8 @@ def predict_epoch(self, loader):
batch_outs = self.predict_batch(data, targets)
total_loss += batch_outs["loss"]
if self.output_dim == 2:
y_preds.append(batch_outs["y_preds"][:, 1].cpu().detach().numpy())
y_preds.append(torch.nn.Softmax(dim=1)(batch_outs["y_preds"])[:, 1]
.cpu().detach().numpy())
else:
values, indices = torch.max(batch_outs["y_preds"], dim=1)
y_preds.append(indices.cpu().detach().numpy())
Expand Down
1 change: 1 addition & 0 deletions pytorch_tabnet/tab_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def __init__(self, input_dim, output_dim, shared_layers, n_glu,
}

if shared_layers is None:
self.shared = None
self.specifics = GLU_Block(input_dim, output_dim,
first=True,
**params)
Expand Down

0 comments on commit 6f7c0e0

Please sign in to comment.