Skip to content

Commit

Permalink
Add seed for random number in UAMDS
Browse files Browse the repository at this point in the history
  • Loading branch information
marinaevers committed Jul 25, 2024
1 parent 35024d8 commit 406d138
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion uadapy/dr/uamds.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def apply_uamds(means: list[np.ndarray], covs: list[np.ndarray], target_dim=2) -
}


def uamds(distributions: list, dims: int=2):
def uamds(distributions: list, dims: int=2, seed: int=0):
"""
Applies the UAMDS algorithm to the provided distributions and returns the projected distributions
in lower-dimensional space. It assumes multivariate normal distributions.
Expand All @@ -534,13 +534,16 @@ def uamds(distributions: list, dims: int=2):
list of input distributions (distribution objects offering mean() and cov() methods)
dims : int
target dimensionality, 2 by default.
seed : int
Set the random seed for the initialization, 0 by default
Returns
-------
list
List of distributions living in projection space (i.e. of provided dimensionality)
"""
try:
np.random.seed(seed)
means = [d.mean() for d in distributions]
covs = [d.cov() for d in distributions]
result = apply_uamds(means, covs, dims)
Expand Down

0 comments on commit 406d138

Please sign in to comment.