-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding analysis module to facilitate all analysis tasks. (#8)
* WIP: added analysis major module * added _infer_one_epoch in basetrainer.py * added test_base.py * added an instance checker in datamanager.base and its test * added test_analysis_opencell.py; fixed bugs in analysis_opencell.py; updated datamanager & BaseTrainer._infer_one_epoch; * updated requirements.txt * added --ignore=E203 for flake8 in pytest-codecov-conda.yml
- Loading branch information
1 parent
d611466
commit 98e98e4
Showing
12 changed files
with
321 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import inspect | ||
from typing import Optional | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
from cytoself.analysis.base import BaseAnalysis | ||
|
||
|
||
class AnalysisOpenCell(BaseAnalysis): | ||
""" | ||
Analysis class for OpenCell data | ||
""" | ||
|
||
def __init__(self, datamanager, trainer, homepath: Optional[str] = None, **kwargs): | ||
super().__init__(datamanager, trainer, homepath, **kwargs) | ||
|
||
def plot_umap_of_embedding_vector( | ||
self, | ||
data_loader: Optional = None, | ||
label_data: Optional = None, | ||
umap_data: Optional = None, | ||
embedding_data: Optional = None, | ||
image_data: Optional = None, | ||
group_col: int = 0, | ||
unique_groups: Optional = None, | ||
**kwargs, | ||
): | ||
# Get compute umap data from embedding_data | ||
if umap_data is None: | ||
if embedding_data is None: | ||
print('Computing embeddings from image...') | ||
if data_loader is None: | ||
if image_data is None: | ||
raise ValueError('image_data is required.') | ||
if label_data is None: | ||
raise ValueError('label_data is required.') | ||
embedding_data = self.trainer.infer_embeddings( | ||
image_data if data_loader is None else data_loader, | ||
**{a: kwargs[a] for a in inspect.getfullargspec(self.trainer.infer_embeddings).args if a in kwargs}, | ||
) | ||
if data_loader is not None: | ||
embedding_data, label_data = embedding_data | ||
print('Computing UMAP coordinates from embeddings...') | ||
umap_data = self._transform_umap( | ||
embedding_data, | ||
**{a: kwargs[a] for a in inspect.getfullargspec(self._transform_umap).args if a in kwargs}, | ||
) | ||
|
||
# Configure arguments | ||
local_figsize = kwargs['figsize'] if 'figsize' in kwargs else (6.4, 4.8) | ||
scatter_kwargs = {a: kwargs[a] for a in inspect.getfullargspec(plt.scatter).args if a in kwargs} | ||
if unique_groups is None: | ||
unique_groups = np.unique(label_data[:, group_col]) | ||
|
||
# Making the plot | ||
fig, ax = plt.subplots(figsize=local_figsize) | ||
for gp in tqdm(unique_groups): | ||
ind = label_data[:, group_col] == gp | ||
ax.scatter(umap_data[ind, 0], umap_data[ind, 1], label=gp, **scatter_kwargs) | ||
if 'title' in kwargs: | ||
ax.set_title(kwargs['title']) | ||
if 'xlabel' in kwargs: | ||
ax.set_xlabel(kwargs['xlabel']) | ||
if 'ylabel' in kwargs: | ||
ax.set_ylabel(kwargs['ylabel']) | ||
fig.tight_layout() | ||
|
||
if 'savepath' in kwargs: | ||
fig.savefig(kwargs['savepath'], dpi=kwargs['dpi'] if 'dpi' in kwargs else 'figure') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os.path | ||
from os.path import join | ||
import umap | ||
from typing import Optional | ||
|
||
|
||
class BaseAnalysis: | ||
""" | ||
Base class for Analysis | ||
""" | ||
|
||
def __init__( | ||
self, | ||
datamanager, | ||
trainer, | ||
homepath: Optional[str] = None, | ||
**kwargs, | ||
): | ||
self.datamanager = datamanager | ||
self.trainer = trainer | ||
self.reducer = None | ||
self.savepath_dict = { | ||
'homepath': join(trainer.savepath_dict['homepath'], 'analysis') if homepath is None else homepath | ||
} | ||
self._init_savepath() | ||
|
||
def _init_savepath(self): | ||
folders = ['umap_figures', 'umap_data', 'feature_spectra_figures', 'feature_spectra_data'] | ||
for f in folders: | ||
p = join(self.savepath_dict['homepath'], f) | ||
if not os.path.exists(p): | ||
os.makedirs(p) | ||
self.savepath_dict[f] = p | ||
|
||
def _fit_umap(self, data, n_neighbors=15, min_dist=0.1, metric='euclidean', verbose=True, **kwargs): | ||
self.reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, verbose=verbose, **kwargs) | ||
self.reducer.fit(data.reshape(data.shape[0], -1)) | ||
|
||
def _transform_umap(self, data, n_neighbors=15, min_dist=0.1, metric='euclidean', verbose=True, **kwargs): | ||
if self.reducer is None: | ||
self._fit_umap(data, n_neighbors, min_dist, metric, verbose, **kwargs) | ||
return self.reducer.transform(data.reshape(data.shape[0], -1)) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import os | ||
from os.path import exists, join | ||
from contextlib import contextmanager | ||
import pytest | ||
|
||
import numpy as np | ||
from ..analysis_opencell import AnalysisOpenCell | ||
from cytoself.trainer.test.test_vanilla_trainer import test_VanillaAETrainer | ||
|
||
|
||
@contextmanager | ||
def assert_not_raises(): | ||
try: | ||
yield | ||
except Exception as e: | ||
raise pytest.fail(f'Did raise {e}') | ||
|
||
|
||
class test_BaseAnalysis(test_VanillaAETrainer): | ||
def setUp(self): | ||
super().setUp() | ||
self.analysis = AnalysisOpenCell(self.datamgr, self.trainer) | ||
self.label_data = np.repeat(np.arange(10), 10).reshape(-1, 1) | ||
np.random.shuffle(self.label_data) | ||
self.file_name = join(self.analysis.savepath_dict['umap_figures'], 'test.png') | ||
|
||
def test_plot_umap_of_embedding_vector_nullinput(self): | ||
with self.assertRaises(ValueError): | ||
self.analysis.plot_umap_of_embedding_vector(label_data=self.label_data) | ||
with self.assertRaises(ValueError): | ||
self.analysis.plot_umap_of_embedding_vector(image_data=np.random.randn(100, 1, 32, 32)) | ||
|
||
def test_plot_umap_of_embedding_vector_umapdata(self): | ||
with assert_not_raises(): | ||
self.analysis.plot_umap_of_embedding_vector( | ||
label_data=self.label_data, | ||
umap_data=np.random.randn(100, 2), | ||
savepath=self.file_name, | ||
title='fig title', | ||
xlabel='x axis', | ||
ylabel='y axis', | ||
) | ||
assert exists(self.file_name) | ||
os.remove(self.file_name) | ||
|
||
def test_plot_umap_of_embedding_vector_embedding(self): | ||
with assert_not_raises(): | ||
self.analysis.plot_umap_of_embedding_vector( | ||
label_data=self.label_data, | ||
embedding_data=np.random.randn(100, 10), | ||
savepath=self.file_name, | ||
title='fig title', | ||
xlabel='x axis', | ||
ylabel='y axis', | ||
) | ||
assert exists(self.file_name) | ||
os.remove(self.file_name) | ||
|
||
def test_plot_umap_of_embedding_vector_image(self): | ||
with assert_not_raises(): | ||
self.analysis.plot_umap_of_embedding_vector( | ||
label_data=self.datamgr.test_loader.dataset.label, | ||
image_data=self.datamgr.test_loader.dataset.data, | ||
savepath=self.file_name, | ||
title='fig title', | ||
xlabel='x axis', | ||
ylabel='y axis', | ||
) | ||
assert exists(self.file_name) | ||
os.remove(self.file_name) | ||
|
||
def test_plot_umap_of_embedding_vector_dataloader(self): | ||
self.datamgr.const_dataset(label_format='index') | ||
self.datamgr.const_dataloader() | ||
with assert_not_raises(): | ||
self.analysis.plot_umap_of_embedding_vector( | ||
data_loader=self.datamgr.test_loader, | ||
savepath=self.file_name, | ||
title='fig title', | ||
xlabel='x axis', | ||
ylabel='y axis', | ||
) | ||
assert exists(self.file_name) | ||
os.remove(self.file_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from os.path import exists, join | ||
import numpy as np | ||
from ..base import BaseAnalysis | ||
from cytoself.trainer.test.test_vanilla_trainer import test_VanillaAETrainer | ||
|
||
|
||
class test_BaseAnalysis(test_VanillaAETrainer): | ||
def setUp(self): | ||
super().setUp() | ||
self.analysis = BaseAnalysis(self.datamgr, self.trainer) | ||
|
||
def test__init_savepath(self): | ||
self.analysis._init_savepath() | ||
for f in ['umap_figures', 'umap_data', 'feature_spectra_figures', 'feature_spectra_data']: | ||
assert f in self.analysis.savepath_dict | ||
assert exists(join(self.analysis.savepath_dict['homepath'], f)) | ||
|
||
def test__fit_umap(self): | ||
self.analysis._fit_umap(np.random.randn(100, 10), verbose=False) | ||
assert self.analysis.reducer.embedding_.shape == (100, 2) | ||
|
||
def test__transform_umap(self): | ||
output = self.analysis._transform_umap(np.random.randn(100, 10)) | ||
assert (output == self.analysis.reducer.embedding_).all() | ||
|
||
self.analysis.reducer = None | ||
output = self.analysis._transform_umap(np.random.randn(100, 10)) | ||
assert (output == self.analysis.reducer.embedding_).all() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.