diff --git a/careless/stats/isigi.py b/careless/stats/isigi.py new file mode 100644 index 0000000..e683f38 --- /dev/null +++ b/careless/stats/isigi.py @@ -0,0 +1,148 @@ +""" +Compute I/sigI from careless output. +""" +import argparse +import numpy as np +import matplotlib.pyplot as plt +import reciprocalspaceship as rs +import seaborn as sns +import os + + +from careless.io.formatter import get_first_key_of_dtype +from careless.stats.parser import BaseParser +class ArgumentParser(BaseParser): + def __init__(self): + super().__init__( + description=__doc__ + ) + + # Required arguments + self.add_argument( + "mtz", + nargs="+", + help="MTZs containing crossvalidation data from careless", + ) + + self.add_argument( + "-b", + "--bins", + default=20, + type=int, + help=("Number of resolution bins to use, the default is 20."), + ) + + self.add_argument( + "--intensity-key", + dest="I_col", + default=None, + type=str, + help=("Intensity key"), + ) + self.add_argument( + "--uncertainty-key", + dest="sigI_col", + default=None, + type=str, + help=("Sigma(Intensity) key"), + ) + self.add_argument( + "--log", + action="store_true", + help=("Use a logarithmic scale for the y-axis."), + ) + + self.add_argument( + "--overall", + action="store_true", + help="Pool all prediction mtz files into a single calculation rather than treating each file individually.", + ) + + +def run_analysis(args): + ds = [] + for m in args.mtz: + _ds = rs.read_mtz(m) + print(m) + #non-isomorphism could lead to different resolution for each mtz + #need to calculate dHKL before concatenating + _ds.compute_dHKL(inplace=True) + if len(m)<50: + _ds['file'] = m + else: + _ds['file'] = os.path.basename(m) + _ds['Spacegroup'] = _ds.spacegroup.xhm() + ds.append(_ds) + ds = rs.concat(ds, check_isomorphous=False) + bins,edges = rs.utils.bin_by_percentile(ds.dHKL, args.bins, ascending=False) + ds['bin'] = bins + labels = [ + f"{e1:0.2f} - {e2:0.2f}" + for e1, e2 in zip(edges[:-1], edges[1:]) + ] + + ikey = args.I_col + if ikey is None: + ikey = get_first_key_of_dtype(ds, 'J') + sigkey=args.sigI_col + if sigkey is None: + sigkey = get_first_key_of_dtype(ds, 'Q') + + if args.overall: + grouper = ds.groupby(["bin"]) + else: + grouper = ds.groupby(["file", "bin"]) + + result = grouper.apply(lambda x : np.mean(x[ikey]/x[sigkey])) + result = rs.DataSet({"I/sigI" : result}).reset_index() + result['Resolution Range (Å)'] = np.array(labels)[result.bin] + result['Spacegroup'] = grouper['Spacegroup'].apply('first').to_numpy() + if not args.overall: + result['file'] = grouper['file'].apply('first').to_numpy() + result = result[['file', 'Resolution Range (Å)', 'bin', 'Spacegroup', 'I/sigI']] + else: + result = result[['Resolution Range (Å)', 'bin', 'Spacegroup', 'I/sigI']] + + + if args.output is not None: + result.to_csv(args.output) + else: + print(result.to_string()) + + plot_kwargs = { + 'data' : result, + 'x' : 'bin', + 'y' : 'I/sigI', + } + + if args.overall: + plot_kwargs['color'] = 'k' + else: + plot_kwargs['hue'] = 'file' + plot_kwargs['palette'] = "Dark2" + + ax=sns.lineplot(**plot_kwargs) + if args.log: + ax.set(yscale='log') + plt.xticks(range(args.bins), labels, rotation=45, ha="right", rotation_mode="anchor") + plt.ylabel(r"$\mathrm{I/\sigma(I)}$ ") + plt.xlabel("Resolution ($\mathrm{\AA}$)") + plt.grid(which='both', axis='both', ls='dashdot') + plt.ylim(args.ylim) + + plt.tight_layout() + + if args.image is not None: + plt.savefig(args.image) + + if args.show: + plt.show() + + +def main(): + parser = ArgumentParser().parse_args() + # print(parser) + run_analysis(parser) + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 54679da..0d0843e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ careless = "careless.careless:main" "careless.cchalf" = "careless.stats.cchalf:main" "careless.ccpred" = "careless.stats.ccpred:main" "careless.image_ccpred" = "careless.stats.image_cc:main" +"careless.isigi" = "careless.stats.isigi:main" "careless.filter_image" = "careless.stats.filter_by_image_cc:main" "careless.plot_history" = "careless.stats.history:main" "careless.bfactor" = "careless.stats.prior_b:main" diff --git a/tests/stats/test_cc.py b/tests/stats/test_cc.py index 35ed204..6572136 100644 --- a/tests/stats/test_cc.py +++ b/tests/stats/test_cc.py @@ -1,4 +1,4 @@ -from careless.stats import cchalf,ccanom,ccpred,rsplit,image_cc,filter_by_image_cc +from careless.stats import cchalf,ccanom,ccpred,rsplit,image_cc,filter_by_image_cc,isigi from tempfile import TemporaryDirectory from os.path import exists from os import symlink @@ -105,6 +105,43 @@ def test_ccpred(predictions_mtz, method, bins, overall, multi): else: assert len(df) == 2*bins +@pytest.mark.parametrize("bins", [1, 5]) +@pytest.mark.parametrize("overall", [True, False]) +@pytest.mark.parametrize("method", ["spearman", "pearson"]) +@pytest.mark.parametrize("multi", [False, True]) +def test_isigi(predictions_mtz, method, bins, overall, multi): + tf = TemporaryDirectory() + csv = f"{tf.name}/out.csv" + png = f"{tf.name}/out.png" + command = f"-o {csv} -i {png} -b {bins} " + if overall: + command = command + ' --overall ' + + if multi: + mtz_0 = f'{tf.name}/test_predictions_0.mtz' + mtz_1 = f'{tf.name}/test_predictions_1.mtz' + symlink(predictions_mtz, mtz_0) + symlink(predictions_mtz, mtz_1) + command = command + f" {mtz_0} " + command = command + f" {mtz_1} " + else: + command = command + f" {predictions_mtz} " + + parser = isigi.ArgumentParser().parse_args(command.split()) + + assert not exists(csv) + assert not exists(png) + isigi.run_analysis(parser) + assert exists(csv) + assert exists(png) + + df = pd.read_csv(csv) + + if multi and not overall: + assert len(df) == 2*bins + else: + assert len(df) == 1*bins + @pytest.mark.parametrize("method", ["spearman", "pearson"]) @pytest.mark.parametrize("multi", [False, True])