Skip to content

Commit

Permalink
move save and load outside of scHPF object & add run_trials method
Browse files Browse the repository at this point in the history
  • Loading branch information
hannaml committed Dec 7, 2018
1 parent 5439789 commit 55d3e49
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 48 deletions.
49 changes: 18 additions & 31 deletions bin/scHPF
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import numpy as np
import pandas as pd
from scipy.io import mmread, mmwrite
from scipy.sparse import coo_matrix
from sklearn.externals import joblib

from schpf import scHPF
from schpf import load_model, save_model, run_trials
import schpf.preprocessing as pp

def _parser():
Expand Down Expand Up @@ -160,6 +160,8 @@ if __name__=='__main__':
print('......found {} cells and {} genes'.format(ncells, ngenes))

print('Generating masks for filtering......')
if args.min_cells < 0:
raise ValueError('min_cells must be >= 0')
mask = pp.min_cells_expressing_mask(umis, args.min_cells)
if args.whitelist is not None and len(args.whitelist):
whitelist = pd.read_csv(args.whitelist, delim_whitespace=True,
Expand Down Expand Up @@ -191,50 +193,35 @@ if __name__=='__main__':
print( 'Loading data......' )
load_fnc = mmread if args.input.endswith('.mtx') else pp.load_coo
train = load_fnc(args.input)

ncells, ngenes = train.shape
msg = '......found {} cells and {} genes'.format(ncells, ngenes)
print(msg)
if ngenes >= 20000:
msg = 'WARNING: you are running scHPF with {} genes,'.format(ngenes)
msg += ' which is more than the ~20k protein coding genes in the'
msg += ' human genome. We suggest running scHPF on protein-coding'
msg += ' genes only.'
print(msg)

if args.validation is not None:
vdata = load_fnc(args.validation)
else:
vdata = None
ncells, ngenes = train.shape
msg = '......found {} cells and {} genes'.format(ncells, ngenes)

# create model
print( 'Running trials......' )
best_loss, best_model = 1e100, None
print('Running trials......' )
dtype = np.float32 if args.float32 else np.float64
for t in range(args.ntrials):
model = scHPF(nfactors=args.nfactors,
min_iter=args.min_iter, max_iter=args.max_iter,
check_freq=args.check_freq, epsilon=args.epsilon,
better_than_n_ago=args.better_than_n_ago,
dtype=dtype
)
model.fit(train, validation_data=vdata, verbose=args.verbose)
model = run_trials(train,
nfactors=args.nfactors, ntrials=args.ntrials,
min_iter=args.min_iter, max_iter=args.max_iter,
check_freq=args.check_freq, epsilon=args.epsilon,
better_than_n_ago=args.better_than_n_ago,
dtype=dtype, verbose=args.verbose,
validation_data=vdata)

loss = model.loss[-1]
if loss < best_loss:
best_model = model
best_loss = loss
print('New best!')

print('Best loss: {0:.6f}'.format(best_loss))
# save the model
print('Saving best model......')
prefix = args.prefix.rstrip('.') + '.' if len(args.prefix) > 0 else ''
outprefix = '{}/{}'.format(args.outdir, prefix)
outname = '{}scHPF_K{}_epsilon{}_{}trials.joblib'.format(
outprefix, args.nfactors, args.epsilon, args.ntrials)
joblib.dump(best_model, outname)
save_model(model, outname)

elif args.cmd == 'score':
print('Loading model......')
model = joblib.load(args.model)
model = load_model(args.model)

print('Calculating scores......')
cell_score = model.cell_score()
Expand Down
2 changes: 1 addition & 1 deletion schpf/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .scHPF_ import scHPF
from .scHPF_ import *
145 changes: 129 additions & 16 deletions schpf/scHPF_.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python

from copy import deepcopy
from warnings import warn

import numpy as np
from scipy.sparse import coo_matrix
Expand Down Expand Up @@ -146,6 +147,8 @@ class scHPF(BaseEstimator):
Variational distributions for eta
beta: HPF_Gamma (optional, default None)
Variational distributions for beta
verbose: bool (optional, default True)
Print messages at each check_freq
"""
def __init__(
self,
Expand All @@ -166,7 +169,8 @@ def __init__(
theta=None,
eta=None,
beta=None,
loss=None
loss=None,
verbose=True,
):
"""Initialize HPF instance"""
self.nfactors = nfactors
Expand All @@ -182,6 +186,7 @@ def __init__(
self.epsilon = epsilon
self.better_than_n_ago = better_than_n_ago
self.dtype = dtype
self.verbose = verbose

self.xi = None
self.eta = None
Expand Down Expand Up @@ -283,7 +288,6 @@ def fit(self, X, validation_data=None, **params):
return self


# TODO copy self but with new values for new data
def project(self, X, validation_data=None, min_iter=10):
"""Get bp,xi and theta for new data while fixing gene scores"""
(bp, _, xi, _, theta, _) = self._fit(X,
Expand All @@ -306,7 +310,10 @@ def save(self, file_name):
file_name : str
Name of file to save model to
"""
joblib.dumps(model, file_name)
msg = 'scHPF.save() is deprecated. Instead, use save_model in the '
msg+= 'schpf module.'
warn(msg, DeprecationWarning)
joblib.dump(model, file_name)


@staticmethod
Expand All @@ -318,6 +325,9 @@ def load(file_name):
file_name : str
Joblib file containing a saved scHPF model
"""
msg = 'scHPF.load(file) is deprecated. Instead, use load_model in the '
msg+= 'schpf module.'
warn(msg, DeprecationWarning)
return joblib.load(file_name)


Expand All @@ -326,7 +336,7 @@ def _score(self, capacity, loading):


def _fit(self, X, validation_data=None, freeze_genes=False, reinit=True,
verbose=True, min_iter=None, message_function=None):
min_iter=None, checkstep_function=None, verbose=None):
"""Combined internal fit/transform function
Parameters
Expand All @@ -340,15 +350,16 @@ def _fit(self, X, validation_data=None, freeze_genes=False, reinit=True,
reinit: bool, (optional, default True)
Randomly initialize variational distributions even if they
already exist. Superseded by freeze_genes.
verbose: bool (optional, default True)
Print messages at each check_freq
min_iter: int (optional, default None)
replaces self.min_iter if given
message_function : function (optional, default None)
Replaces self.min_iter if given. Useful when projecting
new data onto an existing scHPF model.
checkstep_function : function (optional, default None)
A function that takes arguments theta, beta, and t and, if
given, is called at check_interval. Intended use is
to check additional stats during training, potentially with
hardcoded data, but is unrestricted. Use at own risk.
verbose: bool (optional, default None)
If not None, overrides self.verbose
Returns
-------
Expand All @@ -369,10 +380,10 @@ def _fit(self, X, validation_data=None, freeze_genes=False, reinit=True,
"""
# local (convenience) vars for model
nfactors, (ncells, ngenes) = self.nfactors, X.shape
(a, ap, c, cp) = (self.a, self.ap, self.c, self.cp)
a, ap, c, cp = self.a, self.ap, self.c, self.cp

# get empirically set hyperparameters and variational distributions
(bp, dp, xi, eta, theta, beta) = self._setup(X, freeze_genes, reinit)
bp, dp, xi, eta, theta, beta = self._setup(X, freeze_genes, reinit)

# Make first updates for hierarchical prior
# (vi_shape is constant, but want to update full distribution)
Expand All @@ -382,6 +393,7 @@ def _fit(self, X, validation_data=None, freeze_genes=False, reinit=True,

pct_change = []
min_iter = self.min_iter if min_iter is None else min_iter
verbose = self.verbose if verbose is None else verbose
for t in range(self.max_iter):
if t==0 and reinit: #randomize phi for first iteration
random_phi = np.random.dirichlet( np.ones(nfactors),
Expand Down Expand Up @@ -426,9 +438,8 @@ def _fit(self, X, validation_data=None, freeze_genes=False, reinit=True,
msg = '[Iter. {0: >4}] loss:{1:.6f} pct:{2:.9f}'.format(
t, curr, pct_change[-1])
print(msg)
if message_function is not None:
message_function(theta, beta, t)

if checkstep_function is not None:
checkstep_function(theta, beta, t)

# check convergence
if len(self.loss) > 3 and t >= min_iter:
Expand All @@ -440,7 +451,8 @@ def _fit(self, X, validation_data=None, freeze_genes=False, reinit=True,
and (np.abs(prev) > np.abs(curr)))
converged = current_small and prev_small and not_inflection
if converged:
print('converged')
if verbose:
print('converged')
break

# getting worse, and has been for better_than_n_ago checks
Expand All @@ -451,7 +463,8 @@ def _fit(self, X, validation_data=None, freeze_genes=False, reinit=True,
worse_than_n_ago = np.abs(nprev) < np.abs(curr)
getting_worse = np.abs(prev) < np.abs(curr)
if worse_than_n_ago and getting_worse:
print('getting worse break')
if verbose:
print('getting worse break')
break


Expand Down Expand Up @@ -516,11 +529,111 @@ def mean_var_ratio(X, axis):

def _initialize(self, X, freeze_genes=False):
"""Shortcut to setup random distributions without fitting"""
(bp, dp, xi, eta, theta, beta) = self._setup(X, freeze_genes,
bp, dp, xi, eta, theta, beta = self._setup(X, freeze_genes,
reinit=True)
self.bp = bp
self.dp = dp
self.xi = xi
self.eta = eta
self.theta = theta
self.beta = beta


def load_model(file_name):
    """Read a previously saved model back from a joblib file

    Parameters
    ----------
    file_name : str
        Joblib file containing a saved scHPF model

    Returns
    -------
    model
        The object deserialized by `joblib.load` from `file_name`.

    Notes
    -----
    Joblib files are pickle-based, so only load files from sources
    you trust.
    """
    model = joblib.load(file_name)
    return model


def save_model(model, file_name):
    """Write a model to disk as a joblib file

    The model is serialized with joblib, which is similar to pickle but
    preferable for objects holding many numpy arrays.

    Parameters
    ----------
    model : scHPF
        The scHPF model object to save
    file_name : str
        Name of file to save model to
    """
    # The return value of joblib.dump is intentionally discarded.
    joblib.dump(model, file_name)


def run_trials(X, nfactors,
               ntrials=5,
               min_iter=30,
               max_iter=1000,
               check_freq=10,
               epsilon=0.001,
               better_than_n_ago=5,
               dtype=np.float64,
               verbose=True,
               validation_data=None
               ):
    """
    Train with multiple random initializations, selecting model with best loss

    As scHPF uses non-convex optimization, it benefits from training with
    multiple random initializations to avoid local minima.

    Parameters
    ----------
    X: coo_matrix
        Data to fit
    nfactors: int
        Number of factors (K)
    ntrials : int (optional, default 5)
        Number of random initializations for training. Must be >= 1.
    min_iter: int (optional, default 30):
        Minimum number of iterations for training.
    max_iter: int (optional, default 1000):
        Maximum number of iterations for training.
    check_freq: int (optional, default 10)
        Number of training iterations between calculating loss.
    epsilon: float (optional, default 0.001)
        Percent change of loss for convergence.
    better_than_n_ago: int (optional, default 5)
        Stop condition if loss is getting worse.  Stops training if loss
        is worse than `better_than_n_ago`*`check_freq` training steps
        ago and getting worse.
    dtype : (optional, default np.float64)
        Passed unchanged to the scHPF constructor as `dtype`.
    verbose: bool (optional, default True)
        Print a message for each new best trial and the final best loss.
        Also passed to the scHPF constructor as `verbose`.
    validation_data: coo_matrix, (optional, default None)
        validation data, train data used to assess convergence if not given

    Returns
    -------
    best_model : scHPF
        The fitted model whose final recorded loss (`model.loss[-1]`) was
        the lowest across all trials.

    Raises
    ------
    ValueError
        If `ntrials` is less than 1.
    """
    # With zero trials the loop below never runs and the function would
    # silently return None (and report a meaningless best loss), so reject
    # that up front instead of handing callers a non-model to save.
    if ntrials < 1:
        raise ValueError('ntrials must be >= 1')

    ngenes = X.shape[1]
    if ngenes >= 20000:
        msg = 'WARNING: you are running scHPF with {} genes,'.format(ngenes)
        msg += ' which is more than the ~20k protein coding genes in the'
        msg += ' human genome. We suggest running scHPF on protein-coding'
        msg += ' genes only.'
        print(msg)

    # Start from the largest float64 so any finite loss from the first
    # trial is accepted as the initial best.
    best_loss, best_model, best_t = np.finfo(np.float64).max, None, None
    for t in range(ntrials):
        # A fresh model per trial gives each trial its own initialization.
        model = scHPF(nfactors=nfactors,
                      min_iter=min_iter, max_iter=max_iter,
                      check_freq=check_freq, epsilon=epsilon,
                      better_than_n_ago=better_than_n_ago,
                      verbose=verbose, dtype=dtype,
                      )
        model.fit(X, validation_data=validation_data)

        # Trials are compared on the last loss value recorded during fitting.
        loss = model.loss[-1]
        if loss < best_loss:
            best_model = model
            best_loss = loss
            best_t = t
            if verbose:
                print('New best! (trial {})'.format(t))

    if verbose:
        print('Best loss: {0:.6f} ({1})'.format(best_loss, best_t))
    return best_model


0 comments on commit 55d3e49

Please sign in to comment.