Skip to content

Commit

Permalink
Add converter for values defined on TORAX cell grid -> face grid.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696552573
  • Loading branch information
Nush395 authored and Torax team committed Nov 14, 2024
1 parent 29893a4 commit fd7538b
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 1 deletion.
84 changes: 83 additions & 1 deletion torax/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,97 @@
Math operations that are needed for Torax, but are not specific to plasma
physics or differential equation solvers.
"""
from __future__ import annotations
import enum
import functools
import jax
from jax import numpy as jnp
import jaxtyping as jt
from torax import array_typing
from torax import geometry
from torax import jax_utils


@enum.unique
class IntegralPreservationQuantity(enum.Enum):
"""The quantity to preserve the integral of when converting to face values."""
# Indicate that the volume integral should be preserved.
VOLUME = 'volume'
# Indicate that the surface integral should be preserved.
SURFACE = 'surface'
# Indicate that the value integral should be preserved.
VALUE = 'value'


@array_typing.typed
@functools.partial(
jax_utils.jit,
static_argnames=['preserved_quantity',],
)
def cell_to_face(
cell_values: jt.Float[jt.Array, 'rhon'],
geo: geometry.Geometry,
preserved_quantity: IntegralPreservationQuantity = IntegralPreservationQuantity.VALUE,
) -> jt.Float[jt.Array, 'rhon+1']:
"""Convert cell values to face values.
We make four assumptions:
1) Inner face values are the average of neighbouring cells.
2) The left most face value is linearly extrapolated from the left most cell
values.
3) The transformation from cell to face is integration preserving.
4) The cell spacing is constant.
Args:
cell_values: Values defined on the TORAX cell grid.
geo: A geometry object.
preserved_quantity: The quantity to preserve the integral of when converting
to face values.
Returns:
Values defined on the TORAX face grid.
"""
if len(cell_values) < 2:
raise ValueError(
'Cell values must have at least two values to convert to face values.'
)
inner_face_values = (cell_values[:-1] + cell_values[1:]) / 2.0
# Linearly extrapolate to get left value.
left = cell_values[0] - (inner_face_values[0] - cell_values[0])
face_values_without_right = jnp.concatenate([left[None], inner_face_values])
# Preserve integral.
match preserved_quantity:
case IntegralPreservationQuantity.VOLUME:
diff = jnp.sum(
cell_values * geo.vpr
) * geo.drho_norm - jax.scipy.integrate.trapezoid(
face_values_without_right * geo.vpr_face[:-1], geo.rho_face_norm[:-1]
)
right = (
2 * diff / geo.drho_norm
- face_values_without_right[-1] * geo.vpr_face[-2]
) / geo.vpr_face[-1]
case IntegralPreservationQuantity.SURFACE:
diff = jnp.sum(
cell_values * geo.spr_cell
) * geo.drho_norm - jax.scipy.integrate.trapezoid(
face_values_without_right * geo.spr_face[:-1], geo.rho_face_norm[:-1]
)
right = (
2 * diff / geo.drho_norm
- face_values_without_right[-1] * geo.spr_face[-2]
) / geo.spr_face[-1]
case IntegralPreservationQuantity.VALUE:
diff = jnp.sum(
cell_values
) * geo.drho_norm - jax.scipy.integrate.trapezoid(
face_values_without_right, geo.rho_face_norm[:-1]
)
right = 2 * diff / geo.drho_norm - face_values_without_right[-1]

face_values = jnp.concatenate([face_values_without_right, right[None]])
return face_values


@array_typing.typed
@jax_utils.jit
def cell_integration(
Expand Down
101 changes: 101 additions & 0 deletions torax/tests/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Unit tests for torax.math_utils."""
from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -79,6 +80,106 @@ def test_cell_integration(self, num_cell_grid_points: int):
rtol=1e-6, # 1e-7 rtol is too tight for this test to pass.
)

@parameterized.named_parameters(
dict(
testcase_name='with_equally_spaced_cell_values_value_preserved',
cell_values=[1.0, 2.0, 3.0, 4.0],
expected_face_values_except_right=np.array([0.5, 1.5, 2.5, 3.5]),
preserved_quantity=math_utils.IntegralPreservationQuantity.VALUE,
),
dict(
testcase_name='with_sawtooth_cell_values_value_preserved',
cell_values=[-1.0, 2.0, -3.0, 4.0],
expected_face_values_except_right=np.array([-2.5, 0.5, -0.5, 0.5]),
preserved_quantity=math_utils.IntegralPreservationQuantity.VALUE,
),
dict(
testcase_name='with_unevenly_spaced_cell_values_value_preserved',
cell_values=[10, 6, 0, 20],
expected_face_values_except_right=np.array([12, 8, 3, 10]),
preserved_quantity=math_utils.IntegralPreservationQuantity.VALUE,
),
dict(
testcase_name='with_equally_spaced_cell_values_surface_preserved',
cell_values=[1.0, 2.0, 3.0, 4.0],
expected_face_values_except_right=np.array([0.5, 1.5, 2.5, 3.5]),
preserved_quantity=math_utils.IntegralPreservationQuantity.SURFACE,
),
dict(
testcase_name='with_sawtooth_cell_values_surface_preserved',
cell_values=[-1.0, 2.0, -3.0, 4.0],
expected_face_values_except_right=np.array([-2.5, 0.5, -0.5, 0.5]),
preserved_quantity=math_utils.IntegralPreservationQuantity.SURFACE,
),
dict(
testcase_name='with_unevenly_spaced_cell_values_surface_preserved',
cell_values=[10, 6, 0, 20],
expected_face_values_except_right=np.array([12, 8, 3, 10]),
preserved_quantity=math_utils.IntegralPreservationQuantity.SURFACE,
),
dict(
testcase_name='with_equally_spaced_cell_values_volume_preserved',
cell_values=[1.0, 2.0, 3.0, 4.0],
expected_face_values_except_right=np.array([0.5, 1.5, 2.5, 3.5]),
preserved_quantity=math_utils.IntegralPreservationQuantity.VOLUME,
),
dict(
testcase_name='with_sawtooth_cell_values_volume_preserved',
cell_values=[-1.0, 2.0, -3.0, 4.0],
expected_face_values_except_right=np.array([-2.5, 0.5, -0.5, 0.5]),
preserved_quantity=math_utils.IntegralPreservationQuantity.VOLUME,
),
dict(
testcase_name='with_unevenly_spaced_cell_values_volume_preserved',
cell_values=[10, 6, 0, 20],
expected_face_values_except_right=np.array([12, 8, 3, 10]),
preserved_quantity=math_utils.IntegralPreservationQuantity.VOLUME,
),
)
def test_cell_to_face(
self,
cell_values: list[float],
expected_face_values_except_right: np.ndarray,
preserved_quantity: math_utils.IntegralPreservationQuantity,
):
"""Test that the cell_to_face method works as expected."""
geo = geometry.build_circular_geometry(n_rho=len(cell_values))
cell_values = jnp.array(cell_values, dtype=jnp.float32)

face_values = math_utils.cell_to_face(cell_values, geo, preserved_quantity)
chex.assert_shape(face_values, (len(cell_values) + 1,))

np.testing.assert_array_equal(
face_values[:-1], expected_face_values_except_right
)
# Check the integral is preserved.
match preserved_quantity:
case math_utils.IntegralPreservationQuantity.VALUE:
np.testing.assert_allclose(
math_utils.cell_integration(cell_values, geo),
jax.scipy.integrate.trapezoid(face_values, geo.rho_face_norm),
)
case math_utils.IntegralPreservationQuantity.SURFACE:
np.testing.assert_allclose(
math_utils.cell_integration(cell_values * geo.spr_cell, geo),
jax.scipy.integrate.trapezoid(
face_values * geo.spr_face, geo.rho_face_norm
),
)
case math_utils.IntegralPreservationQuantity.VOLUME:
np.testing.assert_allclose(
math_utils.cell_integration(cell_values * geo.vpr, geo),
jax.scipy.integrate.trapezoid(
face_values * geo.vpr_face, geo.rho_face_norm
),
)

def test_cell_to_face_raises_when_too_few_values(self,):
"""Test that the cell_to_face method raises when too few values are provided."""
geo = geometry.build_circular_geometry(n_rho=1)
with self.assertRaises(ValueError):
math_utils.cell_to_face(jnp.array([1.0], dtype=np.float32), geo)


if __name__ == '__main__':
absltest.main()

0 comments on commit fd7538b

Please sign in to comment.