Skip to content

Commit

Permalink
chore: add random seed for reproducibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox authored and eduardocarvp committed Dec 3, 2019
1 parent d892f9c commit 3b1c4d9
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions pytorch_tabnet/tab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class TabModel(object):
def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], cat_emb_dim=1,
n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02,
lambda_sparse=1e-3,
lambda_sparse=1e-3, seed=0,
clip_value=1, verbose=1,
lr=2e-2, optimizer_fn=torch.optim.Adam,
lr_params={}, scheduler_params=None, scheduler_fn=None,
Expand Down Expand Up @@ -53,6 +53,8 @@ def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[],
self.opt_params = {}
self.opt_params['lr'] = self.lr

self.seed = seed
torch.manual_seed(self.seed)
# Defining device
if device_name == 'auto':
if torch.cuda.is_available():
Expand Down Expand Up @@ -311,7 +313,7 @@ class TabNetClassifier(TabModel):

def __repr__(self):
repr_ = f"""TabNetClassifier(n_d={self.n_d}, n_a={self.n_a}, n_steps={self.n_steps},
lr={self.lr}, lr_params={self.lr_params},
lr={self.lr}, lr_params={self.lr_params}, seed={self.seed},
gamma={self.gamma}, n_independent={self.n_independent}, n_shared={self.n_shared},
cat_idxs={self.cat_idxs},
cat_dims={self.cat_dims},
Expand Down Expand Up @@ -641,7 +643,7 @@ class TabNetRegressor(TabModel):

def __repr__(self):
repr_ = f"""TabNetRegressor(n_d={self.n_d}, n_a={self.n_a}, n_steps={self.n_steps},
lr={self.lr}, lr_params={self.lr_params},
lr={self.lr}, lr_params={self.lr_params}, seed={self.seed},
gamma={self.gamma}, n_independent={self.n_independent}, n_shared={self.n_shared},
cat_idxs={self.cat_idxs},
cat_dims={self.cat_dims},
Expand Down

0 comments on commit 3b1c4d9

Please sign in to comment.