Skip to content

Commit

Permalink
Merge pull request #751 from etetoolkit/ete4_#750
Browse files Browse the repository at this point in the history
Ete4 #750
  • Loading branch information
jordibc authored May 31, 2024
2 parents 2fabb2e + c1533ac commit ff6767d
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 54 deletions.
112 changes: 74 additions & 38 deletions ete4/gtdb_taxonomy/gtdbquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,38 @@ def _translate_merged(self, all_taxids):

# return taxid, spname, norm_score

def _get_id2rank(self, internal_taxids):
    """Return a dictionary mapping internal numeric GTDB taxids to their ranks.

    Duplicated ids, None and '' entries in `internal_taxids` are ignored.
    Ids that have no row in the `species` table are left out of the result.

    Example:
      > gtdb._get_id2rank([2174, 205487, 610])
      {2174: 'family', 205487: 'order', 610: 'phylum'}

    Note: numeric taxids are not recognized by the official GTDB taxonomy
    database; they are only meant for internal usage.
    """
    taxids = list(set(internal_taxids) - {None, ''})

    # Values are bound as parameters instead of being pasted into the SQL
    # text as "double-quoted" literals (SQLite only accepts those as a
    # legacy quirk that can be disabled at build time).
    # NOTE(review): assumes self.db is a sqlite3 connection ('?' style
    # placeholders) -- confirm against the constructor.
    # Queries are batched because older SQLite builds cap the number of
    # bound parameters per statement at 999.
    batch_size = 900
    id2rank = {}
    for start in range(0, len(taxids), batch_size):
        batch = taxids[start:start + batch_size]
        placeholders = ','.join('?' * len(batch))
        result = self.db.execute(
            'SELECT taxid, rank FROM species WHERE taxid IN (%s)' % placeholders,
            batch)
        for tax, rank in result.fetchall():
            id2rank[tax] = rank

    return id2rank

def get_rank(self, taxids):
    """Given a list of GTDB string taxids, return a dictionary with their ranks.

    The names are first translated to internal numeric ids with
    _get_name_translator(); the ranks are then read from the `species`
    table and the result is keyed by taxa name again. Names that cannot
    be translated are left out of the result.

    Example:
      > gtdb.get_rank(['c__Thorarchaeia', 'RS_GCF_001477695.1'])
      {'c__Thorarchaeia': 'class', 'RS_GCF_001477695.1': 'subspecies'}
    """
    name2ids = self._get_name_translator(taxids)

    # Remember which queried name produced each internal id; used as a
    # fallback key if the id -> name translation has no entry for an id.
    id2query = {tid: name for name, tids in name2ids.items() for tid in tids}
    internal_ids = list(set(id2query) - {None, ''})
    if not internal_ids:
        return {}

    # One batched translation instead of one database query per result row.
    id2name = self._get_taxid_translator(internal_ids)

    # Values are bound as parameters rather than embedded in the SQL text.
    # NOTE(review): assumes self.db is a sqlite3 connection ('?' style
    # placeholders) -- confirm against the constructor.
    # Queries are batched because older SQLite builds cap the number of
    # bound parameters per statement at 999.
    batch_size = 900
    taxid2rank = {}
    for start in range(0, len(internal_ids), batch_size):
        batch = internal_ids[start:start + batch_size]
        placeholders = ','.join('?' * len(batch))
        result = self.db.execute(
            'SELECT taxid, rank FROM species WHERE taxid IN (%s)' % placeholders,
            batch)
        for tax, rank in result.fetchall():
            taxid2rank[id2name.get(tax, id2query[tax])] = rank

    return taxid2rank

def get_lineage_translator(self, taxids):
def _get_lineage_translator(self, taxids):
"""Given a valid taxid number, return its corresponding lineage track as a
hierarchically sorted list of parent taxids.
"""
Expand All @@ -164,23 +189,22 @@ def get_lineage_translator(self, taxids):
id2lineages = {}
for tax, track in result.fetchall():
id2lineages[tax] = list(map(int, reversed(track.split(","))))

return id2lineages

def get_name_lineage(self, taxnames):
    """Given a list of taxa names, return their lineages as taxa names.

    Returns a list with one single-entry dictionary per translated name,
    of the form {name: [lineage names...]}, where the lineage follows the
    order given by _get_lineage(). Names that _get_name_translator()
    cannot translate do not appear in the result.
    """
    name_lineages = []
    name2taxid = self._get_name_translator(taxnames)
    for name, ids in name2taxid.items():
        # A name may translate to several ids; only the first one is used.
        lineage = self._get_lineage(ids[0])
        id2name = self._get_taxid_translator(lineage)
        name_lineages.append({name: [id2name[taxid] for taxid in lineage]})

    return name_lineages

def get_lineage(self, taxid):
def _get_lineage(self, taxid):
"""Given a valid taxid number, return its corresponding lineage track as a
hierarchically sorted list of parent taxids.
"""
Expand Down Expand Up @@ -215,7 +239,7 @@ def get_common_names(self, taxids):
id2name[tax] = common_name
return id2name

def get_taxid_translator(self, taxids, try_synonyms=True):
def _get_taxid_translator(self, taxids, try_synonyms=True):
"""Given a list of taxids, returns a dictionary with their corresponding
scientific names.
"""
Expand Down Expand Up @@ -245,7 +269,7 @@ def get_taxid_translator(self, taxids, try_synonyms=True):

return id2name

def get_name_translator(self, names):
def _get_name_translator(self, names):
"""
Given a list of taxid scientific names, returns a dictionary translating them into their corresponding taxids.
Exact name match is required for translation.
Expand Down Expand Up @@ -276,11 +300,11 @@ def get_name_translator(self, names):
#name2realname[oname] = sp
return name2id

def translate_to_names(self, taxids):
def _translate_to_names(self, taxids):
"""
Given a list of taxid numbers, returns another list with their corresponding scientific names.
"""
id2name = self.get_taxid_translator(taxids)
id2name = self._get_taxid_translator(taxids)
names = []
for sp in taxids:
names.append(id2name.get(sp, sp))
Expand All @@ -296,7 +320,7 @@ def get_descendant_taxa(self, parent, intermediate_nodes=False, rank_limit=None,
taxid = int(parent)
except ValueError:
try:
taxid = self.get_name_translator([parent])[parent][0]
taxid = self._get_name_translator([parent])[parent][0]
except KeyError:
raise ValueError('%s not found!' %parent)

Expand All @@ -322,7 +346,7 @@ def get_descendant_taxa(self, parent, intermediate_nodes=False, rank_limit=None,
elif found == 1:
return [taxid]
if rank_limit or collapse_subspecies or return_tree:
descendants_spnames = self.get_taxid_translator(list(descendants.keys()))
descendants_spnames = self._get_taxid_translator(list(descendants.keys()))
#tree = self.get_topology(list(descendants.keys()), intermediate_nodes=intermediate_nodes, collapse_subspecies=collapse_subspecies, rank_limit=rank_limit)
tree = self.get_topology(list(descendants_spnames.values()), intermediate_nodes=intermediate_nodes, collapse_subspecies=collapse_subspecies, rank_limit=rank_limit)
if return_tree:
Expand All @@ -333,10 +357,10 @@ def get_descendant_taxa(self, parent, intermediate_nodes=False, rank_limit=None,
return [n.name for n in tree]

elif intermediate_nodes:
return self.translate_to_names([tid for tid, count in descendants.items()])
return self._translate_to_names([tid for tid, count in descendants.items()])
else:
self.translate_to_names([tid for tid, count in descendants.items() if count == 1])
return self.translate_to_names([tid for tid, count in descendants.items() if count == 1])
self._translate_to_names([tid for tid, count in descendants.items() if count == 1])
return self._translate_to_names([tid for tid, count in descendants.items() if count == 1])

def get_topology(self, taxnames, intermediate_nodes=False, rank_limit=None,
collapse_subspecies=False, annotate=True):
Expand All @@ -356,7 +380,7 @@ def get_topology(self, taxnames, intermediate_nodes=False, rank_limit=None,
"""
from .. import PhyloTree
#taxids, merged_conversion = self._translate_merged(taxids)
tax2id = self.get_name_translator(taxnames) #{'f__Korarchaeaceae': [2174], 'o__Peptococcales': [205487], 'p__Huberarchaeota': [610]}
tax2id = self._get_name_translator(taxnames) #{'f__Korarchaeaceae': [2174], 'o__Peptococcales': [205487], 'p__Huberarchaeota': [610]}
taxids = [i[0] for i in tax2id.values()]

if len(taxids) == 1:
Expand All @@ -376,7 +400,7 @@ def get_topology(self, taxnames, intermediate_nodes=False, rank_limit=None,
# If root taxid is not found in postorder, must be a tip node
subtree = [root_taxid]
leaves = set([v for v, count in Counter(subtree).items() if count == 1])
tax2name = self.get_taxid_translator(list(subtree))
tax2name = self._get_taxid_translator(list(subtree))
name2tax ={spname:taxid for taxid,spname in tax2name.items()}
nodes[root_taxid] = PhyloTree({'name': str(root_taxid)})
current_parent = nodes[root_taxid]
Expand All @@ -394,15 +418,15 @@ def get_topology(self, taxnames, intermediate_nodes=False, rank_limit=None,
taxids = set(map(int, taxids))
sp2track = {}
elem2node = {}
id2lineage = self.get_lineage_translator(taxids)
id2lineage = self._get_lineage_translator(taxids)
all_taxids = set()
for lineage in id2lineage.values():
all_taxids.update(lineage)
id2rank = self.get_rank(all_taxids)
id2rank = self._get_id2rank(all_taxids)

tax2name = self.get_taxid_translator(taxids)
tax2name = self._get_taxid_translator(taxids)
all_taxid_codes = set([_tax for _lin in list(id2lineage.values()) for _tax in _lin])
extra_tax2name = self.get_taxid_translator(list(all_taxid_codes - set(tax2name.keys())))
extra_tax2name = self._get_taxid_translator(list(all_taxid_codes - set(tax2name.keys())))
tax2name.update(extra_tax2name)
name2tax ={spname:taxid for taxid,spname in tax2name.items()}

Expand Down Expand Up @@ -452,12 +476,12 @@ def get_topology(self, taxnames, intermediate_nodes=False, rank_limit=None,
n.detach()

if annotate:
self.annotate_tree(tree)
self.annotate_tree(tree, ignore_unclassified=False)

return tree

def annotate_tree(self, t, taxid_attr='name',
tax2name=None, tax2track=None, tax2rank=None):
def annotate_tree(self, t, taxid_attr='name', tax2name=None,
tax2track=None, tax2rank=None, ignore_unclassified=False):
"""Annotate a tree containing taxids as leaf names.
It annotates by adding the properties 'taxid', 'sci_name',
Expand All @@ -481,7 +505,7 @@ def annotate_tree(self, t, taxid_attr='name',
try:
# translate gtdb name -> id
taxaname = getattr(n, taxid_attr, n.props.get(taxid_attr))
tid = self.get_name_translator([taxaname])[taxaname][0]
tid = self._get_name_translator([taxaname])[taxaname][0]
taxids.add(tid)
except (KeyError, ValueError, AttributeError):
pass
Expand All @@ -490,18 +514,18 @@ def annotate_tree(self, t, taxid_attr='name',
taxids, merged_conversion = self._translate_merged(taxids)

if not tax2name or taxids - set(map(int, list(tax2name.keys()))):
tax2name = self.get_taxid_translator(taxids)
tax2name = self._get_taxid_translator(taxids)
if not tax2track or taxids - set(map(int, list(tax2track.keys()))):
tax2track = self.get_lineage_translator(taxids)

tax2track = self._get_lineage_translator(taxids)
all_taxid_codes = set([_tax for _lin in list(tax2track.values()) for _tax in _lin])
extra_tax2name = self.get_taxid_translator(list(all_taxid_codes - set(tax2name.keys())))
extra_tax2name = self._get_taxid_translator(list(all_taxid_codes - set(tax2name.keys())))
tax2name.update(extra_tax2name)

tax2common_name = self.get_common_names(tax2name.keys())

if not tax2rank:
tax2rank = self.get_rank(list(tax2name.keys()))
tax2rank = self._get_id2rank(list(tax2name.keys()))

name2tax ={spname:taxid for taxid,spname in tax2name.items()}
n2leaves = t.get_cached_content()
Expand All @@ -512,10 +536,8 @@ def annotate_tree(self, t, taxid_attr='name',
else:
node_taxid = None
node.add_prop('taxid', node_taxid)

if node_taxid:
tmp_taxid = self.get_name_translator([node_taxid]).get(node_taxid, [None])[0]

tmp_taxid = self._get_name_translator([node_taxid]).get(node_taxid, [None])[0]
if node_taxid in merged_conversion:
node_taxid = merged_conversion[node_taxid]

Expand All @@ -539,16 +561,30 @@ def annotate_tree(self, t, taxid_attr='name',
rank = 'Unknown',
named_lineage = [])
else:
lineage = self._common_lineage([lf.props.get('lineage') for lf in n2leaves[node]])

if ignore_unclassified:
vectors = [lf.props.get('lineage') for lf in n2leaves[node] if lf.props.get('lineage')]
else:
vectors = [lf.props.get('lineage') for lf in n2leaves[node]]
lineage = self._common_lineage(vectors)

rank = tax2rank.get(lineage[-1], 'Unknown')

if lineage[-1]:
ancestor = self.get_taxid_translator([lineage[-1]])[lineage[-1]]
if rank != 'subspecies':
ancestor = self._get_taxid_translator([lineage[-1]])[lineage[-1]]
else:
ancestor = self._get_taxid_translator([lineage[-2]])[lineage[-2]]
lineage = lineage[:-1] # remove subspecies from lineage
rank = tax2rank.get(lineage[-1], 'Unknown') # update rank
else:
ancestor = None

node.add_props(sci_name = tax2name.get(ancestor, str(ancestor)),
common_name = tax2common_name.get(lineage[-1], ''),
taxid = ancestor,
lineage = lineage,
rank = tax2rank.get(lineage[-1], 'Unknown'),
rank = rank,
named_lineage = [tax2name.get(tax, str(tax)) for tax in lineage])

return tax2name, tax2track, tax2rank
Expand Down
18 changes: 11 additions & 7 deletions ete4/ncbi_taxonomy/ncbiquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def get_topology(self, taxids, intermediate_nodes=False, rank_limit=None,
return tree

def annotate_tree(self, t, taxid_attr="name", tax2name=None,
tax2track=None, tax2rank=None):
tax2track=None, tax2rank=None, ignore_unclassified=False):
"""Annotate a tree containing taxids as leaf names.
The annotation adds the properties: 'taxid', 'sci_name',
Expand Down Expand Up @@ -521,14 +521,18 @@ def annotate_tree(self, t, taxid_attr="name", tax2name=None,
rank = 'Unknown',
named_lineage = [])
else:
lineage = self._common_lineage([lf.props.get('lineage') for lf in n2leaves[n]])
if ignore_unclassified:
vectors = [lf.props.get('lineage') for lf in n2leaves[n] if lf.props.get('lineage')]
else:
vectors = [lf.props.get('lineage') for lf in n2leaves[n]]
lineage = self._common_lineage(vectors)
ancestor = lineage[-1]
n.add_props(sci_name = tax2name.get(ancestor, str(ancestor)),
common_name = tax2common_name.get(ancestor, ''),
taxid = ancestor,
lineage = lineage,
rank = tax2rank.get(ancestor, 'Unknown'),
named_lineage = [tax2name.get(tax, str(tax)) for tax in lineage])
common_name = tax2common_name.get(ancestor, ''),
taxid = ancestor,
lineage = lineage,
rank = tax2rank.get(ancestor, 'Unknown'),
named_lineage = [tax2name.get(tax, str(tax)) for tax in lineage])

return tax2name, tax2track, tax2rank

Expand Down
8 changes: 4 additions & 4 deletions ete4/phylo/phylotree.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def collapse_lineage_specific_expansions(self, species=None, return_copy=True):
return prunned


def annotate_ncbi_taxa(self, taxid_attr='species', tax2name=None, tax2track=None, tax2rank=None, dbfile=None):
def annotate_ncbi_taxa(self, taxid_attr='species', tax2name=None, tax2track=None, tax2rank=None, dbfile=None, ignore_unclassified=False):
"""Add NCBI taxonomy annotation to all descendant nodes. Leaf nodes are
expected to contain a feature (name, by default) encoding a valid taxid
number.
Expand Down Expand Up @@ -694,11 +694,11 @@ def annotate_ncbi_taxa(self, taxid_attr='species', tax2name=None, tax2track=None
"""

ncbi = NCBITaxa(dbfile=dbfile)
return ncbi.annotate_tree(self, taxid_attr=taxid_attr, tax2name=tax2name, tax2track=tax2track, tax2rank=tax2rank)
return ncbi.annotate_tree(self, taxid_attr=taxid_attr, tax2name=tax2name, tax2track=tax2track, tax2rank=tax2rank, ignore_unclassified=ignore_unclassified)

def annotate_gtdb_taxa(self, taxid_attr='species', tax2name=None, tax2track=None, tax2rank=None, dbfile=None, ignore_unclassified=False):
    """Add GTDB taxonomy annotation to the nodes of this tree.

    Creates a GTDBTaxa object from `dbfile` and delegates to its
    annotate_tree() method, forwarding `taxid_attr`, `tax2name`,
    `tax2track`, `tax2rank` and `ignore_unclassified` unchanged.

    Returns whatever GTDBTaxa.annotate_tree() returns (presumably the
    tuple tax2name, tax2track, tax2rank -- verify against GTDBTaxa).
    """
    gtdb = GTDBTaxa(dbfile=dbfile)
    return gtdb.annotate_tree(self, taxid_attr=taxid_attr, tax2name=tax2name,
                              tax2track=tax2track, tax2rank=tax2rank,
                              ignore_unclassified=ignore_unclassified)

def ncbi_compare(self, autodetect_duplications=True, cached_content=None):
if not cached_content:
Expand Down
Loading

0 comments on commit ff6767d

Please sign in to comment.