Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] New embeddings API #1023

Merged
merged 15 commits into from
Apr 1, 2024
36 changes: 35 additions & 1 deletion api/python/cellxgene_census/src/cellxgene_census/_get_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
import tiledbsoma as soma
from somacore.options import SparseDFCoord

from ._util import _extract_census_version

from ._experiment import _get_experiment

CENSUS_EMBEDDINGS_LOCATION_BASE_URI = "s3://cellxgene-contrib-public/contrib/cell-census/soma"


def get_anndata(
census: soma.Collection,
Expand All @@ -22,11 +26,14 @@ def get_anndata(
X_name: str = "raw",
X_layers: Optional[Sequence[str]] = (),
obsm_layers: Optional[Sequence[str]] = (),
ebezzi marked this conversation as resolved.
Show resolved Hide resolved
varm_layers: Optional[Sequence[str]] = (),
obs_value_filter: Optional[str] = None,
obs_coords: Optional[SparseDFCoord] = None,
var_value_filter: Optional[str] = None,
var_coords: Optional[SparseDFCoord] = None,
column_names: Optional[soma.AxisColumnNames] = None,
add_obs_embeddings: Optional[Sequence[str]] = (),
add_var_embeddings: Optional[Sequence[str]] = (),
) -> anndata.AnnData:
"""Convenience wrapper around :class:`tiledbsoma.Experiment` query, to build and execute a query,
and return it as an :class:`anndata.AnnData` object.
Expand Down Expand Up @@ -58,6 +65,16 @@ def get_anndata(
Columns to fetch for ``obs`` and ``var`` dataframes.
obsm_layers:
Additional obsm layers to read and return in the ``obsm`` slot.
varm_layers:
Additional varm layers to read and return in the ``varm`` slot.
add_obs_embeddings:
Embeddings to be returned as part of the ``obsm`` slot.
Use :func:`get_all_available_embeddings` to retrieve available embeddings
for this Census version and organism.
add_var_embeddings:
Embeddings to be returned as part of the ``varm`` slot.
Use :func:`get_all_available_embeddings` to retrieve available embeddings
for this Census version and organism.
ebezzi marked this conversation as resolved.
Show resolved Hide resolved

Returns:
An :class:`anndata.AnnData` object containing the census slice.
Expand All @@ -75,14 +92,31 @@ def get_anndata(
exp = _get_experiment(census, organism)
obs_coords = (slice(None),) if obs_coords is None else (obs_coords,)
var_coords = (slice(None),) if var_coords is None else (var_coords,)

with exp.axis_query(
measurement_name,
obs_query=soma.AxisQuery(value_filter=obs_value_filter, coords=obs_coords),
var_query=soma.AxisQuery(value_filter=var_value_filter, coords=var_coords),
) as query:
return query.to_anndata(
adata = query.to_anndata(
X_name=X_name,
column_names=column_names,
X_layers=X_layers,
obsm_layers=obsm_layers,
varm_layers=varm_layers
)

# If add_obs_embeddings or add_var_embeddings are defined, inject them in the appropriate slot
ebezzi marked this conversation as resolved.
Show resolved Hide resolved
if add_obs_embeddings or add_var_embeddings:
obs_soma_joinids = query.obs_joinids()
from cellxgene_census.experimental import get_embedding, get_embedding_metadata_by_name
census_version = _extract_census_version(census)
for emb in add_obs_embeddings:
emb_metadata = get_embedding_metadata_by_name(emb, organism, census_version, "obs_embedding")
uri = f"{CENSUS_EMBEDDINGS_LOCATION_BASE_URI}/{census_version}/{emb_metadata['id']}"
embedding = get_embedding(census_version, uri, obs_soma_joinids)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this will cause the census object to be re-opened. While this shouldn't be an issue, it will result in an extra call. With some effort I can refactor get_embedding to also accept an existing Census object, but I'm not sure if it's worth it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, you should refactor the code to have a (common, shared) function that accepts an already open Census handle

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

adata.obsm[emb] = embedding

return adata


5 changes: 5 additions & 0 deletions api/python/cellxgene_census/src/cellxgene_census/_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import urllib.parse
import tiledbsoma as soma


def _uri_join(base: str, url: str) -> str:
Expand All @@ -18,3 +19,7 @@ def _uri_join(base: str, url: str) -> str:
p_url.fragment,
]
return urllib.parse.urlunparse(parts)

def _extract_census_version(census: soma.Collection):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created a live corpus unit test for this method. This should ensure that this parsing method remains consistent across releases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code has some lint - please run (an up to date) pre-commit across it

"""Extract the Census version from the given Census object."""
return urllib.parse.urlparse(census.uri).path.split("/")[2]
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Experimental API for the CELLxGENE Discover Census."""

from ._embedding import get_embedding, get_embedding_metadata
from ._embedding import get_embedding, get_embedding_metadata, get_embedding_metadata_by_name, get_all_available_embeddings, get_all_census_versions_with_embedding

__all__ = [
"get_embedding",
"get_embedding_metadata",
"get_embedding_metadata_by_name",
"get_all_available_embeddings",
"get_all_census_versions_with_embedding",
]
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
import pandas as pd
import pyarrow as pa
import tiledbsoma as soma
import requests

from .._open import get_default_soma_context, open_soma
from .._release_directory import get_census_version_directory

CELL_CENSUS_EMBEDDINGS_MANIFEST_URL = "https://contrib.cellxgene.cziscience.com/contrib/cell-census/contributions.json"


def get_embedding_metadata(embedding_uri: str, context: soma.options.SOMATileDBContext | None = None) -> dict[str, Any]:
"""Read embedding metadata and return as a Python dict.
Expand Down Expand Up @@ -136,3 +139,98 @@ def get_embedding(
np.put(embedding.reshape(-1), indices, emb)

return embedding

def get_embedding_metadata_by_name(
    embedding_name: str, organism: str, census_version: str, embedding_type: str | None = "obs_embedding"
) -> dict[str, Any]:
    """Return the manifest entry describing a named embedding.

    The embeddings manifest is downloaded from ``CELL_CENSUS_EMBEDDINGS_MANIFEST_URL``
    and filtered on name, organism, embedding type and Census version. When several
    entries match, the one whose ``submission_date`` sorts last is returned.

    Args:
        embedding_name:
            The name of the embedding, e.g. "scvi".
        organism:
            The organism the embedding is associated with. It is compared against the
            ``experiment_name`` field of each manifest entry.
        census_version:
            The Census version tag, e.g., ``"2023-12-15"``.
        embedding_type:
            Either "obs_embedding" or "var_embedding". Defaults to "obs_embedding".
            It is compared against the ``data_type`` field of each manifest entry.

    Returns:
        A dictionary containing metadata describing the embedding.

    Raises:
        ValueError: if no embeddings are found for the specified query parameters.

    """
    manifest_response = requests.get(CELL_CENSUS_EMBEDDINGS_MANIFEST_URL)
    manifest_response.raise_for_status()

    # The manifest is a mapping of id -> metadata; only the metadata values are needed.
    candidates = [
        entry
        for entry in manifest_response.json().values()
        if entry["embedding_name"] == embedding_name
        and entry["experiment_name"] == organism
        and entry["data_type"] == embedding_type
        and entry["census_version"] == census_version
    ]

    if not candidates:
        raise ValueError(f"No embeddings found for {embedding_name}, {organism}, {census_version}, {embedding_type}")

    # Stable sort: among entries sharing a submission date, the one listed last in the manifest wins.
    candidates.sort(key=lambda entry: entry["submission_date"])
    return candidates[-1]

def get_all_available_embeddings(census_version: str) -> list[dict[str, Any]]:
"""Return a dictionary of all available embeddings for a given Census version.

Args:
census_version:
The Census version tag, e.g., ``"2023-12-15"``.

Returns:
A list of dictionaries, each containing metadata describing an available embedding.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Options:

  1. Return a subset of the metadata that only has relevant information (name, organism, etc). The example listed here is only for reference
  2. Return the full metadata.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strongly prefer 1) and having a verbose argument


Examples:
>>> get_all_available_embeddings('2023-12-15')
[{
'experiment_name': 'experiment_1',
'measurement_name': 'RNA',
'organism': "homo_sapiens",
'census_version': '2023-12-15',
'n_embeddings': 1000,
'n_features': 200,
'uri': 's3://bucket/embedding_1'
}]

"""
response = requests.get(CELL_CENSUS_EMBEDDINGS_MANIFEST_URL)
response.raise_for_status()

embeddings = []
manifest = response.json()
for _, obj in manifest.items():
if obj["census_version"] == census_version:
embeddings.append(obj)

return embeddings

def get_all_census_versions_with_embedding(embedding_name: str, organism: str, embedding_type: str | None = "obs_embedding") -> list[str]:
"""
Get a list of all census versions that contain a specific embedding.

Args:
embedding_name:
The name of the embedding, e.g. "scvi".
organism:
The organism for which the embedding is associated.
embedding_type:
The type of embedding. Defaults to "obs_embedding".

Returns:
A list of census versions that contain the specified embedding.
"""
response = requests.get(CELL_CENSUS_EMBEDDINGS_MANIFEST_URL)
response.raise_for_status()

versions = set()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this whole blob could be a simple comprehension which might make it more pythonic (this is a nit, up to you). E.g,

return sorted({ obj['census_version'] for obj in manifest.values() if ...  })

And unless there are duplicates expected, I'm not sure what the set adds? If there are duplicates, doesn't that imply you need more filter criteria?

Copy link
Member Author

@ebezzi ebezzi Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The set is because multiple embeddings can exist for a single alias, and in this case we're only interested in the census version string, so it needs to be deduplicated. I'll rewrite using the comprehension.

manifest = response.json()
for _, obj in manifest.items():
if obj["embedding_name"] == embedding_name and obj["experiment_name"] == organism and obj["data_type"] == embedding_type:
versions.add(obj["census_version"])

return sorted(list(versions))
163 changes: 163 additions & 0 deletions api/python/cellxgene_census/tests/experimental/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import pytest
import requests_mock as rm

from cellxgene_census.experimental import get_all_available_embeddings, get_all_census_versions_with_embedding, get_embedding_metadata_by_name

from cellxgene_census.experimental._embedding import CELL_CENSUS_EMBEDDINGS_MANIFEST_URL


def test_get_embedding_metadata_by_name(requests_mock: rm.Mocker) -> None:
    """Check lookup by name, selection of the latest submission, and the no-match error."""
    mock_embeddings = {
        "embedding-id-1": {
            "id": "embedding-id-1",
            "embedding_name": "emb_1",
            "title": "Embedding 1",
            "description": "First embedding",
            "experiment_name": "homo_sapiens",
            "data_type": "obs_embedding",
            "census_version": "2023-12-15",
            "submission_date": "2023-11-15",
        },
        "embedding-id-2": {
            "id": "embedding-id-2",
            "embedding_name": "emb_1",
            "title": "Embedding 2",
            "description": "Second embedding",
            "experiment_name": "homo_sapiens",
            "data_type": "obs_embedding",
            "census_version": "2023-12-15",
            "submission_date": "2023-12-31",
        },
        "embedding-id-3": {
            "id": "embedding-id-3",
            "embedding_name": "emb_3",
            "title": "Embedding 3",
            "description": "Third embedding",
            "experiment_name": "homo_sapiens",
            "data_type": "obs_embedding",
            "census_version": "2023-12-15",
            "submission_date": "2023-11-15",
        },
    }
    # Only the manifest URL is mocked; NOTE(review): real_http presumably lets any
    # other request pass through to the network - confirm against requests-mock docs.
    requests_mock.real_http = True
    requests_mock.get(CELL_CENSUS_EMBEDDINGS_MANIFEST_URL, json=mock_embeddings)

    embedding = get_embedding_metadata_by_name(
        "emb_1", organism="homo_sapiens", census_version="2023-12-15", embedding_type="obs_embedding"
    )
    assert embedding is not None
    assert embedding["id"] == "embedding-id-2"  # most recent version
    assert embedding == mock_embeddings["embedding-id-2"]

    embedding = get_embedding_metadata_by_name(
        "emb_3", organism="homo_sapiens", census_version="2023-12-15", embedding_type="obs_embedding"
    )
    assert embedding is not None
    assert embedding["id"] == "embedding-id-3"
    assert embedding == mock_embeddings["embedding-id-3"]

    # Each non-matching query gets its own ``pytest.raises`` context. With all the
    # calls inside one context, the first call raises and leaves the block, so the
    # remaining queries would never be executed.
    no_match_queries = [
        ("emb_2", "homo_sapiens", "2023-12-15", "obs_embedding"),  # unknown embedding name
        ("emb_1", "mus_musculus", "2023-12-15", "obs_embedding"),  # wrong organism
        ("emb_1", "homo_sapiens", "2023-10-15", "obs_embedding"),  # wrong census version
        ("emb_1", "mus_musculus", "2023-12-15", "var_embedding"),  # wrong organism and type
    ]
    for name, organism, census_version, embedding_type in no_match_queries:
        with pytest.raises(ValueError):
            get_embedding_metadata_by_name(
                name, organism=organism, census_version=census_version, embedding_type=embedding_type
            )



def test_get_all_available_embeddings(requests_mock: rm.Mocker) -> None:
    """Check that embeddings are listed per Census version, and that an unknown version yields none."""
    # Fields that are identical for every mocked manifest entry.
    shared_fields = {
        "experiment_name": "homo_sapiens",
        "measurement_name": "RNA",
        "n_embeddings": 1000,
        "n_features": 200,
        "data_type": "obs_embedding",
        "census_version": "2023-12-15",
    }
    entries = [
        {
            "id": "embedding-id-1",
            "embedding_name": "emb_1",
            "title": "Embedding 1",
            "description": "First embedding",
            **shared_fields,
        },
        {
            "id": "embedding-id-2",
            "embedding_name": "emb_2",
            "title": "Embedding 2",
            "description": "Second embedding",
            **shared_fields,
        },
    ]
    # The manifest maps each embedding id to its metadata.
    manifest = {entry["id"]: entry for entry in entries}

    requests_mock.real_http = True
    requests_mock.get(CELL_CENSUS_EMBEDDINGS_MANIFEST_URL, json=manifest)

    found = get_all_available_embeddings("2023-12-15")
    assert found is not None
    assert len(found) == 2

    # Query for a non existing version of the Census
    found = get_all_available_embeddings("2024-12-15")
    assert len(found) == 0


def test_get_all_census_versions_with_embedding(requests_mock: rm.Mocker) -> None:
    """Check that Census versions are filtered by embedding name, organism and type, deduplicated and sorted."""
    entries = [
        {
            "id": "embedding-id-1",
            "embedding_name": "emb_1",
            "title": "Embedding 1",
            "description": "First embedding",
            "experiment_name": "homo_sapiens",
            "data_type": "obs_embedding",
            "census_version": "2023-12-15",
        },
        {
            "id": "embedding-id-2",
            "embedding_name": "emb_1",
            "title": "Embedding 2",
            "description": "Second embedding",
            "experiment_name": "homo_sapiens",
            "data_type": "obs_embedding",
            "census_version": "2023-12-15",
        },
        {
            "id": "embedding-id-3",
            "embedding_name": "emb_1",
            "title": "Embedding 3",
            "description": "Third embedding",
            "experiment_name": "mus_musculus",
            "data_type": "obs_embedding",
            "census_version": "2023-12-15",
        },
        {
            "id": "embedding-id-4",
            "embedding_name": "emb_1",
            "title": "Embedding 4",
            "description": "Fourth embedding",
            "experiment_name": "mus_musculus",
            "data_type": "obs_embedding",
            "census_version": "2024-01-01",
        },
        {
            "id": "embedding-id-5",
            "embedding_name": "emb_2",
            "title": "Embedding 5",
            "description": "Fifth embedding",
            "experiment_name": "mus_musculus",
            "data_type": "var_embedding",
            "census_version": "2023-12-15",
        },
    ]
    # The manifest maps each embedding id to its metadata.
    manifest = {entry["id"]: entry for entry in entries}

    requests_mock.real_http = True
    requests_mock.get(CELL_CENSUS_EMBEDDINGS_MANIFEST_URL, json=manifest)

    # (embedding name, organism, embedding type) -> expected sorted, deduplicated versions
    expectations = [
        ("emb_1", "homo_sapiens", "obs_embedding", ["2023-12-15"]),
        ("emb_1", "mus_musculus", "obs_embedding", ["2023-12-15", "2024-01-01"]),
        ("emb_1", "mus_musculus", "var_embedding", []),
        ("emb_2", "mus_musculus", "var_embedding", ["2023-12-15"]),
    ]
    for embedding_name, organism, embedding_type, expected_versions in expectations:
        versions = get_all_census_versions_with_embedding(
            embedding_name, organism=organism, embedding_type=embedding_type
        )
        assert versions == expected_versions
Loading
Loading