-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
81 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
def getrocprob(raw_scores, rocdata):
    """Map raw classifier scores to probabilities via zedstat.

    Args:
        raw_scores: Scores forwarded unchanged as the first argument of
            ``zedstat.zedstat.score_to_probability``.
        rocdata: Mapping that provides ``'roc'``, ``'total_samples'`` and
            ``'positive_samples'``, plus an optional ``'prevalence'``
            override.

    Returns:
        Whatever ``score_to_probability`` returns for the given scores.
    """
    n_positive = rocdata['positive_samples']
    n_total = rocdata['total_samples']

    # The sample prevalence is always computed; it is only *used* when the
    # ROC data does not carry an explicit 'prevalence' entry.
    fallback_prevalence = n_positive / n_total
    chosen_prevalence = rocdata.get('prevalence', fallback_prevalence)

    from zedstat.zedstat import score_to_probability

    return score_to_probability(
        raw_scores,
        df=rocdata['roc'],
        total_samples=n_total,
        positive_samples=n_positive,
        prevalence=chosen_prevalence,
    )
|
||
|
||
def _diss_linear(s, qnet, missing_response='', missing_diss_value=0):
    """Compute a per-position linear dissonance for one sample.

    For every position ``i`` the probability that ``qnet`` assigns to the
    observed response ``str(s[i])`` is compared with the largest probability
    in the predicted distribution for that position:
    ``1 - p(observed) / max(p)``.

    Args:
        s: Sequence of responses, one per position.
        qnet: Object exposing ``predict_distributions(s)``; the result is
            indexed by position and each entry is a mapping from response
            label to probability.
        missing_response: Value in ``s`` that marks an unanswered position.
        missing_diss_value: Dissonance reported for unanswered positions.

    Returns:
        list: One dissonance value per element of ``s``.
    """
    distributions = qnet.predict_distributions(s)
    diss_values = []
    for i, response in enumerate(s):
        # Unanswered positions carry no evidence; report the configured value
        # instead of consulting the distribution.
        if response == missing_response:
            diss_values.append(missing_diss_value)
            continue

        dist = distributions[i]
        # default=0 keeps an empty distribution from raising ValueError and
        # routes it through the same guard as an all-zero distribution.
        max_prob = max(dist.values(), default=0)
        if max_prob == 0:
            diss_values.append(0)
            continue

        # A response the model has never seen falls back to the modal
        # probability, i.e. it contributes zero dissonance.
        prob = float(dist.get(str(response), max_prob))
        diss_values.append(1 - prob / max_prob)
    return diss_values
|
||
def truthfinder(patients_responses, problem_type, classifier_id=346):
    """Score each patient's questionnaire responses with the stored models.

    Args:
        patients_responses: List of dicts, each mapping a patient id to a
            ``{question_id: response}`` dict. Every response must be
            convertible with ``int()``; it is stored as ``str(int(response))``.
        problem_type: Sub-directory name used to locate the qnet model, the
            classifier and the ROC data on disk.
        classifier_id: Numeric suffix of the classifier / ROC pickle files
            (``runif-classifier-<id>.pkl`` and ``runif-roc-<id>.pkl``).

    Returns:
        dict: ``{patient_id: {'probability': ..., 'ci': (...), 'rawscore': ...}}``.
        An empty dict is returned when no patient is supplied. If the same
        patient id occurs more than once, the last occurrence wins.

    Raises:
        ValueError, TypeError: If a response cannot be converted with ``int()``.
        OSError: If one of the model / classifier / ROC files cannot be opened.
    """
    # Flatten the input into two parallel lists so that row ``k`` of every
    # downstream array always belongs to ``patient_ids[k]``, even when a
    # single dict carries several patients.
    patient_ids = []
    normalized_answers = []
    for patient in patients_responses:
        for patient_id, answers in patient.items():
            patient_ids.append(patient_id)
            normalized_answers.append(
                {question_id: str(int(response))
                 for question_id, response in answers.items()}
            )

    if not patient_ids:
        # Nothing to score; avoid loading models and calling predict_proba([]).
        return {}

    # Heavy / optional dependencies are imported only once there is work to do.
    import pickle

    import pandas as pd
    from quasinet.qnet import load_qnet

    model_path = f"models/{problem_type}/random_order_full_model_0.joblib.gz"
    classifier_path = f"classifiers/{problem_type}/runif-classifier-{classifier_id}.pkl"
    roc_path = f"classifiers/{problem_type}/runif-roc-{classifier_id}.pkl"

    # SECURITY NOTE: unpickling executes arbitrary code. These files (and
    # ``problem_type``, which selects them) must come from a trusted source.
    with open(roc_path, 'rb') as file:
        rocdata = pickle.load(file)

    model = load_qnet(model_path)
    classifier = pd.read_pickle(classifier_path)

    all_data_samples = []
    for answers in normalized_answers:
        # Align the answers to the model's feature order; unanswered
        # questions become empty strings.
        row = (pd.DataFrame([answers], columns=model.feature_names)
               .fillna('')
               .astype(str)
               .values)
        all_data_samples.append(row[0])

    diss_values = [_diss_linear(sample, model) for sample in all_data_samples]
    proba = classifier.predict_proba(diss_values)[:, 1]

    # Each entry is indexed as [0] for the point estimate and [1:] for the
    # interval bounds — assumed layout of getrocprob's result; TODO confirm.
    estimated_probability_of_event = getrocprob(proba, rocdata)

    output_dict = {}
    for row_index, patient_id in enumerate(patient_ids):
        estimate = estimated_probability_of_event[row_index]
        output_dict[patient_id] = {'probability': estimate[0],
                                   'ci': tuple(estimate[1:]),
                                   'rawscore': proba[row_index]}
    return output_dict
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from truthfinder import truthfinder | ||
|
||
|
||
|
||
|
||
# Example driver: scores three sample patients with the 'global' models.
problem_type = 'global'

# One dict per patient: {patient_id: {question_id: response}}.
# Responses deliberately mix str, int and float encodings.
patients_responses = [
    {"0": {
        "20": '1', "31": 2.0, "36": 1.0, "51": 2.0, "60": 1.0, "75": 1.0,
        "85": 1.0, "115": 1.0, "123": 2.0, "147": 2.0, "150": 3.0,
        "179": 2.0, "312": 2.0, "355": 2.0, "436": 4.0, "572": 3.0,
        "719": 2.0, "737": 2.0, "747": 2.0, "817": 2.0, "855": 2.0,
        "938": 1.0, "983": 2.0, "1053": 2.0, "1060": 2.0, "1069": 1.0,
        "1132": 2.0, "4201": 1.0, "4203": 1, "4247": 1.0, "4250": 2.0,
        "4251": 2.0, "4254": 2.0, "4255": 1, "4262": 2.0, "4278": 1,
        "4280": 1, "4287": 1, "4289": 1, "4439": 1.0, "4441": 1.0,
        "4442": 2.0, "4452": 1.0, "4453": 0, "4454": 0, "4455": 0,
        "4456": 0.0, "4525": 1, "4527": 1, "4529": 3.0, "4531": 1.0,
        "4532": 2.0, "4545": 3.0, "4553": 1.0, "4555": 1, "4571": 1.0,
        "4574": 2.0, "4587": 1.0,
    }},
    {"9": {
        "20": 2.0, "31": 1.0, "36": 1.0, "46": 1.0, "51": 1.0, "53": 1.0,
        "61": 2.0, "62": 2.0, "76": 1.0, "85": 1.0, "109": 3.0, "110": 4.0,
        "111": 3.0, "123": 2.0, "149": 4.0, "210": 3.0, "255": 2.0,
        "576": 1.0, "603": 4.0, "840": 2.0, "841": 2.0, "855": 2.0,
        "881": 2.0, "927": 2.0, "928": 2.0, "936": 2.0, "1009": 3.0,
        "1041": 3.0, "1080": 3.0, "1125": 2.0, "4203": 1, "4247": 1.0,
        "4248": 4.0, "4250": 2.0, "4251": 3.0, "4255": 3, "4262": 1.0,
        "4278": 1, "4280": 1, "4287": 1, "4289": 1, "4355": 3.0,
        "4439": 3.0, "4441": 3.0, "4451": 2.0, "4452": 3.0, "4453": 0,
        "4454": 0, "4455": 0, "4456": 0.0, "4525": 2, "4527": 1,
        "4554": 2.0, "4555": 3, "4586": 1.0, "4587": 3.0, "4590": 1.0,
        "4596": 3.0,
    }},
    {"PHX1": {
        "20": '1', "31": 2.0, "36": 1.0, "51": 2.0, "60": 1.0, "75": 1.0,
        "85": 1.0, "115": 1.0, "123": 2.0, "147": 2.0, "150": 3.0,
        "179": 2.0, "312": 2.0, "355": 2.0, "436": 4.0, "572": 3.0,
        "719": 2.0, "737": 2.0, "747": 2.0, "817": 2.0, "855": 2.0,
        "938": 1.0, "983": 2.0, "1053": 2.0, "1060": 2.0, "1069": 1.0,
        "1132": 2.0, "4201": 1.0, "4203": 1, "4247": 1.0, "4250": 2.0,
        "4251": 2.0, "4254": 2.0, "4255": 1, "4262": 2.0, "4278": 1,
        "4280": 1,
    }},
]


# Run the models only when executed as a script, so importing this module
# (e.g. to reuse the sample data) does not trigger model loading.
if __name__ == "__main__":
    result = truthfinder(patients_responses, problem_type)
    print(result)