diff --git a/uadapy/dr/uamds.py b/uadapy/dr/uamds.py index 9749639..045d739 100644 --- a/uadapy/dr/uamds.py +++ b/uadapy/dr/uamds.py @@ -521,7 +521,7 @@ def apply_uamds(means: list[np.ndarray], covs: list[np.ndarray], target_dim=2) - } -def uamds(distributions: list, dims: int=2): +def uamds(distributions: list, dims: int=2, seed: int=0): """ Applies the UAMDS algorithm to the provided distributions and returns the projected distributions in lower-dimensional space. It assumes multivariate normal distributions. @@ -534,6 +534,8 @@ def uamds(distributions: list, dims: int=2): list of input distributions (distribution objects offering mean() and cov() methods) dims : int target dimensionality, 2 by default. + seed : int + Set the random seed for the initialization, 0 by default Returns ------- @@ -541,6 +543,7 @@ def uamds(distributions: list, dims: int=2): List of distributions living in projection space (i.e. of provided dimensionality) """ try: + np.random.seed(seed) means = [d.mean() for d in distributions] covs = [d.cov() for d in distributions] result = apply_uamds(means, covs, dims)