diff --git a/enterprise/signals/parameter.py b/enterprise/signals/parameter.py index 6e79280c..a52caed1 100644 --- a/enterprise/signals/parameter.py +++ b/enterprise/signals/parameter.py @@ -228,7 +228,7 @@ def NormalSampler(mu, sigma, size=None): vector value/mu and compatible covariance matrix sigma.""" if np.ndim(sigma) == 2: - return np.random.multivariate_normal(mu, sigma, size=size) + return np.random.multivariate_normal(mu, sigma) else: return np.random.normal(mu, sigma, size=size) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 7606a68d..1882030b 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -98,6 +98,3 @@ def test_normal(self): x1, x2 = NormalSampler(mu, sigma), scipy.stats.multivariate_normal.rvs(mean=mu, cov=sigma) assert x1.shape == x2.shape, msg2 - - x1, x2 = NormalSampler(mu, sigma, size=10), scipy.stats.multivariate_normal.rvs(mean=mu, cov=sigma, size=10) - assert x1.shape == x2.shape, msg2 diff --git a/tests/test_pulsar.py b/tests/test_pulsar.py index b33ec56a..dec156f0 100644 --- a/tests/test_pulsar.py +++ b/tests/test_pulsar.py @@ -20,6 +20,8 @@ from enterprise.pulsar import Pulsar from tests.enterprise_test_data import datadir + +import pint.models.timing_model from pint.models import get_model_and_toas @@ -196,18 +198,18 @@ def setUpClass(cls): def test_deflate_inflate(self): pass - def test_load_radec_psr(cls): + def test_load_radec_psr(self): """Setup the Pulsar object.""" - # initialize Pulsar class with RA DEC - psr = Pulsar( - datadir + "/J0030+0451_RADEC_wrong.par", - datadir + "/J0030+0451_NANOGrav_9yv1.tim", - ephem="DE430", - drop_pintpsr=False, - timing_package="pint", - ) - assert "AstrometryEquatorial" in psr.model.components + with self.assertRaises(pint.models.timing_model.TimingModelError): + # initialize Pulsar class with RAJ DECJ and PMLAMBDA, PMBETA + Pulsar( + datadir + "/J0030+0451_RADEC_wrong.par", + datadir + "/J0030+0451_NANOGrav_9yv1.tim", + ephem="DE430", + drop_pintpsr=False, + timing_package="pint", + ) def test_no_planet(self): """Test exception when incorrect par(tim) file given."""