From a2f96d6c4779e851fa6ed3b33ba5d735cb8c7675 Mon Sep 17 00:00:00 2001 From: ishanuc Date: Tue, 5 Dec 2023 01:28:21 -0600 Subject: [PATCH] upd --- notebooks/deployment-assets/truthfinder.py | 67 ++++++++++++++++++++++ notebooks/deployment-assets/tscript1.py | 14 +++++ 2 files changed, 81 insertions(+) create mode 100644 notebooks/deployment-assets/truthfinder.py create mode 100644 notebooks/deployment-assets/tscript1.py diff --git a/notebooks/deployment-assets/truthfinder.py b/notebooks/deployment-assets/truthfinder.py new file mode 100644 index 00000000..5ffce72c --- /dev/null +++ b/notebooks/deployment-assets/truthfinder.py @@ -0,0 +1,67 @@ +def getrocprob(raw_scores,rocdata): + sample_prevalence=rocdata['positive_samples']/rocdata['total_samples'] + prevalence = rocdata.get('prevalence',sample_prevalence) + + from zedstat.zedstat import score_to_probability + prob=score_to_probability(raw_scores,df=rocdata['roc'], + total_samples=rocdata['total_samples'], + positive_samples=rocdata['positive_samples'], + prevalence=prevalence) + return prob + + +def _diss_linear(s, qnet, missing_response='',missing_diss_value=0): + import numpy as np + Ds = qnet.predict_distributions(s) + diss_values = [] + for i in range(len(s)): + prob = float(Ds[i].get(str(s[i]),np.max(list(Ds[i].values())) )) + max_prob = max(Ds[i].values()) + diss_value = 1 - prob / max_prob if max_prob != 0 else 0 + diss_values.append(diss_value) + return diss_values + +def truthfinder(patients_responses, problem_type): + import pandas as pd + from quasinet.qnet import load_qnet + import pickle + + patients_responses = [{patient_id: {question_id: str(int(response)) + for question_id, + response in patient_responses.items()} + for patient_id, patient_responses in patient.items()} + for patient in patients_responses] + + + model_path = f"models/{problem_type}/random_order_full_model_0.joblib.gz" + classifier_path = f"classifiers/{problem_type}/runif-classifier-{346}.pkl" + roc_path = f"classifiers/{problem_type}/runif-roc-{346}.pkl" + with open(roc_path, 'rb') as file: + rocdata = pickle.load(file) + + model = load_qnet(model_path) + classifier = pd.read_pickle(classifier_path) + all_data_samples = [] + + for patient_response in patients_responses: + for patient_id, responses in patient_response.items(): + resp_df = pd.DataFrame([responses], columns=model.feature_names) + data_samples = (resp_df + .fillna('') # Replace missing values with empty strings + .astype(str) # Convert all values to strings + .values) + all_data_samples.append(data_samples[0]) + + diss_values = [_diss_linear(sample, model) for sample in all_data_samples] + proba = classifier.predict_proba(diss_values)[:,1] + + estimated_probability_of_event = getrocprob(proba,rocdata) + output_dict = {} + for idx, patient_response in enumerate(patients_responses): + for patient_id in patient_response.keys(): + output_dict[patient_id] = {'probability':estimated_probability_of_event[idx][0], + 'ci': tuple(estimated_probability_of_event[idx][1:]), + 'rawscore':proba[idx]} + return output_dict + + diff --git a/notebooks/deployment-assets/tscript1.py b/notebooks/deployment-assets/tscript1.py new file mode 100644 index 00000000..b0fb754a --- /dev/null +++ b/notebooks/deployment-assets/tscript1.py @@ -0,0 +1,14 @@ +from truthfinder import truthfinder + + + + +problem_type = 'global' + +patients_responses = [{"0": {"20": '1', "31": 2.0, "36": 1.0, "51": 2.0, "60": 1.0, "75": 1.0, "85": 1.0, "115": 1.0, "123": 2.0, "147": 2.0, "150": 3.0, "179": 2.0, "312": 2.0, "355": 2.0, "436": 4.0, "572": 3.0, "719": 2.0, "737": 2.0, "747": 2.0, "817": 2.0, "855": 2.0, "938": 1.0, "983": 2.0, "1053": 2.0, "1060": 2.0, "1069": 1.0, "1132": 2.0, "4201": 1.0, "4203": 1, "4247": 1.0, "4250": 2.0, "4251": 2.0, "4254": 2.0, "4255": 1, "4262": 2.0, "4278": 1, "4280": 1, "4287": 1, "4289": 1, "4439": 1.0, "4441": 1.0, "4442": 2.0, "4452": 1.0, "4453": 0, "4454": 0, "4455": 0, "4456": 0.0, "4525": 1, "4527": 1, "4529": 3.0, "4531": 1.0, "4532": 2.0, "4545": 3.0, "4553": 1.0, "4555": 1, "4571": 1.0, "4574": 2.0, "4587": 1.0}}, {"9": {"20": 2.0, "31": 1.0, "36": 1.0, "46": 1.0, "51": 1.0, "53": 1.0, "61": 2.0, "62": 2.0, "76": 1.0, "85": 1.0, "109": 3.0, "110": 4.0, "111": 3.0, "123": 2.0, "149": 4.0, "210": 3.0, "255": 2.0, "576": 1.0, "603": 4.0, "840": 2.0, "841": 2.0, "855": 2.0, "881": 2.0, "927": 2.0, "928": 2.0, "936": 2.0, "1009": 3.0, "1041": 3.0, "1080": 3.0, "1125": 2.0, "4203": 1, "4247": 1.0, "4248": 4.0, "4250": 2.0, "4251": 3.0, "4255": 3, "4262": 1.0, "4278": 1, "4280": 1, "4287": 1, "4289": 1, "4355": 3.0, "4439": 3.0, "4441": 3.0, "4451": 2.0, "4452": 3.0, "4453": 0, "4454": 0, "4455": 0, "4456": 0.0, "4525": 2, "4527": 1, "4554": 2.0, "4555": 3, "4586": 1.0, "4587": 3.0, "4590": 1.0, "4596": 3.0}}, {"PHX1": {"20": '1', "31": 2.0, "36": 1.0, "51": 2.0, "60": 1.0, "75": 1.0, "85": 1.0, "115": 1.0, "123": 2.0, "147": 2.0, "150": 3.0, "179": 2.0, "312": 2.0, "355": 2.0, "436": 4.0, "572": 3.0, "719": 2.0, "737": 2.0, "747": 2.0, "817": 2.0, "855": 2.0, "938": 1.0, "983": 2.0, "1053": 2.0, "1060": 2.0, "1069": 1.0, "1132": 2.0, "4201": 1.0, "4203": 1, "4247": 1.0, "4250": 2.0, "4251": 2.0, "4254": 2.0, "4255": 1, "4262": 2.0, "4278": 1, "4280": 1} } +] + + + +result = truthfinder(patients_responses, problem_type) +print(result)