Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dpnp.linalg.pinv() implementation #1704

Merged
merged 24 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1950828
Impl dpnp.linalg.pinv()
vlad-perevezentsev Feb 8, 2024
095da9a
Add cupy tests for dpnp.linalg.pinv()
vlad-perevezentsev Feb 8, 2024
cb22462
Merge master into impl_pinv
vlad-perevezentsev Feb 8, 2024
b716704
Add tests to test_sycl_queue and test_usm_type
vlad-perevezentsev Feb 9, 2024
d0bcbbd
Add TestPinv to test_linalg.py
vlad-perevezentsev Feb 9, 2024
0293a3f
Merge master into impl_pinv
vlad-perevezentsev Feb 9, 2024
b7ee111
Address remarks
vlad-perevezentsev Feb 12, 2024
7f03075
Add additional checks for rcond parameter
vlad-perevezentsev Feb 12, 2024
34aa37c
Add a more efficient implementation
vlad-perevezentsev Feb 12, 2024
0766347
Update test_sycl_queue
vlad-perevezentsev Feb 12, 2024
bad4ad1
Update TestPinv in test_linalg.py
vlad-perevezentsev Feb 12, 2024
826610b
Merge master into impl_pinv
vlad-perevezentsev Feb 12, 2024
358216a
Merge branch 'master' into impl_pinv
antonwolfy Feb 12, 2024
35dede1
Update tests/test_usm_type.py
antonwolfy Feb 12, 2024
e7b899a
Update test_pinv_hermitian
vlad-perevezentsev Feb 13, 2024
a7bb8e0
Update test_svd_hermitian
vlad-perevezentsev Feb 13, 2024
f9f8010
Merge origin/impl_pinv into impl_pinv
vlad-perevezentsev Feb 13, 2024
4f72d33
Merge master into impl_pinv
vlad-perevezentsev Feb 13, 2024
a198de0
Merge master into impl_pinv
vlad-perevezentsev Feb 14, 2024
16c292c
Use numpy.random.seed in test_linalg
vlad-perevezentsev Feb 14, 2024
c08ea5e
Use setup_method to rudece code duplication in test_linalg
vlad-perevezentsev Feb 14, 2024
a975fc4
Change the seed number
vlad-perevezentsev Feb 15, 2024
c9b8fff
Merge branch 'master' into impl_pinv
antonwolfy Feb 15, 2024
32f8f68
Use numpy.random.seed in test_usm_type and test_sycl_queue
vlad-perevezentsev Feb 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
dpnp_det,
dpnp_eigh,
dpnp_inv,
dpnp_pinv,
dpnp_qr,
dpnp_slogdet,
dpnp_solve,
Expand All @@ -69,6 +70,7 @@
"matrix_rank",
"multi_dot",
"norm",
"pinv",
"qr",
"solve",
"svd",
Expand Down Expand Up @@ -474,6 +476,56 @@ def multi_dot(arrays, out=None):
return result


def pinv(a, rcond=1e-15, hermitian=False):
"""
Compute the (Moore-Penrose) pseudo-inverse of a matrix.

Calculate the generalized inverse of a matrix using its
singular-value decomposition (SVD) and including all large singular values.

For full documentation refer to :obj:`numpy.linalg.inv`.

Parameters
----------
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
Matrix or stack of matrices to be pseudo-inverted.
rcond : {float, dpnp.ndarray, usm_ndarray}, optional
Cutoff for small singular values.
Singular values less than or equal to ``rcond * largest_singular_value``
are set to zero. Broadcasts against the stack of matrices.
Default: ``1e-15``.
hermitian : bool, optional
If ``True``, a is assumed to be Hermitian (symmetric if real-valued),
enabling a more efficient method for finding singular values.
Default: ``False``.

Returns
-------
out : (..., N, M) dpnp.ndarray
The pseudo-inverse of a.

Examples
--------
The following example checks that ``a * a+ * a == a`` and
``a+ * a * a+ == a+``:

>>> import dpnp as np
>>> a = np.random.randn(9, 6)
>>> B = np.linalg.pinv(a)
>>> np.allclose(a, np.dot(a, np.dot(B, a)))
array([ True])
>>> np.allclose(B, np.dot(B, np.dot(a, B)))
array([ True])

"""

dpnp.check_supported_arrays_type(a)
dpnp.check_supported_arrays_type(rcond, scalar_type=True)
check_stacked_2d(a)

return dpnp_pinv(a, rcond=rcond, hermitian=hermitian)


def norm(x1, ord=None, axis=None, keepdims=False):
"""
Matrix or vector norm.
Expand Down
41 changes: 41 additions & 0 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"dpnp_det",
"dpnp_eigh",
"dpnp_inv",
"dpnp_pinv",
"dpnp_qr",
"dpnp_slogdet",
"dpnp_solve",
Expand Down Expand Up @@ -998,6 +999,46 @@ def dpnp_inv(a):
return b_f


def dpnp_pinv(a, rcond=1e-15, hermitian=False):
"""
dpnp_pinv(a, rcond=1e-15, hermitian=False):

Compute the Moore-Penrose pseudoinverse of `a` matrix.

It computes a pseudoinverse of a matrix `a`, which is a generalization
of the inverse matrix with Singular Value Decomposition (SVD).

"""

if a.size == 0:
m, n = a.shape[-2:]
if m == 0 or n == 0:
res_type = a.dtype
else:
res_type = _common_type(a)
return dpnp.empty_like(a, shape=(a.shape[:-2] + (n, m)), dtype=res_type)

if dpnp.is_supported_array_type(rcond):
# Check that `a` and `rcond` are allocated on the same device
# and have the same queue. Otherwise, `ValueError`` will be raised.
get_usm_allocations([a, rcond])
else:
# Allocate dpnp.ndarray if rcond is a scalar
rcond = dpnp.array(rcond, usm_type=a.usm_type, sycl_queue=a.sycl_queue)

u, s, vt = dpnp_svd(a.conj(), full_matrices=False, hermitian=hermitian)

# discard small singular values
cutoff = rcond * dpnp.max(s, axis=-1)
leq = s <= cutoff[..., None]
dpnp.reciprocal(s, out=s)
s[leq] = 0

u = u.swapaxes(-2, -1)
dpnp.multiply(s[..., None], u, out=u)
return dpnp.matmul(vt.swapaxes(-2, -1), u)


def dpnp_qr_batch(a, mode="reduced"):
"""
dpnp_qr_batch(a, mode="reduced")
Expand Down
181 changes: 173 additions & 8 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .helper import (
assert_dtype_allclose,
get_all_dtypes,
get_complex_dtypes,
get_float_complex_dtypes,
has_support_aspect64,
is_cpu_device,
)
Expand Down Expand Up @@ -678,6 +678,11 @@ def test_norm3(array, ord, axis):


class TestQr:
# Set numpy.random.seed for test methods to prevent
# random generation of the input singular matrix
def setup_method(self):
numpy.random.seed(81)

# TODO: New packages that fix issue CMPLRLLVM-53771 are only available in internal CI.
# Skip the tests on cpu until these packages are available for the external CI.
# Specifically dpcpp_linux-64>=2024.1.0
Expand All @@ -702,7 +707,9 @@ class TestQr:
ids=["r", "raw", "complete", "reduced"],
)
def test_qr(self, dtype, shape, mode):
a = numpy.random.rand(*shape).astype(dtype)
a = numpy.random.randn(*shape).astype(dtype)
if numpy.issubdtype(dtype, numpy.complexfloating):
a += 1j * numpy.random.randn(*shape)
ia = inp.array(a)

if mode == "r":
Expand Down Expand Up @@ -772,7 +779,7 @@ def test_qr_empty(self, dtype, shape, mode):
ids=["r", "raw", "complete", "reduced"],
)
def test_qr_strides(self, mode):
a = numpy.random.rand(5, 5)
a = numpy.random.randn(5, 5)
ia = inp.array(a)

# positive strides
Expand Down Expand Up @@ -1032,6 +1039,11 @@ def test_slogdet_errors(self):


class TestSvd:
# Set numpy.random.seed for test methods to prevent
# random generation of the input singular matrix
def setup_method(self):
numpy.random.seed(81)

def get_tol(self, dtype):
tol = 1e-06
if dtype in (inp.float32, inp.complex64):
Expand Down Expand Up @@ -1121,18 +1133,19 @@ def test_svd(self, dtype, shape):
dp_a, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, True
)

@pytest.mark.parametrize("dtype", get_complex_dtypes())
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
@pytest.mark.parametrize("compute_vt", [True, False], ids=["True", "False"])
@pytest.mark.parametrize(
"shape",
[(2, 2), (16, 16)],
ids=["(2,2)", "(16, 16)"],
ids=["(2, 2)", "(16, 16)"],
)
def test_svd_hermitian(self, dtype, compute_vt, shape):
a = numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)
a = numpy.conj(a.T) @ a
a = numpy.random.randn(*shape).astype(dtype)
if numpy.issubdtype(dtype, numpy.complexfloating):
a += 1j * numpy.random.randn(*shape)
a = (a + a.conj().T) / 2

a = a.astype(dtype)
dp_a = inp.array(a)

if compute_vt:
Expand Down Expand Up @@ -1167,3 +1180,155 @@ def test_svd_errors(self):
# a.ndim < 2
a_dp_ndim_1 = a_dp.flatten()
assert_raises(inp.linalg.LinAlgError, inp.linalg.svd, a_dp_ndim_1)


class TestPinv:
# Set numpy.random.seed for test methods to prevent
# random generation of the input singular matrix
def setup_method(self):
numpy.random.seed(81)

def get_tol(self, dtype):
tol = 1e-06
if dtype in (inp.float32, inp.complex64):
tol = 1e-04
elif not has_support_aspect64() and dtype in (
inp.int32,
inp.int64,
None,
):
tol = 1e-04
self._tol = tol

def check_types_shapes(self, dp_B, np_B):
if has_support_aspect64():
assert dp_B.dtype == np_B.dtype
else:
assert dp_B.dtype.kind == np_B.dtype.kind

assert dp_B.shape == np_B.shape

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
"shape",
[(2, 2), (3, 4), (5, 3), (16, 16), (2, 2, 2), (2, 4, 2), (2, 2, 4)],
ids=[
"(2, 2)",
"(3, 4)",
"(5, 3)",
"(16, 16)",
"(2, 2, 2)",
"(2, 4, 2)",
"(2, 2, 4)",
],
)
def test_pinv(self, dtype, shape):
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
a = numpy.random.randn(*shape).astype(dtype)
if numpy.issubdtype(dtype, numpy.complexfloating):
a += 1j * numpy.random.randn(*shape)
a_dp = inp.array(a)

B = numpy.linalg.pinv(a)
B_dp = inp.linalg.pinv(a_dp)

self.check_types_shapes(B_dp, B)
self.get_tol(dtype)
tol = self._tol
assert_allclose(B_dp, B, rtol=tol, atol=tol)

if a.ndim == 2:
reconstructed = inp.dot(a_dp, inp.dot(B_dp, a_dp))
else: # a.ndim > 2
reconstructed = inp.matmul(a_dp, inp.matmul(B_dp, a_dp))

assert_allclose(reconstructed, a_dp, rtol=tol, atol=tol)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
@pytest.mark.parametrize(
"shape",
[(2, 2), (16, 16)],
ids=["(2, 2)", "(16, 16)"],
)
def test_pinv_hermitian(self, dtype, shape):
a = numpy.random.randn(*shape).astype(dtype)
if numpy.issubdtype(dtype, numpy.complexfloating):
a += 1j * numpy.random.randn(*shape)
a = (a + a.conj().T) / 2

a_dp = inp.array(a)

B = numpy.linalg.pinv(a, hermitian=True)
B_dp = inp.linalg.pinv(a_dp, hermitian=True)

self.check_types_shapes(B_dp, B)
self.get_tol(dtype)
tol = self._tol

reconstructed = inp.dot(inp.dot(a_dp, B_dp), a_dp)
assert_allclose(reconstructed, a_dp, rtol=tol, atol=tol)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
"shape",
[(0, 0), (0, 2), (2, 0), (2, 0, 3), (2, 3, 0), (0, 2, 3)],
ids=[
"(0, 0)",
"(0, 2)",
"(2 ,0)",
"(2, 0, 3)",
"(2, 3, 0)",
"(0, 2, 3)",
],
)
def test_pinv_empty(self, dtype, shape):
a = numpy.empty(shape, dtype=dtype)
a_dp = inp.array(a)

B = numpy.linalg.pinv(a)
B_dp = inp.linalg.pinv(a_dp)

assert_dtype_allclose(B_dp, B)

def test_pinv_strides(self):
a = numpy.random.randn(5, 5)
a_dp = inp.array(a)

self.get_tol(a_dp.dtype)
tol = self._tol

# positive strides
B = numpy.linalg.pinv(a[::2, ::2])
B_dp = inp.linalg.pinv(a_dp[::2, ::2])
assert_allclose(B_dp, B, rtol=tol, atol=tol)

# negative strides
B = numpy.linalg.pinv(a[::-2, ::-2])
B_dp = inp.linalg.pinv(a_dp[::-2, ::-2])
assert_allclose(B_dp, B, rtol=tol, atol=tol)

def test_pinv_errors(self):
a_dp = inp.array([[1, 2], [3, 4]], dtype="float32")

# unsupported type `a`
a_np = inp.asnumpy(a_dp)
assert_raises(TypeError, inp.linalg.pinv, a_np)

# unsupported type `rcond`
rcond = numpy.array(0.5, dtype="float32")
assert_raises(TypeError, inp.linalg.pinv, a_dp, rcond)
assert_raises(TypeError, inp.linalg.pinv, a_dp, [0.5])

# non-broadcastable `rcond`
rcond_dp = inp.array([0.5], dtype="float32")
assert_raises(ValueError, inp.linalg.pinv, a_dp, rcond_dp)

# a.ndim < 2
a_dp_ndim_1 = a_dp.flatten()
assert_raises(inp.linalg.LinAlgError, inp.linalg.pinv, a_dp_ndim_1)

# diffetent queue
a_queue = dpctl.SyclQueue()
rcond_queue = dpctl.SyclQueue()
a_dp_q = inp.array(a_dp, sycl_queue=a_queue)
rcond_dp_q = inp.array([0.5], dtype="float32", sycl_queue=rcond_queue)
assert_raises(ValueError, inp.linalg.pinv, a_dp_q, rcond_dp_q)
Loading
Loading