Commit

Merge pull request #17 from saezlab/devel
Merge benchmark devel branch
PauBadiaM authored Sep 1, 2022
2 parents 3f0fcfb + 9c2851c commit 22d6095
Showing 49 changed files with 7,346 additions and 2,422 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/devel.yml
@@ -2,7 +2,7 @@ name: devel

on:
push:
branches: [ development ]
branches: [ devel ]
pull_request:
branches: [ main ]

@@ -21,7 +21,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install wheel
pip install pytest flake8 sklearn skranger omnipath scanpy .
pip install pytest flake8 sklearn skranger omnipath scanpy adjustText .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -46,7 +46,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install wheel
pip install pytest flake8 sklearn skranger omnipath scanpy .
pip install pytest flake8 sklearn skranger omnipath scanpy adjustText .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
@@ -19,7 +19,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install wheel
pip install pytest-cov flake8 sklearn skranger omnipath scanpy .
pip install pytest-cov flake8 sklearn skranger omnipath scanpy adjustText .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -49,7 +49,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install wheel
pip install pytest flake8 sklearn skranger omnipath scanpy .
pip install pytest flake8 sklearn skranger omnipath scanpy adjustText .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
3 changes: 3 additions & 0 deletions .gitignore
@@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

.DS_Store
benchmark-dev.ipynb
23 changes: 9 additions & 14 deletions decoupler/__init__.py
@@ -1,15 +1,17 @@
__version__ = '1.1.13' # noqa: F401
__version__ = '1.2.0' # noqa: F401
__version_info__ = tuple([int(num) for num in __version__.split('.')]) # noqa: F401

from .pre import extract, match, rename_net, get_net_mat, filt_min_n, mask_features # noqa: F401
from .utils import melt, show_methods, check_corr, get_toy_data, summarize_acts, assign_groups # noqa: F401
from .utils import dense_run, p_adjust_fdr # noqa: F401
from .utils import dense_run, p_adjust_fdr, shuffle_net # noqa: F401
from .utils_anndata import get_acts, get_pseudobulk, get_contrast, get_top_targets, format_contrast_results # noqa: F401
from .method_wmean import run_wmean # noqa: F401
from .method_wsum import run_wsum # noqa: F401
from .method_ulm import run_ulm # noqa: F401
from .method_mdt import run_mdt # noqa: F401
from .method_mlm import run_mlm # noqa: F401
from .method_ora import run_ora, test1r # noqa: F401
from .method_udt import run_udt # noqa: F401
from .method_ora import run_ora, test1r, get_ora_df # noqa: F401
from .method_gsva import run_gsva # noqa: F401
from .method_gsea import run_gsea # noqa: F401
from .method_viper import run_viper # noqa: F401
Expand All @@ -18,14 +20,7 @@
from .consensus import cons # noqa: F401
from .omnip import show_resources, get_resource, get_progeny, get_dorothea # noqa: F401
from .plotting import plot_volcano, plot_violins, plot_barplot # noqa: F401

# External libraries go out of main setup
try:
from .method_mdt import run_mdt # noqa: F401
except Exception:
pass

try:
from .method_udt import run_udt # noqa: F401
except Exception:
pass
from .plotting import plot_metrics_scatter, plot_metrics_boxplot, plot_metrics_scatter_cols # noqa: F401
from .benchmark import benchmark, format_benchmark_inputs, get_performances # noqa: F401
from .utils_benchmark import get_toy_benchmark_data, show_metrics # noqa: F401
from .metrics import metric_auroc, metric_auprc, metric_mcauroc, metric_mcauprc # noqa: F401
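Among the new exports above, `shuffle_net` pairs naturally with the benchmark pipeline introduced in this commit, since a permuted copy of a network makes a convenient negative control. A minimal sketch follows; it assumes `get_progeny` returns a long-format net and that `shuffle_net` accepts the column names to permute, neither of which is shown in this diff, and the argument values are illustrative only.

```python
import decoupler as dc

# Prior-knowledge network in decoupler's long format (source, target, weight);
# get_progeny's exact defaults are an assumption here.
net = dc.get_progeny(organism='human', top=100)

# Shuffled copy of the same net to use as a random baseline in the benchmark
# (assumes shuffle_net permutes the given 'target' and 'weight' columns).
rnet = dc.shuffle_net(net, target='target', weight='weight', seed=42)

# Dictionary of nets, as accepted by dc.benchmark (see decoupler/benchmark.py below).
nets = {'progeny': net, 'progeny_shuffled': rnet}
```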
236 changes: 236 additions & 0 deletions decoupler/benchmark.py
@@ -0,0 +1,236 @@
"""
Functions to benchmark methods and nets.
Functions to benchmark methods and nets using perturbation experiments.
"""

import numpy as np
import pandas as pd

from .decouple import decouple
from .utils_anndata import extract_psbulk_inputs
from .pre import rename_net, filt_min_n
from .utils_benchmark import format_acts_grts, append_metrics_scores
from .utils_benchmark import validate_metrics, check_groupby, rename_obs


def get_performances(res, obs, groupby, by, metrics, min_exp=5, pi0=0.5, n_iter=1000,
seed=42, verbose=False):

# Return acts, grts and msks tensors
acts, grts, msks, srcs, mthds, grpbys, grps = format_acts_grts(res, obs, groupby)

# Init empty df
df = []
if msks is not None:
n_grpbys = len(msks)
for i in range(n_grpbys):
msk_i = msks[i]
grpby_i = grpbys[i]
grps_i = grps[i]
n_grps = len(grps_i)
if verbose:
print('Computing metrics for groupby {0}...'.format(grpby_i))
for j in range(n_grps):
msk = msk_i[j]
grp = grps_i[j]
n = np.sum(msk)

# If enough exps, subset by group
if n >= min_exp:
act, grt = acts[msk, :, :], grts[msk, :]

# Special case when groupby == perturb, remove extra grts
if grp in srcs:
m = grp == srcs
grt[:, ~m] = 0.

# Compute and append scores to df
append_metrics_scores(df, grpby_i, grp, act, grt, srcs, mthds, metrics, by, min_exp=min_exp,
pi0=pi0, n_iter=n_iter, seed=seed)
else:
n_exp = acts.shape[0]
if n_exp >= min_exp:

# Compute and append scores to df
if verbose:
print('Computing metrics...')
append_metrics_scores(df, None, None, acts, grts, srcs, mthds, metrics, by, min_exp=min_exp,
pi0=pi0, n_iter=n_iter, seed=seed)

# Format df
df = pd.DataFrame(df, columns=['groupby', 'group', 'source', 'method', 'metric', 'score', 'ci'])

return df


def format_benchmark_inputs(mat, obs, perturb, sign, net, groupby, by, f_expr=True, f_srcs=False,
source='source', target='target', weight='weight', min_n=5,
verbose=False, use_raw=True, decouple_kws={}):

# Extract inputs
if verbose:
print("Extracting inputs...")
mat, obs, var = extract_psbulk_inputs(mat, obs, layer=None, use_raw=use_raw)

# Format groupby
groupby = check_groupby(obs, groupby, perturb, by)

# Rename obs
obs = rename_obs(obs, perturb, sign)

# Rename net
if verbose:
print("Formating net...")
net = rename_net(net, source=decouple_kws['source'], target=decouple_kws['target'], weight=decouple_kws['weight'])
net = filt_min_n(var.index.values.astype('U'), net, min_n=decouple_kws['min_n'])

# Remove experiments without sources in net
if f_expr:
msk = np.full((obs['perturb'].size, ), False)
srcs = net['source'].values.astype('U')
for i, src in enumerate(obs['perturb']):
msk[i] = np.any(np.isin(src, srcs))
if verbose:
n = np.sum(~msk)
print("{0} experiments without sources in net, they will be removed.".format(n))
mat, obs = mat[msk], obs.loc[msk]

# Remove sources without experiments in obs
if f_srcs:
msk = np.isin(net['source'].values, obs['perturb'].values.ravel())
if verbose:
n = np.sum(~msk)
print("{0} sources without experiments in obs, they will be removed.".format(n))
net = net.loc[msk]

return mat, obs, var, net, groupby


def _benchmark(mat, obs, net, perturb, sign, metrics=['auroc', 'auprc'], groupby=None, by='experiment', f_expr=True,
f_srcs=False, min_exp=5, pi0=0.5, n_iter=1000, seed=42, verbose=True, use_raw=True, decouple_kws={}):

# Format inputs
mat, obs, var, net, groupby = format_benchmark_inputs(mat, obs, perturb, sign, net, groupby, by, f_expr=f_expr,
f_srcs=f_srcs, verbose=verbose, use_raw=use_raw,
decouple_kws=decouple_kws)

# Reset net names args
decouple_kws['source'] = 'source'
decouple_kws['target'] = 'target'
decouple_kws['weight'] = 'weight'

# Run prediction
if verbose:
print('Running methods...')
res = decouple([mat, obs.index, var.index], net, verbose=verbose, **decouple_kws)

# Compute metrics
if verbose:
print('Calculating metrics...')
df = get_performances(res, obs, groupby, by, metrics, min_exp=min_exp, pi0=pi0,
n_iter=n_iter, seed=seed, verbose=verbose)
if verbose:
print('Done.')

return df


def benchmark(mat, obs, net, perturb, sign, metrics=['auroc', 'auprc', 'mcauroc', 'mcauprc'], groupby=None,
by='experiment', f_expr=True, f_srcs=False, min_exp=5, pi0=0.5, n_iter=1000, seed=42,
verbose=True, use_raw=True, decouple_kws={}):
"""
Benchmark methods or networks on a given set of perturbation experiments using activity inference with decoupler.
Parameters
----------
mat : list, DataFrame or AnnData
List of [features, matrix], dataframe (samples x features) or an AnnData instance.
obs : DataFrame or None
Metadata containing the perturbed targets and the sign of the perturbation. If mat is AnnData, use mat.obs
attribute instead.
net : DataFrame, dict
Network in long format. Can be dictionary of nets, where key is the name and value is the long format DataFrame.
perturb : str
Column name in obs with perturbed sources.
sign : str, int
Column name in obs with sign of the perturbation. Can be set to 1 or -1 if all experiments are overexpression or
knockouts, respectively.
metrics : list, str
Performance metric(s) to compute. See the description of get_performances for more details. Defaults
to ['auroc', 'auprc', 'mcauroc', 'mcauprc'].
groupby : list, str, None
Performance metric(s) can be computed per group if enough experiments are available.
by : str
Whether to evaluate performances at the "experiment" or at the "source" level.
f_expr : bool
Whether to filter out experiments whose perturbed sources are not in the given net. Defaults to True.
f_srcs : bool
Whether to filter out sources in net for which there is no perturbation data. Defaults to False.
min_exp : int
Minimum of perturbation experiments per group.
pi0 : float
Reference ratio for calibrated metrics. Corresponds to the baseline/reference class imbalance to which
to set the metric.
n_iter : int
Number of downsampling iterations used for the 'mcauroc' and 'mcauprc' metrics.
seed : int
Random seed to use.
verbose : bool
Whether to show progress.
use_raw : bool
Use raw attribute of mat if present.
decouple_kws : dict
Parameters for the decoupler.decouple function. If more than one net, use a nested dictionary where the main
key is the network name and the value is a dictionary with the required arguments.
Returns
-------
df : DataFrame
DataFrame containing the metrics' scores.
"""

# Init default args
default_kws = {'source': 'source', 'target': 'target', 'weight': 'weight', 'min_n': 5}

# Validate by
if by not in ['experiment', 'source']:
raise ValueError('Argument `by` has to be either "experiment" or "source".')

# Validate metrics
validate_metrics(metrics)

# Validate pi0
if pi0 is not None:
if pi0 < 0 or pi0 > 1:
raise ValueError('Argument `pi0` needs to be between 0 and 1.')

# Run benchmark per net
if type(net) is not dict:

# Update decouple args
decouple_kws = {**default_kws, **decouple_kws}

# Run benchmark
df = _benchmark(mat, obs, net, perturb, sign, metrics, groupby, by, f_expr, f_srcs, min_exp, pi0,
n_iter, seed, verbose, use_raw, decouple_kws)
else:
df = []
for net_name in net:

if verbose:
print('Using {0} network...'.format(net_name))

# Update decouple args
decouple_kws.setdefault(net_name, {})
decouple_kws[net_name] = {**default_kws, **decouple_kws[net_name]}

# Run benchmark
tmp = _benchmark(mat, obs, net[net_name], perturb, sign, metrics, groupby, by, f_expr, f_srcs,
min_exp, pi0, n_iter, seed, verbose, use_raw, decouple_kws[net_name])
tmp['net'] = net_name
df.append(tmp)

# Merge all results
df = pd.concat(df)

return df
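For context, here is a minimal usage sketch of the `benchmark` entry point defined above, run on the packaged toy data. Only the `benchmark` signature is taken from the file itself; the assumption that `get_toy_benchmark_data` returns a `(mat, obs, net)` triple with 'perturb' and 'sign' columns in obs is illustrative.

```python
import decoupler as dc

# Toy perturbation experiments shipped for testing the benchmark pipeline
# (assumed to return an expression matrix, metadata and a toy net).
mat, obs, net = dc.get_toy_benchmark_data()

# Score methods per experiment; 'perturb' and 'sign' are assumed obs columns.
df = dc.benchmark(
    mat, obs, net,
    perturb='perturb',   # column with the perturbed source of each experiment
    sign='sign',         # column with the perturbation sign (+1 / -1)
    metrics=['auroc', 'auprc', 'mcauroc', 'mcauprc'],
    by='experiment',
    verbose=True,
)

# df has columns: groupby, group, source, method, metric, score, ci
print(df.head())
```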
2 changes: 1 addition & 1 deletion decoupler/decouple.py
@@ -104,7 +104,7 @@ def decouple(mat, net, source='source', target='target', weight='weight', method
Column name in net with target nodes.
weight : str
Column name in net with weights.
methods : list, str
methods : list, str, None
List of methods to run. If none are provided use weighted top performers (mlm, ulm and wsum). To run all methods set to
"all".
args : dict
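As a quick illustration of the `methods` argument documented in this hunk, a hedged sketch of calling `decouple` with an explicit method list; it assumes `get_toy_data` returns an expression DataFrame plus a matching long-format net, and that `decouple` returns a dictionary of result DataFrames.

```python
import decoupler as dc

# Toy data: expression DataFrame (samples x features) and a long-format net
# (assumed return shape of get_toy_data).
mat, net = dc.get_toy_data()

# Run only the weighted top performers explicitly; methods=None would fall back
# to the same default trio, and methods='all' would run every available method.
res = dc.decouple(mat, net, methods=['mlm', 'ulm', 'wsum'], verbose=False)

# res is expected to be a dictionary of result DataFrames keyed by estimate name.
print(list(res.keys()))
```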
6 changes: 3 additions & 3 deletions decoupler/method_gsva.py
@@ -79,7 +79,7 @@ def density(mat, kcdf=False):
return mat


@nb.njit(nb.types.Tuple((nb.f4[:, :], nb.i4[:, :]))(nb.f4[:, :]), parallel=True)
@nb.njit(nb.types.Tuple((nb.f4[:, :], nb.i4[:, :]))(nb.f4[:, :]), parallel=True, cache=True)
def nb_get_D_I(mat):
n = mat.shape[1]
rev_idx = np.abs(np.arange(start=n, stop=0, step=-1, dtype=nb.f4) - n / 2)
@@ -92,7 +92,7 @@ def nb_get_D_I(mat):
return mat, Idx


@nb.njit(nb.f4(nb.f4[:], nb.i4[:], nb.i4, nb.i4[:], nb.i4[:], nb.i4, nb.f4))
@nb.njit(nb.f4(nb.f4[:], nb.i4[:], nb.i4, nb.i4[:], nb.i4[:], nb.i4, nb.f4), cache=True)
def ks_sample(D, Idx, n_genes, geneset_mask, fset, n_geneset, dec):

sum_gset = 0.0
@@ -121,7 +121,7 @@ def ks_sample(D, Idx, n_genes, geneset_mask, fset, n_geneset, dec):
return mx_value_sign


@nb.njit(nb.f4[:](nb.f4[:, :], nb.i4[:, :], nb.i4[:]), parallel=True)
@nb.njit(nb.f4[:](nb.f4[:, :], nb.i4[:, :], nb.i4[:]), parallel=True, cache=True)
def ks_matrix(D, Idx, fset):
n_samples, n_genes = D.shape
n_geneset = fset.shape[0]
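The only functional change in this file is the addition of `cache=True` to the numba decorators. For readers unfamiliar with the flag: it persists the compiled machine code to disk, so the kernels are not re-jitted in every new Python session. A generic, self-contained sketch of the same pattern (not code from this repository):

```python
import numba as nb
import numpy as np


# cache=True writes the compiled function to an on-disk cache next to the module,
# so the first call in a fresh session loads it instead of recompiling.
@nb.njit(nb.f4[:](nb.f4[:, :]), cache=True)
def row_sums(mat):
    # Sum each row of a float32 matrix into a float32 vector.
    out = np.zeros(mat.shape[0], dtype=np.float32)
    for i in range(mat.shape[0]):
        out[i] = mat[i, :].sum()
    return out


x = np.random.rand(4, 3).astype(np.float32)
print(row_sums(x))
```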