Skip to content

Commit

Permalink
Merge pull request #317 from vallis/master - Fix multivariate-normal …
Browse files Browse the repository at this point in the history
…parameter bug

Looks good, merge
  • Loading branch information
vallis authored Mar 8, 2022
2 parents dffb805 + 6eb2600 commit 2df96c2
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
2 changes: 1 addition & 1 deletion enterprise/signals/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def NormalSampler(mu, sigma, size=None):
vector value/mu and compatible covariance matrix sigma."""

if np.ndim(sigma) == 2:
return np.random.multivariate_normal(mu, sigma, size=size)
return np.random.multivariate_normal(mu, sigma)
else:
return np.random.normal(mu, sigma, size=size)

Expand Down
3 changes: 0 additions & 3 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,3 @@ def test_normal(self):

x1, x2 = NormalSampler(mu, sigma), scipy.stats.multivariate_normal.rvs(mean=mu, cov=sigma)
assert x1.shape == x2.shape, msg2

x1, x2 = NormalSampler(mu, sigma, size=10), scipy.stats.multivariate_normal.rvs(mean=mu, cov=sigma, size=10)
assert x1.shape == x2.shape, msg2
22 changes: 12 additions & 10 deletions tests/test_pulsar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from enterprise.pulsar import Pulsar
from tests.enterprise_test_data import datadir

import pint.models.timing_model
from pint.models import get_model_and_toas


Expand Down Expand Up @@ -196,18 +198,18 @@ def setUpClass(cls):
def test_deflate_inflate(self):
pass

def test_load_radec_psr(cls):
def test_load_radec_psr(self):
"""Setup the Pulsar object."""

# initialize Pulsar class with RA DEC
psr = Pulsar(
datadir + "/J0030+0451_RADEC_wrong.par",
datadir + "/J0030+0451_NANOGrav_9yv1.tim",
ephem="DE430",
drop_pintpsr=False,
timing_package="pint",
)
assert "AstrometryEquatorial" in psr.model.components
with self.assertRaises(pint.models.timing_model.TimingModelError):
# initialize Pulsar class with RAJ DECJ and PMLAMBDA, PMBETA
Pulsar(
datadir + "/J0030+0451_RADEC_wrong.par",
datadir + "/J0030+0451_NANOGrav_9yv1.tim",
ephem="DE430",
drop_pintpsr=False,
timing_package="pint",
)

def test_no_planet(self):
"""Test exception when incorrect par(tim) file given."""
Expand Down

0 comments on commit 2df96c2

Please sign in to comment.