Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] New embeddings API #1023

Merged
merged 15 commits into from
Apr 1, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
import tiledbsoma as soma


def _get_experiment_name(organism: str) -> str:
"""Given an organism name, return the experiment name."""
# lower/snake case the organism name to find the experiment name
return re.sub(r"[ ]+", "_", organism).lower()


def _get_experiment(census: soma.Collection, organism: str) -> soma.Experiment:
"""Given a census :class:`tiledbsoma.Collection`, return the experiment for the named organism.
Organism matching is somewhat flexible, attempting to map from human-friendly
Expand Down Expand Up @@ -39,8 +45,7 @@ def _get_experiment(census: soma.Collection, organism: str) -> soma.Experiment:

>>> human = get_experiment(census, "homo_sapiens")
"""
# lower/snake case the organism name to find the experiment name
exp_name = re.sub(r"[ ]+", "_", organism).lower()
exp_name = _get_experiment_name(organism)

if exp_name not in census["census_data"]:
raise ValueError(f"Unknown organism {organism} - does not exist")
Expand Down
47 changes: 44 additions & 3 deletions api/python/cellxgene_census/src/cellxgene_census/_get_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
import tiledbsoma as soma
from somacore.options import SparseDFCoord

from ._experiment import _get_experiment
from ._experiment import _get_experiment, _get_experiment_name
from ._release_directory import get_census_version_directory
from ._util import _extract_census_version

CENSUS_EMBEDDINGS_LOCATION_BASE_URI = "s3://cellxgene-contrib-public/contrib/cell-census/soma"


def get_anndata(
Expand All @@ -27,6 +31,8 @@ def get_anndata(
var_value_filter: Optional[str] = None,
var_coords: Optional[SparseDFCoord] = None,
column_names: Optional[soma.AxisColumnNames] = None,
add_obs_embeddings: Optional[Sequence[str]] = (),
add_var_embeddings: Optional[Sequence[str]] = (),
) -> anndata.AnnData:
"""Convenience wrapper around :class:`tiledbsoma.Experiment` query, to build and execute a query,
and return it as an :class:`anndata.AnnData` object.
Expand Down Expand Up @@ -58,12 +64,20 @@ def get_anndata(
Columns to fetch for ``obs`` and ``var`` dataframes.
obsm_layers:
Additional obsm layers to read and return in the ``obsm`` slot.
add_obs_embeddings:
Embeddings to be returned as part of the ``obsm`` slot.
Use :func:`get_all_available_embeddings` to retrieve available embeddings
for this Census version and organism.
add_var_embeddings:
Embeddings to be returned as part of the ``varm`` slot.
Use :func:`get_all_available_embeddings` to retrieve available embeddings
for this Census version and organism.
ebezzi marked this conversation as resolved.
Show resolved Hide resolved

Returns:
An :class:`anndata.AnnData` object containing the census slice.

Lifecycle:
maturing
experimental
ebezzi marked this conversation as resolved.
Show resolved Hide resolved

Examples:
>>> get_anndata(census, "Mus musculus", obs_value_filter="tissue_general in ['brain', 'lung']")
Expand All @@ -75,14 +89,41 @@ def get_anndata(
exp = _get_experiment(census, organism)
obs_coords = (slice(None),) if obs_coords is None else (obs_coords,)
var_coords = (slice(None),) if var_coords is None else (var_coords,)

with exp.axis_query(
measurement_name,
obs_query=soma.AxisQuery(value_filter=obs_value_filter, coords=obs_coords),
var_query=soma.AxisQuery(value_filter=var_value_filter, coords=var_coords),
) as query:
return query.to_anndata(
adata = query.to_anndata(
X_name=X_name,
column_names=column_names,
X_layers=X_layers,
obsm_layers=obsm_layers,
)

# If add_obs_embeddings or add_var_embeddings are defined, inject them in the appropriate slot
ebezzi marked this conversation as resolved.
Show resolved Hide resolved
if add_obs_embeddings is not None or add_var_embeddings is not None:
from .experimental._embedding import _get_embedding, get_embedding_metadata_by_name

census_version = _extract_census_version(census)
experiment_name = _get_experiment_name(organism)
census_directory = get_census_version_directory()

if add_obs_embeddings is not None:
obs_soma_joinids = query.obs_joinids()
for emb in add_obs_embeddings:
emb_metadata = get_embedding_metadata_by_name(emb, experiment_name, census_version, "obs_embedding")
uri = f"{CENSUS_EMBEDDINGS_LOCATION_BASE_URI}/{census_version}/{emb_metadata['id']}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't these use urljoin()?

embedding = _get_embedding(census, census_directory, census_version, uri, obs_soma_joinids)
adata.obsm[emb] = embedding

if add_var_embeddings is not None:
var_soma_joinids = query.var_joinids()
for emb in add_var_embeddings:
emb_metadata = get_embedding_metadata_by_name(emb, experiment_name, census_version, "var_embedding")
uri = f"{CENSUS_EMBEDDINGS_LOCATION_BASE_URI}/{census_version}/{emb_metadata['id']}"
embedding = _get_embedding(census, census_directory, census_version, uri, var_soma_joinids)
adata.varm[emb] = embedding

return adata
8 changes: 8 additions & 0 deletions api/python/cellxgene_census/src/cellxgene_census/_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import urllib.parse

import tiledbsoma as soma


def _uri_join(base: str, url: str) -> str:
"""Like urllib.parse.urljoin, but doesn't get confused by s3://."""
Expand All @@ -18,3 +20,9 @@ def _uri_join(base: str, url: str) -> str:
p_url.fragment,
]
return urllib.parse.urlunparse(parts)


def _extract_census_version(census: soma.Collection) -> str:
ebezzi marked this conversation as resolved.
Show resolved Hide resolved
"""Extract the Census version from the given Census object."""
version: str = urllib.parse.urlparse(census.uri).path.split("/")[2]
return version
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
"""Experimental API for the CELLxGENE Discover Census."""

from ._embedding import get_embedding, get_embedding_metadata
from ._embedding import (
get_all_available_embeddings,
get_all_census_versions_with_embedding,
get_embedding,
get_embedding_metadata,
get_embedding_metadata_by_name,
)

__all__ = [
"get_embedding",
"get_embedding_metadata",
"get_embedding_metadata_by_name",
"get_all_available_embeddings",
"get_all_census_versions_with_embedding",
]
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
import requests
import tiledbsoma as soma

from .._open import get_default_soma_context, open_soma
from .._release_directory import get_census_version_directory
from .._release_directory import CensusVersionDescription, CensusVersionName, get_census_version_directory

CELL_CENSUS_EMBEDDINGS_MANIFEST_URL = "https://contrib.cellxgene.cziscience.com/contrib/cell-census/contributions.json"


def get_embedding_metadata(embedding_uri: str, context: soma.options.SOMATileDBContext | None = None) -> dict[str, Any]:
Expand Down Expand Up @@ -47,6 +50,61 @@ def get_embedding_metadata(embedding_uri: str, context: soma.options.SOMATileDBC
return cast(Dict[str, Any], embedding_metadata)


def _get_embedding(
census: soma.Collection,
census_directory: dict[CensusVersionName, CensusVersionDescription],
census_version: str,
embedding_uri: str,
obs_soma_joinids: npt.NDArray[np.int64] | pa.Array,
context: soma.options.SOMATileDBContext | None = None,
) -> npt.NDArray[np.float32]:
"""Private. Like get_embedding, but accepts a Census object and a Census directory."""
if isinstance(obs_soma_joinids, (pa.Array, pa.ChunkedArray, pd.Series)):
obs_soma_joinids = obs_soma_joinids.to_numpy()
assert isinstance(obs_soma_joinids, np.ndarray)
if obs_soma_joinids.dtype != np.int64:
raise TypeError("obs_soma_joinids must be array of int64")

# Allow the user to override context for exceptional cases (e.g. the aws region)
context = context or get_default_soma_context()

# Attempt to resolve census version aliases
resolved_census_version = census_directory.get(census_version, None)

with soma.open(embedding_uri, context=context) as E:
embedding_metadata = json.loads(E.metadata["CxG_embedding_info"])

if resolved_census_version is None:
warnings.warn(
"Unable to determine Census version - skipping validation of Census and embedding version.",
stacklevel=1,
)
elif resolved_census_version != census_directory.get(embedding_metadata["census_version"], None):
raise ValueError("Census and embedding mismatch - census_version not equal")

with open_soma(census_version=census_version, context=context) as census:
experiment_name = embedding_metadata["experiment_name"]
if experiment_name not in census["census_data"]:
raise ValueError("Census and embedding mismatch - experiment_name does not exist")
measurement_name = embedding_metadata["measurement_name"]
if measurement_name not in census["census_data"][experiment_name].ms:
raise ValueError("Census and embedding mismatch - measurement_name does not exist")

embedding_shape = (len(obs_soma_joinids), E.shape[1])
embedding = np.full(embedding_shape, np.NaN, dtype=np.float32, order="C")

obs_indexer = soma.tiledbsoma_build_index(obs_soma_joinids, context=E.context)
for tbl in E.read(coords=(obs_soma_joinids,)).tables():
obs_idx = obs_indexer.get_indexer(tbl.column("soma_dim_0").to_numpy())
feat_idx = tbl.column("soma_dim_1").to_numpy()
emb = tbl.column("soma_data")

indices = obs_idx * E.shape[1] + feat_idx
np.put(embedding.reshape(-1), indices, emb)

return embedding


def get_embedding(
census_version: str,
embedding_uri: str,
Expand Down Expand Up @@ -91,48 +149,119 @@ def get_embedding(
dtype=float32)

"""
if isinstance(obs_soma_joinids, (pa.Array, pa.ChunkedArray, pd.Series)):
obs_soma_joinids = obs_soma_joinids.to_numpy()
assert isinstance(obs_soma_joinids, np.ndarray)
if obs_soma_joinids.dtype != np.int64:
raise TypeError("obs_soma_joinids must be array of int64")
census_directory = get_census_version_directory()

# Allow the user to override context for exceptional cases (e.g. the aws region)
context = context or get_default_soma_context()
with open_soma(census_version=census_version, context=context) as census:
return _get_embedding(
census, census_directory, census_version, embedding_uri, obs_soma_joinids, context=context
)

# Attempt to resolve census version aliases
census_directory = get_census_version_directory()
resolved_census_version = census_directory.get(census_version, None)

with soma.open(embedding_uri, context=context) as E:
embedding_metadata = json.loads(E.metadata["CxG_embedding_info"])
def get_embedding_metadata_by_name(
embedding_name: str, organism: str, census_version: str, embedding_type: str | None = "obs_embedding"
) -> dict[str, Any]:
"""Return metadata for a specific embedding. If more embeddings match the query parameters,
the most recent one will be returned.

if resolved_census_version is None:
warnings.warn(
"Unable to determine Census version - skipping validation of Census and embedding version.",
stacklevel=1,
)
elif resolved_census_version != census_directory.get(embedding_metadata["census_version"], None):
raise ValueError("Census and embedding mismatch - census_version not equal")
Args:
embedding_name:
The name of the embedding, e.g. "scvi".
organism:
The organism for which the embedding is associated.
census_version:
The Census version tag, e.g., ``"2023-12-15"``.
embedding_type:
Either "obs_embedding" or "var_embedding". Defaults to "obs_embedding".

with open_soma(census_version=census_version, context=context) as census:
experiment_name = embedding_metadata["experiment_name"]
if experiment_name not in census["census_data"]:
raise ValueError("Census and embedding mismatch - experiment_name does not exist")
measurement_name = embedding_metadata["measurement_name"]
if measurement_name not in census["census_data"][experiment_name].ms:
raise ValueError("Census and embedding mismatch - measurement_name does not exist")
Returns:
A dictionary containing metadata describing the embedding.

embedding_shape = (len(obs_soma_joinids), E.shape[1])
embedding = np.full(embedding_shape, np.NaN, dtype=np.float32, order="C")
Raises:
ValueError: if no embeddings are found for the specified query parameters.

obs_indexer = soma.tiledbsoma_build_index(obs_soma_joinids, context=E.context)
for tbl in E.read(coords=(obs_soma_joinids,)).tables():
obs_idx = obs_indexer.get_indexer(tbl.column("soma_dim_0").to_numpy())
feat_idx = tbl.column("soma_dim_1").to_numpy()
emb = tbl.column("soma_data")
"""
response = requests.get(CELL_CENSUS_EMBEDDINGS_MANIFEST_URL)
response.raise_for_status()

indices = obs_idx * E.shape[1] + feat_idx
np.put(embedding.reshape(-1), indices, emb)
manifest = cast(dict[str, dict[str, Any]], response.json())
embeddings = []
for _, obj in manifest.items():
if (
obj["embedding_name"] == embedding_name
and obj["experiment_name"] == organism
and obj["data_type"] == embedding_type
and obj["census_version"] == census_version
):
embeddings.append(obj)

return embedding
if len(embeddings) == 0:
raise ValueError(f"No embeddings found for {embedding_name}, {organism}, {census_version}, {embedding_type}")

return sorted(embeddings, key=lambda x: x["submission_date"])[-1]


def get_all_available_embeddings(census_version: str) -> list[dict[str, Any]]:
"""Return a dictionary of all available embeddings for a given Census version.

Args:
census_version:
The Census version tag, e.g., ``"2023-12-15"``.

Returns:
A list of dictionaries, each containing metadata describing an available embedding.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Options:

  1. Return a subset of the metadata that only has relevant information (name, organism, etc). The example listed here is only for reference
  2. Return the full metadata.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strongly prefer 1) and having a verbose argument


Examples:
>>> get_all_available_embeddings('2023-12-15')
[{
'experiment_name': 'experiment_1',
'measurement_name': 'RNA',
'organism': "homo_sapiens",
'census_version': '2023-12-15',
'n_embeddings': 1000,
'n_features': 200,
'uri': 's3://bucket/embedding_1'
}]

"""
response = requests.get(CELL_CENSUS_EMBEDDINGS_MANIFEST_URL)
response.raise_for_status()

embeddings = []
manifest = response.json()
for _, obj in manifest.items():
if obj["census_version"] == census_version:
embeddings.append(obj)

return embeddings


def get_all_census_versions_with_embedding(
ebezzi marked this conversation as resolved.
Show resolved Hide resolved
embedding_name: str, organism: str, embedding_type: str | None = "obs_embedding"
) -> list[str]:
"""Get a list of all census versions that contain a specific embedding.

Args:
embedding_name:
The name of the embedding, e.g. "scvi".
organism:
The organism for which the embedding is associated.
embedding_type:
The type of embedding. Defaults to "obs_embedding".

Returns:
A list of census versions that contain the specified embedding.
"""
response = requests.get(CELL_CENSUS_EMBEDDINGS_MANIFEST_URL)
response.raise_for_status()

versions = set()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this whole blob could be a simple comprehension which might make it more pythonic (this is a nit, up to you). E.g,

return sorted({ obj['census_version'] for obj in manifest.values() if ...  })

And unless there are duplicates expected, I'm not sure what the set adds? If there are duplicates, doesn't that imply you need more filter criteria?

Copy link
Member Author

@ebezzi ebezzi Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The set is because multiple embeddings can exist for a single alias, and in this case we're only interested in the census version string, so it needs to be deduplicated. I'll rewrite using the comprehension.

manifest = response.json()
for _, obj in manifest.items():
if (
obj["embedding_name"] == embedding_name
and obj["experiment_name"] == organism
and obj["data_type"] == embedding_type
):
versions.add(obj["census_version"])

return sorted(versions)
Loading
Loading