[JAX] Skip V100 encoder tests #1262
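This change adds a shared `is_bf16_supported()` helper to the JAX encoder examples and uses it to skip the BF16 encoder tests on GPUs below compute capability 8.0 (such as V100, which is compute capability 7.0). It also removes the forced `--xla_gpu_deterministic_ops` XLA flag from the `qa/L0_jax_unittest` test script.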

Merged · 6 commits · Oct 22, 2024

Changes from all commits
14 changes: 14 additions & 0 deletions examples/jax/encoder/common.py
@@ -0,0 +1,14 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache

from transformer_engine.transformer_engine_jax import get_device_compute_capability


@lru_cache
def is_bf16_supported():
"""Return if BF16 has hardware supported"""
gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 80
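
For reference: V100 is compute capability 7.0 (sm_70), while hardware BF16 support arrived with Ampere (compute capability 8.0), hence the `gpu_arch >= 80` check. A roughly equivalent check can be written against the plain JAX device API; the following is a sketch, not part of the PR, and assumes a CUDA backend where `Device.compute_capability` is populated:

```python
import jax


def is_bf16_supported_via_jax():
    """Best-effort BF16 check using only public JAX device attributes.

    Hypothetical alternative to the PR's helper; assumes a CUDA backend,
    where `compute_capability` is a string such as "7.0" (V100) or
    "8.0" (A100).
    """
    device = jax.devices()[0]
    cc = getattr(device, "compute_capability", None)
    if cc is None:
        # Non-CUDA backend (CPU/TPU): no compute capability to compare.
        return False
    major, _, minor = cc.partition(".")
    return (int(major), int(minor)) >= (8, 0)
```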
4 changes: 4 additions & 0 deletions examples/jax/encoder/test_model_parallel_encoder.py
@@ -22,6 +22,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
@@ -434,6 +436,7 @@ def setUpClass(cls):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
@@ -446,6 +449,7 @@ def test_te_fp8(self):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
3 changes: 3 additions & 0 deletions examples/jax/encoder/test_multigpu_encoder.py
@@ -22,6 +22,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
@@ -402,6 +404,7 @@ def setUpClass(cls):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
12 changes: 8 additions & 4 deletions examples/jax/encoder/test_multiprocessing_encoder.py
@@ -24,6 +24,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
@@ -552,8 +554,9 @@ def encoder_parser(args):
def query_gpu(q):
"""Query GPU info on the system"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
gpu_has_bf16 = is_bf16_supported()
num_gpu = len(jax.devices())
q.put([num_gpu, gpu_has_fp8, reason])
q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason])


def unittest_query_gpu():
@@ -566,15 +569,15 @@ def unittest_query_gpu():
q = mp.Queue()
p = mp.Process(target=query_gpu, args=(q,))
p.start()
num_gpu, gpu_has_fp8, reason = q.get()
num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get()
p.join()
return num_gpu, gpu_has_fp8, reason
return num_gpu, gpu_has_fp8, gpu_has_bf16, reason


class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

num_gpu, gpu_has_fp8, reason = unittest_query_gpu()
num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu()

def exec(self, use_fp8):
"""Run 3 epochs for testing"""
@@ -598,6 +601,7 @@ def exec(self, use_fp8):

return results

@unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
results = self.exec(False)
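
Note that the capability probe in this file runs inside a `multiprocessing` child so that the parent pytest process never initializes the CUDA/JAX backend before the per-GPU worker processes are launched. A minimal, self-contained sketch of that pattern follows (the names here are illustrative, not taken from the PR):

```python
import multiprocessing as mp


def _probe(queue):
    """Runs in the child; the parent never touches the GPU."""
    import jax  # imported here so only the child initializes the backend

    queue.put(len(jax.devices()))


def probe_gpu_count():
    """Spawn a child process, collect its answer, return it to the parent."""
    queue = mp.Queue()
    proc = mp.Process(target=_probe, args=(queue,))
    proc.start()
    count = queue.get()  # read before join(), mirroring the PR's ordering
    proc.join()
    return count


if __name__ == "__main__":
    print(probe_gpu_count())
```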
3 changes: 3 additions & 0 deletions examples/jax/encoder/test_single_gpu_encoder.py
@@ -19,6 +19,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
@@ -321,6 +323,7 @@ def setUpClass(cls):
"""Run 4 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
2 changes: 0 additions & 2 deletions qa/L0_jax_unittest/test.sh
@@ -18,7 +18,5 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt

pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist

# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py