From 8766e95edd1f58465acc5d3d98960d286dd36b05 Mon Sep 17 00:00:00 2001 From: Luiz Irber Date: Wed, 8 Jan 2020 18:49:38 +0000 Subject: [PATCH] making compute usable inside Python prepare compute test command fix default ksize parsing failing test: if output is a string or path, should we open it as a file? --- sourmash/cli/compute.py | 16 +++- sourmash/command_compute.py | 158 ++++++++++++++++----------------- tests/conftest.py | 5 ++ tests/sourmash_tst_utils.py | 15 ++++ tests/test_sourmash_compute.py | 29 +++--- 5 files changed, 130 insertions(+), 93 deletions(-) diff --git a/sourmash/cli/compute.py b/sourmash/cli/compute.py index fafdf7a7fa..2338a6ca9f 100644 --- a/sourmash/cli/compute.py +++ b/sourmash/cli/compute.py @@ -6,12 +6,24 @@ from sourmash.cli.utils import add_construct_moltype_args +def ksize_parser(ksizes): + # get list of k-mer sizes for which to compute sketches + if ',' in ksizes: + ksizes = ksizes.split(',') + ksizes = list(map(int, ksizes)) + else: + ksizes = [int(ksizes)] + + return ksizes + + def subparser(subparsers): subparser = subparsers.add_parser('compute') sketch_args = subparser.add_argument_group('Sketching options') sketch_args.add_argument( - '-k', '--ksizes', default='21,31,51', + '-k', '--ksizes', default=[21, 31, 51], + type=ksize_parser, help='comma-separated list of k-mer sizes; default=%(default)s' ) sketch_args.add_argument( @@ -126,4 +138,4 @@ def subparser(subparsers): def main(args): from sourmash.command_compute import compute - return compute(args) + return compute(**vars(args)) diff --git a/sourmash/command_compute.py b/sourmash/command_compute.py index e3fea9210b..ee78797850 100644 --- a/sourmash/command_compute.py +++ b/sourmash/command_compute.py @@ -18,7 +18,13 @@ DEFAULT_LINE_COUNT = 1500 -def compute(args): +def compute(filenames=None, check_sequence=False, ksizes=(21, 31, 51), dna=True, dayhoff=False, hp=False, + singleton=False, count_valid_reads=0, barcodes_file=None, line_count=DEFAULT_LINE_COUNT, + rename_10x_barcodes=None, write_barcode_meta_csv=None, save_fastas=None, + email='', scaled=10000, force=False, output=None, num_hashes=500, protein=False, + name_from_first=False, seed=42, input_is_protein=False, merge=None, quiet=False, + track_abundance=False, randomize=False, license='CC0', + input_is_10x=False, processes=2, **kwargs): """Compute the signature for one or more files. Use cases: @@ -31,74 +37,66 @@ def compute(args): => creates one output file file.sig, with all sequences from file1.fa and file2.fa combined into one signature. """ - set_quiet(args.quiet) + set_quiet(quiet) - if args.license != 'CC0': + if license != 'CC0': error('error: sourmash only supports CC0-licensed signatures. sorry!') sys.exit(-1) - if args.input_is_protein and args.dna: + if input_is_protein and dna: notify('WARNING: input is protein, turning off nucleotide hashing') - args.dna = False - args.protein = True + dna = False + protein = True - if args.scaled: - if args.scaled < 1: + if scaled: + if scaled < 1: error('ERROR: --scaled value must be >= 1') sys.exit(-1) - if args.scaled != round(args.scaled, 0): + if scaled != round(scaled, 0): error('ERROR: --scaled value must be integer value') sys.exit(-1) - if args.scaled >= 1e9: + if scaled >= 1e9: notify('WARNING: scaled value is nonsensical!? Continuing anyway.') - if args.num_hashes != 0: + if num_hashes != 0: notify('setting num_hashes to 0 because --scaled is set') - args.num_hashes = 0 + num_hashes = 0 - notify('computing signatures for files: {}', ", ".join(args.filenames)) + notify('computing signatures for files: {}', ", ".join(filenames)) - if args.randomize: + if randomize: notify('randomizing file list because of --randomize') - random.shuffle(args.filenames) - - # get list of k-mer sizes for which to compute sketches - ksizes = args.ksizes - if ',' in ksizes: - ksizes = ksizes.split(',') - ksizes = list(map(int, ksizes)) - else: - ksizes = [int(ksizes)] + random.shuffle(filenames) notify('Computing signature for ksizes: {}', str(ksizes)) num_sigs = 0 - if args.dna and args.protein: + if dna and protein: notify('Computing both nucleotide and protein signatures.') num_sigs = 2*len(ksizes) - elif args.dna and args.dayhoff: + elif dna and dayhoff: notify('Computing both nucleotide and Dayhoff-encoded protein ' 'signatures.') num_sigs = 2*len(ksizes) - elif args.dna and args.hp: + elif dna and hp: notify('Computing both nucleotide and Hp-encoded protein ' 'signatures.') num_sigs = 2*len(ksizes) - elif args.dna: + elif dna: notify('Computing only nucleotide (and not protein) signatures.') num_sigs = len(ksizes) - elif args.protein: + elif protein: notify('Computing only protein (and not nucleotide) signatures.') num_sigs = len(ksizes) - elif args.dayhoff: + elif dayhoff: notify('Computing only Dayhoff-encoded protein (and not nucleotide) ' 'signatures.') num_sigs = len(ksizes) - elif args.hp: + elif hp: notify('Computing only hp-encoded protein (and not nucleotide) ' 'signatures.') num_sigs = len(ksizes) - if (args.protein or args.dayhoff or args.hp) and not args.input_is_protein: + if (protein or dayhoff or hp) and not input_is_protein: bad_ksizes = [ str(k) for k in ksizes if k % 3 != 0 ] if bad_ksizes: error('protein ksizes must be divisible by 3, sorry!') @@ -111,50 +109,48 @@ def compute(args): error('...nothing to calculate!? Exiting!') sys.exit(-1) - if args.merge and not args.output: + if merge and not output: error("must specify -o with --merge") sys.exit(-1) def make_minhashes(): - seed = args.seed - # one minhash for each ksize Elist = [] for k in ksizes: - if args.protein: - E = MinHash(ksize=k, n=args.num_hashes, + if protein: + E = MinHash(ksize=k, n=num_hashes, is_protein=True, dayhoff=False, hp=False, - track_abundance=args.track_abundance, - scaled=args.scaled, + track_abundance=track_abundance, + scaled=scaled, seed=seed) Elist.append(E) - if args.dayhoff: - E = MinHash(ksize=k, n=args.num_hashes, + if dayhoff: + E = MinHash(ksize=k, n=num_hashes, is_protein=True, dayhoff=True, hp=False, - track_abundance=args.track_abundance, - scaled=args.scaled, + track_abundance=track_abundance, + scaled=scaled, seed=seed) Elist.append(E) - if args.hp: - E = MinHash(ksize=k, n=args.num_hashes, + if hp: + E = MinHash(ksize=k, n=num_hashes, is_protein=True, dayhoff=False, hp=True, - track_abundance=args.track_abundance, - scaled=args.scaled, + track_abundance=track_abundance, + scaled=scaled, seed=seed) Elist.append(E) - if args.dna: - E = MinHash(ksize=k, n=args.num_hashes, + if dna: + E = MinHash(ksize=k, n=num_hashes, is_protein=False, dayhoff=False, hp=False, - track_abundance=args.track_abundance, - scaled=args.scaled, + track_abundance=track_abundance, + scaled=scaled, seed=seed) Elist.append(E) return Elist @@ -173,8 +169,8 @@ def build_siglist(Elist, filename, name=None): def save_siglist(siglist, output_fp, filename=None): # save! if output_fp: - sigfile_name = args.output.name - sig.save_signatures(siglist, args.output) + sigfile_name = output.name + sig.save_signatures(siglist, output) else: if filename is None: raise Exception("internal error, filename is None") @@ -185,49 +181,49 @@ def save_siglist(siglist, output_fp, filename=None): 'saved signature(s) to {}. Note: signature license is CC0.', sigfile_name) - if args.track_abundance: + if track_abundance: notify('Tracking abundance of input k-mers.') - if not args.merge: - if args.output: + if not merge: + if output: siglist = [] - for filename in args.filenames: + for filename in filenames: sigfile = os.path.basename(filename) + '.sig' - if not args.output and os.path.exists(sigfile) and not \ - args.force: + if not output and os.path.exists(sigfile) and not \ + force: notify('skipping {} - already done', filename) continue - if args.singleton: + if singleton: siglist = [] for n, record in enumerate(screed.open(filename)): # make minhashes for each sequence Elist = make_minhashes() add_seq(Elist, record.sequence, - args.input_is_protein, args.check_sequence) + input_is_protein, check_sequence) siglist += build_siglist(Elist, filename, name=record.name) notify('calculated {} signatures for {} sequences in {}', len(siglist), n + 1, filename) - elif args.input_is_10x: + elif input_is_10x: from bam2fasta import cli as bam2fasta_cli # Initializing time startt = time.time() metadata = [ - "--write-barcode-meta-csv", args.write_barcode_meta_csv] if args.write_barcode_meta_csv else ['', ''] - save_fastas = ["--save-fastas", args.save_fastas] if args.save_fastas else ['', ''] - barcodes_file = ["--barcodes-file", args.barcodes_file] if args.barcodes_file else ['', ''] + "--write-barcode-meta-csv", write_barcode_meta_csv] if write_barcode_meta_csv else ['', ''] + save_fastas = ["--save-fastas", save_fastas] if save_fastas else ['', ''] + barcodes_file = ["--barcodes-file", barcodes_file] if barcodes_file else ['', ''] rename_10x_barcodes = \ - ["--rename-10x-barcodes", args.rename_10x_barcodes] if args.rename_10x_barcodes else ['', ''] + ["--rename-10x-barcodes", rename_10x_barcodes] if rename_10x_barcodes else ['', ''] bam_to_fasta_args = [ '--filename', filename, - '--min-umi-per-barcode', str(args.count_valid_reads), - '--processes', str(args.processes), - '--line-count', str(args.line_count), + '--min-umi-per-barcode', str(count_valid_reads), + '--processes', str(processes), + '--line-count', str(line_count), barcodes_file[0], barcodes_file[1], rename_10x_barcodes[0], rename_10x_barcodes[1], save_fastas[0], save_fastas[1], @@ -244,7 +240,7 @@ def save_siglist(siglist, output_fp, filename=None): # make minhashes for each sequence Elist = make_minhashes() add_seq(Elist, record.sequence, - args.input_is_protein, args.check_sequence) + input_is_protein, check_sequence) siglist += build_siglist(Elist, fasta, name=record.name) @@ -264,16 +260,16 @@ def save_siglist(siglist, output_fp, filename=None): if n % 10000 == 0: if n: notify('\r...{} {}', filename, n, end='') - elif args.name_from_first: + elif name_from_first: name = record.name add_seq(Elist, record.sequence, - args.input_is_protein, args.check_sequence) + input_is_protein, check_sequence) notify('...{} {} sequences', filename, n, end='') sigs = build_siglist(Elist, filename, name) - if args.output: + if output: siglist += sigs else: siglist = sigs @@ -281,18 +277,18 @@ def save_siglist(siglist, output_fp, filename=None): notify('calculated {} signatures for {} sequences in {}', len(sigs), n + 1, filename) - if not args.output: - save_siglist(siglist, args.output, sigfile) + if not output: + save_siglist(siglist, output, sigfile) - if args.output: - save_siglist(siglist, args.output, sigfile) + if output: + save_siglist(siglist, output, sigfile) else: # single name specified - combine all # make minhashes for the whole file Elist = make_minhashes() n = 0 total_seq = 0 - for filename in args.filenames: + for filename in filenames: # consume & calculate signatures notify('... reading sequences from {}', filename) @@ -301,14 +297,14 @@ def save_siglist(siglist, output_fp, filename=None): notify('\r... {} {}', filename, n, end='') add_seq(Elist, record.sequence, - args.input_is_protein, args.check_sequence) + input_is_protein, check_sequence) notify('... {} {} sequences', filename, n + 1) total_seq += n + 1 - siglist = build_siglist(Elist, filename, name=args.merge) + siglist = build_siglist(Elist, filename, name=merge) notify('calculated {} signatures for {} sequences taken from {} files', - len(siglist), total_seq, len(args.filenames)) + len(siglist), total_seq, len(filenames)) # at end, save! - save_siglist(siglist, args.output) + save_siglist(siglist, output) diff --git a/tests/conftest.py b/tests/conftest.py index 5f5ae54adc..098e6b721e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,11 @@ import pytest +@pytest.fixture(params=[True, False]) +def cli(request): + return request.param + + @pytest.fixture(params=[True, False]) def track_abundance(request): return request.param diff --git a/tests/sourmash_tst_utils.py b/tests/sourmash_tst_utils.py index 43e53d618a..42bfec8328 100644 --- a/tests/sourmash_tst_utils.py +++ b/tests/sourmash_tst_utils.py @@ -1,6 +1,7 @@ "Various utilities used by sourmash tests." from __future__ import print_function +import contextlib import sys import os import tempfile @@ -116,6 +117,9 @@ def runscript(scriptname, args, **kwargs): os.chdir(cwd) + if status is None: + status = 0 + if status != 0 and not fail_ok: print(out) print(err) @@ -257,3 +261,14 @@ def run_shell_cmd(cmd, fail_ok=False, in_directory=None): return (proc.returncode, out, err) finally: os.chdir(cwd) + + +@contextlib.contextmanager +def working_dir(path): + """Changes working directory and returns to previous on exit.""" + prev_cwd = os.getcwd() + os.chdir(path) + try: + yield + finally: + os.chdir(prev_cwd) diff --git a/tests/test_sourmash_compute.py b/tests/test_sourmash_compute.py index ff389848e6..982a0a67f6 100644 --- a/tests/test_sourmash_compute.py +++ b/tests/test_sourmash_compute.py @@ -21,12 +21,16 @@ from sourmash import VERSION -def test_do_sourmash_compute(): +def test_do_sourmash_compute(cli): with utils.TempDirectory() as location: testdata1 = utils.get_test_data('short.fa') - status, out, err = utils.runscript('sourmash', - ['compute', '-k', '31', testdata1], - in_directory=location) + + with utils.working_dir(location): + if cli: + status, out, err = utils.runscript('sourmash', + ['compute', '-k', '31', testdata1]) + else: + sourmash.command_compute.compute([testdata1], ksizes=[31]) sigfile = os.path.join(location, 'short.fa.sig') assert os.path.exists(sigfile) @@ -35,7 +39,7 @@ def test_do_sourmash_compute(): assert sig.name().endswith('short.fa') -def test_do_sourmash_compute_output_valid_file(): +def test_do_sourmash_compute_output_valid_file(cli): """ Trigger bug #123 """ with utils.TempDirectory() as location: testdata1 = utils.get_test_data('short.fa') @@ -43,11 +47,16 @@ def test_do_sourmash_compute_output_valid_file(): testdata3 = utils.get_test_data('short3.fa') sigfile = os.path.join(location, 'short.fa.sig') - status, out, err = utils.runscript('sourmash', - ['compute', '-k', '31', '-o', sigfile, - testdata1, - testdata2, testdata3], - in_directory=location) + with utils.working_dir(location): + if cli: + status, out, err = utils.runscript('sourmash', + ['compute', '-k', '31', '-o', sigfile, + testdata1, + testdata2, testdata3]) + else: + sourmash.command_compute.compute([testdata1, testdata2, testdata3], + ksizes=[31], + output=sigfile) assert os.path.exists(sigfile) assert not out # stdout should be empty