EXP: switch to using calc-full-gather.py #18

Open · wants to merge 6 commits into base: main

Snakefile: 18 changes (15 additions & 3 deletions)

@@ -95,6 +95,16 @@ rule sketch:
         expand("sketches/metag/{n}.sig.zip", n=METAGENOME_NAMES),
         expand("sketches/genomes/{n}.sig.zip", n=GENOME_NAMES),
 
+rule fastgather:
+    input:
+        expand("outputs/metag_gather/{n}.{k}.fastgather.csv",
+               n=METAGENOME_NAMES, k=GATHER_KSIZE),
+
+rule gather:
+    input:
+        expand("outputs/metag_gather/{n}.{k}.gather.csv",
+               n=METAGENOME_NAMES, k=GATHER_KSIZE),
+
 def genome_inp(wc):
     return GENOME_NAMES[wc.name]
 
@@ -259,9 +269,11 @@ rule metag_gather:
     resources:
         gather=1
     shell: """
-        sourmash gather {input.query} {input.db} -k {wildcards.k} \
-            --picklist {input.fastgather_out}:match_md5:md5 \
-            -o {output.csv} > {output.out}
+        scripts/calc-full-gather.py {input.query} \
+            {input.db} {input.fastgather_out} -o {output.csv} > {output.out}
+        # sourmash gather {input.query} {input.db} -k {wildcards.k} \
+        #    --picklist {input.fastgather_out}:match_md5:md5 \
+        #    -o {output.csv} > {output.out}
     """
 
 rule prepare_taxdb:

scripts/calc-full-gather.py: 201 changes (201 additions & 0 deletions, new file)

@@ -0,0 +1,201 @@
#! /usr/bin/env python
# copied from https://github.com/ctb/2024-calc-full-gather commit 87d7d9ec8c67c
# for now.
"""

CTB TODO:
- deal with abundance stuff
- deal with threshold bp
- support the usual gather output? matches, matched hashes, etc.
- multiple databases...
"""
import argparse
import sys
import sourmash
import csv

from sourmash.search import GatherResult, format_bp
from sourmash.logging import print_results, set_quiet


def get_ident(name):
    return name.split(' ')[0]


def zipfile_load_ss_from_row(db, row):
    data = db.storage.load(row['internal_location'])
    sigs = sourmash.signature.load_signatures(data)

    return_sig = None
    for ss in sigs:
        if ss.md5sum() == row['md5']:
            assert return_sig is None   # there can only be one!
            return_sig = ss

    if return_sig is None:
        raise ValueError("no match to requested row in db")
    return return_sig


def main():
    p = argparse.ArgumentParser()
    p.add_argument('query', help='query sketch')
    p.add_argument('database', help='database of sketches (zip file)')
    p.add_argument('fastgather_csv', help='output from fastgather')
    p.add_argument('--scaled', type=int, default=1000,
                   help='scaled value for comparison')
    p.add_argument('--threshold-bp', type=int, default=50000,
                   help='threshold for matches')
    p.add_argument('-o', '--output', default=None,
                   help='CSV output')
    p.add_argument('-q', '--quiet', default=False, action='store_true',
                   help='suppress output')
    p.add_argument('--estimate-ani-ci', default=False, action='store_true',
                   help='estimate ANI confidence intervals (default: False)')
    args = p.parse_args()

    set_quiet(args.quiet)

    db = sourmash.load_file_as_index(args.database)

    fastgather_results = []
    with open(args.fastgather_csv, 'r', newline='') as fp:
        r = csv.DictReader(fp)
        fastgather_results.extend(r)

    print(f"loaded {len(fastgather_results)} results.")
    for header in 'query_filename,rank,query_name,query_md5,match_name,match_md5,intersect_bp'.split(','):
        assert header in fastgather_results[0].keys()
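
    # For reference, the fastgather CSV must contain at least the columns
    # asserted above; an illustrative header line (column order is an
    # assumption, not guaranteed) would be:
    #   query_filename,rank,query_name,query_md5,match_name,match_md5,intersect_bp,...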

    # find manifest entries => load directly? do we have that API?

    # or, do picklist?
    pl = sourmash.picklist.SignaturePicklist('prefetch',
                                             pickfile=args.fastgather_csv)
    _, dup_vals = pl.load()

    mf = db.manifest
    mf = mf.select_to_manifest(picklist=pl)

    # order rows by rank/order in gather result
    ident_to_row = {}
    for row in mf.rows:
        name = row['name']
        ident = get_ident(name)
        ident_to_row[ident] = row

    ordered_rows = []
    for n, gather_result in enumerate(fastgather_results):
        assert n == int(gather_result['rank'])
        ident = get_ident(gather_result['match_name'])
        mf_row = ident_to_row[ident]
        ordered_rows.append(mf_row)

    # guess ksize, get scaled - from first match
    first_ss = zipfile_load_ss_from_row(db, ordered_rows[0])
    ksize = first_ss.minhash.ksize
    scaled = max(args.scaled, first_ss.minhash.scaled)

    print(f"ksize={ksize}, scaled={scaled}")

    query_ss = sourmash.load_file_as_index(args.query)
    query_ss = query_ss.select(ksize=ksize, scaled=scaled)
    query_ss = list(query_ss.signatures())
    assert len(query_ss) == 1, query_ss
    query_ss = query_ss[0]
    assert query_ss.minhash.track_abundance, \
        "Query signatures must have abundance (for now)."

    orig_query_mh = query_ss.minhash.downsample(scaled=scaled)
    query_mh = orig_query_mh.to_mutable()

    orig_query_abunds = query_mh.hashes
    sum_abunds = sum(orig_query_abunds.values())

    # initialize output
    csv_writer = sourmash.sourmash_args.FileOutputCSV(args.output)
    outfp = csv_writer.open()
    result_writer = None

    # iterate over results, row by row
    screen_width = 80
    is_abundance = True
    sum_f_uniq_found = 0.
    found = False
    for rank, mf_row in enumerate(ordered_rows):
        best_match = zipfile_load_ss_from_row(db, mf_row)

        # flatten (drop abundances) so this match's hashes can be
        # subtracted from the remaining query below
        found_mh = best_match.minhash.downsample(scaled=scaled).flatten()

        n_weighted_missed = sum(orig_query_abunds[k] for k in query_mh.hashes)
        sum_weighted_found = sum_abunds - n_weighted_missed

        result = GatherResult(query_ss,
                              best_match,
                              cmp_scaled=scaled,
                              filename=args.database,
                              gather_result_rank=rank,
                              gather_querymh=query_mh,
                              ignore_abundance=False,
                              threshold_bp=args.threshold_bp,
                              orig_query_len=len(orig_query_mh),
                              orig_query_abunds=orig_query_abunds,
                              estimate_ani_ci=args.estimate_ani_ci,
                              sum_weighted_found=sum_weighted_found,
                              total_weighted_hashes=sum_abunds)

        # remove this match's hashes from the remaining query
        query_mh.remove_many(found_mh)

        sum_f_uniq_found += result.f_unique_to_query

        if not found:               # first result? print header.
            if is_abundance:
                print_results("")
                print_results("overlap     p_query p_match avg_abund")
                print_results("---------   ------- ------- ---------")
            else:
                print_results("")
                print_results("overlap     p_query p_match")
                print_results("---------   ------- -------")

            found = True

        # print interim result & save in `found` list for later use
        pct_query = '{:.1f}%'.format(result.f_unique_weighted*100)
        pct_genome = '{:.1f}%'.format(result.f_match*100)

        if is_abundance:
            name = result.match._display_name(screen_width - 41)
            average_abund = '{:.1f}'.format(result.average_abund)
            print_results('{:9}   {:>7} {:>7} {:>9}    {}',
                          format_bp(result.intersect_bp), pct_query, pct_genome,
                          average_abund, name)
        else:
            name = result.match._display_name(screen_width - 31)
            print_results('{:9}   {:>7} {:>7}    {}',
                          format_bp(result.intersect_bp), pct_query, pct_genome,
                          name)

        # write out
        if result_writer is None:
            result_writer = result.init_dictwriter(outfp)
        result.write(result_writer)

        outfp.flush()
        sys.stdout.flush()

    csv_writer.close()

    if found:
        # use last result!
        if is_abundance and result:
            p_covered = result.sum_weighted_found / result.total_weighted_hashes
            p_covered *= 100
            print_results(f'the recovered matches hit {p_covered:.1f}% of the abundance-weighted query.')

        print_results(f'the recovered matches hit {sum_f_uniq_found*100:.1f}% of the query k-mers (unweighted).')


if __name__ == '__main__':
    sys.exit(main())
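
For reference, a standalone invocation of the new script, mirroring the shell command in the metag_gather rule above (the file names here are illustrative, not taken from the PR):

    ./scripts/calc-full-gather.py metag.sig.zip db.sig.zip metag.fastgather.csv -o metag.gather.csv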