Skip to content

Commit

Permalink
Merge pull request #4 from Flamefire/jax
Browse files Browse the repository at this point in the history
Use Bazel --distdir
  • Loading branch information
ThomasHoffmann77 authored May 22, 2024
2 parents cf6f043 + 2e9329d commit 3b51fbe
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 17 deletions.
11 changes: 7 additions & 4 deletions easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ dependencies = [
# note: this *must* be the exact same commit as used in third_party/{xla,"other"}/workspace.bzl
local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4'
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'
local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit
local_repo_opt += '--bazel_options="--override_repository=tf_runtime=%%(builddir)s/runtime-%s" ' % local_tfrt_commit

local_extract_cmd = 'cp %s %(builddir)s/archives'
local_repo_opt = '--bazel_options="--distdir=%(builddir)s/archives" '
local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" '
local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" '

Expand Down Expand Up @@ -94,18 +95,20 @@ components = [
{
'download_filename': '%s.tar.gz' % local_xla_commit,
'filename': 'xla-%s.tar.gz' % local_xla_commit,
'extract_cmd': local_extract_cmd,
},
{
'download_filename': '%s.tar.gz' % local_tfrt_commit,
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit,
'extract_cmd': local_extract_cmd,
},
],
'source_urls': [
'https://github.com/google/jax/archive/',
'https://github.com/tensorflow/runtime/archive',
'https://github.com/openxla/xla/archive'
],
'patches': [('jax-0.4.25_fix-pybind11-systemlib.patch', '../xla-' + local_xla_commit)],
'patches': ['jax-0.4.25_fix-pybind11-systemlib.patch'],
'checksums': [
{'jaxlib-v0.4.25.tar.gz':
'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'},
Expand All @@ -114,7 +117,7 @@ components = [
{'tf_runtime-0aeefb1660d7e37964b2bb71b1f518096bda9a25.tar.gz':
'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
{'jax-0.4.25_fix-pybind11-systemlib.patch':
'ec93de5628e4d40d3378b92784f7d1e5b0b43bd207a86badeffd44a42e0b1d47'},
'daad5b726d1a138431b05eb60ecf4c89c7b5148eb939721800bdf43d804ca033'},
],
'start_dir': 'jax-jaxlib-v%(version)s',
# Avoid warning (treated as error) in upb/table.c
Expand Down
46 changes: 33 additions & 13 deletions easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,37 @@ Add missing value for System Pybind11 Bazel config

Author: Alexander Grund (TU Dresden)

--- xla-orig/third_party/tsl/third_party/systemlibs/pybind11.BUILD
+++ xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4/third_party/tsl/third_party/systemlibs/pybind11.BUILD
@@ -6,3 +6,10 @@
"@tsl//third_party/python_runtime:headers",
],
)
+
+# Needed by pybind11_bazel.
+config_setting(
+ name = "osx",
+ constraint_values = ["@platforms//os:osx"],
+)
+
diff --git a/third_party/xla/fix-pybind11-systemlib.patch b/third_party/xla/fix-pybind11-systemlib.patch
new file mode 100644
index 000000000..68bd2063d
--- /dev/null
+++ b/third_party/xla/fix-pybind11-systemlib.patch
@@ -0,0 +1,13 @@
+--- xla-orig/third_party/tsl/third_party/systemlibs/pybind11.BUILD
++++ xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4/third_party/tsl/third_party/systemlibs/pybind11.BUILD
+@@ -6,3 +6,10 @@
+ "@tsl//third_party/python_runtime:headers",
+ ],
+ )
++
++# Needed by pybind11_bazel.
++config_setting(
++ name = "osx",
++ constraint_values = ["@platforms//os:osx"],
++)
++
diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index ebc8d9838..125e1c173 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -29,6 +29,9 @@ def repo():
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
+ patch_file = [
+ "//third_party/xla:fix-pybind11-systemlib.patch",
+ ],
)

# For development, one often wants to make changes to the TF repository as well

0 comments on commit 3b51fbe

Please sign in to comment.