Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor clean_obs_names #532

Merged
merged 11 commits into from
Jul 14, 2021
90 changes: 47 additions & 43 deletions scvelo/core/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,20 @@
from scipy.sparse import csr_matrix, issparse, spmatrix

from anndata import AnnData
from scanpy._utils import deprecated_arg_names

from scvelo import logging as logg
from ._arithmetic import sum


@deprecated_arg_names(
{"data": "adata", "copy": "inplace", "ID_length": "id_length", "base": "alphabet"}
)
def clean_obs_names(
data: AnnData,
base: str = "[AGTCBDHKMNRSVWY]",
ID_length: int = 12,
copy: bool = False,
adata: AnnData,
alphabet: str = "[AGTCBDHKMNRSVWY]",
id_length: int = 12,
inplace: bool = True,
) -> Optional[AnnData]:
"""Clean up the obs_names.

Expand All @@ -31,14 +35,14 @@ def clean_obs_names(

Arguments
---------
data
adata
Annotated data matrix.
base
alphabet
Genetic code letters to be identified.
ID_length
id_length
Length of the Genetic Codes in the samples.
copy
Return a copy instead of writing to adata.
inplace
Whether to update `adata` inplace or not.

Returns
-------
Expand All @@ -50,53 +54,53 @@ def clean_obs_names(
names of the identified sample batches
"""

def get_base_list(name, base):
base_list = base
while re.search(base_list + base, name) is not None:
base_list += base
if len(base_list) == 0:
raise ValueError("Encountered an invalid ID in obs_names: ", name)
return base_list

adata = data.copy() if copy else data

names = adata.obs_names
base_list = get_base_list(names[0], base)
if not inplace:
adata = adata.copy()

if len(np.unique([len(name) for name in adata.obs_names])) == 1:
start, end = re.search(base_list, names[0]).span()
newIDs = [name[start:end] for name in names]
start, end = 0, len(newIDs[0])
for i in range(end - ID_length):
if np.any([ID[i] not in base for ID in newIDs]):
if adata.obs_names.map(len).unique().size == 1:
start = re.search(alphabet, adata.obs_names[0]).start()
end = start + re.search(f"{alphabet}*", adata.obs_names[0][start:]).end()
new_obs_names = [obs_name[start:end] for obs_name in adata.obs_names]
start, end = 0, len(new_obs_names[0])
for i in range(end - id_length):
if np.any(
[new_obs_name[i] not in alphabet for new_obs_name in new_obs_names]
):
start += 1
if np.any([ID[::-1][i] not in base for ID in newIDs]):
if np.any(
[
new_obs_name[::-1][i] not in alphabet
for new_obs_name in new_obs_names
]
):
end -= 1

newIDs = [ID[start:end] for ID in newIDs]
prefixes = [names[i].replace(newIDs[i], "") for i in range(len(names))]
new_obs_names = [new_obs_name[start:end] for new_obs_name in new_obs_names]
prefixes = [
obs_name.replace(new_obs_name, "")
for obs_name, new_obs_name in zip(adata.obs_names, new_obs_names)
]
else:
prefixes, newIDs = [], []
for name in names:
match = re.search(base_list, name)
newID = (
re.search(get_base_list(name, base), name).group()
if match is None
else match.group()
)
newIDs.append(newID)
prefixes.append(name.replace(newID, ""))

adata.obs_names = newIDs
def rename_obs(obs_name):
start = re.search(alphabet, obs_name).start()
new_obs_name = re.search(f"{alphabet}*", obs_name[start:]).group()
return new_obs_name, obs_name.replace(new_obs_name, "")

new_obs_names, prefixes = zip(*adata.obs_names.map(rename_obs))

adata.obs_names = new_obs_names
adata.obs_names_make_unique()

if len(prefixes[0]) > 0 and len(np.unique(prefixes)) > 1:
adata.obs["sample_batch"] = (
pd.Categorical(prefixes)
if len(np.unique(prefixes)) < adata.n_obs
else prefixes
)

adata.obs_names_make_unique()
return adata if copy else None
if not inplace:
return adata


def cleanup(
Expand Down