diff --git a/.github/workflows/test-aviary.yml b/.github/workflows/test-aviary.yml index deefa77b..54930740 100644 --- a/.github/workflows/test-aviary.yml +++ b/.github/workflows/test-aviary.yml @@ -1,5 +1,5 @@ name: Test Aviary with Setup-Miniconda From Marketplace -on: [push] +on: [push, pull_request] jobs: miniconda: @@ -34,3 +34,5 @@ jobs: run: | aviary -h python test/test_assemble.py + python test/test_recover.py + python test/test_run_checkm.py -b diff --git a/aviary/aviary.py b/aviary/aviary.py index 1b5a9f05..9d143a57 100644 --- a/aviary/aviary.py +++ b/aviary/aviary.py @@ -494,6 +494,14 @@ def main(): default='global' ) + binning_group.add_argument( + '--refinery-max-iterations', '--refinery_max_iterations', + help='Maximum number of iterations for Rosella refinery. Set to 0 to skip refinery.', + dest='refinery_max_iterations', + type=int, + default=5 + ) + binning_group.add_argument( '--skip-binners', '--skip_binners', '--skip_binner', '--skip-binner', help='Optional list of binning algorithms to skip. Can be any combination of: \n' diff --git a/aviary/modules/binning/binning.smk b/aviary/modules/binning/binning.smk index 7c021b22..c0314dba 100644 --- a/aviary/modules/binning/binning.smk +++ b/aviary/modules/binning/binning.smk @@ -363,7 +363,8 @@ rule checkm_rosella: pplacer_threads = config["pplacer_threads"], checkm2_db_path = config["checkm2_db_folder"], bin_folder = "data/rosella_bins/", - extension = "fna" + extension = "fna", + refinery_max_iterations = config["refinery_max_iterations"], group: 'binning' output: output_folder = directory("data/rosella_bins/checkm2_out/"), @@ -372,20 +373,8 @@ rule checkm_rosella: "../../envs/checkm2.yaml" threads: config["max_threads"] - shell: - 'touch {output.output_file}; ' - 'if [ `ls "{params.bin_folder}" |grep .fna$ |wc -l` -eq 0 ]; then ' - 'echo "No bins found in {params.bin_folder}"; ' - 'touch {output.output_file}; ' - 'mkdir -p {output.output_folder}; ' - 'else ' - - 'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; ' - 'echo "Using CheckM2 database $CHECKM2DB"; ' - 'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; ' - 'cp {output.output_folder}/quality_report.tsv {output.output_file}; ' - - 'fi' + script: + "scripts/run_checkm.py" rule checkm_metabat2: input: @@ -403,12 +392,8 @@ rule checkm_metabat2: "../../envs/checkm2.yaml" threads: config["max_threads"] - shell: - 'touch {output.output_file}; ' - 'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; ' - 'echo "Using CheckM2 database $CHECKM2DB"; ' - 'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; ' - 'cp {output.output_folder}/quality_report.tsv {output.output_file}' + script: + "scripts/run_checkm.py" rule checkm_semibin: input: @@ -426,12 +411,8 @@ rule checkm_semibin: "../../envs/checkm2.yaml" threads: config["max_threads"] - shell: - 'touch {output.output_file}; ' - 'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; ' - 'echo "Using CheckM2 database $CHECKM2DB"; ' - 'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; ' - 'cp {output.output_folder}/quality_report.tsv {output.output_file}' + script: + "scripts/run_checkm.py" rule refine_rosella: input: @@ -449,7 +430,7 @@ rule refine_rosella: extension = "fna", output_folder = "data/rosella_refined/", min_bin_size = config["min_bin_size"], - max_iterations = 5, + max_iterations = config["refinery_max_iterations"], pplacer_threads = config["pplacer_threads"], max_contamination = 15, final_refining = False @@ -480,7 +461,7 @@ rule refine_metabat2: extension = "fa", output_folder = "data/metabat2_refined/", min_bin_size = config["min_bin_size"], - max_iterations = 5, + max_iterations = config["refinery_max_iterations"], pplacer_threads = config["pplacer_threads"], max_contamination = 15, final_refining = False @@ -509,7 +490,7 @@ rule refine_semibin: extension = "fa", output_folder = "data/semibin_refined/", min_bin_size = config["min_bin_size"], - max_iterations = 5, + max_iterations = config["refinery_max_iterations"], pplacer_threads = config["pplacer_threads"], max_contamination = 15, final_refining = False @@ -623,7 +604,7 @@ rule refine_dastool: extension = "fa", output_folder = "data/refined_bins/", min_bin_size = config["min_bin_size"], - max_iterations = 5, + max_iterations = config["refinery_max_iterations"], pplacer_threads = config["pplacer_threads"], max_contamination = 15, final_refining = True diff --git a/aviary/modules/binning/scripts/run_checkm.py b/aviary/modules/binning/scripts/run_checkm.py new file mode 100755 index 00000000..d989823d --- /dev/null +++ b/aviary/modules/binning/scripts/run_checkm.py @@ -0,0 +1,30 @@ +import subprocess +import shutil +import os +from pathlib import Path + +def checkm(checkm2_db, bin_folder, bin_ext, refinery_max_iterations, output_folder, output_file, threads): + if len([f for f in os.listdir(bin_folder) if f.endswith(bin_ext)]) == 0: + print(f"No bins found in {bin_folder}") + os.makedirs(output_folder) + Path(output_file).touch() + elif refinery_max_iterations == 0: + print("Skipping pre-refinery CheckM2 rules") + os.makedirs(output_folder) + Path(output_file).touch() + else: + print(f"Using CheckM2 database {checkm2_db}/uniref100.KO.1.dmnd") + subprocess.run(f"CHECKM2DB={checkm2_db}/uniref100.KO.1.dmnd checkm2 predict -i {bin_folder}/ -x {bin_ext} -o {output_folder} -t {threads} --force") + shutil.copy(f"{output_folder}/quality_report.tsv", output_file) + + +if __name__ == '__main__': + checkm2_db = snakemake.params.checkm2_db_path + bin_folder = snakemake.params.bin_folder + bin_ext = snakemake.params.extension + refinery_max_iterations = snakemake.params.refinery_max_iterations + output_folder = snakemake.output.output_folder + output_file = snakemake.output.output_file + threads = snakemake.threads + + checkm(checkm2_db, bin_folder, bin_ext, refinery_max_iterations, output_folder, output_file, threads) diff --git a/aviary/modules/processor.py b/aviary/modules/processor.py index ab2f5039..4e6a251a 100644 --- a/aviary/modules/processor.py +++ b/aviary/modules/processor.py @@ -108,6 +108,7 @@ def __init__(self, self.min_contig_size = args.min_contig_size self.min_bin_size = args.min_bin_size self.semibin_model = args.semibin_model + self.refinery_max_iterations = args.refinery_max_iterations self.skip_binners = [] if args.skip_binners: @@ -125,6 +126,7 @@ def __init__(self, self.min_contig_size = 1500 self.min_bin_size = 200000 self.semibin_model = 'global' + self.refinery_max_iterations = 5 self.skip_binners = ["none"] try: @@ -308,6 +310,7 @@ def make_config(self): conf["gsa_mappings"] = self.gsa_mappings conf["skip_binners"] = self.skip_binners conf["semibin_model"] = self.semibin_model + conf["refinery_max_iterations"] = self.refinery_max_iterations conf["max_threads"] = int(self.threads) conf["pplacer_threads"] = int(self.pplacer_threads) conf["max_memory"] = int(self.max_memory) diff --git a/test/test_recover.py b/test/test_recover.py index 5ddc6005..0d73654e 100644 --- a/test/test_recover.py +++ b/test/test_recover.py @@ -4,6 +4,7 @@ import os import tempfile import extern +from snakemake import load_configfile path_to_data = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data') path_to_conda = os.path.join(path_to_data,'.conda') @@ -122,5 +123,84 @@ def test_recover_skip_binners(self): # Unnecessary self.assertTrue("complete_assembly_with_qc" not in output) + def test_recover_no_singlem(self): + with tempfile.TemporaryDirectory() as tmpdir: + cmd = ( + f"GTDBTK_DATA_PATH=. " + f"CHECKM2DB=. " + f"EGGNOG_DATA_DIR=. " + f"aviary recover " + f"--workflow recover_mags_no_singlem " + f"--assembly {ASSEMBLY} " + f"-1 {FORWARD_READS} " + f"-2 {REVERSE_READS} " + f"--output {tmpdir}/test " + f"--conda-prefix {path_to_conda} " + f"--dryrun " + f"--snakemake-cmds \" --quiet\" " + ) + output = extern.run(cmd) + + # Binners + self.assertTrue("prepare_binning_files" in output) + self.assertTrue("get_bam_indices" in output) + self.assertTrue("metabat_sens" in output) + self.assertTrue("metabat_spec" in output) + self.assertTrue("metabat_ssens" in output) + self.assertTrue("metabat_sspec" in output) + self.assertTrue("metabat2" in output) + self.assertTrue("maxbin2" in output) + self.assertTrue("rosella" in output) + self.assertTrue("semibin" in output) + self.assertTrue("vamb" in output) + self.assertTrue("concoct" in output) + self.assertTrue("das_tool" in output) + + # Refinery + self.assertTrue("checkm_metabat2" in output) + self.assertTrue("refine_metabat2" in output) + self.assertTrue("checkm_rosella" in output) + self.assertTrue("refine_rosella" in output) + self.assertTrue("checkm_semibin" in output) + self.assertTrue("refine_semibin" in output) + self.assertTrue("checkm_das_tool" in output) + self.assertTrue("refine_dastool" in output) + + # Extras + self.assertTrue("checkm2" in output) + self.assertTrue("gtdbtk" in output) + self.assertTrue("get_abundances" in output) + self.assertTrue("singlem_pipe_reads" not in output) + self.assertTrue("singlem_appraise" not in output) + self.assertTrue("finalize_stats" in output) + self.assertTrue("recover_mags_no_singlem" in output) + + # Unnecessary + self.assertTrue("complete_assembly_with_qc" not in output) + + def test_recover_config(self): + with tempfile.TemporaryDirectory() as tmpdir: + cmd = ( + f"GTDBTK_DATA_PATH=. " + f"CHECKM2DB=. " + f"EGGNOG_DATA_DIR=. " + f"aviary recover " + f"--refinery-max-iterations 3 " + f"--assembly {ASSEMBLY} " + f"-1 {FORWARD_READS} " + f"-2 {REVERSE_READS} " + f"--output {tmpdir}/test " + f"--conda-prefix {path_to_conda} " + f"--dryrun " + f"--snakemake-cmds \" --quiet\" " + ) + extern.run(cmd) + + config_path = os.path.join(tmpdir, "test", "config.yaml") + self.assertTrue(os.path.exists(config_path)) + config = load_configfile(config_path) + + self.assertEqual(config["refinery_max_iterations"], 3) + if __name__ == '__main__': unittest.main() diff --git a/test/test_run_checkm.py b/test/test_run_checkm.py new file mode 100644 index 00000000..8e878577 --- /dev/null +++ b/test/test_run_checkm.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +import unittest +import os +import tempfile +from aviary.modules.binning.scripts.run_checkm import checkm +from unittest.mock import patch +import subprocess +from pathlib import Path + +def create_output(_): + os.makedirs("output_folder") + Path(os.path.join("output_folder", "quality_report.tsv")).touch() + +class Tests(unittest.TestCase): + def test_run_checkm(self): + with tempfile.TemporaryDirectory() as tmpdir: + os.chdir(tmpdir) + with patch.object(subprocess, "run", side_effect=create_output) as mock_subprocess: + checkm2_db = os.path.join("checkm2_db") + os.makedirs(checkm2_db) + bin_folder = os.path.join("bin_folder") + os.makedirs(bin_folder) + Path(os.path.join(bin_folder, "bin_1.fna")).touch() + + checkm(checkm2_db, bin_folder, "fna", 1, "output_folder", "output_file", 1) + self.assertTrue(os.path.exists("output_file")) + mock_subprocess.assert_called_once_with(f"CHECKM2DB={checkm2_db}/uniref100.KO.1.dmnd checkm2 predict -i {bin_folder}/ -x fna -o output_folder -t 1 --force") + + def test_run_checkm_no_bins(self): + with tempfile.TemporaryDirectory() as tmpdir: + os.chdir(tmpdir) + with patch.object(subprocess, "run") as mock_subprocess: + checkm2_db = os.path.join("checkm2_db") + os.makedirs(checkm2_db) + bin_folder = os.path.join("bin_folder") + os.makedirs(bin_folder) + + checkm(checkm2_db, bin_folder, "fna", 1, "output_folder", "output_file", 1) + self.assertTrue(os.path.exists("output_file")) + mock_subprocess.assert_not_called() + + def test_run_checkm_no_refinery(self): + with tempfile.TemporaryDirectory() as tmpdir: + os.chdir(tmpdir) + with patch.object(subprocess, "run") as mock_subprocess: + checkm2_db = os.path.join("checkm2_db") + os.makedirs(checkm2_db) + bin_folder = os.path.join("bin_folder") + os.makedirs(bin_folder) + Path(os.path.join(bin_folder, "bin_1.fna")).touch() + + checkm(checkm2_db, bin_folder, "fna", 0, "output_folder", "output_file", 1) + self.assertTrue(os.path.exists("output_file")) + mock_subprocess.assert_not_called() + + +if __name__ == '__main__': + unittest.main()