-
Notifications
You must be signed in to change notification settings - Fork 744
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PyTorch] Update to 2.1 #1426
[PyTorch] Update to 2.1 #1426
Conversation
@sbrunk Could you review this pull request? |
I still plan to add another improvement: support for stateless datasets, dataloaders, etc... |
If the upstream libtorch binaries don't cause any issues, I'm in favor of switching. I think it has a number of advantages:
I will do some tests like running the test suite of Storch using this branch and the upstream binaries to see if I run into any issues. One thing that might be interesting to look into is MPS acceleration on MacOS, as soon most macs will run on ARM with GPU support. |
The libtorch binary for mac that can be downloaded from PyTorch main page is still x86_64 only. |
There's no need to switch to anything, the presets already support LibTorch: |
RIght, but what exactly does the user gain in using our binary compared to libtorch, now that libtorch links with CUDA 12 ? If nothing significant, it seems logical to embed libtorch which supports more hardware and for the reason listed by @sbrunk. One problem I can think of is if the cuda presets is updated to a version and libtorch is not available with this cuda version, or the other way around. |
Like I mentioned previously many times, LibTorch links statically with MKL, so it can't be used together with other libraries that use BLAS and LAPACK like GSL, OpenCV, NumPy, SciPy, Smile, etc |
Ok. Maybe can we use libtorch_cuda only from provided binaries and compile libtorch_cpu ? I'm done with what I planned to add to this PR. |
The remaining error during build on Windows is:
This is related to issue bytedeco/javacpp#720. Else we can skip the single function needing this function pointer: |
Sounds like a bug in MSVC 2019:
If it doesn't look like an important function, we can do that, sure. |
Where can we get libtorch_cuda for CUDA 12.3? |
According to https://developer.nvidia.com/blog/cuda-toolkit-12-0-released-for-general-availability/ we could run with 12.3 a binary compiled for 12.1 ? |
CUDA doesn't maintain backward compatibility of the ABI, even for minor version upgrades.
|
Have you seen the "compatibility" paragraph on the link above ? It seems to be new since CUDA 11. |
Well, try it and tell me if it works! I don't think it does. NVIDIA says many things, but reality is often different. |
I'm trying to test this locally, but I'm running into issues.
Fresh pytorch clone from the cppbuild. There's a |
Have you run the cppbuid ? |
Yes sorry should have mentioned that. I'm running Update I re-checked the native build and I think I did not put CUDA in the right place, sorry for the noise. |
I managed to build it locally now, and the Storch tests are passing (with minimal changes) with 2.1. :) Testing with CUDA is causing issues though. My current assumption is that we might need to extend the CUDA arch list to support Ampere GPUs. See sbrunk/storch#62 |
I did some more testing. All tests were done on an Ampere GPU with compute capability 8.6 Ubuntu 22.04 with the latest Nvidia driver currently available in the official package sources which is 535.104.12 which means CUDA driver version 12.2
To make sure that the correct native libraries are chosen I always checked the symlinks in So it looks like building against 12.1 makes it work again, and is compatible with 12.3 at runtime. I still need to verify this for 2.0.1 |
Thanks for reporting all these tests.
That explains the whole thing. Hopefully upstream libtorch_cuda and libtorch_cuda_linalg will work with JavaCPP build of other libs and we could get rid of PTX nightmares. |
It looks like either the GitHub runners or NVCC from CUDA 12.3 are a bit faster. In any case, we have an extra build hour that we could use for 8.0 arch. |
I tried to replace libtorch_cuda in JavaCPP pytorch with "official" binary libtorch_cuda, but this binary is linked with a a libcudart provided in the archive and this would need to patch the elf and the dll files to link with JavaCPP cuda. Maybe there is another option: copy cubins from a library file to another, but I don't know how to do it. @saudet You might as well replace |
@sbrunk Apart from that, anything else that needs to be fixed? |
A lot of those "changes" is simply because you're changing around the parsing order. If there are no good reasons to change the order, please revert the order to how it was before. It makes it hard to see the differences. |
|
Where is that script? |
Just added to the repo |
Ok, so can you please run that against master and make sure nothing changes? If something changes, please fix it so that nothing changes. |
Included in this PR:
functions
Tensor.item_bool
andTensor.item_byte
Tensor.data_ptr_byte
andTensor.data_ptr_bool
CUDACachingAllocator
Example<Tensor, Tensor>
andExample<Tensor,NoTarget>
ArrayRef
arguments of primitive types accept primitive Java array or variadic.register_module
generic and its return type the class of the registered module.