
Add CLI to install dependencies #104

Merged: 4 commits, Oct 14, 2024
2 changes: 1 addition & 1 deletion .github/workflows/doc-build.yml
@@ -16,7 +16,7 @@ on:

jobs:
  build_documentation:
    runs-on: ubuntu-latest
    runs-on: ubuntu-22.04
    env:
      COMMIT_SHA: ${{ github.event.pull_request.head.sha }}
      PR_NUMBER: ${{ github.event.number }}
2 changes: 1 addition & 1 deletion .github/workflows/doc-pr-build.yml
@@ -15,7 +15,7 @@ concurrency:

jobs:
  build_documentation:
    runs-on: ubuntu-latest
    runs-on: ubuntu-22.04
Member:
Is 24.04 not supported yet?

@tengomucho (Collaborator, Author), Oct 14, 2024:
No, we had the same problem last week on optimum, see here. Ubuntu latest comes with Python 3.12, which is not yet supported by torch xla.
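As an editorial aside, the constraint discussed here could be made explicit in code. The sketch below is hypothetical and not part of this PR; the exact supported range is an assumption based on the comment that torch xla does not yet support Python 3.12.

```python
import sys

def python_supported(major: int = sys.version_info.major,
                     minor: int = sys.version_info.minor) -> bool:
    """Return True if the interpreter is in the (assumed) torch_xla-supported range.

    Assumption for illustration: CPython 3.8 through 3.11 are supported,
    3.12 is not (per the discussion in this thread).
    """
    return (3, 8) <= (major, minor) <= (3, 11)
```

A pre-install guard like this would fail fast with a clear message instead of letting pip error out on a missing wheel.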

    env:
      COMMIT_SHA: ${{ github.event.pull_request.head.sha }}
      PR_NUMBER: ${{ github.event.number }}
6 changes: 1 addition & 5 deletions Makefile
@@ -90,11 +90,7 @@ tgi_server:
	VERSION=${VERSION} TGI_VERSION=${TGI_VERSION} make -C text-generation-inference/server gen-server

jetstream_requirements:
	bash install-jetstream-pt.sh
	python -m pip install .[jetstream-pt] \
		-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
		-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
		-f https://storage.googleapis.com/libtpu-releases/index.html
	python optimum/tpu/cli.py install-jetstream-pt --force

tgi_test_jetstream: test_installs jetstream_requirements tgi_server
	find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
6 changes: 3 additions & 3 deletions README.md
@@ -6,7 +6,7 @@ Optimum-TPU

[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://huggingface.co/docs/optimum/index)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)

[![Optimum TPU / Test TGI on TPU](https://github.com/huggingface/optimum-tpu/actions/workflows/test-pytorch-xla-tpu-tgi.yml/badge.svg)](https://github.com/huggingface/optimum-tpu/actions/workflows/test-pytorch-xla-tpu-tgi.yml)
</div>

[Tensor Processing Units (TPU)](https://cloud.google.com/tpu) are AI accelerators made by Google to optimize
@@ -49,10 +49,10 @@ Please see the [TGI specific documentation](text-generation-inference) on how to

### JetStream Pytorch Engine

`optimum-tpu` provides an optional support of JetStream Pytorch engine inside of TGI. This support can be installed using the dedicated command:
`optimum-tpu` provides an optional support of JetStream Pytorch engine inside of TGI. This support can be installed using the dedicated CLI command:

```shell
make jetstream_requirements
optimum-tpu install-jetstream-pytorch
```

To enable the support, export the environment variable `JETSTREAM_PT=1`.
18 changes: 0 additions & 18 deletions install-jetstream-pt.sh

This file was deleted.

113 changes: 113 additions & 0 deletions optimum/tpu/cli.py
@@ -0,0 +1,113 @@
import importlib.util
import os
import shutil
import subprocess
import sys
from pathlib import Path

import click
import typer


TORCH_VER = "2.4.0"
Member:
Should you source the version from torch directly? Hard-coding it seems very error-prone and easy to forget in the long run, doesn't it?

Collaborator (Author):
I was actually thinking that in the long run I will try to remove the dependency altogether, so users could install optimum-tpu and then install torch CPU using this CLI: this will avoid installing the CUDA dependencies of the default torch package.
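For reference, the reviewer's suggestion could look roughly like the sketch below: read the version of the installed distribution via `importlib.metadata` and fall back to the pin only when torch is absent. This is an illustration, not part of the PR; `pinned_or_installed_version` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

def pinned_or_installed_version(dist: str, pin: str) -> str:
    """Return the installed version of `dist`, or the hard-coded pin if not installed."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return pin

# Hypothetical usage in place of the constant:
# TORCH_VER = pinned_or_installed_version("torch", "2.4.0")
```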

JETSTREAM_PT_VER = "ec4ac8f6b180ade059a2284b8b7d843b3cab0921"
DEFAULT_DEPS_PATH = os.path.join(Path.home(), ".jetstream-deps")

app = typer.Typer()


def _check_module(module_name: str):
    spec = importlib.util.find_spec(module_name)
    return spec is not None


def _run(cmd: str):
    split_cmd = cmd.split()
    subprocess.check_call(split_cmd)


def _install_torch_cpu():
    # install torch CPU version to avoid installing CUDA dependencies
    _run(sys.executable + f" -m pip install torch=={TORCH_VER} --index-url https://download.pytorch.org/whl/cpu")


@app.command()
def install_pytorch_xla(
    force: bool = False,
):
    """
    Installs PyTorch XLA with TPU support.

    Args:
        force (bool): When set, force reinstalling even if Pytorch XLA is already installed.
    """
    if not force and _check_module("torch") and _check_module("torch_xla"):
        typer.confirm(
            "PyTorch XLA is already installed. Do you want to reinstall it?",
            default=False,
            abort=True,
        )
    _install_torch_cpu()
    _run(
        sys.executable
        + f" -m pip install torch-xla[tpu]=={TORCH_VER} -f https://storage.googleapis.com/libtpu-releases/index.html"
    )
    click.echo()
    click.echo(click.style("PyTorch XLA has been installed.", bold=True))


@app.command()
def install_jetstream_pytorch(
    deps_path: str = DEFAULT_DEPS_PATH,
    yes: bool = False,
):
    """
    Installs Jetstream Pytorch with TPU support.

    Args:
        deps_path (str): Path where Jetstream Pytorch dependencies will be installed.
        yes (bool): When set, proceed installing without asking questions.
    """
    if not _check_module("torch"):
        _install_torch_cpu()
    if not yes and _check_module("jetstream_pt") and _check_module("torch_xla2"):
        typer.confirm(
            "Jetstream Pytorch is already installed. Do you want to reinstall it?",
            default=False,
            abort=True,
        )

    jetstream_repo_dir = os.path.join(deps_path, "jetstream-pytorch")
    if not yes and os.path.exists(jetstream_repo_dir):
        typer.confirm(
            f"Directory {jetstream_repo_dir} already exists. Do you want to delete it and reinstall Jetstream Pytorch?",
            default=False,
            abort=True,
        )
    shutil.rmtree(jetstream_repo_dir, ignore_errors=True)
    # Create the directory if it does not exist
    os.makedirs(deps_path, exist_ok=True)
    # Clone and install Jetstream Pytorch
    os.chdir(deps_path)
    _run("git clone https://github.com/google/jetstream-pytorch.git")
    os.chdir("jetstream-pytorch")
    _run(f"git checkout {JETSTREAM_PT_VER}")
    _run("git submodule update --init --recursive")
    # We cannot install from a temporary directory: the directory must not be deleted after the
    # script finishes, because the editable install resolves its dependencies from there.
    _run(sys.executable + " -m pip install -e .")

    # Install PyTorch XLA with pallas support
    _run(
        sys.executable
        + f" -m pip install torch_xla[pallas]=={TORCH_VER} "
        + " -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html"
        + " -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html"
        + " -f https://storage.googleapis.com/libtpu-releases/index.html"
    )
    click.echo()
    click.echo(click.style("Jetstream Pytorch has been installed.", bold=True))


if __name__ == "__main__":
    sys.exit(app())
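One detail worth noting about the `_run` helper in the file above: `str.split()` breaks on any whitespace, so a command containing a quoted argument would be tokenized incorrectly. A more robust variant (an editorial suggestion, not part of the PR) would use `shlex.split`:

```python
import shlex
import subprocess

def run_checked(cmd: str) -> None:
    # shlex.split keeps quoted arguments together, unlike str.split()
    subprocess.check_call(shlex.split(cmd))
```

For the fixed pip and git commands used in this CLI the difference does not matter, which is likely why the simpler form was chosen.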
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
"transformers == 4.41.1",
"torch == 2.4.0",
"torch-xla[tpu] == 2.4.0",
'typer == 0.6.1',
"loguru == 0.6.0",
"sentencepiece == 0.2.0",
]
@@ -103,4 +104,7 @@ filterwarnings = [
"ignore:`do_sample` is set",
"ignore:Device capability of jax",
"ignore:`tensorflow` can conflict",
]
]

[project.scripts]
optimum-tpu = "optimum.tpu.cli:app"
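The `[project.scripts]` entry above is what turns the typer app into the `optimum-tpu` console command: on install, pip generates a wrapper that imports the module before the colon and calls the attribute after it. A minimal sketch of that resolution (illustrative only; `resolve_entry_point` is a hypothetical helper, not the installer's actual code):

```python
from importlib import import_module

def resolve_entry_point(spec: str):
    """Resolve a '[project.scripts]'-style 'module:attr' spec to a callable."""
    module_name, _, attr = spec.partition(":")
    return getattr(import_module(module_name), attr)

# With this PR installed, resolve_entry_point("optimum.tpu.cli:app")
# would return the typer application object.
```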