Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement polar decomposition #1697

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
850bb62
added first drafs regarding PD etc.
Oct 28, 2024
a253ac7
added file for tests
Oct 28, 2024
32d8f32
first draft of condition estimation, split=1 still with bug in solve_…
Oct 29, 2024
13607a8
Merge branch 'main' into features/1696-Implement_Polar_Decomposition
mrfh92 Oct 29, 2024
cd3c007
added condest to linalg.basics, including tests
Oct 31, 2024
84bed5e
Merge branch 'features/1696-Implement_Polar_Decomposition' of github.…
Oct 31, 2024
fd7dbd4
Merge branch 'main' into features/1696-Implement_Polar_Decomposition
mrfh92 Oct 31, 2024
884d179
removed bug
Oct 31, 2024
c051391
test coverage for uncovered line
Oct 31, 2024
ac37f76
Batched QR, update of unit tests is missing so far
Oct 31, 2024
b879471
Merge branch 'main' into features/1707-Batched_QR
Nov 14, 2024
4f67a93
removed old tests that threw errors for batched inputs
Nov 14, 2024
cec049c
final debugging of tests
Nov 14, 2024
4fffb39
added changes to docs
Nov 14, 2024
7d37a40
dummy change for benchmarking run
Nov 15, 2024
6b2a18f
Merge branch 'main' into features/1696-Implement_Polar_Decomposition
Nov 15, 2024
5b4b801
Merge branch 'features/1707-Batched_QR' into features/1696-Implement_…
Nov 15, 2024
d918d2c
started with ZoloPD
Nov 15, 2024
286104e
implementation of ZoloPD + tests
Nov 15, 2024
31db459
added seeds for random in the tests to ensure reproducibility
Nov 18, 2024
efa65b7
removed file "ausprobieren.py"
Nov 18, 2024
a674e2c
final clean up
Nov 18, 2024
8107cca
...
Dec 3, 2024
704d8f0
created branch for QR in case split=0 and non-tall-skinny matrices
Dec 3, 2024
7580552
...
Dec 4, 2024
b087c28
Merge branch 'main' into features/1696-Implement_Polar_Decomposition
mrfh92 Dec 9, 2024
7dbfcbe
Merge branch 'features/1696-Implement_Polar_Decomposition' of github.…
Dec 9, 2024
6430afd
QR for split=0 and non tall-skinny data
Dec 9, 2024
7b26be1
Merge branch 'main' into features/1736-QR_for_non-tall-skinny_matrice…
mrfh92 Dec 9, 2024
3a77070
debugging
Dec 10, 2024
b6bd730
Merge branch 'features/1736-QR_for_non-tall-skinny_matrices_and_split…
Dec 10, 2024
3be5913
Merge branch 'features/1736-QR_for_non-tall-skinny_matrices_and_split…
Dec 10, 2024
159cb18
took new qr for split=0 into account
Dec 10, 2024
505fd4b
Merge branch 'main' into features/1696-Implement_Polar_Decomposition
mrfh92 Dec 10, 2024
d308bdf
added random seed
Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions heat/core/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .qr import *
from .svdtools import *
from .svd import *
from .pd import *
114 changes: 114 additions & 0 deletions heat/core/linalg/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@
from .. import statistics
from .. import stride_tricks
from .. import types
from ..random import randn
from .qr import qr
from .solver import solve_triangular

__all__ = [
"condest",
"cross",
"det",
"dot",
Expand All @@ -45,6 +49,116 @@
]


def _estimate_largest_singularvalue(A: DNDarray, algorithm: str = "fro") -> DNDarray:
    """
    Returns an upper bound for the largest singular value of the 2D DNDarray ``A``.

    Parameters
    ----------
    A : DNDarray
        The matrix, i.e., a 2D DNDarray, whose largest singular value is to be bounded.
    algorithm : str
        Name of the estimation strategy. Only "fro" (default) is available; it returns the
        Frobenius norm of ``A``, which is never smaller than the largest singular value.

    Raises
    ------
    TypeError
        If ``algorithm`` is not a string.
    NotImplementedError
        If ``algorithm`` is a string other than "fro".
    """
    # guard clauses first: reject everything that is not the single supported strategy
    if not isinstance(algorithm, str):
        raise TypeError(
            f"Parameter 'algorithm' needs to be a string, but is {algorithm} with data type {type(algorithm)}."
        )
    if algorithm != "fro":
        raise NotImplementedError("So far only algorithm='fro' implemented.")

    frobenius_norm = matrix_norm(A, ord="fro")
    # drop singleton dimensions of the norm result before handing it back
    return frobenius_norm.squeeze()


def condest(
    A: DNDarray, p: Union[int, str] = None, algorithm: str = "randomized", params: dict = None
) -> DNDarray:
    """
    Computes a (possibly randomized) upper estimate of the l2-condition number of the input 2D DNDarray.

    Parameters
    ----------
    A : DNDarray
        The matrix, i.e., a 2D DNDarray, for which the condition number shall be estimated.
    p : int or str (optional)
        The norm to use for the condition number computation. If None, the l2-norm (default, p=2) is used.
        So far, only p=2 is implemented.
    algorithm : str
        The algorithm to use for the estimation. Currently, only "randomized" (default) is implemented.
    params : dict (optional)
        A dictionary of parameters required for the chosen algorithm; if not provided, default values for the respective algorithm are chosen.
        If `algorithm="randomized"` the number of random samples to use must be specified under the key "nsamples"; default is 10.

    Returns
    -------
    DNDarray
        The estimate of the condition number, with singleton dimensions squeezed out.

    Raises
    ------
    ValueError
        If ``p`` is neither None nor 2, if ``params["nsamples"]`` is not a positive integer,
        or if ``A`` is not two-dimensional.
    TypeError
        If ``algorithm`` is not a string, if ``params`` is not None and not a dictionary containing
        the key "nsamples", or if ``A`` is not a DNDarray.
    NotImplementedError
        If ``algorithm`` is a string other than "randomized".

    Notes
    ----------
    The "randomized" algorithm follows the approach described in [1]; note that in the paper actually the condition number w.r.t. the Frobenius norm is estimated.
    However, this yields an upper bound for the condition number w.r.t. the l2-norm as well.

    References
    ----------
    [1] T. Gudmundsson, C. S. Kenney, and A. J. Laub. Small-Sample Statistical Estimates for Matrix Norms. SIAM Journal on Matrix Analysis and Applications 1995 16:3, 776-792.
    """
    if p is None:
        p = 2
    if p != 2:
        raise ValueError(
            f"Only the case p=2 (condition number w.r.t. the euclidean norm) is implemented so far, but input was p={p} (type: {type(p)})."
        )
    if not isinstance(algorithm, str):
        raise TypeError(
            f"Parameter 'algorithm' needs to be a string, but is {algorithm} with data type {type(algorithm)}."
        )
    if algorithm != "randomized":
        raise NotImplementedError(
            "So far only algorithm='randomized' is implemented. Please open an issue on GitHub if you would like to suggest implementing another algorithm."
        )

    if params is None:
        nsamples = 10  # set default value
    else:
        if not isinstance(params, dict) or "nsamples" not in params:
            raise TypeError(
                "If not None, 'params' needs to be a dictionary containing the number of samples under the key 'nsamples'."
            )
        nsamples = params["nsamples"]
        # bool is a subclass of int; reject it explicitly so that True is not read as "1 sample"
        if isinstance(nsamples, bool) or not isinstance(nsamples, int) or nsamples <= 0:
            raise ValueError(
                f"The number of samples needs to be a positive integer, but is {params['nsamples']} with data type {type(params['nsamples'])}."
            )

    # validate the matrix itself; without this, wrong inputs fail later with an unrelated
    # AttributeError/IndexError when the shape is accessed
    if not isinstance(A, DNDarray):
        raise TypeError(f"Input 'A' needs to be a DNDarray, but is {type(A)}.")
    if A.ndim != 2:
        raise ValueError(f"Input 'A' needs to be a 2D DNDarray, but its dimension is {A.ndim}.")

    m = A.shape[0]
    n = A.shape[1]

    if n > m:
        # the algorithm only works for m >= n, but fortunately, the condition number (w.r.t. l2-norm) is invariant under transposition
        return condest(A.T, p=p, algorithm=algorithm, params=params)

    # only the triangular factor is needed: R has the same singular values as A
    _, R = qr(A, mode="r")  # only R factor is computed in QR

    # random samples from unit sphere
    # regarding the split: if A.split == 1, then n is probably large and we should split along an axis of size n; otherwise, both n and nsamples should be small
    Q, R_not_used = qr(
        randn(
            n,
            nsamples,
            dtype=A.dtype,
            split=0 if A.split == 1 else None,
            device=A.device,
            comm=A.comm,
        )
    )
    del R_not_used

    # product of the sampled Frobenius norms of R and of R^{-1} (the latter via a triangular solve)
    # NOTE(review): the small-sample estimator of [1] would suggest a scaling of n / nsamples
    # (sqrt(n / nsamples) per norm) rather than sqrt(m / nsamples) - please confirm the factor.
    est = (
        matrix_norm(R @ Q)
        * A.dtype((m / nsamples) ** 0.5, comm=A.comm)
        * matrix_norm(solve_triangular(R, Q))
    )

    return est.squeeze()


def cross(
a: DNDarray, b: DNDarray, axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int = -1
) -> DNDarray:
Expand Down
237 changes: 237 additions & 0 deletions heat/core/linalg/pd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
"""
Implements polar decomposition (PD)
"""

import numpy as np
import collections
import torch
from typing import Type, Callable, Dict, Any, TypeVar, Union, Tuple

from ..communication import MPICommunication
from ..dndarray import DNDarray
from .. import factories
from .. import types
from ..linalg import matrix_norm, vector_norm, matmul, qr, solve_triangular
from .basics import _estimate_largest_singularvalue, condest
from ..indexing import where
from ..random import randn
from ..devices import Device
from ..manipulations import vstack, hstack, concatenate, diag, balance
from ..exponential import sqrt
from .. import statistics
from mpi4py import MPI

from scipy.special import ellipj
from scipy.special import ellipkm1

# public names exported by ``from .pd import *``; the underscore-prefixed helpers stay private
__all__ = ["pd"]


def _zolopd_n_iterations(r: int, kappa: float) -> int:
    """
    Returns the number of iterations required in the Zolotarev-PD algorithm.
    See the Table 3.1 in: Nakatsukasa, Y., & Freund, R. W. (2016). Computing the polar decomposition with applications. SIAM Review, 58(3), DOI: https://doi.org/10.1137/140990334

    Parameters
    ----------
    r : int
        The parameter ``r`` of the paper; must be between 1 and 8.
    kappa : float
        (Estimate of) the condition number of the matrix, named as in the paper.

    Returns
    -------
    int
        The number of iterations for the smallest tabulated condition number bound that is
        not exceeded by ``kappa``; the last row is used for everything above 1e7.

    Raises
    ------
    ValueError
        If ``r`` is not between 1 and 8. (Previously ``r=0`` silently returned the entry
        for ``r=8`` through negative indexing and ``r>8`` raised an IndexError.)
    """
    if not 1 <= r <= 8:
        raise ValueError(f"Parameter r must be an integer between 1 and 8, but is {r}.")

    # (upper bound on kappa, iteration counts for r = 1, ..., 8)
    # The counts never decrease with growing kappa and never increase with growing r.
    # The row for kappa <= 1e3 used to read [3, 3, 2, ...], i.e. fewer iterations for r=1 than
    # the better conditioned case kappa <= 1e2; it is corrected to the values of Table 3.1.
    iteration_table = (
        (1e2, (4, 3, 2, 2, 2, 2, 2, 2)),
        (1e3, (4, 3, 3, 2, 2, 2, 2, 2)),
        (1e5, (5, 3, 3, 3, 2, 2, 2, 2)),
        (1e7, (5, 4, 3, 3, 3, 2, 2, 2)),
    )
    for kappa_bound, counts in iteration_table:
        if kappa <= kappa_bound:
            return counts[r - 1]
    # anything worse conditioned than 1e7 (up to roughly 1e16 in the paper)
    return (6, 4, 3, 3, 3, 3, 3, 2)[r - 1]


def _compute_zolotarev_coefficients(
    r: int, ell: float, device: Union[str, Device], dtype: types.datatype = types.float64
) -> Tuple[DNDarray, DNDarray, DNDarray]:
    """
    Computes c=(c_i)_i defined in equation (3.4), as well as a=(a_j)_j and Mhat defined in formulas (4.2)/(4.3) of the paper Nakatsukasa, Y., & Freund, R. W. (2016). Computing the polar decomposition with applications. SIAM Review, 58(3), DOI: https://doi.org/10.1137/140990334.
    Evaluations of the respective complete elliptic integral of the first kind and the Jacobi elliptic functions are imported from SciPy.

    Parameters
    ----------
    r : int
        The parameter ``r`` of the paper; ``2 * r`` coefficients ``c_i`` and ``r`` coefficients ``a_j`` are computed.
    ell : float
        The parameter ``ell`` of the paper.
    device : str or Device
        Device on which the returned DNDarrays are created; passed on to ``factories.array``.
    dtype : types.datatype
        Heat data type of the returned DNDarrays (required for reasons of consistency).

    Returns
    -------
    Tuple[DNDarray, DNDarray, DNDarray]
        The tuple ``(c, a, Mhat)`` - in this order - where ``c`` has ``2 * r`` entries, ``a`` has ``r``
        entries and ``Mhat`` is created from a scalar. None of the arrays is split.
    """
    # u_i = i * K' / (2r + 1) for i = 1, ..., 2r; ellipkm1(x) evaluates K(1 - x), so K' is the
    # complete elliptic integral for the parameter 1 - ell^2
    uu = np.arange(1, 2 * r + 1) * ellipkm1(ell**2) / (2 * r + 1)
    # first two outputs of ellipj are the Jacobi elliptic functions sn and cn
    ellipfcts = np.asarray(ellipj(uu, 1 - ell**2)[:2])
    # c_i = ell^2 * sn(u_i)^2 / cn(u_i)^2
    cc = ell**2 * ellipfcts[0, :] ** 2 / ellipfcts[1, :] ** 2

    aa = np.zeros(r)
    Mhat = 1
    # 0-based j here corresponds to index j + 1 in the paper: cc[2 * j] is c_{2j-1} (odd index),
    # cc[2 * j + 1] is c_{2j} (even index) in the paper's 1-based numbering
    for j in range(r):
        c_odd_j = cc[2 * j]
        p1 = 1
        p2 = 1
        for k in range(r):
            p1 *= c_odd_j - cc[2 * k + 1]
            if k != j:
                # skip k == j, which would contribute a zero factor to the denominator
                p2 *= c_odd_j - cc[2 * k]
        aa[j] = -p1 / p2
        Mhat *= (1 + cc[2 * j]) / (1 + cc[2 * j + 1])

    return (
        factories.array(cc, dtype=dtype, split=None, device=device),
        factories.array(aa, dtype=dtype, split=None, device=device),
        factories.array(Mhat, dtype=dtype, split=None, device=device),
    )


def pd(
    A: DNDarray,
    r: int = 8,
    calcH: bool = True,
    condition_estimate: float = 0.0,
    silent: bool = True,
) -> Union[DNDarray, Tuple[DNDarray, DNDarray]]:
    """
    Computes the so-called polar decomposition of the input 2D DNDarray ``A``, i.e., it returns the orthogonal matrix ``U`` and the symmetric, positive definite
    matrix ``H`` such that ``A = U @ H``.

    Parameters
    ----------
    A : ht.DNDarray,
        The input matrix for which the polar decomposition is computed;
        must be two-dimensional, of data type float32 or float64, and must have at least as many rows as columns.
    r : int, optional, default: 8
        The parameter r used in the Zolotarev-PD algorithm; must be an integer between 1 and 8.
        Higher values of r lead to faster convergence, but memory consumption is proportional to r.
    calcH : bool, optional, default: True
        If True, the function returns the symmetric, positive definite matrix H. If False, only the orthogonal matrix U is returned.
    condition_estimate : float, optional, default: 0.
        This argument allows to provide an estimate for the condition number of the input matrix ``A``, if such estimate is already known.
        If a positive number greater than 1., this value is used as an estimate for the condition number of A.
        If smaller or equal than 1., the condition number is estimated internally (default).
        Must be a Python float; other types (including int) are rejected.
    silent : bool, optional, default: True
        If True, the function does not print any output. If False, some information is printed during the computation
        (by the process with rank 0 only).

    Returns
    -------
    DNDarray or Tuple[DNDarray, DNDarray]
        The tuple ``(U, H)`` if ``calcH`` is True, otherwise only ``U``. ``H`` is returned with the same split as ``A``.

    Raises
    ------
    TypeError
        If ``A`` is not a DNDarray, if its data type is neither float32 nor float64, or if ``condition_estimate`` is not a float.
    ValueError
        If ``A`` is not two-dimensional, if it has fewer rows than columns, or if ``r`` is not an integer between 1 and 8.

    Notes
    -----
    The implementation follows Algorithm 5.1 in Reference [1]; however, instead of switching from QR to Cholesky decomposition depending on the condition number,
    we stick to QR decomposition in all iterations.

    References
    ----------
    [1] Nakatsukasa, Y., & Freund, R. W. (2016). Computing the polar decomposition with applications. SIAM Review, 58(3), DOI: https://doi.org/10.1137/140990334.
    """
    # check whether input is DNDarray of correct shape
    if not isinstance(A, DNDarray):
        raise TypeError(f"Input ``A`` needs to be a DNDarray but is {type(A)}.")
    if not A.ndim == 2:
        raise ValueError(f"Input ``A`` needs to be a 2D DNDarray, but its dimension is {A.ndim}.")
    if A.shape[0] < A.shape[1]:
        raise ValueError(
            f"Input ``A`` must have at least as many rows as columns, but has shape {A.shape}."
        )
    # check if A is a real floating point matrix and choose tolerances tol accordingly
    # (the two values are approximately the machine epsilon of float32 and float64, respectively)
    if A.dtype == types.float32:
        tol = 1.19e-7
    elif A.dtype == types.float64:
        tol = 2.22e-16
    else:
        raise TypeError(
            f"Input ``A`` must be of data type float32 or float64 but has data type {A.dtype}"
        )

    # check if input for r is reasonable
    if not isinstance(r, int) or r < 1 or r > 8:
        raise ValueError(
            f"If specified, input ``r`` must be an integer between 1 and 8, but is {r} of data type {type(r)}."
        )

    # check if input for condition_estimate is reasonable
    if not isinstance(condition_estimate, float):
        raise TypeError(
            f"If specified, input ``condition_estimate`` must be a float but is {type(condition_estimate)}."
        )

    # upper bound for the largest singular value (default strategy of the helper), as Python scalar
    alpha = _estimate_largest_singularvalue(A).item()

    # use the user-provided condition number only if it is larger than 1, otherwise estimate it
    if condition_estimate <= 1.0:
        kappa = condest(A).item()
    else:
        kappa = condition_estimate

    if A.comm.rank == 0 and not silent:
        print(
            f"Condition number estimate: {kappa:2.2e} / Estimate for largest singular value: {alpha:2.2e}."
        )

    # initialize X for the iteration: input ``A``, normalized by largest singular value
    # NOTE(review): alpha is zero for an all-zero input, which is not caught here - confirm intended
    X = A / alpha

    # iteration counter and maximum number of iterations
    it = 0
    itmax = _zolopd_n_iterations(r, kappa)

    # parameters and coefficients, see Ref. [1] for their meaning
    ell = 1.0 / kappa
    c, a, Mhat = _compute_zolotarev_coefficients(r, ell, A.device, dtype=A.dtype)

    while it < itmax:
        it += 1
        if not silent:
            if A.comm.rank == 0:
                print(f"Starting Zolotarev-PD iteration no. {it}...")
        # remember current X for later convergence check
        X_old = X.copy()

        # repeat X r-times and create (repeated) identity matrix
        # this allows to compute the r-many QR decomposition and matrix multiplications in batch-parallel manner
        # the process-local chunk is repeated along a new leading (batch) axis, hence the split axis moves by one
        # NOTE(review): no device is passed to factories.array here (unlike the eye() call below) -
        # presumably it is inferred from the local tensor; confirm for non-default devices
        X = factories.array(
            X.larray.repeat(r, 1, 1),
            is_split=X.split + 1 if X.split is not None else None,
            comm=A.comm,
        )
        cId = factories.eye(A.shape[1], dtype=A.dtype, comm=A.comm, split=A.split, device=A.device)
        cId = factories.array(
            cId.larray.repeat(r, 1, 1),
            is_split=cId.split + 1 if cId.split is not None else None,
            comm=A.comm,
        )
        # scale the j-th identity by the square root of the j-th odd-indexed coefficient c[0], c[2], ...
        cId *= c[0::2].reshape(-1, 1, 1) ** 0.5
        # stack X on top of the scaled identities (per batch entry) and orthogonalize the stacked matrices
        X = concatenate([X, cId], axis=1)
        Q, _ = qr(X)
        # upper block of Q (rows belonging to X) and transposed lower block (rows belonging to the identity)
        Q1 = Q[:, : A.shape[0], : A.shape[1]].balance()
        Q2 = Q[:, A.shape[0] :, : A.shape[1]].transpose([0, 2, 1]).balance()
        del Q
        # per batch entry j: Mhat * (X / r + a_j / sqrt(c_j) * Q1_j @ Q2_j); the X / r terms add up to X in the sum below
        X = Mhat * (
            X[:, : A.shape[0], :].balance() / r
            + a.reshape(-1, 1, 1)
            / c[0::2].reshape(-1, 1, 1) ** 0.5
            * matmul(Q1, Q2).resplit_(X.split)
        )
        del (Q1, Q2)
        # finally, sum over the batch-dimension to get back the result of the iteration
        X = X.sum(axis=0)

        # check for convergence and break if tolerance is reached
        # (relative change in Frobenius norm, compared against tol ** (1 / (2r + 1)); never checked in the first iteration)
        if it > 1 and matrix_norm(X - X_old, ord="fro") / matrix_norm(X, ord="fro") <= tol ** (
            1 / (2 * r + 1)
        ):
            if not silent:
                if A.comm.rank == 0:
                    print(f"Zolotarev-PD iteration converged after {it} iterations.")
            break
        elif it < itmax:
            # if another iteration is necessary, update coefficients and parameters for next iteration
            # NOTE(review): c and Mhat are DNDarrays, so ell presumably becomes a DNDarray here and is
            # passed on as such to the SciPy-based coefficient computation - confirm this is supported
            ellold = ell
            ell = 1
            for j in range(r):
                ell *= (ellold**2 + c[2 * j + 1]) / (ellold**2 + c[2 * j])
            ell *= Mhat * ellold
            c, a, Mhat = _compute_zolotarev_coefficients(r, ell, A.device, dtype=A.dtype)
        else:
            # last admissible iteration finished without meeting the convergence criterion
            if not silent:
                if A.comm.rank == 0:
                    print(
                        f"Zolotarev-PD iteration did not reach the convergence criterion after {itmax} iterations, which is most likely due to limited numerical accuracy and/or poor estimation of the condition number. The result may still be useful, but should be handeled with care!"
                    )
    # postprocessing: compute H if requested
    if calcH:
        H = matmul(X.T, A)
        # symmetrize H explicitly
        H = 0.5 * (H + H.T.resplit(H.split))
        return X, H.resplit(A.split)
    else:
        return X
Loading
Loading