Skip to content

Commit

Permalink
Merge pull request #103 from AroneyS/add-refinery-max-iterations-arg
Browse files Browse the repository at this point in the history
Add refinery max iterations arg
  • Loading branch information
AroneyS authored Jul 28, 2023
2 parents 019a9bf + 0ff4bc3 commit 1a699ac
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 32 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test-aviary.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: Test Aviary with Setup-Miniconda From Marketplace
on: [push]
on: [push, pull_request]

jobs:
miniconda:
Expand Down Expand Up @@ -34,3 +34,5 @@ jobs:
run: |
aviary -h
python test/test_assemble.py
python test/test_recover.py
python test/test_run_checkm.py -b
8 changes: 8 additions & 0 deletions aviary/aviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,14 @@ def main():
default='global'
)

binning_group.add_argument(
'--refinery-max-iterations', '--refinery_max_iterations',
help='Maximum number of iterations for Rosella refinery. Set to 0 to skip refinery.',
dest='refinery_max_iterations',
type=int,
default=5
)

binning_group.add_argument(
'--skip-binners', '--skip_binners', '--skip_binner', '--skip-binner',
help='Optional list of binning algorithms to skip. Can be any combination of: \n'
Expand Down
43 changes: 12 additions & 31 deletions aviary/modules/binning/binning.smk
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ rule checkm_rosella:
pplacer_threads = config["pplacer_threads"],
checkm2_db_path = config["checkm2_db_folder"],
bin_folder = "data/rosella_bins/",
extension = "fna"
extension = "fna",
refinery_max_iterations = config["refinery_max_iterations"],
group: 'binning'
output:
output_folder = directory("data/rosella_bins/checkm2_out/"),
Expand All @@ -372,20 +373,8 @@ rule checkm_rosella:
"../../envs/checkm2.yaml"
threads:
config["max_threads"]
shell:
'touch {output.output_file}; '
'if [ `ls "{params.bin_folder}" |grep .fna$ |wc -l` -eq 0 ]; then '
'echo "No bins found in {params.bin_folder}"; '
'touch {output.output_file}; '
'mkdir -p {output.output_folder}; '
'else '

'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; '
'echo "Using CheckM2 database $CHECKM2DB"; '
'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; '
'cp {output.output_folder}/quality_report.tsv {output.output_file}; '

'fi'
script:
"scripts/run_checkm.py"

rule checkm_metabat2:
input:
Expand All @@ -403,12 +392,8 @@ rule checkm_metabat2:
"../../envs/checkm2.yaml"
threads:
config["max_threads"]
shell:
'touch {output.output_file}; '
'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; '
'echo "Using CheckM2 database $CHECKM2DB"; '
'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; '
'cp {output.output_folder}/quality_report.tsv {output.output_file}'
script:
"scripts/run_checkm.py"

rule checkm_semibin:
input:
Expand All @@ -426,12 +411,8 @@ rule checkm_semibin:
"../../envs/checkm2.yaml"
threads:
config["max_threads"]
shell:
'touch {output.output_file}; '
'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; '
'echo "Using CheckM2 database $CHECKM2DB"; '
'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; '
'cp {output.output_folder}/quality_report.tsv {output.output_file}'
script:
"scripts/run_checkm.py"

rule refine_rosella:
input:
Expand All @@ -449,7 +430,7 @@ rule refine_rosella:
extension = "fna",
output_folder = "data/rosella_refined/",
min_bin_size = config["min_bin_size"],
max_iterations = 5,
max_iterations = config["refinery_max_iterations"],
pplacer_threads = config["pplacer_threads"],
max_contamination = 15,
final_refining = False
Expand Down Expand Up @@ -480,7 +461,7 @@ rule refine_metabat2:
extension = "fa",
output_folder = "data/metabat2_refined/",
min_bin_size = config["min_bin_size"],
max_iterations = 5,
max_iterations = config["refinery_max_iterations"],
pplacer_threads = config["pplacer_threads"],
max_contamination = 15,
final_refining = False
Expand Down Expand Up @@ -509,7 +490,7 @@ rule refine_semibin:
extension = "fa",
output_folder = "data/semibin_refined/",
min_bin_size = config["min_bin_size"],
max_iterations = 5,
max_iterations = config["refinery_max_iterations"],
pplacer_threads = config["pplacer_threads"],
max_contamination = 15,
final_refining = False
Expand Down Expand Up @@ -623,7 +604,7 @@ rule refine_dastool:
extension = "fa",
output_folder = "data/refined_bins/",
min_bin_size = config["min_bin_size"],
max_iterations = 5,
max_iterations = config["refinery_max_iterations"],
pplacer_threads = config["pplacer_threads"],
max_contamination = 15,
final_refining = True
Expand Down
30 changes: 30 additions & 0 deletions aviary/modules/binning/scripts/run_checkm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import subprocess
import shutil
import os
from pathlib import Path

def checkm(checkm2_db, bin_folder, bin_ext, refinery_max_iterations, output_folder, output_file, threads):
    """Run CheckM2 on a folder of bins, or write placeholder outputs when skipped.

    CheckM2 is skipped when ``bin_folder`` contains no file ending in
    ``bin_ext`` or when ``refinery_max_iterations`` is 0. In both cases
    ``output_folder`` is created and an empty ``output_file`` is touched.

    Args:
        checkm2_db: Folder containing ``uniref100.KO.1.dmnd``.
        bin_folder: Folder that is scanned for bin files.
        bin_ext: Filename suffix identifying bin files (e.g. ``"fna"``).
        refinery_max_iterations: 0 disables the CheckM2 run.
        output_folder: Folder CheckM2 writes its results to.
        output_file: Destination for the copy of ``quality_report.tsv``.
        threads: Thread count passed to ``checkm2 predict -t``.

    Raises:
        subprocess.CalledProcessError: If ``checkm2`` exits with a non-zero status.
    """
    # Local import so this function does not depend on an edit to the
    # module-level import block.
    import shlex

    has_bins = any(name.endswith(bin_ext) for name in os.listdir(bin_folder))
    if not has_bins or refinery_max_iterations == 0:
        if not has_bins:
            print(f"No bins found in {bin_folder}")
        else:
            print("Skipping pre-refinery CheckM2 rules")
        # exist_ok mirrors `mkdir -p`: a leftover folder must not abort the run.
        os.makedirs(output_folder, exist_ok=True)
        Path(output_file).touch()
        return

    db_file = f"{checkm2_db}/uniref100.KO.1.dmnd"
    print(f"Using CheckM2 database {db_file}")

    # Quote every path so that whitespace or shell metacharacters survive the
    # shell; quoting leaves ordinary paths unchanged.
    quoted_db = shlex.quote(str(db_file))
    quoted_bins = shlex.quote(f"{bin_folder}/")
    quoted_ext = shlex.quote(str(bin_ext))
    quoted_out = shlex.quote(str(output_folder))
    command = (
        f"CHECKM2DB={quoted_db} checkm2 predict -i {quoted_bins} "
        f"-x {quoted_ext} -o {quoted_out} -t {threads} --force"
    )

    # The command relies on a leading `VAR=value` assignment, which only a
    # shell understands, so shell=True is required. check=True surfaces a
    # CheckM2 failure here instead of as a missing quality_report.tsv below.
    subprocess.run(command, shell=True, check=True)
    shutil.copy(f"{output_folder}/quality_report.tsv", output_file)


if __name__ == '__main__':
    # `snakemake` is not defined in this file; it is expected to be provided
    # by Snakemake when a rule runs this file through its `script:` directive.
    rule_params = snakemake.params
    rule_outputs = snakemake.output

    # NOTE(review): several rules run this script, but it is not visible that
    # all of them pass `refinery_max_iterations` as a param - confirm. Fall
    # back to the workflow config so a rule without the param does not fail
    # with AttributeError. If neither defines it the value stays None, which
    # `checkm` treats as "do not skip".
    refinery_max_iterations = getattr(rule_params, "refinery_max_iterations", None)
    if refinery_max_iterations is None:
        refinery_max_iterations = snakemake.config.get("refinery_max_iterations")

    checkm(
        checkm2_db=rule_params.checkm2_db_path,
        bin_folder=rule_params.bin_folder,
        bin_ext=rule_params.extension,
        refinery_max_iterations=refinery_max_iterations,
        output_folder=rule_outputs.output_folder,
        output_file=rule_outputs.output_file,
        threads=snakemake.threads,
    )
3 changes: 3 additions & 0 deletions aviary/modules/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(self,
self.min_contig_size = args.min_contig_size
self.min_bin_size = args.min_bin_size
self.semibin_model = args.semibin_model
self.refinery_max_iterations = args.refinery_max_iterations

self.skip_binners = []
if args.skip_binners:
Expand All @@ -125,6 +126,7 @@ def __init__(self,
self.min_contig_size = 1500
self.min_bin_size = 200000
self.semibin_model = 'global'
self.refinery_max_iterations = 5
self.skip_binners = ["none"]

try:
Expand Down Expand Up @@ -308,6 +310,7 @@ def make_config(self):
conf["gsa_mappings"] = self.gsa_mappings
conf["skip_binners"] = self.skip_binners
conf["semibin_model"] = self.semibin_model
conf["refinery_max_iterations"] = self.refinery_max_iterations
conf["max_threads"] = int(self.threads)
conf["pplacer_threads"] = int(self.pplacer_threads)
conf["max_memory"] = int(self.max_memory)
Expand Down
80 changes: 80 additions & 0 deletions test/test_recover.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import tempfile
import extern
from snakemake import load_configfile

path_to_data = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data')
path_to_conda = os.path.join(path_to_data,'.conda')
Expand Down Expand Up @@ -122,5 +123,84 @@ def test_recover_skip_binners(self):
# Unnecessary
self.assertTrue("complete_assembly_with_qc" not in output)

def test_recover_no_singlem(self):
    # Dry-run `aviary recover` with the `recover_mags_no_singlem` workflow and
    # check, by substring match on the captured output, which rule names
    # appear. Each entry is (rule name, whether it must appear); the order
    # matches the order in which the checks are made.
    expectations = (
        # Binners
        ("prepare_binning_files", True),
        ("get_bam_indices", True),
        ("metabat_sens", True),
        ("metabat_spec", True),
        ("metabat_ssens", True),
        ("metabat_sspec", True),
        ("metabat2", True),
        ("maxbin2", True),
        ("rosella", True),
        ("semibin", True),
        ("vamb", True),
        ("concoct", True),
        ("das_tool", True),
        # Refinery
        ("checkm_metabat2", True),
        ("refine_metabat2", True),
        ("checkm_rosella", True),
        ("refine_rosella", True),
        ("checkm_semibin", True),
        ("refine_semibin", True),
        ("checkm_das_tool", True),
        ("refine_dastool", True),
        # Extras
        ("checkm2", True),
        ("gtdbtk", True),
        ("get_abundances", True),
        ("singlem_pipe_reads", False),
        ("singlem_appraise", False),
        ("finalize_stats", True),
        ("recover_mags_no_singlem", True),
        # Unnecessary
        ("complete_assembly_with_qc", False),
    )

    with tempfile.TemporaryDirectory() as workdir:
        command = "".join([
            "GTDBTK_DATA_PATH=. ",
            "CHECKM2DB=. ",
            "EGGNOG_DATA_DIR=. ",
            "aviary recover ",
            "--workflow recover_mags_no_singlem ",
            f"--assembly {ASSEMBLY} ",
            f"-1 {FORWARD_READS} ",
            f"-2 {REVERSE_READS} ",
            f"--output {workdir}/test ",
            f"--conda-prefix {path_to_conda} ",
            "--dryrun ",
            '--snakemake-cmds " --quiet" ',
        ])
        dryrun_output = extern.run(command)

        for rule_name, should_appear in expectations:
            self.assertTrue((rule_name in dryrun_output) == should_appear)

def test_recover_config(self):
    # Dry-run `aviary recover` with `--refinery-max-iterations 3`, then load
    # the config.yaml found under the output directory and check that the
    # value is stored under the `refinery_max_iterations` key.
    with tempfile.TemporaryDirectory() as workdir:
        command = "".join([
            "GTDBTK_DATA_PATH=. ",
            "CHECKM2DB=. ",
            "EGGNOG_DATA_DIR=. ",
            "aviary recover ",
            "--refinery-max-iterations 3 ",
            f"--assembly {ASSEMBLY} ",
            f"-1 {FORWARD_READS} ",
            f"-2 {REVERSE_READS} ",
            f"--output {workdir}/test ",
            f"--conda-prefix {path_to_conda} ",
            "--dryrun ",
            '--snakemake-cmds " --quiet" ',
        ])
        extern.run(command)

        generated_config = os.path.join(workdir, "test", "config.yaml")
        self.assertTrue(os.path.exists(generated_config))

        loaded = load_configfile(generated_config)
        self.assertEqual(loaded["refinery_max_iterations"], 3)

# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
59 changes: 59 additions & 0 deletions test/test_run_checkm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python3

import unittest
import os
import tempfile
from aviary.modules.binning.scripts.run_checkm import checkm
from unittest.mock import patch
import subprocess
from pathlib import Path

def create_output(*_args, **_kwargs):
    # Stand-in for `subprocess.run` in the tests: instead of running CheckM2 it
    # creates `output_folder/quality_report.tsv` in the current directory.
    # All arguments are accepted and ignored so that the mock keeps working
    # regardless of which positional or keyword arguments the caller passes.
    os.makedirs("output_folder", exist_ok=True)
    Path(os.path.join("output_folder", "quality_report.tsv")).touch()

class Tests(unittest.TestCase):
    # Unit tests for `checkm`. `subprocess.run` is always patched, so CheckM2
    # itself is never executed.

    def setUp(self):
        # Each test runs inside its own scratch directory because `checkm` is
        # called with relative paths. Cleanups run in reverse order of
        # registration: first return to the previous working directory, then
        # delete the scratch directory, so the process is never left sitting
        # in a directory that no longer exists.
        previous_cwd = os.getcwd()
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        self.addCleanup(os.chdir, previous_cwd)
        os.chdir(scratch.name)

    def _prepare_inputs(self, with_bin):
        # Create the database folder and the bin folder, optionally with one bin.
        checkm2_db = os.path.join("checkm2_db")
        os.makedirs(checkm2_db)
        bin_folder = os.path.join("bin_folder")
        os.makedirs(bin_folder)
        if with_bin:
            Path(os.path.join(bin_folder, "bin_1.fna")).touch()
        return checkm2_db, bin_folder

    def test_run_checkm(self):
        checkm2_db, bin_folder = self._prepare_inputs(with_bin=True)
        expected_cmd = f"CHECKM2DB={checkm2_db}/uniref100.KO.1.dmnd checkm2 predict -i {bin_folder}/ -x fna -o output_folder -t 1 --force"

        def fake_run(*args, **_kwargs):
            # Forward only the command so this works whatever extra options
            # `checkm` passes to `subprocess.run`.
            create_output(args[0])

        with patch.object(subprocess, "run", side_effect=fake_run) as mock_subprocess:
            checkm(checkm2_db, bin_folder, "fna", 1, "output_folder", "output_file", 1)

        self.assertTrue(os.path.exists("output_file"))
        mock_subprocess.assert_called_once()
        # Pin the command line itself rather than the full call signature.
        self.assertEqual(mock_subprocess.call_args[0][0], expected_cmd)

    def test_run_checkm_no_bins(self):
        checkm2_db, bin_folder = self._prepare_inputs(with_bin=False)

        with patch.object(subprocess, "run") as mock_subprocess:
            checkm(checkm2_db, bin_folder, "fna", 1, "output_folder", "output_file", 1)

        self.assertTrue(os.path.exists("output_file"))
        mock_subprocess.assert_not_called()

    def test_run_checkm_no_refinery(self):
        checkm2_db, bin_folder = self._prepare_inputs(with_bin=True)

        with patch.object(subprocess, "run") as mock_subprocess:
            checkm(checkm2_db, bin_folder, "fna", 0, "output_folder", "output_file", 1)

        self.assertTrue(os.path.exists("output_file"))
        mock_subprocess.assert_not_called()


# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 1a699ac

Please sign in to comment.