Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return named tuple for linalg functions per python array API #2276

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .github/workflows/array-api-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,7 @@ array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes]
array_api_tests/test_signatures.py::test_info_func_signature[devices]
array_api_tests/test_signatures.py::test_info_func_signature[dtypes]

# do not return a namedtuple
array_api_tests/test_linalg.py::test_eigh
array_api_tests/test_linalg.py::test_slogdet
array_api_tests/test_linalg.py::test_svd

# hypothesis found failures
array_api_tests/test_linalg.py::test_qr
array_api_tests/test_operators_and_elementwise_functions.py::test_clip

# unexpected result is returned
Expand Down
107 changes: 67 additions & 40 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@
# pylint: disable=invalid-name
# pylint: disable=no-member

from typing import NamedTuple

import numpy
from dpctl.tensor._numpy_helper import normalize_axis_tuple

import dpnp

from .dpnp_utils_linalg import (
EighResult,
QRResult,
SlogdetResult,
SVDResult,
assert_2d,
assert_stacked_2d,
assert_stacked_square,
Expand All @@ -66,6 +72,11 @@
)

__all__ = [
"EigResult",
"EighResult",
"QRResult",
"SlogdetResult",
"SVDResult",
"cholesky",
"cond",
"cross",
Expand Down Expand Up @@ -100,6 +111,12 @@
]


# pylint:disable=missing-class-docstring
class EigResult(NamedTuple):
eigenvalues: dpnp.ndarray
eigenvectors: dpnp.ndarray


def cholesky(a, /, *, upper=False):
"""
Cholesky decomposition.
Expand Down Expand Up @@ -451,17 +468,18 @@ def eig(a):

Returns
-------
A namedtuple with the following attributes:

eigenvalues : (..., M) dpnp.ndarray
The eigenvalues, each repeated according to its multiplicity.
The eigenvalues are not necessarily ordered. The resulting
array will be of complex type, unless the imaginary part is
zero in which case it will be cast to a real type. When `a`
is real the resulting eigenvalues will be real (0 imaginary
part) or occur in conjugate pairs
The eigenvalues are not necessarily ordered. The resulting array will
be of complex type, unless the imaginary part is zero in which case it
will be cast to a real type. When `a` is real the resulting eigenvalues
will be real (zero imaginary part) or occur in conjugate pairs.
eigenvectors : (..., M, M) dpnp.ndarray
The normalized (unit "length") eigenvectors, such that the
column ``v[:,i]`` is the eigenvector corresponding to the
eigenvalue ``w[i]``.
The normalized (unit "length") eigenvectors, such that the column
``eigenvectors[:,i]`` is the eigenvector corresponding to the
eigenvalue ``eigenvalues[i]``.

Note
----
Expand Down Expand Up @@ -532,7 +550,7 @@ def eig(a):
# Since geev function from OneMKL LAPACK is not implemented yet,
# use NumPy for this calculation.
w_np, v_np = numpy.linalg.eig(dpnp.asnumpy(a))
return (
return EigResult(
dpnp.array(w_np, sycl_queue=a_sycl_queue, usm_type=a_usm_type),
dpnp.array(v_np, sycl_queue=a_sycl_queue, usm_type=a_usm_type),
)
Expand Down Expand Up @@ -565,12 +583,14 @@ def eigh(a, UPLO="L"):

Returns
-------
w : (..., M) dpnp.ndarray
The eigenvalues in ascending order, each repeated according to
its multiplicity.
v : (..., M, M) dpnp.ndarray
The column ``v[:, i]`` is the normalized eigenvector corresponding
to the eigenvalue ``w[i]``.
A namedtuple with the following attributes:

eigenvalues : (..., M) dpnp.ndarray
The eigenvalues in ascending order, each repeated according to its
multiplicity.
eigenvectors : (..., M, M) dpnp.ndarray
The column ``eigenvectors[:, i]`` is the normalized eigenvector
corresponding to the eigenvalue ``eigenvalues[i]``.

See Also
--------
Expand Down Expand Up @@ -644,7 +664,7 @@ def eigvals(a):
Illustration, using the fact that the eigenvalues of a diagonal matrix
are its diagonal elements, that multiplying a matrix on the left
by an orthogonal matrix, `Q`, and on the right by `Q.T` (the transpose
of `Q`), preserves the eigenvalues of the "middle" matrix. In other words,
of `Q`), preserves the eigenvalues of the "middle" matrix. In other words,
if `Q` is orthogonal, then ``Q * A * Q.T`` has the same eigenvalues as
``A``:

Expand Down Expand Up @@ -839,7 +859,7 @@ def lstsq(a, b, rcond=None):
gradient of roughly 1 and cut the y-axis at, more or less, -1.

We can rewrite the line equation as ``y = Ap``, where ``A = [[x 1]]``
and ``p = [[m], [c]]``. Now use `lstsq` to solve for `p`:
and ``p = [[m], [c]]``. Now use `lstsq` to solve for `p`:

>>> A = np.vstack([x, np.ones(len(x))]).T
>>> A
Expand Down Expand Up @@ -1252,7 +1272,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
Parameters
----------
x : {dpnp.ndarray, usm_ndarray}
Input array. If `axis` is ``None``, `x` must be 1-D or 2-D, unless
Input array. If `axis` is ``None``, `x` must be 1-D or 2-D, unless
`ord` is ``None``. If both `axis` and `ord` are ``None``, the 2-norm
of ``x.ravel`` will be returned.
ord : {int, float, inf, -inf, "fro", "nuc"}, optional
Expand Down Expand Up @@ -1557,20 +1577,22 @@ def qr(a, mode="reduced"):
Returns
-------
When mode is "reduced" or "complete", the result will be a namedtuple with
the attributes Q and R.
Q : dpnp.ndarray
the attributes `Q` and `R`.

Q : dpnp.ndarray of float or complex, optional
A matrix with orthonormal columns.
When mode = "complete" the result is an orthogonal/unitary matrix
depending on whether or not a is real/complex.
The determinant may be either +/- 1 in that case.
In case the number of dimensions in the input array is greater
than 2 then a stack of the matrices with above properties is returned.
R : dpnp.ndarray
The upper-triangular matrix or a stack of upper-triangular matrices
if the number of dimensions in the input array is greater than 2.
(h, tau) : tuple of dpnp.ndarray
The `h` array contains the Householder reflectors that generate Q along
with R. The `tau` array contains scaling factors for the reflectors.
When mode is ``"complete"`` the result is an orthogonal/unitary matrix
depending on whether or not `a` is real/complex. The determinant may be
either ``+/- 1`` in that case. In case the number of dimensions in the
input array is greater than 2 then a stack of the matrices with above
properties is returned.
R : dpnp.ndarray of float or complex, optional
The upper-triangular matrix or a stack of upper-triangular matrices if
the number of dimensions in the input array is greater than 2.
(h, tau) : tuple of dpnp.ndarray of float or complex, optional
The array `h` contains the Householder reflectors that generate `Q`
along with `R`. The `tau` array contains scaling factors for the
reflectors.

Examples
--------
Expand Down Expand Up @@ -1709,22 +1731,25 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):

Returns
-------
u : { (…, M, M), (…, M, K) } dpnp.ndarray
When `compute_uv` is ``True``, the result is a namedtuple with the
following attribute names:

U : { (…, M, M), (…, M, K) } dpnp.ndarray
Unitary matrix, where M is the number of rows of the input array `a`.
The shape of the matrix `u` depends on the value of `full_matrices`.
If `full_matrices` is ``True``, `u` has the shape (…, M, M).
If `full_matrices` is ``False``, `u` has the shape (…, M, K), where
K = min(M, N), and N is the number of columns of the input array `a`.
If `compute_uv` is ``False``, neither `u` or `Vh` are computed.
s : (…, K) dpnp.ndarray
The shape of the matrix `U` depends on the value of `full_matrices`.
If `full_matrices` is ``True``, `U` has the shape (…, M, M). If
`full_matrices` is ``False``, `U` has the shape (…, M, K), where
``K = min(M, N)``, and N is the number of columns of the input array
`a`. If `compute_uv` is ``False``, neither `U` or `Vh` are computed.
S : (…, K) dpnp.ndarray
Vector containing the singular values of `a`, sorted in descending
order. The length of `s` is min(M, N).
order. The length of `S` is min(M, N).
Vh : { (…, N, N), (…, K, N) } dpnp.ndarray
Unitary matrix, where N is the number of columns of the input array `a`.
The shape of the matrix `Vh` depends on the value of `full_matrices`.
If `full_matrices` is ``True``, `Vh` has the shape (…, N, N).
If `full_matrices` is ``False``, `Vh` has the shape (…, K, N).
If `compute_uv` is ``False``, neither `u` or `Vh` are computed.
If `compute_uv` is ``False``, neither `U` or `Vh` are computed.

Examples
--------
Expand Down Expand Up @@ -1852,6 +1877,8 @@ def slogdet(a):

Returns
-------
A namedtuple with the following attributes:

sign : (...) dpnp.ndarray
A number representing the sign of the determinant. For a real matrix,
this is 1, 0, or -1. For a complex matrix, this is a complex number
Expand Down
Loading
Loading