Skip to content

Commit

Permalink
updated
Browse files Browse the repository at this point in the history
  • Loading branch information
ishanuc committed Jan 25, 2024
1 parent cf5722c commit 0a597d6
Show file tree
Hide file tree
Showing 16 changed files with 584 additions and 322 deletions.
96 changes: 62 additions & 34 deletions build/lib/truthnet/truthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,6 @@
from zedstat import zedstat
from concurrent.futures import ProcessPoolExecutor

# VeRITAS
# The objective here is to train the VeRITAS model, which will be used to determine if a
# subject is being adversarial in a structured interview. An example of such
# adversarial responses is when a subject is malingering in a
# mental health diagnosis interview or an automated computer-aided diagnostic test.
#
# datapath is the path to the survey or interview database where we have
# a reasonable number of responses logged from which we will learn our model.
# Typically we would have a "target label" column which identifies the ground truth
# on whether the subject belonged to one category or the other (has a certain mental health condition, as
# determined by a psychiatrist, or not). Note that this column is not ground truth on malingering; there is typically no ground truth available on that.

# index_present describes whether the training dataframe/dataset has an index column in the first column.
# training fraction is the fraction used as training data to learn the "cross-talk or Q-net" models. The remaining or "test data" is used to infer the three decision thresholds.
# query limit is the number of features or items (columns in the data) that are used to determine the
# malingering status in deployment. The decision thresholds are determined using this many features, where the features are ordered from most predictive to least predictive using a SHAP analysis.

# There are three decision thresholds in the VeRITAS model, described below.
#
# QNETmodel
# \Phi(x_{-i}) is a probability distribution of outcomes over the ith variable or question asked, given all other responses (notation for which is x_{-i}). In general we can also have other entries missing, and such missing data is interpreted as a distribution over all possible outcomes at that index of missing data. This qnet model allows us to define a metric between two response vectors x, y, denoted as \theta(x,y), and allows us to define the probability Pr(x \rightarrow x).
#
# LOWER_DECISION THRESHOLD is an estimate of the negative log-likelihood -log Pr(x \rightarrow x) for a given x, per item with a non-missing response. It turns out that as this estimate falls below 1, the response becomes extremely unlikely to have been naturally generated.
#
# VERITAS THRESHOLD: captures the average deviation of a response vector from what the model says the responses should be.
#
# UPPER THRESHOLD estimates a threshold on the ratio of log-likelihoods of a response being produced by a qnet inferred from the positive cases vs. that inferred from the negative cases.
# So for a non-malingering response, one needs to be above the UPPER threshold, below the VERITAS threshold, and above the LOWER threshold.

global_NSTR = None
global_steps = None
global_model = None
Expand All @@ -53,6 +24,11 @@ def init_globals(model, steps, NSTR):
'''
global variable initialization necessary for
getting maximum paralleization in calibration
Parameters:
model: The model to be used globally across parallel tasks.
steps: The number of steps to be used for a specific operation, globally.
NSTR: Network String Representation, a global variable to represent the network state.
'''
global global_model, global_steps, global_NSTR
global_model = model
Expand All @@ -62,14 +38,21 @@ def init_globals(model, steps, NSTR):
def task(seed):
    '''
    Helper function for parallelization.

    Draws one sample with qsample using the module-level globals
    (global_NSTR, global_model, global_steps) and evaluates it.

    Parameters:
    seed: Not referenced in this function body; presumably present so the
        function can be mapped over an iterable of seeds -- confirm with
        the caller before relying on it for reproducibility.

    Returns:
    A tuple (funcm(s, global_model), dissonance_distr_median(s, global_model))
    for the drawn sample s.
    '''
    s = qsample(global_NSTR, global_model, steps=global_steps)
    return funcm(s, global_model), dissonance_distr_median(s, global_model)

class truthnet:
"""
The truthnet class is designed to train the Veritas model which is used to determine if a
subject is being adversarial in a structured interview. It is particularly focused on identifying
subject is being deceptive or untruthful or insincere in a structured interview.
Examples of target scenarios include identifying
adversarial responses in contexts like mental health diagnosis interviews
or automated computer-aided diagnostic tests.
Expand All @@ -85,6 +68,9 @@ class truthnet:
- problem (str): A description or identifier for the type of problem being addressed.
- threshold_alpha (float): Significance level for lower decision threshold.
- threshold_alpha_veritas (float): Significance level for Veritas threshold.
- veritas_model (dict): model, see detailed documentation on veritas model.
- problem (str): descriptive string for problem
- VERBOSE (bool): flag to denote if there should be verbose output
"""
def __init__(self, datapath,
target_label,
Expand Down Expand Up @@ -219,8 +205,9 @@ def funcm_(S):
self.veritas_model['model']=modelpos
self.veritas_model['model_neg']=modelneg
self.veritas_model['problem']=self.problem
self.veritas_model['shapvalues']=shap_values

else:
else:

X=df_test.values.astype(str)

Expand All @@ -245,6 +232,9 @@ def funcm_(S):
def save(self, filepath):
'''
save veritas model
Parameters:
filepath (str): The path where the model should be saved.
'''
with gzip.open(filepath, 'wb') as file:
M=self.veritas_model
Expand All @@ -258,9 +248,11 @@ def calibrate(self,
from the trained model. It involves sampling, revealing, and fitting distributions
to determine appropriate thresholds.
Parameters:
- qsteps (int): Number of steps for q-sampling during calibration.
- calibration_num (int): Number of samples to use for calibration.
qsteps (int): Steps for q-sampling during calibration.
num_workers (int): Number of parallel workers for calibration.
calibration_num (int): Number of calibration samples.
"""

featurenames = self.veritas_model['model'].feature_names
Expand Down Expand Up @@ -347,6 +339,15 @@ def calibrate(self,
return

def synccols(self, df_):
"""
Synchronize columns between positive and negative cases.
Parameters:
df_ (DataFrame): The DataFrame to process.
Returns:
DataFrame: A DataFrame with synchronized columns.
"""
df=df_.copy()
if self.target_label:
df1 = df[df[self.target_label] == str(self.target_label_positive)]
Expand All @@ -360,13 +361,28 @@ def synccols(self, df_):

def load_veritas_model(filepath):
    '''
    Load a Veritas model from a specified file.

    The file is opened with gzip in binary mode and its contents are
    unpickled.

    Parameters:
    filepath (str): The path to the file containing the saved Veritas model.

    Returns:
    The loaded Veritas model.
    '''
    # SECURITY: pickle.load can execute arbitrary code while deserializing;
    # only load model files that come from a trusted source.
    with gzip.open(filepath, 'rb') as file:
        return pickle.load(file)

def remove_identical_columns(df):
'''
Remove columns from a DataFrame that have identical values across all rows.
Parameters:
df (DataFrame): The DataFrame to process.
Returns:
DataFrame: A DataFrame with identical columns removed.
'''
columns_to_drop = [col for col in df.columns if df[col].nunique() == 1]
df_cleaned = df.drop(columns=columns_to_drop)

Expand All @@ -376,6 +392,18 @@ def remove_identical_columns(df):
def train(datapath,modelpath,
shapnum=10,target_label=None,
query_limit=20,calibration_num=5000):
'''
Train a Veritas model with specified parameters.
Parameters:
datapath (str): Path to the data file.
modelpath (str): Path to save the trained model.
shapnum (int): Number of samples for SHAP value calculation.
target_label (str): Target label column name.
query_limit (int): Limit on the number of features to use.
calibration_num (int): Number of samples for calibration.
'''
TR=truthnet(datapath=datapath,
target_label=target_label,
query_limit=query_limit,VERBOSE=False)
Expand Down
26 changes: 12 additions & 14 deletions build/lib/truthnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,12 @@ def validate(response_dataframe,C0,C1,C2,


if validation_type == "withdx":
mratio=(response_dataframe[(response_dataframe.mg==-1) & (response_dataframe.dx==1)].index.size)/response_dataframe.dx.sum()
mratio=(response_dataframe[(response_dataframe.mg==-1)
& (response_dataframe.dx==1)].index.size)/response_dataframe.dx.sum()
fullauc=zt.auc()

if plots:
plt.style.use('seaborn-dark-palette')
#plt.style.use('seaborn-dark-palette')

plt.figure(figsize=[20,12])
plt.subplot(231)
Expand All @@ -149,29 +151,27 @@ def validate(response_dataframe,C0,C1,C2,
cf=response_dataframe.corr()
plt.subplot(234)
sns.heatmap(cf,cmap='jet',alpha=.5)



plt.subplot(235)

plt.plot(fpr,tpr,'g',lw=2)
plt.gca().legend(['R20'])
zt.get().tpr.plot(style='-b',lw=2)
fullauc=zt.auc()

ax = plt.subplot(236)
ax.text(0.5, 0.6, f'malinger prevalenec in DX: {mratio:.2f}', fontsize=16, ha='center')
ax.text(0.5, 0.4, f'AUC: {fullauc[0]:.2f} $\pm$ {fullauc[1]-fullauc[0]:.2f}', fontsize=16, ha='center')
ax.set_xticks([])
ax.set_yticks([])
ax.set_frame_on(False)
else:
return {'auc':fullauc,'mratio':mratio}
return {'auc':fullauc,'mratio':mratio}, response_dataframe, zt

if validation_type == "fnrexpt":
fnr=response_dataframe[(response_dataframe.mg==1)].index.size/response_dataframe.index.size
if plots:
plt.style.use('seaborn-dark-palette')
#plt.style.use('seaborn-dark-palette')

plt.figure(figsize=[20,12])
plt.subplot(231)
Expand All @@ -191,21 +191,20 @@ def validate(response_dataframe,C0,C1,C2,
cf=response_dataframe.corr()
plt.subplot(234)
sns.heatmap(cf,cmap='jet',alpha=.5)


ax = plt.subplot(236)
ax.text(0.5, 0.6, f'FNR in EXPT: {fnr:.2f}', fontsize=16, ha='center')
ax.set_xticks([])
ax.set_yticks([])
ax.set_frame_on(False)
else:
return {'fnr':fnr}

return {'fnr':fnr}, response_dataframe


if validation_type == "noscore":
mrate=response_dataframe[response_dataframe.mg==-1].index.size/response_dataframe.index.size
if plots:
plt.style.use('seaborn-dark-palette')
#plt.style.use('seaborn-dark-palette')

plt.figure(figsize=[8,8])
plt.subplot(111)
Expand All @@ -216,10 +215,9 @@ def validate(response_dataframe,C0,C1,C2,

ax = plt.gca()


ax.text(0.65, 0.8, f'mrate: {mrate:.2f}', fontsize=16, ha='center')
else:
return {'mrate':mrate}

return {'mrate':mrate}, response_dataframe

plt.savefig(outfile,dpi=300,bbox_inches='tight',transparent=True)

Expand Down
Binary file modified dist/truthnet-0.0.23-py3-none-any.whl
Binary file not shown.
Binary file modified dist/truthnet-0.0.23.tar.gz
Binary file not shown.
Loading

0 comments on commit 0a597d6

Please sign in to comment.