Skip to content

Commit

Permalink
Adding analysis module to facilitate all analysis tasks. (#8)
Browse files Browse the repository at this point in the history
* WIP: added analysis major module

* added _infer_one_epoch in basetrainer.py

* added test_base.py

* added an instance checker in datamanager.base and its test

* added test_analysis_opencell.py;
fixed bugs in analysis_opencell.py;
updated datamanager & BaseTrainer._infer_one_epoch;

* updated requirements.txt

* added --ignore=E203 for flake8 in pytest-codecov-conda.yml
  • Loading branch information
li-li-github authored May 25, 2022
1 parent d611466 commit 98e98e4
Show file tree
Hide file tree
Showing 12 changed files with 321 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest-codecov-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
run: |
conda install flake8
# stop the build if there are Python syntax errors or undefined names
flake8 cytoself --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 cytoself --count --select=E9,F63,F7,F82, --ignore=E203, --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 cytoself --count --exit-zero --max-complexity=15 --max-line-length=127 --statistics
- name: Generate coverage report
Expand Down
Empty file added cytoself/analysis/__init__.py
Empty file.
71 changes: 71 additions & 0 deletions cytoself/analysis/analysis_opencell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import inspect
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from cytoself.analysis.base import BaseAnalysis


class AnalysisOpenCell(BaseAnalysis):
    """
    Analysis class for OpenCell data
    """

    def __init__(self, datamanager, trainer, homepath: Optional[str] = None, **kwargs):
        super().__init__(datamanager, trainer, homepath, **kwargs)

    @staticmethod
    def _matching_kwargs(func, kwargs):
        """
        Selects the entries of kwargs whose names are named parameters of func

        Keyword-only parameters are accepted as well as positional ones, so options
        declared after a bare ``*`` in the signature of func are not silently dropped.

        Parameters
        ----------
        func : callable
            Function whose signature is inspected
        kwargs : dict
            Candidate keyword arguments

        Returns
        -------
        dict
            Subset of kwargs that func names explicitly
        """
        spec = inspect.getfullargspec(func)
        accepted = set(spec.args) | set(spec.kwonlyargs)
        return {name: value for name, value in kwargs.items() if name in accepted}

    def plot_umap_of_embedding_vector(
        self,
        data_loader: Optional = None,
        label_data: Optional = None,
        umap_data: Optional = None,
        embedding_data: Optional = None,
        image_data: Optional = None,
        group_col: int = 0,
        unique_groups: Optional = None,
        **kwargs,
    ):
        """
        Plots a 2D scatter of UMAP coordinates, one scatter series per label group

        The UMAP coordinates come from the first available source, in this order:
        umap_data, embedding_data, data_loader, image_data.

        Parameters
        ----------
        data_loader
            Passed to trainer.infer_embeddings; the result is unpacked as (embeddings, labels)
            and the labels replace label_data
        label_data
            2D-indexable label array; column group_col is used for grouping.
            Required unless the labels are obtained through data_loader.
        umap_data
            Precomputed coordinates; columns 0 and 1 are plotted
        embedding_data
            Embeddings handed to _transform_umap
        image_data
            Passed to trainer.infer_embeddings when no data_loader is given
        group_col : int
            Column of label_data that holds the group identity
        unique_groups
            Groups to plot; defaults to the unique values of the group column
        kwargs
            figsize, title, xlabel, ylabel, savepath and dpi are consumed here;
            entries matching named parameters of trainer.infer_embeddings,
            _transform_umap or plt.scatter are forwarded to those calls.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If image_data is needed but missing, or if no label data is available.
        """
        # Validate up front so that a missing input is reported clearly and
        # before any expensive inference or UMAP computation is started.
        needs_inference = umap_data is None and embedding_data is None
        labels_from_loader = needs_inference and data_loader is not None
        if needs_inference and data_loader is None and image_data is None:
            raise ValueError('image_data is required.')
        if label_data is None and not labels_from_loader:
            raise ValueError('label_data is required.')

        # Compute umap data from embedding_data (or from images when necessary)
        if umap_data is None:
            if embedding_data is None:
                print('Computing embeddings from image...')
                embedding_data = self.trainer.infer_embeddings(
                    image_data if data_loader is None else data_loader,
                    **self._matching_kwargs(self.trainer.infer_embeddings, kwargs),
                )
                if data_loader is not None:
                    embedding_data, label_data = embedding_data
            print('Computing UMAP coordinates from embeddings...')
            umap_data = self._transform_umap(
                embedding_data,
                **self._matching_kwargs(self._transform_umap, kwargs),
            )

        # Configure arguments
        local_figsize = kwargs.get('figsize', (6.4, 4.8))
        scatter_kwargs = self._matching_kwargs(plt.scatter, kwargs)
        if unique_groups is None:
            unique_groups = np.unique(label_data[:, group_col])

        # Making the plot
        fig, ax = plt.subplots(figsize=local_figsize)
        for gp in tqdm(unique_groups):
            ind = label_data[:, group_col] == gp
            ax.scatter(umap_data[ind, 0], umap_data[ind, 1], label=gp, **scatter_kwargs)
        if 'title' in kwargs:
            ax.set_title(kwargs['title'])
        if 'xlabel' in kwargs:
            ax.set_xlabel(kwargs['xlabel'])
        if 'ylabel' in kwargs:
            ax.set_ylabel(kwargs['ylabel'])
        fig.tight_layout()

        if 'savepath' in kwargs:
            fig.savefig(kwargs['savepath'], dpi=kwargs.get('dpi', 'figure'))
42 changes: 42 additions & 0 deletions cytoself/analysis/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os.path
from os.path import join
import umap
from typing import Optional


class BaseAnalysis:
    """
    Base class for Analysis
    """

    def __init__(
        self,
        datamanager,
        trainer,
        homepath: Optional[str] = None,
        **kwargs,
    ):
        """
        Stores the collaborators and prepares the output folders

        Parameters
        ----------
        datamanager
            Data manager object; stored as-is
        trainer
            Trainer object; its savepath_dict['homepath'] is used when homepath is None
        homepath : str
            Root folder for analysis outputs; defaults to '<trainer homepath>/analysis'
        kwargs
            Accepted for signature compatibility; not used here
        """
        self.datamanager = datamanager
        self.trainer = trainer
        self.reducer = None
        self.savepath_dict = {
            'homepath': join(trainer.savepath_dict['homepath'], 'analysis') if homepath is None else homepath
        }
        self._init_savepath()

    def _init_savepath(self):
        """
        Creates the output sub-folders under homepath and records them in savepath_dict

        Returns
        -------
        None
        """
        folders = ['umap_figures', 'umap_data', 'feature_spectra_figures', 'feature_spectra_data']
        for f in folders:
            p = join(self.savepath_dict['homepath'], f)
            # exist_ok avoids the check-then-create race of a separate exists() test.
            os.makedirs(p, exist_ok=True)
            self.savepath_dict[f] = p

    def _fit_umap(self, data, n_neighbors=15, min_dist=0.1, metric='euclidean', verbose=True, **kwargs):
        """
        Fits a new UMAP reducer on data and stores it in self.reducer

        Parameters
        ----------
        data
            Array whose first axis indexes samples; the remaining axes are flattened
        n_neighbors, min_dist, metric, verbose
            Passed to umap.UMAP
        kwargs
            Additional keyword arguments for umap.UMAP

        Returns
        -------
        None
        """
        self.reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, verbose=verbose, **kwargs)
        self.reducer.fit(data.reshape(data.shape[0], -1))

    def _transform_umap(self, data, n_neighbors=15, min_dist=0.1, metric='euclidean', verbose=True, **kwargs):
        """
        Transforms data with the stored reducer, fitting one on data first if none exists

        Parameters
        ----------
        data
            Array whose first axis indexes samples; the remaining axes are flattened
        n_neighbors, min_dist, metric, verbose, kwargs
            Only used when a reducer has to be fitted; see _fit_umap

        Returns
        -------
        Result of self.reducer.transform on the flattened data
        """
        if self.reducer is None:
            self._fit_umap(data, n_neighbors, min_dist, metric, verbose, **kwargs)
        return self.reducer.transform(data.reshape(data.shape[0], -1))
Empty file.
84 changes: 84 additions & 0 deletions cytoself/analysis/test/test_analysis_opencell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
from os.path import exists, join
from contextlib import contextmanager
import pytest

import numpy as np
from ..analysis_opencell import AnalysisOpenCell
from cytoself.trainer.test.test_vanilla_trainer import test_VanillaAETrainer


@contextmanager
def assert_not_raises():
    """
    Context manager that fails the running test if the wrapped block raises

    Any Exception escaping the with-body is reported through pytest.fail,
    with the exception text included in the failure message.
    """
    try:
        yield
    except Exception as e:
        # pytest.fail raises by itself, so wrapping it in `raise` is unnecessary.
        pytest.fail(f'Did raise {e}')


class test_BaseAnalysis(test_VanillaAETrainer):
    """
    Exercises AnalysisOpenCell.plot_umap_of_embedding_vector with each kind of input
    """

    # Figure decorations shared by every plotting test.
    _DECORATIONS = {'title': 'fig title', 'xlabel': 'x axis', 'ylabel': 'y axis'}

    def setUp(self):
        super().setUp()
        self.analysis = AnalysisOpenCell(self.datamgr, self.trainer)
        # 100 labels, 10 of each class, as a single-column array in random order
        labels = np.repeat(np.arange(10), 10).reshape(-1, 1)
        self.label_data = labels
        np.random.shuffle(self.label_data)
        figure_dir = self.analysis.savepath_dict['umap_figures']
        self.file_name = join(figure_dir, 'test.png')

    def _check_and_remove_figure(self):
        # The figure must have been written; delete it so tests stay independent.
        assert exists(self.file_name)
        os.remove(self.file_name)

    def test_plot_umap_of_embedding_vector_nullinput(self):
        # Labels without images
        with self.assertRaises(ValueError):
            self.analysis.plot_umap_of_embedding_vector(label_data=self.label_data)
        # Images without labels
        with self.assertRaises(ValueError):
            self.analysis.plot_umap_of_embedding_vector(image_data=np.random.randn(100, 1, 32, 32))

    def test_plot_umap_of_embedding_vector_umapdata(self):
        with assert_not_raises():
            self.analysis.plot_umap_of_embedding_vector(
                label_data=self.label_data,
                umap_data=np.random.randn(100, 2),
                savepath=self.file_name,
                **self._DECORATIONS,
            )
        self._check_and_remove_figure()

    def test_plot_umap_of_embedding_vector_embedding(self):
        with assert_not_raises():
            self.analysis.plot_umap_of_embedding_vector(
                label_data=self.label_data,
                embedding_data=np.random.randn(100, 10),
                savepath=self.file_name,
                **self._DECORATIONS,
            )
        self._check_and_remove_figure()

    def test_plot_umap_of_embedding_vector_image(self):
        with assert_not_raises():
            self.analysis.plot_umap_of_embedding_vector(
                label_data=self.datamgr.test_loader.dataset.label,
                image_data=self.datamgr.test_loader.dataset.data,
                savepath=self.file_name,
                **self._DECORATIONS,
            )
        self._check_and_remove_figure()

    def test_plot_umap_of_embedding_vector_dataloader(self):
        # Rebuild the datasets with index labels before creating the loaders
        self.datamgr.const_dataset(label_format='index')
        self.datamgr.const_dataloader()
        with assert_not_raises():
            self.analysis.plot_umap_of_embedding_vector(
                data_loader=self.datamgr.test_loader,
                savepath=self.file_name,
                **self._DECORATIONS,
            )
        self._check_and_remove_figure()
28 changes: 28 additions & 0 deletions cytoself/analysis/test/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from os.path import exists, join
import numpy as np
from ..base import BaseAnalysis
from cytoself.trainer.test.test_vanilla_trainer import test_VanillaAETrainer


class test_BaseAnalysis(test_VanillaAETrainer):
    """
    Tests for the folder setup and the UMAP helpers of BaseAnalysis
    """

    def setUp(self):
        super().setUp()
        self.analysis = BaseAnalysis(self.datamgr, self.trainer)

    def _assert_transform_matches_fit(self):
        # With no reducer present, _transform_umap fits on the data it is given,
        # so its output is compared with the embedding stored on the reducer.
        coords = self.analysis._transform_umap(np.random.randn(100, 10))
        assert (coords == self.analysis.reducer.embedding_).all()

    def test__init_savepath(self):
        self.analysis._init_savepath()
        paths = self.analysis.savepath_dict
        for folder in ('umap_figures', 'umap_data', 'feature_spectra_figures', 'feature_spectra_data'):
            assert folder in paths
            assert exists(join(paths['homepath'], folder))

    def test__fit_umap(self):
        samples = np.random.randn(100, 10)
        self.analysis._fit_umap(samples, verbose=False)
        assert self.analysis.reducer.embedding_.shape == (100, 2)

    def test__transform_umap(self):
        self._assert_transform_matches_fit()

        # Drop the reducer so that a second call has to fit again
        self.analysis.reducer = None
        self._assert_transform_matches_fit()
14 changes: 0 additions & 14 deletions cytoself/datamanager/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional, Callable

from numpy.typing import ArrayLike
from torch.utils.data import Dataset

Expand Down Expand Up @@ -37,19 +36,6 @@ def const_dataset(self, *args, **kwargs):
pass


# A potential class object to extract train, val and test data.
# class Subset(Dataset):
# def __init__(self, dataset, indices):
# self._indices = indices
# self._dataset = dataset
#
# def __len__(self):
# return len(self._indices)
#
# def __getitem__(self, index):
# return self._dataset[self._indices[index]]


class PreloadedDataset(Dataset):
"""
Dataset class for preloaded data
Expand Down
40 changes: 40 additions & 0 deletions cytoself/datamanager/datamanager_oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,22 @@ def const_dataset(
self.test_dataset = PreloadedDataset(test_label, test_data, transform, self.unique_labels, label_format)

def const_dataloader(self, shuffle: bool = True, **kwargs):
"""
Constructs DataLoader
Parameters
----------
shuffle : bool
Shuffle batch if True
Returns
-------
None
"""
_assert_dtype(self.train_dataset.label, self.train_dataset.label_format)
_assert_dtype(self.val_dataset.label, self.val_dataset.label_format)
_assert_dtype(self.test_dataset.label, self.test_dataset.label_format)
self.train_loader = DataLoader(
self.train_dataset, self.batch_size, shuffle=shuffle, num_workers=self.num_workers, **kwargs
)
Expand All @@ -305,6 +321,30 @@ def const_dataloader(self, shuffle: bool = True, **kwargs):
)


def _assert_dtype(label, label_format):
"""
Asserts dtype for DataLoader
Parameters
----------
label : numpy array
Label data
label_format : str
label_format argument in const_dataset
Returns
-------
None
"""
if (
label_format is None
and isinstance(label, np.ndarray)
and not (np.issubdtype(label, np.number) or np.issubdtype(label, bool))
):
raise TypeError(f'label must be numerical to use dataloader, instead {label.dtype} is given.')


def get_file_df(basepath: str, suffix: Union[str, Sequence] = 'label', extension: str = 'npy'):
"""
Creates a DataFrame of data paths.
Expand Down
4 changes: 4 additions & 0 deletions cytoself/datamanager/test/test_datamanager_oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def test_const_dataset_labelonly(self):
self._assert_dataset(datamgr.test_dataset, keys=('labels',))

def test_const_dataloader(self):
self.datamgr.const_dataset(label_format=None)
with self.assertRaises(TypeError):
self.datamgr.const_dataloader(shuffle=False)

self.datamgr.const_dataset()
self.datamgr.const_dataloader(shuffle=False)
for train_data0 in self.datamgr.train_dataset:
Expand Down
Loading

0 comments on commit 98e98e4

Please sign in to comment.