diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py
index db8e2da4..4d5dd16b 100755
--- a/pytorch_tabnet/tab_model.py
+++ b/pytorch_tabnet/tab_model.py
@@ -502,7 +502,6 @@ def train_epoch(self, train_loader):
 
         if self.scheduler is not None:
             self.scheduler.step()
-            print("Current learning rate: ", self.optimizer.param_groups[-1]["lr"])
         return epoch_metrics
 
     def train_batch(self, data, targets):
diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py
index 9c11b85f..edd929f3 100644
--- a/pytorch_tabnet/tab_network.py
+++ b/pytorch_tabnet/tab_network.py
@@ -263,34 +263,28 @@ def __init__(self, input_dim, output_dim, shared_layers, n_glu,
             Float value between 0 and 1 which will be used for momentum in batch norm
         """
+        params = {
+            'n_glu': n_glu,
+            'virtual_batch_size': virtual_batch_size,
+            'momentum': momentum,
+            'device': device
+        }
+
         if shared_layers is None:
             self.specifics = GLU_Block(input_dim, output_dim,
-                                       n_glu=n_glu,
                                        first=True,
-                                       virtual_batch_size=virtual_batch_size,
-                                       momentum=momentum,
-                                       device=device)
+                                       **params)
         else:
             self.shared = GLU_Block(input_dim, output_dim,
-                                    n_glu=n_glu,
                                     first=True,
                                     shared_layers=shared_layers,
-                                    virtual_batch_size=virtual_batch_size,
-                                    momentum=momentum,
-                                    device=device)
+                                    **params)
             self.specifics = GLU_Block(output_dim, output_dim,
-                                       n_glu=n_glu,
-                                       virtual_batch_size=virtual_batch_size,
-                                       momentum=momentum,
-                                       device=device)
+                                       **params)
 
     def forward(self, x):
         if self.shared is not None:
-            # print('-------before----------')
-            # print(self.shared.glu_layers[0].bn.bn.running_mean)
             x = self.shared(x)
-            # print('-------after-----------')
-            # print(self.shared.glu_layers[0].bn.bn.running_mean)
         x = self.specifics(x)
         return x
 
@@ -308,32 +302,21 @@ def __init__(self, input_dim, output_dim, n_glu=2, first=False, shared_layers=No
         self.glu_layers = torch.nn.ModuleList()
         self.scale = torch.sqrt(torch.FloatTensor([0.5]).to(device))
 
-        if shared_layers:
-            for glu_id in range(self.n_glu):
-                if glu_id == 0:
-                    self.glu_layers.append(GLU_Layer(input_dim, output_dim,
-                                                     fc=shared_layers[glu_id],
-                                                     virtual_batch_size=virtual_batch_size,
-                                                     momentum=momentum,
-                                                     device=device))
-                else:
-                    self.glu_layers.append(GLU_Layer(output_dim, output_dim,
-                                                     fc=shared_layers[glu_id],
-                                                     virtual_batch_size=virtual_batch_size,
-                                                     momentum=momentum,
-                                                     device=device))
-        else:
-            for glu_id in range(self.n_glu):
-                if glu_id == 0:
-                    self.glu_layers.append(GLU_Layer(input_dim, output_dim,
-                                                     virtual_batch_size=virtual_batch_size,
-                                                     momentum=momentum,
-                                                     device=device))
-                else:
-                    self.glu_layers.append(GLU_Layer(output_dim, output_dim,
-                                                     virtual_batch_size=virtual_batch_size,
-                                                     momentum=momentum,
-                                                     device=device))
+        params = {
+            'virtual_batch_size': virtual_batch_size,
+            'momentum': momentum,
+            'device': device
+        }
+
+        fc = shared_layers[0] if shared_layers else None
+        self.glu_layers.append(GLU_Layer(input_dim, output_dim,
+                                         fc=fc,
+                                         **params))
+        for glu_id in range(1, self.n_glu):
+            fc = shared_layers[glu_id] if shared_layers else None
+            self.glu_layers.append(GLU_Layer(output_dim, output_dim,
+                                             fc=fc,
+                                             **params))
 
     def forward(self, x):
         if self.first:  # the first layer of the block has no scale multiplication
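
Both hunks in tab_network.py apply the same refactor: keyword arguments that were repeated in every `GLU_Block`/`GLU_Layer` call are collected once in a `params` dict and unpacked with `**params`, and the duplicated `if shared_layers`/`else` loops collapse into a single loop that resolves `fc` per layer. A minimal sketch of that pattern follows; the `Layer` class and `build_layers` helper are hypothetical stand-ins for illustration only, not pytorch_tabnet code.

```python
# Hypothetical stand-ins used only to illustrate the refactor; not pytorch_tabnet code.
class Layer:
    def __init__(self, input_dim, output_dim, fc=None,
                 virtual_batch_size=128, momentum=0.02, device='cpu'):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.fc = fc
        self.virtual_batch_size = virtual_batch_size
        self.momentum = momentum
        self.device = device


def build_layers(input_dim, output_dim, n_glu, shared_layers=None,
                 virtual_batch_size=128, momentum=0.02, device='cpu'):
    # Gather the keyword arguments shared by every call once, then unpack them
    # with ** instead of repeating them in each constructor invocation.
    params = {
        'virtual_batch_size': virtual_batch_size,
        'momentum': momentum,
        'device': device,
    }
    layers = []
    # First layer maps input_dim -> output_dim, the rest map output_dim -> output_dim;
    # fc falls back to None whenever no shared layers are provided.
    fc = shared_layers[0] if shared_layers else None
    layers.append(Layer(input_dim, output_dim, fc=fc, **params))
    for glu_id in range(1, n_glu):
        fc = shared_layers[glu_id] if shared_layers else None
        layers.append(Layer(output_dim, output_dim, fc=fc, **params))
    return layers


layers = build_layers(16, 8, n_glu=3)
print(len(layers), layers[0].input_dim, layers[1].input_dim)  # 3 16 8
```

The objects constructed and the branch taken for shared versus non-shared layers are the same as with the original nested conditionals; only the call-site duplication is removed.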