Skip to content

Commit

Permalink
No download by default (#707)
Browse files Browse the repository at this point in the history
* Do not download libtorch by default.

* CI fix.

* More CI fixes.

* Update the changelog.
  • Loading branch information
LaurentMazare authored May 15, 2023
1 parent 6b0956d commit 6653692
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/rust-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: check
args: --features download-libtorch

test:
name: Test Suite
Expand All @@ -38,6 +39,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: test
args: --features download-libtorch

fmt:
name: Rustfmt
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## v0.13.0 - unreleased yet
### Added
- Make the libtorch download opt-in rather than a default behavior. The libtorch
library download can still be triggered by enabling the `download-libtorch`
feature, [707](https://github.com/LaurentMazare/tch-rs/pull/707).
- Rename the `of_...` conversion functions to `from_...` so as to be closer to
the Rust best practices,
[706](https://github.com/LaurentMazare/tch-rs/pull/706). This is a breaking
Expand Down
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ anyhow = "1"
members = ["torch-sys", "examples/python-extension"]

[features]
default = ["torch-sys/download-libtorch"]
download-libtorch = ["torch-sys/download-libtorch"]
python-extension = ["torch-sys/python-extension"]
rl_python = ["cpython"]
rl-python = ["cpython"]
doc-only = ["torch-sys/doc-only"]
cuda-tests = []

Expand All @@ -49,7 +49,7 @@ features = [ "doc-only" ]

[[example]]
name = "reinforcement-learning"
required-features = ["rl_python"]
required-features = ["rl-python"]

[[example]]
name = "stable-diffusion"
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ your system. You can either:
- Install libtorch manually and let the build script know about it via the `LIBTORCH` environment variable.
- Use a Python PyTorch install, to do this set `LIBTORCH_USE_PYTORCH=1`.
- When a system-wide libtorch can't be found and `LIBTORCH` is not set, the
build script will download a pre-built binary version of libtorch. By default
a CPU version is used. The `TORCH_CUDA_VERSION` environment variable can be
set to `cu117` in order to get a pre-built binary using CUDA 11.7.
build script can download a pre-built binary version of libtorch by using
the `download-libtorch` feature. By default a CPU version is used. The
`TORCH_CUDA_VERSION` environment variable can be set to `cu117` in order to
get a pre-built binary using CUDA 11.7.

### System-wide Libtorch

Expand Down
14 changes: 14 additions & 0 deletions torch-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ import sysconfig
print('PYTHON_INCLUDE:', sysconfig.get_path('include'))
";

const NO_DOWNLOAD_ERROR_MESSAGE: &str = r"
Cannot find a libtorch install, you can either:
- Install libtorch manually and set the LIBTORCH environment variable to appropriate path.
- Use a system wide install in /usr/lib/libtorch.so.
- Use a Python environment with PyTorch installed by setting LIBTORCH_USE_PYTORCH=1
See the readme for more details:
https://github.com/LaurentMazare/tch-rs/blob/main/README.md
";

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Os {
Linux,
Expand Down Expand Up @@ -233,6 +243,10 @@ impl SystemInfo {
} else if let Some(pathbuf) = Self::check_system_location(os) {
Ok(pathbuf)
} else {
if !cfg!(feature = "download-libtorch") {
anyhow::bail!(NO_DOWNLOAD_ERROR_MESSAGE)
}

let device = match env_var_rerun("TORCH_CUDA_VERSION") {
Ok(cuda_env) => match os {
Os::Linux | Os::Windows => cuda_env
Expand Down

0 comments on commit 6653692

Please sign in to comment.