def pinv(a, rcond=1e-15, hermitian=False):
    """
    Compute the (Moore-Penrose) pseudo-inverse of a matrix.

    Calculate the generalized inverse of a matrix using its
    singular-value decomposition (SVD) and including all large singular values.

    For full documentation refer to :obj:`numpy.linalg.pinv`.

    Parameters
    ----------
    a : (..., M, N) {dpnp.ndarray, usm_ndarray}
        Matrix or stack of matrices to be pseudo-inverted.
    rcond : {float, dpnp.ndarray, usm_ndarray}, optional
        Cutoff for small singular values.
        Singular values less than or equal to ``rcond * largest_singular_value``
        are set to zero. Broadcasts against the stack of matrices.
        Default: ``1e-15``.
    hermitian : bool, optional
        If ``True``, `a` is assumed to be Hermitian (symmetric if real-valued),
        enabling a more efficient method for finding singular values.
        Default: ``False``.

    Returns
    -------
    out : (..., N, M) dpnp.ndarray
        The pseudo-inverse of `a`.

    Raises
    ------
    TypeError
        If `a` is not a supported array type, or `rcond` is neither a scalar
        nor a supported array type (e.g. a ``numpy.ndarray`` or a list).
    dpnp.linalg.LinAlgError
        If `a` has fewer than two dimensions.
    ValueError
        If `rcond` is an array that cannot be broadcast against the stack of
        matrices, or is allocated on a different queue than `a`.

    Examples
    --------
    The following example checks that ``a * a+ * a == a`` and
    ``a+ * a * a+ == a+``:

    >>> import dpnp as np
    >>> a = np.random.randn(9, 6)
    >>> B = np.linalg.pinv(a)
    >>> np.allclose(a, np.dot(a, np.dot(B, a)))
    array([ True])
    >>> np.allclose(B, np.dot(B, np.dot(a, B)))
    array([ True])

    """

    # Validate the inputs up front so that unsupported types and malformed
    # shapes are reported before any computation is dispatched.
    dpnp.check_supported_arrays_type(a)
    # `rcond` may also be a plain Python scalar, hence `scalar_type=True`.
    dpnp.check_supported_arrays_type(rcond, scalar_type=True)
    check_stacked_2d(a)

    return dpnp_pinv(a, rcond=rcond, hermitian=hermitian)
def dpnp_pinv(a, rcond=1e-15, hermitian=False):
    """
    dpnp_pinv(a, rcond=1e-15, hermitian=False):

    Return the Moore-Penrose pseudoinverse of `a`.

    The pseudoinverse generalizes the matrix inverse and is assembled here
    from the reduced Singular Value Decomposition (SVD) of `a`: singular
    values not exceeding ``rcond * max(singular values)`` are zeroed, the
    remaining ones are inverted, and the factors are multiplied back together
    in transposed order.

    """

    if a.size == 0:
        # Nothing to decompose: hand back an uninitialized array whose two
        # trailing dimensions are swapped with respect to the input.
        rows, cols = a.shape[-2:]
        out_dtype = a.dtype if (rows == 0 or cols == 0) else _common_type(a)
        out_shape = a.shape[:-2] + (cols, rows)
        return dpnp.empty_like(a, shape=out_shape, dtype=out_dtype)

    if not dpnp.is_supported_array_type(rcond):
        # A scalar `rcond` is materialized as a dpnp.ndarray that shares the
        # USM type and SYCL queue of `a`.
        rcond = dpnp.array(rcond, usm_type=a.usm_type, sycl_queue=a.sycl_queue)
    else:
        # Verify that `a` and `rcond` live on the same device and queue;
        # a `ValueError` is raised otherwise.
        get_usm_allocations([a, rcond])

    # Decomposing conj(a) means the plain transposes taken below act as
    # conjugate transposes of the factors of `a`.
    u, s, vt = dpnp_svd(a.conj(), full_matrices=False, hermitian=hermitian)

    # The mask of negligible singular values must be taken before `s` is
    # overwritten with its reciprocal.
    threshold = rcond * dpnp.max(s, axis=-1)
    is_small = s <= threshold[..., None]
    dpnp.reciprocal(s, out=s)
    s[is_small] = 0

    # Scale the rows of the transposed `u` in place by the inverted
    # singular values.
    u_t = u.swapaxes(-2, -1)
    dpnp.multiply(s[..., None], u_t, out=u_t)

    v = vt.swapaxes(-2, -1)
    return dpnp.matmul(v, u_t)
class TestPinv:
    """Tests for ``dpnp.linalg.pinv`` checked against ``numpy.linalg.pinv``."""

    # Set numpy.random.seed for test methods to prevent
    # random generation of the input singular matrix
    def setup_method(self):
        numpy.random.seed(81)

    def get_tol(self, dtype):
        """Store in ``self._tol`` the comparison tolerance used for `dtype`."""
        tol = 1e-06
        if dtype in (inp.float32, inp.complex64):
            tol = 1e-04
        elif not has_support_aspect64() and dtype in (
            inp.int32,
            inp.int64,
            None,
        ):
            # The looser tolerance is also used for these dtypes when the
            # device lacks fp64 support.
            tol = 1e-04
        self._tol = tol

    def check_types_shapes(self, dp_B, np_B):
        """Assert that the dpnp result matches NumPy's dtype and shape."""
        if has_support_aspect64():
            assert dp_B.dtype == np_B.dtype
        else:
            # Without fp64 support only the dtype kind is compared.
            assert dp_B.dtype.kind == np_B.dtype.kind

        assert dp_B.shape == np_B.shape

    @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
    @pytest.mark.parametrize(
        "shape",
        [(2, 2), (3, 4), (5, 3), (16, 16), (2, 2, 2), (2, 4, 2), (2, 2, 4)],
        ids=[
            "(2, 2)",
            "(3, 4)",
            "(5, 3)",
            "(16, 16)",
            "(2, 2, 2)",
            "(2, 4, 2)",
            "(2, 2, 4)",
        ],
    )
    def test_pinv(self, dtype, shape):
        """Compare with NumPy and verify the identity ``a @ a+ @ a == a``."""
        a = numpy.random.randn(*shape).astype(dtype)
        if numpy.issubdtype(dtype, numpy.complexfloating):
            a += 1j * numpy.random.randn(*shape)
        a_dp = inp.array(a)

        B = numpy.linalg.pinv(a)
        B_dp = inp.linalg.pinv(a_dp)

        self.check_types_shapes(B_dp, B)
        self.get_tol(dtype)
        tol = self._tol
        assert_allclose(B_dp, B, rtol=tol, atol=tol)

        if a.ndim == 2:
            reconstructed = inp.dot(a_dp, inp.dot(B_dp, a_dp))
        else:  # a.ndim > 2
            reconstructed = inp.matmul(a_dp, inp.matmul(B_dp, a_dp))

        assert_allclose(reconstructed, a_dp, rtol=tol, atol=tol)

    @pytest.mark.parametrize("dtype", get_float_complex_dtypes())
    @pytest.mark.parametrize(
        "shape",
        [(2, 2), (16, 16)],
        ids=["(2, 2)", "(16, 16)"],
    )
    def test_pinv_hermitian(self, dtype, shape):
        """Check ``hermitian=True`` on a Hermitian (symmetrized) input."""
        a = numpy.random.randn(*shape).astype(dtype)
        if numpy.issubdtype(dtype, numpy.complexfloating):
            a += 1j * numpy.random.randn(*shape)
        # Averaging with the conjugate transpose makes the matrix Hermitian.
        a = (a + a.conj().T) / 2

        a_dp = inp.array(a)

        B = numpy.linalg.pinv(a, hermitian=True)
        B_dp = inp.linalg.pinv(a_dp, hermitian=True)

        self.check_types_shapes(B_dp, B)
        self.get_tol(dtype)
        tol = self._tol

        reconstructed = inp.dot(inp.dot(a_dp, B_dp), a_dp)
        assert_allclose(reconstructed, a_dp, rtol=tol, atol=tol)

    @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
    @pytest.mark.parametrize(
        "shape",
        [(0, 0), (0, 2), (2, 0), (2, 0, 3), (2, 3, 0), (0, 2, 3)],
        ids=[
            "(0, 0)",
            "(0, 2)",
            "(2, 0)",
            "(2, 0, 3)",
            "(2, 3, 0)",
            "(0, 2, 3)",
        ],
    )
    def test_pinv_empty(self, dtype, shape):
        """Check zero-size inputs against NumPy."""
        a = numpy.empty(shape, dtype=dtype)
        a_dp = inp.array(a)

        B = numpy.linalg.pinv(a)
        B_dp = inp.linalg.pinv(a_dp)

        assert_dtype_allclose(B_dp, B)

    def test_pinv_strides(self):
        """Check inputs that are strided views of a larger array."""
        a = numpy.random.randn(5, 5)
        a_dp = inp.array(a)

        self.get_tol(a_dp.dtype)
        tol = self._tol

        # positive strides
        B = numpy.linalg.pinv(a[::2, ::2])
        B_dp = inp.linalg.pinv(a_dp[::2, ::2])
        assert_allclose(B_dp, B, rtol=tol, atol=tol)

        # negative strides
        B = numpy.linalg.pinv(a[::-2, ::-2])
        B_dp = inp.linalg.pinv(a_dp[::-2, ::-2])
        assert_allclose(B_dp, B, rtol=tol, atol=tol)

    def test_pinv_errors(self):
        """Check the exceptions raised for invalid arguments."""
        a_dp = inp.array([[1, 2], [3, 4]], dtype="float32")

        # unsupported type `a`
        a_np = inp.asnumpy(a_dp)
        assert_raises(TypeError, inp.linalg.pinv, a_np)

        # unsupported type `rcond`
        rcond = numpy.array(0.5, dtype="float32")
        assert_raises(TypeError, inp.linalg.pinv, a_dp, rcond)
        assert_raises(TypeError, inp.linalg.pinv, a_dp, [0.5])

        # non-broadcastable `rcond`
        rcond_dp = inp.array([0.5], dtype="float32")
        assert_raises(ValueError, inp.linalg.pinv, a_dp, rcond_dp)

        # a.ndim < 2
        a_dp_ndim_1 = a_dp.flatten()
        assert_raises(inp.linalg.LinAlgError, inp.linalg.pinv, a_dp_ndim_1)

        # different queue
        a_queue = dpctl.SyclQueue()
        rcond_queue = dpctl.SyclQueue()
        a_dp_q = inp.array(a_dp, sycl_queue=a_queue)
        rcond_dp_q = inp.array([0.5], dtype="float32", sycl_queue=rcond_queue)
        assert_raises(ValueError, inp.linalg.pinv, a_dp_q, rcond_dp_q)
@pytest.mark.parametrize(
    "shape, hermitian, rcond_as_array",
    [
        ((4, 4), False, False),
        ((4, 4), False, True),
        ((2, 0), False, False),
        ((4, 4), True, False),
        ((4, 4), True, True),
        ((2, 2, 3), False, False),
        ((2, 2, 3), False, True),
        ((0, 2, 3), False, False),
        ((1, 0, 3), False, False),
    ],
    # Each id mirrors the parameters of the case at the same position.
    ids=[
        "(4, 4)",
        "(4, 4), rcond_as_array",
        "(2, 0)",
        "(4, 4), hermitian",
        "(4, 4), hermitian, rcond_as_array",
        "(2, 2, 3)",
        "(2, 2, 3), rcond_as_array",
        "(0, 2, 3)",
        "(1, 0, 3)",
    ],
)
@pytest.mark.parametrize(
    "device",
    valid_devices,
    ids=[device.filter_string for device in valid_devices],
)
def test_pinv(shape, hermitian, rcond_as_array, device):
    """Check that ``dpnp.linalg.pinv`` matches NumPy and keeps the input queue."""
    numpy.random.seed(81)
    if hermitian:
        # conj(a.T) @ a is Hermitian by construction.
        a_np = numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)
        a_np = numpy.conj(a_np.T) @ a_np
    else:
        a_np = numpy.random.randn(*shape)

    a_dp = dpnp.array(a_np, device=device)

    if rcond_as_array:
        # Pass `rcond` as a 0-d array allocated on the same device as `a_dp`.
        rcond_np = numpy.array(1e-15)
        rcond_dp = dpnp.array(1e-15, device=device)

        B_result = dpnp.linalg.pinv(a_dp, rcond=rcond_dp, hermitian=hermitian)
        B_expected = numpy.linalg.pinv(
            a_np, rcond=rcond_np, hermitian=hermitian
        )

    else:
        # rcond == 1e-15 by default
        B_result = dpnp.linalg.pinv(a_dp, hermitian=hermitian)
        B_expected = numpy.linalg.pinv(a_np, hermitian=hermitian)

    assert_allclose(B_expected, B_result, rtol=1e-3, atol=1e-4)

    B_queue = B_result.sycl_queue

    assert_sycl_queue_equal(B_queue, a_dp.sycl_queue)
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
@pytest.mark.parametrize(
    "shape, hermitian",
    [
        ((4, 4), False),
        ((2, 0), False),
        ((4, 4), True),
        ((2, 2, 3), False),
        ((0, 2, 3), False),
        ((1, 0, 3), False),
    ],
    # Each id mirrors the parameters of the case at the same position.
    ids=[
        "(4, 4)",
        "(2, 0)",
        "(4, 4), hermitian",
        "(2, 2, 3)",
        "(0, 2, 3)",
        "(1, 0, 3)",
    ],
)
def test_pinv(shape, hermitian, usm_type):
    """Check that ``dp.linalg.pinv`` returns an array with the input's USM type."""
    # The input comes from dpnp's generator, so that is the one to seed;
    # seeding NumPy's generator would not affect `dp.random.randn`.
    dp.random.seed(81)
    if hermitian:
        # conj(a.T) @ a is Hermitian by construction.
        a = dp.random.randn(*shape) + 1j * dp.random.randn(*shape)
        a = dp.conj(a.T) @ a
    else:
        a = dp.random.randn(*shape)

    a = dp.array(a, usm_type=usm_type)
    B = dp.linalg.pinv(a, hermitian=hermitian)

    assert a.usm_type == B.usm_type
class TestPinv(unittest.TestCase):
    """Compare ``cupy.linalg.pinv`` with ``numpy.linalg.pinv``."""

    @testing.for_dtypes("ifdFD")
    @_condition.retry(10)
    def check_x(self, a_shape, rcond, dtype):
        """Run one comparison for a random input of shape `a_shape`."""
        a_gpu = testing.shaped_random(a_shape, dtype=dtype)
        a_cpu = cupy.asnumpy(a_gpu)
        # Snapshot used to verify that pinv leaves its input untouched.
        a_gpu_before = a_gpu.copy()

        # A float `rcond` is passed through as is; anything else becomes an
        # array of the matching library on each side.
        if isinstance(rcond, float):
            rcond_cpu = rcond
            rcond_gpu = rcond
        else:
            rcond_cpu = numpy.asarray(rcond)
            rcond_gpu = cupy.asarray(rcond_cpu)

        expected = numpy.linalg.pinv(a_cpu, rcond=rcond_cpu)
        actual = cupy.linalg.pinv(a_gpu, rcond=rcond_gpu)

        assert_dtype_allclose(actual, expected)
        testing.assert_array_equal(a_gpu_before, a_gpu)

    def test_pinv(self):
        cases = (
            ((3, 3), 1e-15),
            ((2, 4), 1e-15),
            ((3, 2), 1e-15),
            ((4, 4), 0.3),
            ((2, 5), 0.5),
            ((5, 3), 0.6),
        )
        for shape, cutoff in cases:
            self.check_x(shape, rcond=cutoff)

    def test_pinv_batched(self):
        for shape in ((2, 3, 4), (2, 3, 4, 5)):
            self.check_x(shape, rcond=1e-15)

    def test_pinv_batched_vector_rcond(self):
        # One cutoff per matrix in the stack.
        self.check_x((2, 3, 4), rcond=[0.2, 0.8])
        self.check_x((2, 3, 4, 5), rcond=[[0.2, 0.9, 0.1], [0.7, 0.2, 0.5]])

    def test_pinv_size_0(self):
        empty_shapes = ((3, 0), (0, 3), (0, 0), (0, 2, 3), (2, 0, 3))
        for shape in empty_shapes:
            self.check_x(shape, rcond=1e-15)