Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add refinery max iterations arg #103

Merged
merged 8 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/test-aviary.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: Test Aviary with Setup-Miniconda From Marketplace
on: [push]
on: [push, pull_request]

jobs:
miniconda:
Expand Down Expand Up @@ -34,3 +34,5 @@ jobs:
run: |
aviary -h
python test/test_assemble.py
python test/test_recover.py
python test/test_run_checkm.py -b
8 changes: 8 additions & 0 deletions aviary/aviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,14 @@ def main():
default='global'
)

binning_group.add_argument(
'--refinery-max-iterations', '--refinery_max_iterations',
help='Maximum number of iterations for Rosella refinery. Set to 0 to skip refinery.',
dest='refinery_max_iterations',
type=int,
default=5
)

binning_group.add_argument(
'--skip-binners', '--skip_binners', '--skip_binner', '--skip-binner',
help='Optional list of binning algorithms to skip. Can be any combination of: \n'
Expand Down
43 changes: 12 additions & 31 deletions aviary/modules/binning/binning.smk
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ rule checkm_rosella:
pplacer_threads = config["pplacer_threads"],
checkm2_db_path = config["checkm2_db_folder"],
bin_folder = "data/rosella_bins/",
extension = "fna"
extension = "fna",
refinery_max_iterations = config["refinery_max_iterations"],
group: 'binning'
output:
output_folder = directory("data/rosella_bins/checkm2_out/"),
Expand All @@ -372,20 +373,8 @@ rule checkm_rosella:
"../../envs/checkm2.yaml"
threads:
config["max_threads"]
shell:
'touch {output.output_file}; '
'if [ `ls "{params.bin_folder}" |grep .fna$ |wc -l` -eq 0 ]; then '
'echo "No bins found in {params.bin_folder}"; '
'touch {output.output_file}; '
'mkdir -p {output.output_folder}; '
'else '

'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; '
'echo "Using CheckM2 database $CHECKM2DB"; '
'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; '
'cp {output.output_folder}/quality_report.tsv {output.output_file}; '

'fi'
script:
"scripts/run_checkm.py"

rule checkm_metabat2:
input:
Expand All @@ -403,12 +392,8 @@ rule checkm_metabat2:
"../../envs/checkm2.yaml"
threads:
config["max_threads"]
shell:
'touch {output.output_file}; '
'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; '
'echo "Using CheckM2 database $CHECKM2DB"; '
'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; '
'cp {output.output_folder}/quality_report.tsv {output.output_file}'
script:
"scripts/run_checkm.py"

rule checkm_semibin:
input:
Expand All @@ -426,12 +411,8 @@ rule checkm_semibin:
"../../envs/checkm2.yaml"
threads:
config["max_threads"]
shell:
'touch {output.output_file}; '
'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; '
'echo "Using CheckM2 database $CHECKM2DB"; '
'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; '
'cp {output.output_folder}/quality_report.tsv {output.output_file}'
script:
"scripts/run_checkm.py"

rule refine_rosella:
input:
Expand All @@ -449,7 +430,7 @@ rule refine_rosella:
extension = "fna",
output_folder = "data/rosella_refined/",
min_bin_size = config["min_bin_size"],
max_iterations = 5,
max_iterations = config["refinery_max_iterations"],
pplacer_threads = config["pplacer_threads"],
max_contamination = 15,
final_refining = False
Expand Down Expand Up @@ -480,7 +461,7 @@ rule refine_metabat2:
extension = "fa",
output_folder = "data/metabat2_refined/",
min_bin_size = config["min_bin_size"],
max_iterations = 5,
max_iterations = config["refinery_max_iterations"],
pplacer_threads = config["pplacer_threads"],
max_contamination = 15,
final_refining = False
Expand Down Expand Up @@ -509,7 +490,7 @@ rule refine_semibin:
extension = "fa",
output_folder = "data/semibin_refined/",
min_bin_size = config["min_bin_size"],
max_iterations = 5,
max_iterations = config["refinery_max_iterations"],
pplacer_threads = config["pplacer_threads"],
max_contamination = 15,
final_refining = False
Expand Down Expand Up @@ -623,7 +604,7 @@ rule refine_dastool:
extension = "fa",
output_folder = "data/refined_bins/",
min_bin_size = config["min_bin_size"],
max_iterations = 5,
max_iterations = config["refinery_max_iterations"],
pplacer_threads = config["pplacer_threads"],
max_contamination = 15,
final_refining = True
Expand Down
30 changes: 30 additions & 0 deletions aviary/modules/binning/scripts/run_checkm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import subprocess
import shutil
import os
from pathlib import Path

def checkm(checkm2_db, bin_folder, bin_ext, refinery_max_iterations, output_folder, output_file, threads):
    """Run ``checkm2 predict`` on a folder of bins, or write placeholder outputs.

    Placeholder outputs (an output folder and an empty ``output_file``) are
    written without invoking CheckM2 when either no file in ``bin_folder``
    ends with ``bin_ext`` or ``refinery_max_iterations`` is 0.

    Args:
        checkm2_db: Folder that contains ``uniref100.KO.1.dmnd``.
        bin_folder: Folder whose entries are scanned for bins.
        bin_ext: File-name suffix identifying a bin (e.g. ``"fna"``).
        refinery_max_iterations: Refinery iteration limit; 0 skips CheckM2.
        output_folder: Folder CheckM2 writes its results into.
        output_file: Destination for a copy of ``quality_report.tsv``.
        threads: Thread count passed to ``checkm2 predict -t``.

    Raises:
        subprocess.CalledProcessError: If ``checkm2`` exits with a non-zero
            status.
    """
    has_bins = any(entry.endswith(bin_ext) for entry in os.listdir(bin_folder))

    if not has_bins or refinery_max_iterations == 0:
        if not has_bins:
            print(f"No bins found in {bin_folder}")
        else:
            print("Skipping pre-refinery CheckM2 rules")
        # The folder may already be present (e.g. left over from an earlier
        # run), which must not turn a deliberate skip into a failure.
        os.makedirs(output_folder, exist_ok=True)
        Path(output_file).touch()
        return

    db_file = f"{checkm2_db}/uniref100.KO.1.dmnd"
    print(f"Using CheckM2 database {db_file}")

    # A "VAR=value command ..." string only works through a shell; without
    # shell=True subprocess treats the whole string as the program name.
    # Pass an argument list and hand the database path over via env instead.
    command = [
        "checkm2", "predict",
        "-i", f"{bin_folder}/",
        "-x", str(bin_ext),
        "-o", str(output_folder),
        "-t", str(threads),
        "--force",
    ]
    subprocess.run(command, env={**os.environ, "CHECKM2DB": db_file}, check=True)
    shutil.copy(f"{output_folder}/quality_report.tsv", output_file)


if __name__ == '__main__':
    # Executed through a Snakemake ``script:`` directive: the ``snakemake``
    # object supplies the rule's params, outputs and thread count, which are
    # forwarded to checkm() in its positional order.
    rule_params = snakemake.params
    db_dir = rule_params.checkm2_db_path
    bins_dir = rule_params.bin_folder
    suffix = rule_params.extension
    iteration_limit = rule_params.refinery_max_iterations

    rule_outputs = snakemake.output
    report_dir = rule_outputs.output_folder
    report_file = rule_outputs.output_file

    checkm(
        db_dir,
        bins_dir,
        suffix,
        iteration_limit,
        report_dir,
        report_file,
        snakemake.threads,
    )
3 changes: 3 additions & 0 deletions aviary/modules/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(self,
self.min_contig_size = args.min_contig_size
self.min_bin_size = args.min_bin_size
self.semibin_model = args.semibin_model
self.refinery_max_iterations = args.refinery_max_iterations

self.skip_binners = []
if args.skip_binners:
Expand All @@ -125,6 +126,7 @@ def __init__(self,
self.min_contig_size = 1500
self.min_bin_size = 200000
self.semibin_model = 'global'
self.refinery_max_iterations = 5
self.skip_binners = ["none"]

try:
Expand Down Expand Up @@ -308,6 +310,7 @@ def make_config(self):
conf["gsa_mappings"] = self.gsa_mappings
conf["skip_binners"] = self.skip_binners
conf["semibin_model"] = self.semibin_model
conf["refinery_max_iterations"] = self.refinery_max_iterations
conf["max_threads"] = int(self.threads)
conf["pplacer_threads"] = int(self.pplacer_threads)
conf["max_memory"] = int(self.max_memory)
Expand Down
80 changes: 80 additions & 0 deletions test/test_recover.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import tempfile
import extern
from snakemake import load_configfile

path_to_data = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data')
path_to_conda = os.path.join(path_to_data,'.conda')
Expand Down Expand Up @@ -122,5 +123,84 @@ def test_recover_skip_binners(self):
# Unnecessary
self.assertTrue("complete_assembly_with_qc" not in output)

def test_recover_no_singlem(self):
    """Dry-run ``recover_mags_no_singlem`` and check which names appear in the output."""
    # (name, must_appear) pairs, evaluated in this order.
    expectations = (
        # Binners
        ("prepare_binning_files", True),
        ("get_bam_indices", True),
        ("metabat_sens", True),
        ("metabat_spec", True),
        ("metabat_ssens", True),
        ("metabat_sspec", True),
        ("metabat2", True),
        ("maxbin2", True),
        ("rosella", True),
        ("semibin", True),
        ("vamb", True),
        ("concoct", True),
        ("das_tool", True),
        # Refinery
        ("checkm_metabat2", True),
        ("refine_metabat2", True),
        ("checkm_rosella", True),
        ("refine_rosella", True),
        ("checkm_semibin", True),
        ("refine_semibin", True),
        ("checkm_das_tool", True),
        ("refine_dastool", True),
        # Extras
        ("checkm2", True),
        ("gtdbtk", True),
        ("get_abundances", True),
        ("singlem_pipe_reads", False),
        ("singlem_appraise", False),
        ("finalize_stats", True),
        ("recover_mags_no_singlem", True),
        # Unnecessary
        ("complete_assembly_with_qc", False),
    )

    with tempfile.TemporaryDirectory() as workdir:
        cmd = (
            f"GTDBTK_DATA_PATH=. "
            f"CHECKM2DB=. "
            f"EGGNOG_DATA_DIR=. "
            f"aviary recover "
            f"--workflow recover_mags_no_singlem "
            f"--assembly {ASSEMBLY} "
            f"-1 {FORWARD_READS} "
            f"-2 {REVERSE_READS} "
            f"--output {workdir}/test "
            f"--conda-prefix {path_to_conda} "
            f"--dryrun "
            f"--snakemake-cmds \" --quiet\" "
        )
        output = extern.run(cmd)

        for name, must_appear in expectations:
            found = name in output
            self.assertTrue(found if must_appear else not found)

def test_recover_config(self):
    """Check that ``--refinery-max-iterations`` is written to the generated config.yaml."""
    with tempfile.TemporaryDirectory() as workdir:
        command_parts = (
            f"GTDBTK_DATA_PATH=. ",
            f"CHECKM2DB=. ",
            f"EGGNOG_DATA_DIR=. ",
            f"aviary recover ",
            f"--refinery-max-iterations 3 ",
            f"--assembly {ASSEMBLY} ",
            f"-1 {FORWARD_READS} ",
            f"-2 {REVERSE_READS} ",
            f"--output {workdir}/test ",
            f"--conda-prefix {path_to_conda} ",
            f"--dryrun ",
            f"--snakemake-cmds \" --quiet\" ",
        )
        # Each part carries its own trailing space, so plain concatenation
        # yields the full command line.
        extern.run("".join(command_parts))

        written_config = os.path.join(workdir, "test", "config.yaml")
        self.assertTrue(os.path.exists(written_config))

        loaded = load_configfile(written_config)
        self.assertEqual(loaded["refinery_max_iterations"], 3)

# Run the test cases in this file when it is executed directly.
if __name__ == '__main__':
    unittest.main()
59 changes: 59 additions & 0 deletions test/test_run_checkm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python3

import unittest
import os
import tempfile
from aviary.modules.binning.scripts.run_checkm import checkm
from unittest.mock import patch
import subprocess
from pathlib import Path

def create_output(_):
    """Stand-in for the CheckM2 run: create ``output_folder/quality_report.tsv``.

    The single argument is ignored. Paths are relative to the current
    working directory.
    """
    report_dir = Path("output_folder")
    report_dir.mkdir()
    (report_dir / "quality_report.tsv").touch()

class Tests(unittest.TestCase):
    """Tests for ``checkm`` with ``subprocess.run`` replaced by a mock."""

    def setUp(self):
        # The fixtures use relative paths, so every test runs inside its own
        # scratch directory. Cleanups run last-in-first-out: the original
        # working directory is restored before the scratch directory is
        # removed, so the process is never left inside a deleted directory.
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        self.addCleanup(os.chdir, os.getcwd())
        os.chdir(scratch.name)

    def _make_inputs(self, with_bin):
        """Create the database and bin folders; optionally add one ``.fna`` bin."""
        checkm2_db = os.path.join("checkm2_db")
        os.makedirs(checkm2_db)
        bin_folder = os.path.join("bin_folder")
        os.makedirs(bin_folder)
        if with_bin:
            Path(os.path.join(bin_folder, "bin_1.fna")).touch()
        return checkm2_db, bin_folder

    def test_run_checkm(self):
        def fake_run(*args, **kwargs):
            # create_output takes one ignored argument; wrap it so the mock
            # tolerates any positional/keyword arguments given to run().
            create_output(args)

        with patch.object(subprocess, "run", side_effect=fake_run) as mock_subprocess:
            checkm2_db, bin_folder = self._make_inputs(with_bin=True)

            checkm(checkm2_db, bin_folder, "fna", 1, "output_folder", "output_file", 1)
            self.assertTrue(os.path.exists("output_file"))

            mock_subprocess.assert_called_once()
            args, kwargs = mock_subprocess.call_args

            # Accept the command as one string or as an argument list, and
            # the database either inline or through ``env``: this pins what
            # is executed rather than the exact shape of the run() call.
            command = args[0]
            if not isinstance(command, str):
                command = " ".join(str(part) for part in command)
            self.assertIn(
                f"checkm2 predict -i {bin_folder}/ -x fna -o output_folder -t 1 --force",
                command,
            )

            db_file = f"{checkm2_db}/uniref100.KO.1.dmnd"
            env = kwargs.get("env") or {}
            self.assertTrue(
                f"CHECKM2DB={db_file}" in command or env.get("CHECKM2DB") == db_file
            )

    def test_run_checkm_no_bins(self):
        with patch.object(subprocess, "run") as mock_subprocess:
            checkm2_db, bin_folder = self._make_inputs(with_bin=False)

            checkm(checkm2_db, bin_folder, "fna", 1, "output_folder", "output_file", 1)
            self.assertTrue(os.path.exists("output_file"))
            mock_subprocess.assert_not_called()

    def test_run_checkm_no_refinery(self):
        with patch.object(subprocess, "run") as mock_subprocess:
            checkm2_db, bin_folder = self._make_inputs(with_bin=True)

            checkm(checkm2_db, bin_folder, "fna", 0, "output_folder", "output_file", 1)
            self.assertTrue(os.path.exists("output_file"))
            mock_subprocess.assert_not_called()


# Run the test cases in this file when it is executed directly.
if __name__ == '__main__':
    unittest.main()
Loading