{tools}[gfbf/2023a] jax v0.4.25 w/ CUDA 12.1.1 #20119

Merged
8e4e184
{tools}[foss/2023a] jax v0.4.25, ml_dtypes v0.3.2 w/ CUDA 12.1.1
ThomasHoffmann77 Mar 14, 2024
94e9567
fix style
ThomasHoffmann77 Mar 14, 2024
1827967
fix test_no_log_spam
ThomasHoffmann77 Mar 15, 2024
021b248
update patch
ThomasHoffmann77 Mar 15, 2024
a90bc88
isolate some tests
ThomasHoffmann77 Mar 19, 2024
8c5d1f3
Update easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb
ThomasHoffmann77 Mar 23, 2024
1839f6d
Update easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb
ThomasHoffmann77 Mar 27, 2024
0fb9863
Merge branch 'easybuilders:develop' into 20240314161733_new_pr_jax0425
ThomasHoffmann77 Mar 27, 2024
d80e8fe
Update and rename jax-0.4.25-foss-2023a-CUDA-12.1.1.eb to jax-0.4.25-…
ThomasHoffmann77 Mar 27, 2024
a025bfb
Update and rename ml_dtypes-0.3.2-foss-2023a.eb to ml_dtypes-0.3.2-gf…
ThomasHoffmann77 Mar 27, 2024
9291f70
Update jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
ThomasHoffmann77 Apr 8, 2024
07bb12d
Update easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
ThomasHoffmann77 Apr 11, 2024
7809bd3
Fix usage of system Pybind11
Flamefire Apr 11, 2024
dc112fc
Merge pull request #2 from Flamefire/jax_improvements
ThomasHoffmann77 Apr 11, 2024
65e87a6
Update jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
ThomasHoffmann77 Apr 12, 2024
c5f0711
Update jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
ThomasHoffmann77 Apr 12, 2024
cf6f043
Update easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
ThomasHoffmann77 Apr 29, 2024
2e9329d
Use Bazel --distdir
Flamefire May 8, 2024
3b51fbe
Merge pull request #4 from Flamefire/jax
ThomasHoffmann77 May 22, 2024
a643eeb
Update easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
ThomasHoffmann77 May 22, 2024
a045edf
revert workaround for framework bug in extract command in easyconfig …
boegel May 30, 2024
2a1b972
Delete easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.…
ThomasHoffmann77 Jun 3, 2024
53afba3
Delete easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-gfbf-2023a.eb
ThomasHoffmann77 Jun 3, 2024
ee19f89
Merge branch 'develop' of https://github.com/easybuilders/easybuild-e…
Jun 3, 2024
c6518ee
Update jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
ThomasHoffmann77 Jun 27, 2024
144 changes: 144 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
@@ -0,0 +1,144 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.25'
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://pypi.python.org/pypi/jax'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'gfbf', 'version': '2023a'}
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]

builddependencies = [
    ('Bazel', '6.3.1'),
    ('pytest-xdist', '3.3.1'),
    # git 2.x required to fetch repository 'io_bazel_rules_docker'
    ('git', '2.41.0', '-nodocs'),
    ('matplotlib', '3.7.2'),  # required for tests/lobpcg_test.py
    ('poetry', '1.5.1'),
    ('pybind11', '2.11.1'),
]

dependencies = [
    ('CUDA', '12.1.1', '', SYSTEM),
    ('cuDNN', '8.9.2.26', versionsuffix, SYSTEM),
    ('NCCL', '2.18.3', versionsuffix),
    ('zlib', '1.2.13'),
    ('Python', '3.11.3'),
    ('SciPy-bundle', '2023.07'),
    ('flatbuffers-python', '23.5.26'),
    ('ml_dtypes', '0.3.2'),
]

# download XLA and other tarballs up front so that Bazel does not download them during the build
# note: this *must* be the exact same commit as used in third_party/{xla,"other"}/workspace.bzl
local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4'
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'

local_extract_cmd = 'cp %s %(builddir)s/archives'
local_repo_opt = '--bazel_options="--distdir=%(builddir)s/archives" '
local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" '
local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" '


# Some tests require an isolated run:
local_isolated_tests = [
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
    'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' +
    '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
]
# deliberately not testing in parallel, as that results in (additional) failing tests;
# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
# see https://github.com/google/jax/issues/7323 and
# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
# use NVIDIA_TF32_OVERRIDE=0 to disable TF32 Tensor Cores and avoid losing numerical precision;
local_test_exports = [
    "NVIDIA_TF32_OVERRIDE=0",
    "CUDA_VISIBLE_DEVICES=0",
    "XLA_PYTHON_CLIENT_ALLOCATOR=platform",
    "JAX_ENABLE_X64=true",
]
local_test = ''.join(['export %s;' % x for x in local_test_exports])
# run all tests at once except for local_isolated_tests:
local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests])
# run remaining local_isolated_tests separately:
local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests])

use_pip = True

default_easyblock = 'PythonPackage'
default_component_specs = {
    'sources': [SOURCE_TAR_GZ],
    'source_urls': [PYPI_SOURCE],
    'start_dir': '%(name)s-%(version)s',
    'use_pip': True,
    'sanity_pip_check': True,
    'download_dep_fail': True,
}

components = [
    ('absl-py', '2.1.0', {
        'options': {'modulename': 'absl'},
        'checksums': ['7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff'],
    }),
    ('jaxlib', version, {
        'sources': [
            '%(name)s-v%(version)s.tar.gz',
            {
                'download_filename': '%s.tar.gz' % local_xla_commit,
                'filename': 'xla-%s.tar.gz' % local_xla_commit,
                'extract_cmd': local_extract_cmd,
            },
            {
                'download_filename': '%s.tar.gz' % local_tfrt_commit,
                'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit,
                'extract_cmd': local_extract_cmd,
            },
        ],
        'source_urls': [
            'https://github.com/google/jax/archive/',
            'https://github.com/tensorflow/runtime/archive',
            'https://github.com/openxla/xla/archive'
        ],
        'patches': ['jax-0.4.25_fix-pybind11-systemlib.patch'],
        'checksums': [
            {'jaxlib-v0.4.25.tar.gz':
             'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'},
            {'xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4.tar.gz':
             '8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'},
            {'tf_runtime-0aeefb1660d7e37964b2bb71b1f518096bda9a25.tar.gz':
             'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
            {'jax-0.4.25_fix-pybind11-systemlib.patch':
             'daad5b726d1a138431b05eb60ecf4c89c7b5148eb939721800bdf43d804ca033'},
        ],
        'start_dir': 'jax-jaxlib-v%(version)s',
        # Avoid warning (treated as error) in upb/table.c
        'buildopts': local_repo_opt + ' --bazel_options="--copt=-Wno-maybe-uninitialized"'
    }),
]

exts_list = [
    (name, version, {
        'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'],
        'runtest': local_test,
        'source_tmpl': '%(name)s-v%(version)s.tar.gz',
        'source_urls': ['https://github.com/google/jax/archive/'],
        'checksums': [
            {'jax-v0.4.25.tar.gz': '8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'},
            {'jax-0.4.25_fix_env_test_no_log_spam.patch':
             'a18b5f147569d9ad41025124333a0f04fd0d0e0f9e4309658d7f6b9b838e2e2a'},
        ],
    }),
]

sanity_pip_check = True

moduleclass = 'tools'
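
The `local_test` string in the easyconfig above is assembled from three concatenated pieces, which can be hard to picture from the expressions alone. The following is a minimal standalone sketch that reuses the same expressions to show the shape of the resulting shell command. The two test IDs are hypothetical placeholders, and the export list is shortened for readability.

```python
# Standalone sketch of how the easyconfig assembles its test command.
# The test IDs below are hypothetical placeholders, not real jax tests.
isolated_tests = [
    'tests/a_test.py::A::test_x',
    'tests/b_test.py::B::test_y',
]
test_exports = [
    "NVIDIA_TF32_OVERRIDE=0",
    "CUDA_VISIBLE_DEVICES=0",
]

# 1) environment exports, each terminated with ';'
cmd = ''.join(['export %s;' % x for x in test_exports])
# 2) one pytest run over the whole test directory, minus the isolated tests
cmd += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in isolated_tests])
# 3) each isolated test in its own pytest process, chained with '&&'
cmd += ' && '.join(['pytest -vv %s' % x for x in isolated_tests])

print(cmd)
```

Because the pieces are chained with `&&`, the isolated tests only run if the main pytest invocation succeeds, and the whole command stops at the first failing step.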
@@ -0,0 +1,38 @@
Add missing value for System Pybind11 Bazel config

Author: Alexander Grund (TU Dresden)

diff --git a/third_party/xla/fix-pybind11-systemlib.patch b/third_party/xla/fix-pybind11-systemlib.patch
new file mode 100644
index 000000000..68bd2063d
--- /dev/null
+++ b/third_party/xla/fix-pybind11-systemlib.patch
@@ -0,0 +1,13 @@
+--- xla-orig/third_party/tsl/third_party/systemlibs/pybind11.BUILD
++++ xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4/third_party/tsl/third_party/systemlibs/pybind11.BUILD
+@@ -6,3 +6,10 @@
+ "@tsl//third_party/python_runtime:headers",
+ ],
+ )
++
++# Needed by pybind11_bazel.
++config_setting(
++ name = "osx",
++ constraint_values = ["@platforms//os:osx"],
++)
++
diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index ebc8d9838..125e1c173 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -29,6 +29,9 @@ def repo():
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
+ patch_file = [
+ "//third_party/xla:fix-pybind11-systemlib.patch",
+ ],
)

# For development, one often wants to make changes to the TF repository as well

@@ -0,0 +1,18 @@
# Thomas Hoffmann, EMBL Heidelberg, [email protected], 2024/03
# avoid overriding LD_LIBRARY_PATH, which would lead to the test error 'error while loading shared libraries: libpython3.11.so.1.0: cannot open shared object file: No such file or directory'
diff -ru jax-jax-v0.4.25/tests/logging_test.py jax-jax-v0.4.25_fix_env_test_no_log_spam/tests/logging_test.py
--- jax-jax-v0.4.25/tests/logging_test.py 2024-02-24 19:25:17.000000000 +0100
+++ jax-jax-v0.4.25_fix_env_test_no_log_spam/tests/logging_test.py 2024-03-15 12:00:34.133022613 +0100
@@ -72,8 +72,11 @@
python = sys.executable
assert "python" in python
# Make sure C++ logging is at default level for the test process.
+ import os
+ tmp_env=os.environ.copy()
+ tmp_env['TF_CPP_MIN_LOG_LEVEL']='1'
proc = subprocess.run([python, "-c", program], capture_output=True,
- env={"TF_CPP_MIN_LOG_LEVEL": "1"})
+ env=tmp_env)

lines = proc.stdout.split(b"\n")
lines.extend(proc.stderr.split(b"\n"))
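
The patch above is needed because `subprocess.run` treats `env` as the complete environment of the child process, not as a set of overrides on top of the parent's. A minimal sketch of the difference, independent of jax, is shown here. The variable name `DEMO_INHERITED_VAR` is made up for illustration and stands in for something like `LD_LIBRARY_PATH`.

```python
import os
import subprocess
import sys

# Hypothetical variable standing in for LD_LIBRARY_PATH in the parent process.
os.environ["DEMO_INHERITED_VAR"] = "from-parent"

program = "import os; print(os.environ.get('DEMO_INHERITED_VAR', 'missing'))"

# Variant 1 (original test): env is a fresh dict, so the child sees only this variable.
replaced_env = {"TF_CPP_MIN_LOG_LEVEL": "1"}
replaced = subprocess.run([sys.executable, "-c", program],
                          capture_output=True, text=True, env=replaced_env)

# Variant 2 (what the patch does): copy the parent environment, then override one key.
inherited_env = os.environ.copy()
inherited_env["TF_CPP_MIN_LOG_LEVEL"] = "1"
inherited = subprocess.run([sys.executable, "-c", program],
                           capture_output=True, text=True, env=inherited_env)

# With a replaced environment the interpreter may not even start on some systems,
# so fall back to stderr when stdout is empty.
print("replaced env :", replaced.stdout.strip() or replaced.stderr.strip())
print("inherited env:", inherited.stdout.strip())
```

On an installation where the Python shared library is found via `LD_LIBRARY_PATH`, as is typical for module-based software stacks, the first variant is the one that would be expected to fail at interpreter start-up, which matches the error quoted in the patch header.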