From e804bd59da0972b350b12d31735aec0e3db9e4e3 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Thu, 16 Sep 2021 18:43:04 +0800 Subject: [PATCH] fix dynload for cufft on windows (#51) 1. fix dynload for cufft on windows; 2. fix unittests. --- paddle/fluid/platform/dynload/dynamic_loader.cc | 7 ++++++- python/paddle/fluid/tests/unittests/fft/test_fft.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index d43e901371582..629a50561d9c6 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -109,6 +109,9 @@ static constexpr char* win_cusolver_lib = static constexpr char* win_cusparse_lib = "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll;cusparse64_10.dll"; +static constexpr char* win_cufft_lib = + "cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR + ".dll;cufft64_" CUDA_VERSION_MAJOR ".dll;cufft64_10.dll"; #else static constexpr char* win_curand_lib = "curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR @@ -122,7 +125,9 @@ static constexpr char* win_cusolver_lib = static constexpr char* win_cusparse_lib = "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll"; -static constexpr char* win_cufft_lib = "cufft64_" CUDA_MAJOR_VERSION ".dll"; +static constexpr char* win_cufft_lib = + "cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR + ".dll;cufft64_" CUDA_VERSION_MAJOR ".dll"; #endif // CUDA_VERSION #endif diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index d563b3bed5204..1b4bcb709331c 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -534,8 +534,8 @@ def test_hfft2(self): [('test_n_nagative', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), (-2, -1), 'backward', ValueError), \ - ('test_n_equal_input_length', - np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 1), (-2, -1), + ('test_zero_point', + np.random.randn(4, 4, 1) + 1j * np.random.randn(4, 4, 1), None, (-2, -1), "backward", ValueError), \ ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (0, 0), (-2, -1), 'backward', ValueError), \