Skip to content

Commit

Permalink
Adding test and fixing deprecated unit parameters (#211)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Yahalomi <[email protected]>

Co-authored-by: Daniel Yahalomi <[email protected]>
  • Loading branch information
dfm and dyahalomi authored Jun 22, 2021
1 parent 6dab90c commit 3dd400d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/exoplanet/orbits/keplerian.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from exoplanet_core.pymc import ops

from ..citations import add_citations_to_model
from ..units import has_unit, to_unit
from ..utils import as_tensor_variable
from ..units import has_unit, to_unit, with_unit
from ..utils import as_tensor_variable, deprecation_warning
from .constants import G_grav, au_per_R_sun, c_light, gcc_per_sun


Expand Down Expand Up @@ -90,11 +90,21 @@ def __init__(
rho_star=None,
ror=None,
model=None,
contact_points_kwargs=None,
**kwargs
):
add_citations_to_model(self.__citations__, model=model)

if "m_planet_units" in kwargs:
deprecation_warning(
"'m_planet_units' is deprecated; Use `with_unit` instead"
)
m_planet = with_unit(m_planet, kwargs.pop("m_planet_units"))
if "rho_star_units" in kwargs:
deprecation_warning(
"'rho_star_units' is deprecated; Use `with_unit` instead"
)
rho_star = with_unit(rho_star, kwargs.pop("rho_star_units"))

self.jacobians = defaultdict(lambda: defaultdict(None))

daordtau = None
Expand Down
37 changes: 37 additions & 0 deletions tests/units_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-

import astropy.units as u
import numpy as np

from exoplanet import units
from exoplanet.orbits import KeplerianOrbit


def test_mass_units():
P_earth = 365.256
Tper_earth = 2454115.5208333
inclination_earth = np.radians(45.0)

orbit1 = KeplerianOrbit(
period=P_earth,
t_periastron=Tper_earth,
incl=inclination_earth,
m_planet=units.with_unit(1.0, u.M_earth),
)
orbit2 = KeplerianOrbit(
period=P_earth,
t_periastron=Tper_earth,
incl=inclination_earth,
m_planet=1.0,
m_planet_units=u.M_earth,
)

t = np.linspace(Tper_earth, Tper_earth + 1000, 1000)
rv1 = orbit1.get_radial_velocity(t).eval()
rv_diff = np.max(rv1) - np.min(rv1)
assert rv_diff < 1.0, "with_unit"

rv2 = orbit2.get_radial_velocity(t).eval()
rv_diff = np.max(rv2) - np.min(rv2)
assert rv_diff < 1.0, "m_planet_units"
np.testing.assert_allclose(rv2, rv1)

0 comments on commit 3dd400d

Please sign in to comment.