diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb new file mode 100644 index 000000000000..c193b35c83b1 --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -0,0 +1,144 @@ +# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild +# Author: Denis Kristak +# Updated by: Alex Domingo (Vrije Universiteit Brussel) +# Updated by: Thomas Hoffmann (EMBL Heidelberg) +easyblock = 'PythonBundle' + +name = 'jax' +version = '0.4.25' +versionsuffix = '-CUDA-%(cudaver)s' + +homepage = 'https://pypi.python.org/pypi/jax' +description = """Composable transformations of Python+NumPy programs: +differentiate, vectorize, JIT to GPU/TPU, and more""" + +toolchain = {'name': 'gfbf', 'version': '2023a'} +cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"] + +builddependencies = [ + ('Bazel', '6.3.1'), + ('pytest-xdist', '3.3.1'), + # git 2.x required to fetch repository 'io_bazel_rules_docker' + ('git', '2.41.0', '-nodocs'), + ('matplotlib', '3.7.2'), # required for tests/lobpcg_test.py + ('poetry', '1.5.1'), + ('pybind11', '2.11.1'), +] + +dependencies = [ + ('CUDA', '12.1.1', '', SYSTEM), + ('cuDNN', '8.9.2.26', versionsuffix, SYSTEM), + ('NCCL', '2.18.3', versionsuffix), + ('zlib', '1.2.13'), + ('Python', '3.11.3'), + ('SciPy-bundle', '2023.07'), + ('flatbuffers-python', '23.5.26'), + ('ml_dtypes', '0.3.2'), +] + +# downloading xla and other tarballs to avoid that Bazel downloads it during the build +# note: this *must* be the exact same commit as used in third_party/{xla,"other"}/workspace.bzl +local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' +local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' + +local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives' +local_repo_opt = '--bazel_options="--distdir=%(builddir)s/archives" ' +local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' +local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" ' + + +# Some tests require an isolated run: +local_isolated_tests = [ + 'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp', + 'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc', + 'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' + + '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32', +] +# deliberately not testing in parallel, as that results in (additional) failing tests; +# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing, +# see https://github.com/google/jax/issues/7323 and +# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst; +# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs; +# use NVIDIA_TF32_OVERRIDE=0 to avoid loosing numerical precision by disabling TF32 Tensor Cores; +local_test_exports = [ + "NVIDIA_TF32_OVERRIDE=0", + "CUDA_VISIBLE_DEVICES=0", + "XLA_PYTHON_CLIENT_ALLOCATOR=platform", + "JAX_ENABLE_X64=true", +] +local_test = ''.join(['export %s;' % x for x in local_test_exports]) +# run all tests at once except for local_isolated_tests: +local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests]) +# run remaining local_isolated_tests separately: +local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests]) + +use_pip = True + +default_easyblock = 'PythonPackage' +default_component_specs = { + 'sources': [SOURCE_TAR_GZ], + 'source_urls': [PYPI_SOURCE], + 'start_dir': '%(name)s-%(version)s', + 'use_pip': True, + 'sanity_pip_check': True, + 'download_dep_fail': True, +} + +components = [ + ('absl-py', '2.1.0', { + 'options': {'modulename': 'absl'}, + 'checksums': ['7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff'], + }), + ('jaxlib', version, { + 'sources': [ + '%(name)s-v%(version)s.tar.gz', + { + 'download_filename': '%s.tar.gz' % local_xla_commit, + 'filename': 'xla-%s.tar.gz' % local_xla_commit, + 'extract_cmd': local_extract_cmd, + }, + { + 'download_filename': '%s.tar.gz' % local_tfrt_commit, + 'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit, + 'extract_cmd': local_extract_cmd, + }, + ], + 'source_urls': [ + 'https://github.com/google/jax/archive/', + 'https://github.com/tensorflow/runtime/archive', + 'https://github.com/openxla/xla/archive' + ], + 'patches': ['jax-0.4.25_fix-pybind11-systemlib.patch'], + 'checksums': [ + {'jaxlib-v0.4.25.tar.gz': + 'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'}, + {'xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4.tar.gz': + '8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'}, + {'tf_runtime-0aeefb1660d7e37964b2bb71b1f518096bda9a25.tar.gz': + 'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, + {'jax-0.4.25_fix-pybind11-systemlib.patch': + 'daad5b726d1a138431b05eb60ecf4c89c7b5148eb939721800bdf43d804ca033'}, + ], + 'start_dir': 'jax-jaxlib-v%(version)s', + # Avoid warning (treated as error) in upb/table.c + 'buildopts': local_repo_opt + ' --bazel_options="--copt=-Wno-maybe-uninitialized"' + }), +] + +exts_list = [ + (name, version, { + 'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'], + 'runtest': local_test, + 'source_tmpl': '%(name)s-v%(version)s.tar.gz', + 'source_urls': ['https://github.com/google/jax/archive/'], + 'checksums': [ + {'jax-v0.4.25.tar.gz': '8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'}, + {'jax-0.4.25_fix_env_test_no_log_spam.patch': + 'a18b5f147569d9ad41025124333a0f04fd0d0e0f9e4309658d7f6b9b838e2e2a'}, + ], + }), +] + +sanity_pip_check = True + +moduleclass = 'tools' diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch new file mode 100644 index 000000000000..c404ee6917fb --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch @@ -0,0 +1,38 @@ +Add missing value for System Pybind11 Bazel config + +Author: Alexander Grund (TU Dresden) + +diff --git a/third_party/xla/fix-pybind11-systemlib.patch b/third_party/xla/fix-pybind11-systemlib.patch +new file mode 100644 +index 000000000..68bd2063d +--- /dev/null ++++ b/third_party/xla/fix-pybind11-systemlib.patch +@@ -0,0 +1,13 @@ ++--- xla-orig/third_party/tsl/third_party/systemlibs/pybind11.BUILD +++++ xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4/third_party/tsl/third_party/systemlibs/pybind11.BUILD ++@@ -6,3 +6,10 @@ ++ "@tsl//third_party/python_runtime:headers", ++ ], ++ ) +++ +++# Needed by pybind11_bazel. +++config_setting( +++ name = "osx", +++ constraint_values = ["@platforms//os:osx"], +++) +++ +diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl +index ebc8d9838..125e1c173 100644 +--- a/third_party/xla/workspace.bzl ++++ b/third_party/xla/workspace.bzl +@@ -29,6 +29,9 @@ def repo(): + sha256 = XLA_SHA256, + strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), + urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), ++ patch_file = [ ++ "//third_party/xla:fix-pybind11-systemlib.patch", ++ ], + ) + + # For development, one often wants to make changes to the TF repository as well + diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch new file mode 100644 index 000000000000..ad9196084373 --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch @@ -0,0 +1,18 @@ +# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/03 +# avoid overriding LD_LIBRARY_PATH, which would lead to test error: error while loading shared libraries: libpython3.11.so.1.0: cannot open shared object file: No such file or directory' +diff -ru jax-jax-v0.4.25/tests/logging_test.py jax-jax-v0.4.25_fix_env_test_no_log_spam/tests/logging_test.py +--- jax-jax-v0.4.25/tests/logging_test.py 2024-02-24 19:25:17.000000000 +0100 ++++ jax-jax-v0.4.25_fix_env_test_no_log_spam/tests/logging_test.py 2024-03-15 12:00:34.133022613 +0100 +@@ -72,8 +72,11 @@ + python = sys.executable + assert "python" in python + # Make sure C++ logging is at default level for the test process. ++ import os ++ tmp_env=os.environ.copy() ++ tmp_env['TF_CPP_MIN_LOG_LEVEL']='1' + proc = subprocess.run([python, "-c", program], capture_output=True, +- env={"TF_CPP_MIN_LOG_LEVEL": "1"}) ++ env=tmp_env) + + lines = proc.stdout.split(b"\n") + lines.extend(proc.stderr.split(b"\n"))