
Commit

chore: clean code
eduardocarvp committed Feb 7, 2020
1 parent 5f0e43f commit d8631f8
Showing 2 changed files with 25 additions and 43 deletions.
pytorch_tabnet/tab_model.py (1 change: 0 additions & 1 deletion)
@@ -502,7 +502,6 @@ def train_epoch(self, train_loader):
 
         if self.scheduler is not None:
             self.scheduler.step()
-            print("Current learning rate: ", self.optimizer.param_groups[-1]["lr"])
         return epoch_metrics
 
     def train_batch(self, data, targets):
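Note on the tab_model.py change: the deleted line printed the learning rate read from the optimizer's last param group after each scheduler step. If that logging is still wanted, here is a minimal sketch of doing it outside train_epoch, assuming an object that exposes optimizer and scheduler attributes as in the diff above (the helper name is hypothetical, not part of the commit):

# Sketch only: step the scheduler and report the current learning rate,
# mirroring the expression removed by this commit.
# `model` is assumed to expose `optimizer` and `scheduler` attributes.
def step_and_report_lr(model):
    if model.scheduler is not None:
        model.scheduler.step()
    # param_groups[-1]["lr"] is the same lookup used by the removed print
    return model.optimizer.param_groups[-1]["lr"]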
pytorch_tabnet/tab_network.py (67 changes: 25 additions & 42 deletions)
@@ -263,34 +263,28 @@ def __init__(self, input_dim, output_dim, shared_layers, n_glu,
             Float value between 0 and 1 which will be used for momentum in batch norm
         """
 
+        params = {
+            'n_glu': n_glu,
+            'virtual_batch_size': virtual_batch_size,
+            'momentum': momentum,
+            'device': device
+        }
+
         if shared_layers is None:
             self.specifics = GLU_Block(input_dim, output_dim,
-                                       n_glu=n_glu,
                                        first=True,
-                                       virtual_batch_size=virtual_batch_size,
-                                       momentum=momentum,
-                                       device=device)
+                                       **params)
         else:
             self.shared = GLU_Block(input_dim, output_dim,
-                                    n_glu=n_glu,
                                     first=True,
                                     shared_layers=shared_layers,
-                                    virtual_batch_size=virtual_batch_size,
-                                    momentum=momentum,
-                                    device=device)
+                                    **params)
             self.specifics = GLU_Block(output_dim, output_dim,
-                                       n_glu=n_glu,
-                                       virtual_batch_size=virtual_batch_size,
-                                       momentum=momentum,
-                                       device=device)
+                                       **params)
 
     def forward(self, x):
         if self.shared is not None:
-            # print('-------before----------')
-            # print(self.shared.glu_layers[0].bn.bn.running_mean)
             x = self.shared(x)
-            # print('-------after-----------')
-            # print(self.shared.glu_layers[0].bn.bn.running_mean)
         x = self.specifics(x)
         return x
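The hunk above replaces the keyword arguments repeated across all three GLU_Block constructions with a single params dict unpacked via **params. A self-contained illustration of that keyword-unpacking pattern, using generic names and values (nothing here is pytorch_tabnet API):

# Generic illustration of the **params pattern (hypothetical function and values).
def make_block(in_dim, out_dim, n_glu=2, virtual_batch_size=128,
               momentum=0.02, device='cpu', first=False):
    return (in_dim, out_dim, n_glu, virtual_batch_size, momentum, device, first)

params = {
    'n_glu': 4,
    'virtual_batch_size': 256,
    'momentum': 0.02,
    'device': 'cpu',
}

# The two calls are equivalent; the dict keeps the shared keywords in one place.
explicit = make_block(64, 8, n_glu=4, virtual_batch_size=256,
                      momentum=0.02, device='cpu', first=True)
unpacked = make_block(64, 8, first=True, **params)
assert explicit == unpacked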

@@ -308,32 +302,21 @@ def __init__(self, input_dim, output_dim, n_glu=2, first=False, shared_layers=None,
         self.glu_layers = torch.nn.ModuleList()
         self.scale = torch.sqrt(torch.FloatTensor([0.5]).to(device))
 
-        if shared_layers:
-            for glu_id in range(self.n_glu):
-                if glu_id == 0:
-                    self.glu_layers.append(GLU_Layer(input_dim, output_dim,
-                                                     fc=shared_layers[glu_id],
-                                                     virtual_batch_size=virtual_batch_size,
-                                                     momentum=momentum,
-                                                     device=device))
-                else:
-                    self.glu_layers.append(GLU_Layer(output_dim, output_dim,
-                                                     fc=shared_layers[glu_id],
-                                                     virtual_batch_size=virtual_batch_size,
-                                                     momentum=momentum,
-                                                     device=device))
-        else:
-            for glu_id in range(self.n_glu):
-                if glu_id == 0:
-                    self.glu_layers.append(GLU_Layer(input_dim, output_dim,
-                                                     virtual_batch_size=virtual_batch_size,
-                                                     momentum=momentum,
-                                                     device=device))
-                else:
-                    self.glu_layers.append(GLU_Layer(output_dim, output_dim,
-                                                     virtual_batch_size=virtual_batch_size,
-                                                     momentum=momentum,
-                                                     device=device))
+        params = {
+            'virtual_batch_size': virtual_batch_size,
+            'momentum': momentum,
+            'device': device
+        }
+
+        fc = shared_layers[0] if shared_layers else None
+        self.glu_layers.append(GLU_Layer(input_dim, output_dim,
+                                         fc=fc,
+                                         **params))
+        for glu_id in range(1, self.n_glu):
+            fc = shared_layers[glu_id] if shared_layers else None
+            self.glu_layers.append(GLU_Layer(output_dim, output_dim,
+                                             fc=fc,
+                                             **params))
 
     def forward(self, x):
         if self.first: # the first layer of the block has no scale multiplication
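The second hunk collapses the four near-identical GLU_Layer constructions into a single loop: the first layer maps input_dim to output_dim, later layers map output_dim to output_dim, and fc is taken from shared_layers when it is provided. A quick standalone check that the new loop produces the same argument sequence as the old branches for n_glu >= 1; plain tuples stand in for GLU_Layer modules and all names here are illustrative, not the library's API:

# Sketch only: compare the old branchy construction with the refactored loop.
def build_old(input_dim, output_dim, n_glu, shared_layers):
    layers = []
    if shared_layers:
        for glu_id in range(n_glu):
            in_dim = input_dim if glu_id == 0 else output_dim
            layers.append((in_dim, output_dim, shared_layers[glu_id]))
    else:
        for glu_id in range(n_glu):
            in_dim = input_dim if glu_id == 0 else output_dim
            # the old code simply omitted fc here, which is assumed to default to None
            layers.append((in_dim, output_dim, None))
    return layers

def build_new(input_dim, output_dim, n_glu, shared_layers):
    layers = []
    fc = shared_layers[0] if shared_layers else None
    layers.append((input_dim, output_dim, fc))
    for glu_id in range(1, n_glu):
        fc = shared_layers[glu_id] if shared_layers else None
        layers.append((output_dim, output_dim, fc))
    return layers

assert build_old(16, 8, 2, None) == build_new(16, 8, 2, None)
assert build_old(16, 8, 2, ['fc0', 'fc1']) == build_new(16, 8, 2, ['fc0', 'fc1'])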
