[MRG] Speed up sourmash gather by ignoring unidentifiable hashes #1613

Merged: 22 commits, Jun 20, 2021
90 changes: 65 additions & 25 deletions src/sourmash/commands.py
@@ -626,7 +626,7 @@ def _yield_all_sigs(queries, ksize, moltype):


def gather(args):
from .search import gather_databases, format_bp
from .search import GatherDatabases, format_bp

set_quiet(args.quiet, args.debug)
moltype = sourmash_args.calculate_moltype(args)
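The import change above reflects the switch from the old gather_databases generator, which yielded (result, weighted_missed, next_query) triples, to the new GatherDatabases class, which is iterated for (result, weighted_missed) pairs and keeps the remaining query state on the object itself. A minimal consumption sketch, assuming query, counters, and noident_mh are built as in the prefetch code further down and args is the parsed CLI namespace (names mirror the diff; no API beyond what the diff itself uses):

gather_iter = GatherDatabases(query, counters,
                              threshold_bp=args.threshold_bp,
                              ignore_abundance=args.ignore_abundance,
                              noident_mh=noident_mh)

found = []
for result, weighted_missed in gather_iter:   # 2-tuple now, not a 3-tuple
    found.append(result)

# state that used to arrive as the third tuple element now lives on the iterator:
remaining_query = gather_iter.query   # the still-unassigned part of the query
final_scaled = gather_iter.scaled     # the scaled value gather ended up at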
@@ -676,36 +676,47 @@ def gather(args):
notify("Starting prefetch sweep across databases.")
prefetch_query = query.copy()
prefetch_query.minhash = prefetch_query.minhash.flatten()
noident_mh = prefetch_query.minhash.to_mutable()
save_prefetch = SaveSignaturesToLocation(args.save_prefetch)
save_prefetch.open()

counters = []
for db in databases:
counter = None
try:
counter = db.counter_gather(prefetch_query, args.threshold_bp)
except ValueError:
if picklist:
# catch "no signatures to search" ValueError...
continue
else:
raise # re-raise other errors, if no picklist.

save_prefetch.add_many(counter.siglist)
# subtract found hashes as we can.
for found_sig in counter.siglist:
noident_mh.remove_many(found_sig.minhash)
counters.append(counter)

notify(f"Found {len(save_prefetch)} signatures via prefetch; now doing gather.")
save_prefetch.close()
else:
counters = databases
# we can't track unidentified hashes w/o prefetch
noident_mh = None

## ok! now do gather -

found = []
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
orig_query_mh = query.minhash
next_query = query
gather_iter = GatherDatabases(query, counters,
threshold_bp=args.threshold_bp,
ignore_abundance=args.ignore_abundance,
noident_mh=noident_mh)

gather_iter = gather_databases(query, counters, args.threshold_bp,
args.ignore_abundance)
for result, weighted_missed, next_query in gather_iter:
for result, weighted_missed in gather_iter:
if not len(found): # first result? print header.
if is_abundance:
print_results("")
@@ -737,6 +748,11 @@ def gather(args):
break


# report on thresholding -
if gather_iter.query:
# if still a query, then we failed the threshold.
notify(f'found less than {format_bp(args.threshold_bp)} in common. => exiting')

# basic reporting:
print_results(f'\nfound {len(found)} matches total;')
if args.num_results and len(found) == args.num_results:
@@ -745,6 +761,8 @@
p_covered = (1 - weighted_missed) * 100
print_results(f'the recovered matches hit {p_covered:.1f}% of the query')
print_results('')
if gather_iter.scaled != query.minhash.scaled:
print_results(f'WARNING: final scaled was {gather_iter.scaled}, vs query scaled of {query.minhash.scaled}')

# save CSV?
if found and args.output:
@@ -772,25 +790,31 @@ def gather(args):

# save unassigned hashes?
if args.output_unassigned:
if not len(next_query.minhash):
remaining_query = gather_iter.query
if not (remaining_query.minhash or noident_mh):
notify('no unassigned hashes to save with --output-unassigned!')
else:
notify(f"saving unassigned hashes to '{args.output_unassigned}'")

if noident_mh:
remaining_mh = remaining_query.minhash.to_mutable()
remaining_mh += noident_mh
remaining_query.minhash = remaining_mh

if is_abundance:
# next_query is flattened; reinflate abundances
hashes = set(next_query.minhash.hashes)
# remaining_query is flattened; reinflate abundances
hashes = set(remaining_query.minhash.hashes)
orig_abunds = orig_query_mh.hashes
abunds = { h: orig_abunds[h] for h in hashes }

abund_query_mh = orig_query_mh.copy_and_clear()
# orig_query might have been downsampled...
abund_query_mh.downsample(scaled=next_query.minhash.scaled)
abund_query_mh.downsample(scaled=gather_iter.scaled)
abund_query_mh.set_abundances(abunds)
next_query.minhash = abund_query_mh
remaining_query.minhash = abund_query_mh

with FileOutput(args.output_unassigned, 'wt') as fp:
sig.save_signatures([ next_query ], fp)
sig.save_signatures([ remaining_query ], fp)

if picklist:
sourmash_args.report_picklist(args, picklist)
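Re-inflating abundances for --output-unassigned works by taking the hashes that survived gather and looking their abundances back up in the untouched original query MinHash. A standalone sketch of the same idea; it assumes downsample() returns a downsampled copy (as in current sourmash), which is why the result is reassigned here:

hashes = set(remaining_query.minhash.hashes)   # unassigned hashes, flattened

orig_abunds = orig_query_mh.hashes             # original hash -> abundance mapping
abunds = {h: orig_abunds[h] for h in hashes}

abund_query_mh = orig_query_mh.copy_and_clear()
# bring the empty copy down to the scaled that gather ended at
abund_query_mh = abund_query_mh.downsample(scaled=gather_iter.scaled)
abund_query_mh.set_abundances(abunds)
remaining_query.minhash = abund_query_mh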
@@ -800,7 +824,7 @@

def multigather(args):
"Gather many signatures against multiple databases."
from .search import gather_databases, format_bp
from .search import GatherDatabases, format_bp

set_quiet(args.quiet)
moltype = sourmash_args.calculate_moltype(args)
@@ -858,16 +882,23 @@ def multigather(args):
counters = []
prefetch_query = query.copy()
prefetch_query.minhash = prefetch_query.minhash.flatten()
noident_mh = prefetch_query.minhash.to_mutable()

counters = []
for db in databases:
counter = db.counter_gather(prefetch_query, args.threshold_bp)
for found_sig in counter.siglist:
noident_mh.remove_many(found_sig.minhash)
counters.append(counter)

found = []
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
for result, weighted_missed, next_query in gather_databases(query, counters, args.threshold_bp, args.ignore_abundance):
gather_iter = GatherDatabases(query, counters,
threshold_bp=args.threshold_bp,
ignore_abundance=args.ignore_abundance,
noident_mh=noident_mh)
for result, weighted_missed in gather_iter:
if not len(found): # first result? print header.
if is_abundance:
print_results("")
@@ -895,6 +926,10 @@
name)
found.append(result)

# report on thresholding -
if gather_iter.query.minhash:
# if still a query, then we failed the threshold.
notify(f'found less than {format_bp(args.threshold_bp)} in common. => exiting')

# basic reporting
print_results('\nfound {} matches total;', len(found))
@@ -938,18 +973,21 @@

output_unassigned = output_base + '.unassigned.sig'
with open(output_unassigned, 'wt') as fp:
remaining_query = gather_iter.query
if noident_mh:
remaining_mh = remaining_query.minhash.to_mutable()
remaining_mh += noident_mh.downsample(scaled=remaining_mh.scaled)
remaining_query.minhash = remaining_mh

if not found:
notify('nothing found - entire query signature unassigned.')
elif not len(query.minhash):
elif not remaining_query:
notify('no unassigned hashes! not saving.')
else:
notify('saving unassigned hashes to "{}"', output_unassigned)

e = MinHash(ksize=query.minhash.ksize, n=0,
scaled=next_query.minhash.scaled)
e.add_many(next_query.minhash.hashes)
# CTB: note, multigather does not save abundances
sig.save_signatures([ sig.SourmashSignature(e) ], fp)
sig.save_signatures([ remaining_query ], fp)
n += 1

# fini, next query!
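Merging the unidentifiable hashes back into the leftover query only works if both MinHashes agree on scaled; noident_mh was built from the un-downsampled prefetch query, so it is downsampled to the leftover's scaled before the += merge. In isolation (names as in the diff above):

remaining_query = gather_iter.query
if noident_mh:
    remaining_mh = remaining_query.minhash.to_mutable()
    # match resolutions before merging the unidentifiable hashes back in
    remaining_mh += noident_mh.downsample(scaled=remaining_mh.scaled)
    remaining_query.minhash = remaining_mh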
@@ -1134,6 +1172,7 @@ def prefetch(args):

# iterate over signatures in db one at a time, for each db;
# find those with sufficient overlap
ident_mh = query_mh.copy_and_clear()
noident_mh = query_mh.to_mutable()

did_a_search = False # track whether we did _any_ search at all!
@@ -1157,8 +1196,10 @@
for result in prefetch_database(query, db, args.threshold_bp):
match = result.match

# track remaining "untouched" hashes.
noident_mh.remove_many(match.minhash.hashes)
# track found & "untouched" hashes.
match_mh = match.minhash.downsample(scaled=query.minhash.scaled)
ident_mh += query.minhash & match_mh.flatten()
noident_mh.remove_many(match.minhash)

# output match info as we go
if csvout_fp:
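prefetch now builds the matched-hash set incrementally instead of deriving it at the end: each match is downsampled to the query's scaled, intersected with the query, and folded into ident_mh, while the same hashes are dropped from noident_mh. The per-match bookkeeping in isolation (threshold_bp stands in for args.threshold_bp; the & intersection operator on MinHash is used exactly as in the diff, and the query_mh assignment is an assumption standing in for the earlier setup code):

query_mh = query.minhash               # assumed; the real code prepares this earlier
ident_mh = query_mh.copy_and_clear()   # collects every query hash seen in a match
noident_mh = query_mh.to_mutable()     # starts as the whole query, shrinks per match

for db in databases:
    for result in prefetch_database(query, db, threshold_bp):
        match = result.match
        match_mh = match.minhash.downsample(scaled=query_mh.scaled)
        ident_mh += query_mh & match_mh.flatten()
        noident_mh.remove_many(match.minhash)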
@@ -1194,15 +1235,14 @@
notify(f"saved {matches_out.count} matches to CSV file '{args.output}'")
csvout_fp.close()

matched_query_mh = query_mh.to_mutable()
matched_query_mh.remove_many(noident_mh.hashes)
notify(f"of {len(query_mh)} distinct query hashes, {len(matched_query_mh)} were found in matches above threshold.")
assert len(query_mh) == len(ident_mh) + len(noident_mh)
notify(f"of {len(query_mh)} distinct query hashes, {len(ident_mh)} were found in matches above threshold.")
notify(f"a total of {len(noident_mh)} query hashes remain unmatched.")

if args.save_matching_hashes:
filename = args.save_matching_hashes
notify(f"saving {len(matched_query_mh)} matched hashes to '{filename}'")
ss = sig.SourmashSignature(matched_query_mh)
notify(f"saving {len(ident_mh)} matched hashes to '{filename}'")
ss = sig.SourmashSignature(ident_mh)
with open(filename, "wt") as fp:
sig.save_signatures([ss], fp)
