Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] circumvent a very slow MinHash.remove_many(...) call in sourmash gather #2123

Merged
merged 3 commits into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions src/sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,7 @@ def gather(args):
scaled = query_mh.scaled

counters = []
ident_mh = noident_mh.copy_and_clear()
for db in databases:
counter = None
try:
Expand All @@ -734,13 +735,16 @@ def gather(args):
raise # re-raise other errors, if no picklist.

save_prefetch.add_many(counter.signatures())
# subtract found hashes as we can.
for found_sig in counter.signatures():
noident_mh.remove_many(found_sig.minhash)

# optionally calculate and save prefetch csv
if prefetch_csvout_fp:
assert scaled
# update found/not found hashes from the union/intersection of
# found.
union_found = counter.union_found
ident_mh.add_many(union_found)
noident_mh.remove_many(union_found)

# optionally calculate and output prefetch info to csv
if prefetch_csvout_fp:
for found_sig in counter.signatures():
# calculate intersection stats and info
prefetch_result = PrefetchResult(prefetch_query, found_sig, cmp_scaled=scaled,
threshold_bp=args.threshold_bp, estimate_ani_ci=args.estimate_ani_ci)
Expand All @@ -762,6 +766,7 @@ def gather(args):
counters = databases
# we can't track unidentified hashes w/o prefetch
noident_mh = None
ident_mh = None

## ok! now do gather -

Expand All @@ -775,6 +780,7 @@ def gather(args):
threshold_bp=args.threshold_bp,
ignore_abundance=args.ignore_abundance,
noident_mh=noident_mh,
ident_mh=ident_mh,
estimate_ani_ci=args.estimate_ani_ci)

for result, weighted_missed in gather_iter:
Expand Down Expand Up @@ -930,23 +936,28 @@ def multigather(args):
counters = []
prefetch_query = query.copy()
prefetch_query.minhash = prefetch_query.minhash.flatten()
ident_mh = prefetch_query.minhash.copy_and_clear()
noident_mh = prefetch_query.minhash.to_mutable()

counters = []
for db in databases:
counter = db.counter_gather(prefetch_query, args.threshold_bp)
for found_sig in counter.signatures():
noident_mh.remove_many(found_sig.minhash)
counters.append(counter)

# track found/not found hashes
union_found = counter.union_found
noident_mh.remove_many(union_found)
ident_mh.add_many(union_found)

found = []
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
orig_query_mh = query.minhash
gather_iter = GatherDatabases(query, counters,
threshold_bp=args.threshold_bp,
ignore_abundance=args.ignore_abundance,
noident_mh=noident_mh)
noident_mh=noident_mh,
ident_mh=ident_mh)
for result, weighted_missed in gather_iter:
if not len(found): # first result? print header.
if is_abundance:
Expand Down
20 changes: 20 additions & 0 deletions src/sourmash/index/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,26 @@ def signatures(self):
for ss in self.siglist.values():
yield ss

@property
def union_found(self):
"""Return a MinHash containing all found hashes in the query.

This calculates the union of the found matches, intersected
with the original query.
"""
orig_query_mh = self.orig_query_mh

# create empty MinHash from orig query
found_mh = orig_query_mh.copy_and_clear()

# for each match, intersect match with query & then add to found_mh.
for ss in self.siglist.values():
intersect_mh = flatten_and_intersect_scaled(ss.minhash,
orig_query_mh)
found_mh.add_many(intersect_mh)

return found_mh

def peek(self, cur_query_mh, *, threshold_bp=0):
"Get next 'gather' result for this database, w/o changing counters."
self.query_started = 1
Expand Down
9 changes: 6 additions & 3 deletions src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ class GatherDatabases:
"Iterator object for doing gather/min-set-cov."

def __init__(self, query, counters, *,
threshold_bp=0, ignore_abundance=False, noident_mh=None, estimate_ani_ci=False):
threshold_bp=0, ignore_abundance=False, noident_mh=None, ident_mh=None, estimate_ani_ci=False):
# track original query information for later usage?
track_abundance = query.minhash.track_abundance and not ignore_abundance
self.orig_query = query
Expand All @@ -675,8 +675,11 @@ def __init__(self, query, counters, *,
noident_mh = query_mh.copy_and_clear()
self.noident_mh = noident_mh.to_frozen()

query_mh = query_mh.to_mutable()
query_mh.remove_many(noident_mh)
if ident_mh is None:
query_mh = query_mh.to_mutable()
query_mh.remove_many(noident_mh)
else:
query_mh = ident_mh.to_mutable()

orig_query_mh = query_mh.flatten()
query.minhash = orig_query_mh.to_mutable()
Expand Down