Skip to content

Commit

Permalink
zero-range-features bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ppdebreuck committed May 12, 2020
1 parent ca28d16 commit 7e484bd
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
4 changes: 1 addition & 3 deletions modnet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ def fit(self,data:MODData, val_fraction = 0.0, val_key = None, lr=0.001, epochs

def predict(self,data):

df = pd.DataFrame(columns=self.optimal_descriptors[:self.n_feat])
df = df.append(data.get_featurized_df()).replace([np.inf, -np.inf, np.nan], 0)
x = df[self.optimal_descriptors[:self.n_feat]].values
x = data.get_featurized_df()[self.optimal_descriptors[:self.n_feat]].values

#Scale the input features:
if self.xscale == 'minmax':
Expand Down
14 changes: 9 additions & 5 deletions modnet/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
from typing import Dict, List
import pickle
import os
database = None
database = pd.DataFrame([])

def nmi_target(df_feat,df_target):

frange = df_feat.max(axis=0)-df_feat.min(axis=0)
to_drop = frange[frange==0].index
df_feat = df_feat.drop(to_drop,axis=1)

target_name = df_target.columns[0]
mi_df = pd.DataFrame([],columns=[target_name],index=df_feat.columns)
Expand All @@ -24,11 +28,10 @@ def nmi_target(df_feat,df_target):
S_mi = mutual_info_regression(df_target[target_name].values.reshape(-1, 1),df_target[target_name])[0]

diag={}
to_drop=[]
for x in df_feat.columns:
diag[x]=(mutual_info_regression(df_feat[x].values.reshape(-1, 1),df_feat[x]))[0]
if diag[x] == 0:
to_drop.append(x) # features which have an entropy of zero are useless
#if diag[x] < 0.01:
# to_drop.append(x) # features which have an entropy of nearly zero are useless

mi_df.drop(to_drop,inplace=True)
for x in mi_df.index:
Expand Down Expand Up @@ -264,7 +267,8 @@ def featurize(self,fast=0,db_file='feature_database.pkl'):
print('Fast featurization on, retrieving from database...')
this_dir, this_filename = os.path.split(__file__)
global database
database = pd.read_pickle(db_file)
if len(database) == 0:
database = pd.read_pickle(db_file)
mpids_done = [x for x in self.mpids if x in database.index]
print('Retrieved features for {} out of {} materials'.format(len(mpids_done),len(self.mpids)))
df_done = database.loc[mpids_done]
Expand Down

0 comments on commit 7e484bd

Please sign in to comment.