[MRG] Speed up sourmash gather by ignoring unidentifiable hashes (#1613)

* refactor gather_databases into iterator class

* remove notify statements in search.py

* Fix flakes

* tests passing

* more cleanup

* cleanup; still passing :)

* do expensive stuff less frequently; tests passing

* add property access to cmp_scaled

* simplify interface

* add noident_mh arg

* refactor/add in noident_mh to calcs

* first attempt to fully integrate noident_mh; 6 tests failing

* houston, we have liftoff

* updated multigather, too

* update multigather output_unassigned with noident_mh

* groan update more remove_many calls

* one more remove_many

* add debug

* <sigh> avoid calculating matched_query_mh until end

* better fix for prefetch hang

* remove unneeded comment

* add check for downsampling reporting in gather
ctb authored Jun 20, 2021
1 parent 74de59a commit 0814bcc
Showing 3 changed files with 192 additions and 89 deletions.
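The heart of the change, for orientation: during the prefetch sweep, every query hash that shows up in no database is accumulated into `noident_mh`, and that MinHash is handed to the new `GatherDatabases` iterator so the gather loop never has to revisit those hashes. A minimal, set-based sketch of that bookkeeping (illustrative only; the real code operates on sourmash `MinHash` objects via `flatten()`/`to_mutable()`/`remove_many()`, not on Python sets):

```python
# Illustrative sketch only -- plain sets standing in for MinHash sketches.

def split_identifiable(query_hashes, database_hash_sets):
    """Partition query hashes into those present in at least one database
    (identifiable) and those present in none (unidentifiable)."""
    ident = set()
    for db_hashes in database_hash_sets:
        ident |= query_hashes & db_hashes
    noident = query_hashes - ident
    return ident, noident

query = {1, 2, 3, 4, 5, 6}
databases = [{1, 2, 9}, {3, 10}]

ident, noident = split_identifiable(query, databases)
assert ident == {1, 2, 3}
assert noident == {4, 5, 6}

# gather only needs to assign the identifiable hashes; the unidentifiable ones
# are added back at the end when writing --output-unassigned.
```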
90 changes: 65 additions & 25 deletions src/sourmash/commands.py
@@ -626,7 +626,7 @@ def _yield_all_sigs(queries, ksize, moltype):


def gather(args):
from .search import gather_databases, format_bp
from .search import GatherDatabases, format_bp

set_quiet(args.quiet, args.debug)
moltype = sourmash_args.calculate_moltype(args)
@@ -676,36 +676,47 @@ def gather(args):
notify("Starting prefetch sweep across databases.")
prefetch_query = query.copy()
prefetch_query.minhash = prefetch_query.minhash.flatten()
noident_mh = prefetch_query.minhash.to_mutable()
save_prefetch = SaveSignaturesToLocation(args.save_prefetch)
save_prefetch.open()

counters = []
for db in databases:
counter = None
try:
counter = db.counter_gather(prefetch_query, args.threshold_bp)
except ValueError:
if picklist:
# catch "no signatures to search" ValueError...
continue
else:
raise # re-raise other errors, if no picklist.

save_prefetch.add_many(counter.siglist)
# subtract found hashes as we can.
for found_sig in counter.siglist:
noident_mh.remove_many(found_sig.minhash)
counters.append(counter)

notify(f"Found {len(save_prefetch)} signatures via prefetch; now doing gather.")
save_prefetch.close()
else:
counters = databases
# we can't track unidentified hashes w/o prefetch
noident_mh = None

## ok! now do gather -

found = []
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
orig_query_mh = query.minhash
next_query = query
gather_iter = GatherDatabases(query, counters,
threshold_bp=args.threshold_bp,
ignore_abundance=args.ignore_abundance,
noident_mh=noident_mh)

gather_iter = gather_databases(query, counters, args.threshold_bp,
args.ignore_abundance)
for result, weighted_missed, next_query in gather_iter:
for result, weighted_missed in gather_iter:
if not len(found): # first result? print header.
if is_abundance:
print_results("")
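Aside on the refactor visible in this hunk: the old `gather_databases` generator had to yield `next_query` alongside each result, while the new `GatherDatabases` iterator class keeps that bookkeeping as state the caller reads off the object (`gather_iter.query`, `gather_iter.scaled`) after the loop. A toy sketch of that generator-to-iterator-class pattern (hypothetical stand-in class, not the sourmash implementation):

```python
# Toy example of the pattern only; GatherDatabases itself is more involved.

class CountDown:
    """Iterator class whose leftover state stays inspectable after iteration."""
    def __init__(self, start):
        self.remaining = start            # analogous to gather_iter.query / .scaled

    def __iter__(self):
        return self

    def __next__(self):
        if self.remaining <= 0:
            raise StopIteration
        self.remaining -= 1
        return self.remaining             # yield only the per-step result

it = CountDown(3)
assert list(it) == [2, 1, 0]
assert it.remaining == 0                  # state still available afterward, no need
                                          # to thread it through every yielded tuple
```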
@@ -737,6 +748,11 @@ def gather(args):
break


# report on thresholding -
if gather_iter.query:
# if still a query, then we failed the threshold.
notify(f'found less than {format_bp(args.threshold_bp)} in common. => exiting')

# basic reporting:
print_results(f'\nfound {len(found)} matches total;')
if args.num_results and len(found) == args.num_results:
@@ -745,6 +761,8 @@
p_covered = (1 - weighted_missed) * 100
print_results(f'the recovered matches hit {p_covered:.1f}% of the query')
print_results('')
if gather_iter.scaled != query.minhash.scaled:
print_results(f'WARNING: final scaled was {gather_iter.scaled}, vs query scaled of {query.minhash.scaled}')

# save CSV?
if found and args.output:
@@ -772,25 +790,31 @@ def gather(args):

# save unassigned hashes?
if args.output_unassigned:
if not len(next_query.minhash):
remaining_query = gather_iter.query
if not (remaining_query.minhash or noident_mh):
notify('no unassigned hashes to save with --output-unassigned!')
else:
notify(f"saving unassigned hashes to '{args.output_unassigned}'")

if noident_mh:
remaining_mh = remaining_query.minhash.to_mutable()
remaining_mh += noident_mh
remaining_query.minhash = remaining_mh

if is_abundance:
# next_query is flattened; reinflate abundances
hashes = set(next_query.minhash.hashes)
# remaining_query is flattened; reinflate abundances
hashes = set(remaining_query.minhash.hashes)
orig_abunds = orig_query_mh.hashes
abunds = { h: orig_abunds[h] for h in hashes }

abund_query_mh = orig_query_mh.copy_and_clear()
# orig_query might have been downsampled...
abund_query_mh.downsample(scaled=next_query.minhash.scaled)
abund_query_mh.downsample(scaled=gather_iter.scaled)
abund_query_mh.set_abundances(abunds)
next_query.minhash = abund_query_mh
remaining_query.minhash = abund_query_mh

with FileOutput(args.output_unassigned, 'wt') as fp:
sig.save_signatures([ next_query ], fp)
sig.save_signatures([ remaining_query ], fp)

if picklist:
sourmash_args.report_picklist(args, picklist)
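A note on the abundance handling in this hunk: gather runs on a flattened (abundance-free) copy of the query, so when the unassigned hashes are written out their abundances are looked up again from the original query via `copy_and_clear()` + `set_abundances()`. A dict-based sketch of that reinflation step (illustrative only):

```python
# Illustrative sketch only -- dicts standing in for abundance-tracking MinHashes.

orig_abunds = {1: 5, 2: 2, 3: 7, 4: 1, 9: 3}   # hash -> abundance in the original query
remaining_hashes = {2, 4}                       # identifiable hashes gather could not assign
noident_hashes = {9}                            # hashes never seen in any database

unassigned = remaining_hashes | noident_hashes
reinflated = {h: orig_abunds[h] for h in unassigned}
assert reinflated == {2: 2, 4: 1, 9: 3}
```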
@@ -800,7 +824,7 @@

def multigather(args):
"Gather many signatures against multiple databases."
from .search import gather_databases, format_bp
from .search import GatherDatabases, format_bp

set_quiet(args.quiet)
moltype = sourmash_args.calculate_moltype(args)
@@ -858,16 +882,23 @@ def multigather(args):
counters = []
prefetch_query = query.copy()
prefetch_query.minhash = prefetch_query.minhash.flatten()
noident_mh = prefetch_query.minhash.to_mutable()

counters = []
for db in databases:
counter = db.counter_gather(prefetch_query, args.threshold_bp)
for found_sig in counter.siglist:
noident_mh.remove_many(found_sig.minhash)
counters.append(counter)

found = []
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
for result, weighted_missed, next_query in gather_databases(query, counters, args.threshold_bp, args.ignore_abundance):
gather_iter = GatherDatabases(query, counters,
threshold_bp=args.threshold_bp,
ignore_abundance=args.ignore_abundance,
noident_mh=noident_mh)
for result, weighted_missed in gather_iter:
if not len(found): # first result? print header.
if is_abundance:
print_results("")
@@ -895,6 +926,10 @@ def multigather(args):
name)
found.append(result)

# report on thresholding -
if gather_iter.query.minhash:
# if still a query, then we failed the threshold.
notify(f'found less than {format_bp(args.threshold_bp)} in common. => exiting')

# basic reporting
print_results('\nfound {} matches total;', len(found))
@@ -938,18 +973,21 @@ def multigather(args):

output_unassigned = output_base + '.unassigned.sig'
with open(output_unassigned, 'wt') as fp:
remaining_query = gather_iter.query
if noident_mh:
remaining_mh = remaining_query.minhash.to_mutable()
remaining_mh += noident_mh.downsample(scaled=remaining_mh.scaled)
remaining_query.minhash = remaining_mh

if not found:
notify('nothing found - entire query signature unassigned.')
elif not len(query.minhash):
elif not remaining_query:
notify('no unassigned hashes! not saving.')
else:
notify('saving unassigned hashes to "{}"', output_unassigned)

e = MinHash(ksize=query.minhash.ksize, n=0,
scaled=next_query.minhash.scaled)
e.add_many(next_query.minhash.hashes)
# CTB: note, multigather does not save abundances
sig.save_signatures([ sig.SourmashSignature(e) ], fp)
sig.save_signatures([ remaining_query ], fp)
n += 1

# fini, next query!
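The `remaining_mh += noident_mh.downsample(...)` line above is there because two sketches can only be combined at the same `scaled` value. Roughly, a FracMinHash with `scaled=S` keeps the hashes below about 2**64 / S, so downsampling to a larger `scaled` simply drops the hashes above the new cutoff. A rough, self-contained illustration (formula and values approximate, for intuition only):

```python
# Illustrative sketch only -- not the sourmash implementation.
MAX_HASH = 2**64

def downsample(hashes, scaled):
    """Keep only hashes under the cutoff implied by `scaled` (approximate rule)."""
    cutoff = MAX_HASH // scaled
    return {h for h in hashes if h < cutoff}

sketch = {10**12, 10**14, 10**15}        # pretend hash values from a finer sketch
coarser = downsample(sketch, 100_000)    # move to scaled=100000 before combining

assert coarser == {10**12, 10**14}       # 10**15 exceeds 2**64 // 100000 and is dropped
```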
@@ -1134,6 +1172,7 @@ def prefetch(args):

# iterate over signatures in db one at a time, for each db;
# find those with sufficient overlap
ident_mh = query_mh.copy_and_clear()
noident_mh = query_mh.to_mutable()

did_a_search = False # track whether we did _any_ search at all!
@@ -1157,8 +1196,10 @@ def prefetch(args):
for result in prefetch_database(query, db, args.threshold_bp):
match = result.match

# track remaining "untouched" hashes.
noident_mh.remove_many(match.minhash.hashes)
# track found & "untouched" hashes.
match_mh = match.minhash.downsample(scaled=query.minhash.scaled)
ident_mh += query.minhash & match_mh.flatten()
noident_mh.remove_many(match.minhash)

# output match info as we go
if csvout_fp:
@@ -1194,15 +1235,14 @@ def prefetch(args):
notify(f"saved {matches_out.count} matches to CSV file '{args.output}'")
csvout_fp.close()

matched_query_mh = query_mh.to_mutable()
matched_query_mh.remove_many(noident_mh.hashes)
notify(f"of {len(query_mh)} distinct query hashes, {len(matched_query_mh)} were found in matches above threshold.")
assert len(query_mh) == len(ident_mh) + len(noident_mh)
notify(f"of {len(query_mh)} distinct query hashes, {len(ident_mh)} were found in matches above threshold.")
notify(f"a total of {len(noident_mh)} query hashes remain unmatched.")

if args.save_matching_hashes:
filename = args.save_matching_hashes
notify(f"saving {len(matched_query_mh)} matched hashes to '{filename}'")
ss = sig.SourmashSignature(matched_query_mh)
notify(f"saving {len(ident_mh)} matched hashes to '{filename}'")
ss = sig.SourmashSignature(ident_mh)
with open(filename, "wt") as fp:
sig.save_signatures([ss], fp)
