-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CLI to install dependencies #104
Changes from all commits
6417dc7
062bcc5
4a45443
1d4cb8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import importlib.util
import os
import shlex
import shutil
import subprocess
import sys
from pathlib import Path

import click
import typer
|
||
|
||
# Pinned PyTorch release; the torch-xla wheels installed below must match this version.
TORCH_VER = "2.4.0"
Reviewer: Should you source the version from the package metadata instead of hard-coding it here? Author: I was actually thinking that in the long run I will try to remove the dependency altogether, so users could install it independently.
||
# Pinned jetstream-pytorch commit SHA checked out during install, for reproducible installs.
JETSTREAM_PT_VER = "ec4ac8f6b180ade059a2284b8b7d843b3cab0921"
# Default directory (under the user's home) where the jetstream-pytorch repo is cloned.
DEFAULT_DEPS_PATH = os.path.join(Path.home(), ".jetstream-deps")
|
||
# Typer application exposing the install commands below as a CLI.
app = typer.Typer()
|
||
|
||
def _check_module(module_name: str): | ||
spec = importlib.util.find_spec(module_name) | ||
return spec is not None | ||
|
||
|
||
def _run(cmd: str): | ||
split_cmd = cmd.split() | ||
subprocess.check_call(split_cmd) | ||
|
||
|
||
def _install_torch_cpu():
    """Install the pinned torch release from the CPU wheel index."""
    # Pointing pip at the CPU index avoids pulling in the CUDA dependency stack.
    pip_cmd = f"{sys.executable} -m pip install torch=={TORCH_VER} --index-url https://download.pytorch.org/whl/cpu"
    _run(pip_cmd)
|
||
|
||
@app.command()
def install_pytorch_xla(
    force: bool = False,
):
    """
    Installs PyTorch XLA with TPU support.

    Args:
        force (bool): When set, force reinstalling even if Pytorch XLA is already installed.
    """
    # If both packages are already importable, ask before reinstalling (abort on "no").
    if not force and _check_module("torch") and _check_module("torch_xla"):
        typer.confirm(
            "PyTorch XLA is already installed. Do you want to reinstall it?",
            default=False,
            abort=True,
        )
    # CPU torch first, then the TPU-enabled torch-xla wheel from the libtpu index.
    _install_torch_cpu()
    xla_cmd = (
        f"{sys.executable} -m pip install torch-xla[tpu]=={TORCH_VER}"
        " -f https://storage.googleapis.com/libtpu-releases/index.html"
    )
    _run(xla_cmd)
    click.echo()
    click.echo(click.style("PyTorch XLA has been installed.", bold=True))
|
||
|
||
@app.command()
def install_jetstream_pytorch(
    deps_path: str = DEFAULT_DEPS_PATH,
    yes: bool = False,
):
    """
    Installs Jetstream Pytorch with TPU support.

    Args:
        deps_path (str): Path where Jetstream Pytorch dependencies will be installed.
        yes (bool): When set, proceed installing without asking questions.
    """
    # jetstream-pytorch needs torch present; install the CPU wheel to avoid CUDA deps.
    if not _check_module("torch"):
        _install_torch_cpu()
    # If already importable, confirm before reinstalling (abort on "no") unless --yes.
    if not yes and _check_module("jetstream_pt") and _check_module("torch_xla2"):
        typer.confirm(
            "Jetstream Pytorch is already installed. Do you want to reinstall it?",
            default=False,
            abort=True,
        )

    jetstream_repo_dir = os.path.join(deps_path, "jetstream-pytorch")
    if not yes and os.path.exists(jetstream_repo_dir):
        typer.confirm(
            f"Directory {jetstream_repo_dir} already exists. Do you want to delete it and reinstall Jetstream Pytorch?",
            default=False,
            abort=True,
        )
    shutil.rmtree(jetstream_repo_dir, ignore_errors=True)
    # Create the directory if it does not exist
    os.makedirs(deps_path, exist_ok=True)
    # Clone and install Jetstream Pytorch at the pinned commit.
    # Fix: clone from the canonical GitHub host (the previous URL pointed at a mirror domain).
    os.chdir(deps_path)
    _run("git clone https://github.com/google/jetstream-pytorch.git")
    os.chdir("jetstream-pytorch")
    _run(f"git checkout {JETSTREAM_PT_VER}")
    _run("git submodule update --init --recursive")
    # We cannot install in a temporary directory because the directory should not be deleted after the script finishes,
    # because it will install its dependencies from that directory.
    _run(sys.executable + " -m pip install -e .")

    # Install PyTorch XLA pallas (JAX/libtpu wheels resolved via the extra find-links indexes).
    _run(
        sys.executable
        + f" -m pip install torch_xla[pallas]=={TORCH_VER} "
        + " -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html"
        + " -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html"
        + " -f https://storage.googleapis.com/libtpu-releases/index.html"
    )
    click.echo()
    click.echo(click.style("Jetstream Pytorch has been installed.", bold=True))
|
||
|
||
if __name__ == "__main__":
    # Entry point when executed directly: run the Typer app and propagate its exit code.
    sys.exit(app())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is
24.04
not supported yet?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, we had the same problem last week on Optimum — see the linked issue. `ubuntu-latest` now ships Python 3.12, which is not yet supported by PyTorch/XLA.