Skip to content

Commit

Permalink
Merge pull request #23 from mzmine/library-handler
Browse files Browse the repository at this point in the history
Library handler
  • Loading branch information
niekdejonge authored Nov 2, 2023
2 parents 4d30d73 + 0c4f014 commit 92ae6ee
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 39 deletions.
175 changes: 151 additions & 24 deletions library_spectra_validation/library_handler.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,163 @@
from matchms.importing import load_spectra
from matchms.filtering.SpectrumProcessor import SpectrumProcessor
from filters import PRIMARY_FILTERS
from validation_pipeline import Modification, SpectrumRepairer, SpectrumValidator

class LibraryHandler:
"""Stores the 3 different types of spectra. Correct, repaired, wrong.
Has internal organization using spectrum ids"""

def __init__(self, f, pipeline):
#todo modify default pipeline
def __init__(self, f):
metadata_field_harmonization = SpectrumProcessor(predefined_pipeline=None,
additional_filters=PRIMARY_FILTERS)
self.spectra = metadata_field_harmonization.process_spectrums(load_spectra(f))
self.pipeline = pipeline
self.spectra_dictionary = {
'valid': None, #[id1, id2,...]
'repaired': None, #[id1:[modifications],..]
'invalid': None #also a dictionary
}
self.modifications = {} #todo change to Modifications class

def clean_and_validate_spectrum(self, spectrum_id):
spectrum = self.spectra[spectrum_id]
modifications = self.pipeline.run(spectrum)
spectrum_id.update_spectra_dictionary(spectrum_id, modifications)
self.modifications.append(modifications)

def update_spectra_dictionary(self, spectrum_id, modifications):
self.spectra_dictionary[modifications["spectra_quality"]["updated"]].append(spectrum_id) #valid, repaired,...
if ((modifications["spectra_quality"]["updated"] != None) &
(modifications["spectra_quality"]["updated"] != modifications["spectra_quality"]["previous"])):
self.spectra_dictionary[modifications["spectra_quality"]["previous"]].remove(spectrum_id)

def run(self):
    """Reset the handler's bookkeeping and process every loaded spectrum.

    Creates a fresh repairer and validator, empties the per-spectrum
    records and then delegates the actual work to ``initial_run``.
    """
    # Processors reused by every later repair/validation call.
    self.spectrum_repairer = SpectrumRepairer()
    self.spectrum_validator = SpectrumValidator()

    # Spectrum ids, split by whether they currently count as valid.
    self.validated_spectra, self.nonvalidated_spectra = [], []
    # Per-spectrum records, keyed by spectrum id.
    self.modifications, self.failed_requirements = {}, {}

    self.initial_run()

def initial_run(self):
    """Repair and validate every loaded spectrum once.

    For each spectrum id the repairer is run first; the modifications it
    reports are stored in ``self.modifications``. The repaired spectrum is
    then validated, the failed requirements are stored in
    ``self.failed_requirements``, the valid/non-valid lists are refreshed and
    the repaired spectrum replaces the stored one.
    """
    # The superseded ``clean_and_validate_spectrum`` call was dropped: it
    # relied on ``self.pipeline``, which ``__init__(self, f)`` does not set.
    for spectrum_id in range(len(self.spectra)):
        modifications, repaired_spectrum = (
            self.spectrum_repairer.process_spectrum_store_modifications(self.spectra[spectrum_id]))
        self.modifications[spectrum_id] = modifications

        # Validation runs on the repaired spectrum, not on the raw input.
        self.failed_requirements[spectrum_id] = (
            self.spectrum_validator.process_spectrum_store_failed_filters(repaired_spectrum))
        self.update_spectra_quality_lists(spectrum_id)
        self.spectra[spectrum_id] = repaired_spectrum

# iterate over all failed requirements
# it's almost streamlit
# for the dashboard run should use spectrum id
# for spectrum_id in range(len(self.spectra)):
# if len(self.failed_requirements[spectrum_id]) != 0:
# self.pass_user_validation_info(spectrum_id)
# #todo should we grab here state variable from streamlit - accept or change
# # self.user_approve_repair(spectrum_id)
# # self.user_metadata_change(spectrum_id)

def update_spectra_quality_lists(self, spectrum_id):
    """Move ``spectrum_id`` into the list that matches its current state.

    A spectrum counts as valid only when it has no failed requirements and
    none of its modifications has ``validated_by_user`` set to ``False``.
    The id is appended to the matching list (if not already present) and
    removed from the other one.
    """
    has_failures = len(self.failed_requirements[spectrum_id]) != 0
    has_unapproved = any(entry.validated_by_user is False
                         for entry in self.modifications[spectrum_id])

    if has_failures or has_unapproved:
        target, other = self.nonvalidated_spectra, self.validated_spectra
    else:
        target, other = self.validated_spectra, self.nonvalidated_spectra

    if spectrum_id not in target:
        target.append(spectrum_id)
    if spectrum_id in other:
        other.remove(spectrum_id)

def return_user_validation_info(self, spectrum_id):
    """Return everything recorded for a spectrum that is not yet valid.

    Returns a tuple ``(modifications, failed_requirements, spectrum)``.
    Asserts that ``spectrum_id`` is currently in ``nonvalidated_spectra``.
    """
    assert spectrum_id in self.nonvalidated_spectra

    return (self.modifications[spectrum_id],
            self.failed_requirements[spectrum_id],
            self.spectra[spectrum_id])

def approve_repair(self, spectrum_id, field_name):
    """Mark every modification made to ``field_name`` as approved by the user.

    Afterwards the valid/non-valid lists are refreshed for this spectrum.
    """
    on_this_field = [entry for entry in self.modifications[spectrum_id]
                     if entry.metadata_field == field_name]
    for entry in on_this_field:
        entry.validated_by_user = True
    self.update_spectra_quality_lists(spectrum_id)

def approve_all_repairs(self, spectrum_id):
    """Mark every modification of a spectrum as approved by the user.

    Afterwards the valid/non-valid lists are refreshed for this spectrum.
    """
    recorded = self.modifications[spectrum_id]
    for entry in recorded:
        entry.validated_by_user = True
    self.update_spectra_quality_lists(spectrum_id)

def decline_last_repair(self, spectrum_id, field_name):
    """Undo the most recent modification made to ``field_name``.

    The modifications are searched from newest to oldest for the first entry
    on ``field_name`` whose ``after`` value equals the field's current value.
    That entry is reverted (the field is set back to ``before``) and removed
    from the list. Exactly one modification is undone per call; nothing
    happens when no entry matches.
    """
    spectrum = self.spectra[spectrum_id]
    modifications = self.modifications[spectrum_id]
    current_value = spectrum.get(field_name)

    # Walk backwards so the *last* matching change is found, and stop right
    # after the deletion so the list is never mutated while still iterating.
    for mod_idx in range(len(modifications) - 1, -1, -1):
        modification = modifications[mod_idx]
        if modification.metadata_field == field_name and modification.after == current_value:
            spectrum.set(field_name, modification.before)
            self.spectra[spectrum_id] = spectrum
            del modifications[mod_idx]
            return
    # todo run validation after.

def decline_all_repairs_on_a_field(self, spectrum_id, field_name):
    """Undo every repair recorded for one metadata field.

    The number of modifications on ``field_name`` is counted up front and
    ``decline_last_repair`` is then called that many times.
    """
    remaining = sum(1 for entry in self.modifications[spectrum_id]
                    if entry.metadata_field == field_name)
    # Peel the repairs off one at a time, newest first.
    while remaining > 0:
        self.decline_last_repair(spectrum_id, field_name)
        remaining -= 1
    # todo run validation after.

def decline_all_repairs_spectrum(self, spectrum_id):
    """Undo all modifications made to a spectrum.

    Modifications are reverted newest first: each one is popped from the end
    of the list and its field is set back to ``before``. When the list is
    empty every field holds the value it had before its first recorded
    modification. The list is emptied in place.
    """
    spectrum = self.spectra[spectrum_id]
    modifications = self.modifications[spectrum_id]

    # Popping from the end always terminates, unlike repeatedly scanning for
    # an entry whose ``after`` matches the current value, and never mutates
    # the list while it is being iterated.
    while modifications:
        modification = modifications.pop()
        spectrum.set(modification.metadata_field, modification.before)
    self.spectra[spectrum_id] = spectrum

def decline_wrapper(self, spectrum_id, field_name, only_last_repair: bool):
    """Dispatch a decline request and re-validate the spectrum afterwards.

    ``field_name is None`` undoes every repair of the spectrum. Otherwise
    ``only_last_repair`` selects between undoing just the last repair on the
    field and undoing all repairs on it. The failed requirements and the
    valid/non-valid lists are refreshed in every case.
    """
    if field_name is None:
        self.decline_all_repairs_spectrum(spectrum_id)
    elif not only_last_repair:
        self.decline_all_repairs_on_a_field(spectrum_id, field_name)
    else:
        self.decline_last_repair(spectrum_id, field_name)

    reverted_spectrum = self.spectra[spectrum_id]
    self.failed_requirements[spectrum_id] = (
        self.spectrum_validator.process_spectrum_store_failed_filters(reverted_spectrum))
    self.update_spectra_quality_lists(spectrum_id)

def user_metadata_change(self, field_name, user_input, spectrum_id):
    """Overwrite a metadata field with a value supplied by the user.

    The change is recorded as an already approved ``Modification`` (with the
    field's previous value as ``before``), the field is rewritten on the
    stored spectrum, and validation plus the valid/non-valid lists are
    refreshed.
    """
    spectrum = self.spectra[spectrum_id]

    # Record the manual change before the field is overwritten.
    manual_change = Modification(metadata_field=field_name,
                                 before=spectrum.get(field_name),
                                 after=user_input,
                                 logging_message="Manual change",
                                 validated_by_user=True)
    self.modifications[spectrum_id].append(manual_change)

    self.spectra[spectrum_id].set(field_name, user_input)
    self.failed_requirements[spectrum_id] = (
        self.spectrum_validator.process_spectrum_store_failed_filters(self.spectra[spectrum_id]))
    self.update_spectra_quality_lists(spectrum_id)


def user_rerun_repair(self, spectrum_id, rerun: bool):
    """Rerun repair and validation for one spectrum on the user's request.

    Intended to be linked to a button in a dashboard. Nothing happens when
    ``rerun`` is false. Otherwise the repairer is run again, its
    modifications replace the stored ones, the repaired spectrum replaces the
    stored spectrum, and validation plus the valid/non-valid lists are
    refreshed.
    """
    if not rerun:  # todo do we even need it??
        return

    # The repairer returns ``(modifications, spectrum)``; both parts must be
    # stored, as ``initial_run`` does.
    modifications, repaired_spectrum = (
        self.spectrum_repairer.process_spectrum_store_modifications(self.spectra[spectrum_id]))
    self.modifications[spectrum_id] = modifications
    self.spectra[spectrum_id] = repaired_spectrum

    self.failed_requirements[spectrum_id] = (
        self.spectrum_validator.process_spectrum_store_failed_filters(repaired_spectrum))
    # Keep the valid/non-valid lists in step, like the other mutating methods.
    self.update_spectra_quality_lists(spectrum_id)




63 changes: 63 additions & 0 deletions library_spectra_validation/tests/test_library_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from library_handler import LibraryHandler


def test_init_library_handler():
    """Constructing a handler from the example library must not raise."""
    example_file = "./examples/test_case_correct.mgf"
    LibraryHandler(example_file)


def test_approve_repairs():
    """Approving the repairs on "inchi" flags the first modification as validated."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    target_id = 0
    handler.approve_repair(spectrum_id=target_id, field_name="inchi")
    first_modification = handler.modifications[target_id][0]
    assert first_modification.validated_by_user is True


def test_approve_all_repairs():
    """Approving all repairs flags every modification of the spectrum as validated."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    target_id = 0
    handler.approve_all_repairs(spectrum_id=target_id)
    for entry in handler.modifications[target_id]:
        assert entry.validated_by_user is True


def test_decline_last_repairs():
    """Declining the last "inchi" repair leaves one modification on spectrum 0."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    target_id = 0
    handler.decline_last_repair(spectrum_id=target_id, field_name="inchi")
    remaining = handler.modifications[0]
    assert len(remaining) == 1


def test_decline_all_repairs_on_a_field():
    """Declining all "inchi" repairs leaves one modification on spectrum 0."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    target_id = 0
    # todo add test that actually has multiple repairs for one field
    handler.decline_all_repairs_on_a_field(spectrum_id=target_id, field_name="inchi")
    remaining = handler.modifications[0]
    assert len(remaining) == 1


def test_decline_all_repairs_spectrum():
    """Declining every repair empties the modification list of spectrum 0."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    target_id = 0
    handler.decline_all_repairs_spectrum(spectrum_id=target_id)
    remaining = handler.modifications[0]
    assert len(remaining) == 0

# todo check that change is undone


# Declining with field_name=None should undo every repair on spectrum 0,
# re-run validation and leave the spectrum in the non-validated list.
def test_decline_wrapper():
library_handler = LibraryHandler("./examples/test_case_correct.mgf")
# NOTE(review): this stores a reference, not a copy. The decline methods
# mutate the stored spectrum in place, so the equality assertion below
# appears to compare the object with itself and would pass even if nothing
# was undone. Consider snapshotting the unrepaired spectrum -- TODO confirm.
original_spectrum = library_handler.spectra[0]
spectrum_id = 0
library_handler.decline_wrapper(spectrum_id=spectrum_id, field_name=None, only_last_repair=False)
assert len(library_handler.modifications[0]) == 0
# check that changes were undone
assert original_spectrum == library_handler.spectra[spectrum_id]
# The expected count of 3 depends on the example file's content.
assert len(library_handler.failed_requirements[spectrum_id]) == 3
assert spectrum_id in library_handler.nonvalidated_spectra


def test_user_metadata_change():
    """A manual change to "smiles" is written to the stored spectrum."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    target_id = 0
    handler.user_metadata_change(spectrum_id=target_id, field_name="smiles", user_input="CCC")
    stored_smiles = handler.spectra[target_id].get("smiles")
    assert stored_smiles == "CCC"
4 changes: 0 additions & 4 deletions library_spectra_validation/tests/test_spectra_loading.py

This file was deleted.

42 changes: 31 additions & 11 deletions library_spectra_validation/validation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,28 @@
"""

import logging
from typing import Iterable, List, Optional, Union
from typing import Iterable, List, Optional, Union, Tuple
from matchms.filtering.SpectrumProcessor import SpectrumProcessor
from matchms import Spectrum

logger = logging.getLogger("matchms")


class Modification:
    """Record of one change made to a metadata field of a spectrum."""

    def __init__(self, metadata_field, before, after, logging_message, validated_by_user):
        # Which field changed, and why.
        self.metadata_field, self.logging_message = metadata_field, logging_message
        # Value before and after the change.
        self.before, self.after = before, after
        # self.original =
        # Whether the user has approved this change.
        self.validated_by_user = validated_by_user


class RequirementFailure:
    """Record of one metadata field that failed a validation requirement."""

    def __init__(self, metadata_field, logging_message):
        # Field the failed requirement refers to, and the message logged for it.
        self.metadata_field, self.logging_message = metadata_field, logging_message


def find_modifications(spectrum_old, spectrum_new, logging_message: str):
"""Checks which modifications have been made in a filter step"""
modifications = []
Expand All @@ -34,7 +41,7 @@ def find_modifications(spectrum_old, spectrum_new, logging_message: str):
modifications.append(
Modification(metadata_field=metadata_field,
before=spectrum_old.get(metadata_field),
after=spectrum_new(metadata_field),
after=spectrum_new.get(metadata_field),
logging_message=logging_message,
validated_by_user=False))
return modifications
Expand All @@ -50,7 +57,7 @@ def process_spectrum(self, spectrum,
processing_report=None):
raise AttributeError("process spectrum is not a valid method of SpectrumValidator")

def process_spectrum_store_modifications(self, spectrum) -> List[Modification]:
def process_spectrum_store_modifications(self, spectrum) -> Tuple[List[Modification], Spectrum]:
if not self.filters:
raise TypeError("No filters to process")
modifications = []
Expand All @@ -64,21 +71,31 @@ def process_spectrum_store_modifications(self, spectrum) -> List[Modification]:
if spectrum_out is None:
raise AttributeError("SpectrumRepairer is only expected to repair spectra, not set to None")
spectrum = spectrum_out
return modifications
return modifications, spectrum


class SpectrumValidator(SpectrumProcessor):
def __init__(self):
# todo add the fields each requirement checks.
fields_checked_by_filter = {filter_name: [fields_checked]}
self.fields_checked_by_filter = {
"require_precursor_mz": ["precursor_mz"],
"require_valid_annotation": ["smiles", "inchi", "inchikey"],
"require_correct_ionmode": ["ionmode", "adduct", "charge"],
# "require_parent_mass_match_smiles": ["smiles", "parent_mass"]
}
# todo require adduct, precursor mz and parent mass match.
# todo add all the checks for formatting. That everything is filled and of the expected format.
super().__init__(predefined_pipeline=None,
additional_filters=list(fields_checked_by_filter.keys()))

additional_filters=("require_precursor_mz",
"require_valid_annotation",
("require_correct_ionmode", {"ion_mode_to_keep": "both"}),
# ("require_parent_mass_match_smiles", {'mass_tolerance': 0.1}),
))
# todo add require parent mass match smiles after matchms release.
def process_spectrum(self, spectrum,
processing_report=None):
raise AttributeError("process spectrum is not a valid method of SpectrumValidator")

def process_spectrum_store_failed_filters(self, spectrum) -> List[Modification]:
def process_spectrum_store_failed_filters(self, spectrum) -> List[RequirementFailure]:
if not self.filters:
raise TypeError("No filters to process")
failed_requirements = []
Expand All @@ -87,5 +104,8 @@ def process_spectrum_store_failed_filters(self, spectrum) -> List[Modification]:
logging_message = ""
spectrum_out = filter_func(spectrum)
if spectrum_out is None:
failed_requirements += logging_message
fields_changed = self.fields_checked_by_filter[filter_func.__name__]
for field_changed in fields_changed:
failed_requirements.append(RequirementFailure(field_changed,
logging_message))
return failed_requirements

0 comments on commit 92ae6ee

Please sign in to comment.