Skip to content

Commit

Permalink
adjust tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Jul 24, 2024
1 parent e654d19 commit 22985b0
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from math import prod

import jax.numpy as jnp
import numpy as np
import pytest
from conftest import initialize_distributed
from jax.experimental import mesh_utils, multihost_utils
Expand All @@ -14,6 +15,7 @@
from numpy.testing import assert_allclose

import jaxdecomp
from jaxdecomp._src import PENCILS, SLAB_XY, SLAB_YZ

# Initialize cuDecomp
initialize_distributed()
Expand All @@ -37,13 +39,21 @@ def create_spmd_array(global_shape, pdims):
key=jax.random.PRNGKey(rank))
# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('y', 'z'))
mesh = Mesh(devices.T, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
local_array, mesh, P('z', 'y'))

return global_array, mesh


def print_array(array):
print(f"shape {array.shape} rank {rank}")
for z in range(array.shape[0]):
for y in range(array.shape[1]):
for x in range(array.shape[2]):
print(f"[{z},{y},{x}] {array[z,y,x]}")


pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100
pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100

Expand All @@ -61,6 +71,13 @@ def test_fft(pdims, global_shape):

print("*" * 80)
print(f"Testing with pdims {pdims} and global shape {global_shape}")
if pdims[0] == 1:
penciltype = SLAB_XY
elif pdims[1] == 1:
penciltype = SLAB_YZ
else:
penciltype = PENCILS
print(f"Decomposition type {penciltype}")

global_array, mesh = create_spmd_array(global_shape, pdims)

Expand All @@ -82,14 +99,21 @@ def test_fft(pdims, global_shape):
assert_allclose(
gathered_array.imag, gathered_rec_array.imag, rtol=1e-7, atol=1e-7)

print(f"Reconstruction check OK!")

# Check the forward FFT
transpose_back = [1, 2, 0]
if penciltype == SLAB_YZ:
transpose_back = [2, 0, 1]
else:
transpose_back = [1, 2, 0]
jax_karray_transposed = jax_karray.transpose(transpose_back)
assert_allclose(
gathered_karray.real, jax_karray_transposed.real, rtol=1e-7, atol=1e-7)
assert_allclose(
gathered_karray.imag, jax_karray_transposed.imag, rtol=1e-7, atol=1e-7)

print(f"FFT with transpose check OK!")


# Cartesian product tests
@pytest.mark.parametrize("pdims",
Expand Down

0 comments on commit 22985b0

Please sign in to comment.