From b23d1080b0b999d3cd8ce3f9f62cf9e7adb0410b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Fri, 25 Nov 2022 14:41:35 +0200 Subject: [PATCH] fix test_get --- tests/test_slisemap.py | 12 ++++++----- tests/utils.py | 46 ++++++++++++++++++++++-------------------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/tests/test_slisemap.py b/tests/test_slisemap.py index a72a370..e84380f 100644 --- a/tests/test_slisemap.py +++ b/tests/test_slisemap.py @@ -156,7 +156,8 @@ def test_predict(): def test_get(): - sm = get_slisemap(40, 5, intercept=True, lasso=0, ridge=0, z_norm=0) + sm = get_slisemap(40, 5, intercept=True, lasso=0, ridge=0, z_norm=0, seed=459872) + sm.lbfgs(10) assert torch.allclose(sm._Z, sm.get_Z(False, False, False)) assert torch.allclose(torch.sqrt(torch.sum(sm._Z**2) / sm.n), torch.ones(1)) Z = sm.get_Z(numpy=False) @@ -164,13 +165,14 @@ def test_get(): torch.sqrt(torch.sum(Z**2) / sm.n) / sm.radius, torch.ones(1) ) assert torch.allclose(sm.get_D(numpy=False), torch.cdist(Z, Z), 1e-4, 1e-6) - assert torch.allclose(sm.get_W(numpy=False), sm.kernel(torch.cdist(Z, Z))) - assert sm.get_X(intercept=False).shape[1] == sm.m - 1 - L = sm.get_L(numpy=False) W = sm.get_W(numpy=False) - assert np.allclose(sm.value(True), tonp(torch.sum(L * W, 1)), 1e-4, 1e-6) + assert torch.allclose(W, sm.kernel(torch.cdist(Z, Z))) + L = sm.get_L(numpy=False) + assert_allclose(sm.value(), torch.sum(L * W).cpu().item(), "loss", 1e-4, 1e-6) + assert_allclose(sm.value(True), tonp(torch.sum(L * W, 1)), "ind_loss", 1e-4, 1e-6) assert sm.get_Y(False, True).shape == (40,) assert sm.get_Y(False, False).shape == (40, 1) + assert sm.get_X(intercept=False).shape[1] == sm.m - 1 assert sm.get_X(False, False).shape == (40, 5) assert sm.get_X(False, True).shape == (40, 6) assert sm.get_B(False).shape == (40, 6) diff --git a/tests/utils.py b/tests/utils.py index db1b90f..5b9f389 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -34,13 +34,13 @@ def get_slisemap( **kwargs, ) else: - y = npr.normal(size=n) - sm = Slisemap(X, y, lasso=lasso, **kwargs) - if randomB: - sm._B = torch.normal(0, 1, sm._B.shape, **sm.tensorargs) - return sm - - + y = npr.normal(size=n) + sm = Slisemap(X, y, lasso=lasso, **kwargs) + if randomB: + sm._B = torch.normal(0, 1, sm._B.shape, **sm.tensorargs) + return sm + + def get_slisemap2( n=20, m=5, @@ -65,13 +65,13 @@ def get_slisemap2( **kwargs, ) else: - sm = Slisemap(X, y, lasso=lasso, random_state=seed, **kwargs) - if randomB: - sm._B = torch.normal( - 0, 1, sm._B.shape, **sm.tensorargs, generator=sm._random_state - ) - if cheat: - angles = 2 * np.pi * cl / 3 # Assume k=3, d=2 + sm = Slisemap(X, y, lasso=lasso, random_state=seed, **kwargs) + if randomB: + sm._B = torch.normal( + 0, 1, sm._B.shape, **sm.tensorargs, generator=sm._random_state + ) + if cheat: + angles = 2 * np.pi * cl / 3 # Assume k=3, d=2 Z = np.stack((np.sin(angles), np.cos(angles)), 1) end = [sm.radius * 0.99, sm.radius * 1.01] Z = Z * np.linspace(end, end[::-1], len(cl)) @@ -82,20 +82,22 @@ def get_slisemap2( def assert_allclose(x, y, label="", *args, **kwargs): - assert np.allclose(x, y, *args, **kwargs), f"{label}: {x} != {y}" + assert np.allclose( + x, y, *args, **kwargs + ), f"{label}: {x} != {y}\nmax abs diff: {np.max(np.abs(x-y))}" def all_finite(x: Union[float, np.ndarray], *args: Union[float, np.ndarray]) -> bool: if len(args) > 0: return np.all(np.all(np.isfinite(y)) for y in [x, *args]) return np.all(np.isfinite(x)) - - -def assert_approx_ge(x, y, label=None, tolerance=0.05): - tolerance *= (np.abs(x) + np.abs(y)) * 0.5 - if label: - assert np.all(x > y - tolerance), f"{label}: {x} !>= {y}" - else: + + +def assert_approx_ge(x, y, label=None, tolerance=0.05): + tolerance *= (np.abs(x) + np.abs(y)) * 0.5 + if label: + assert np.all(x > y - tolerance), f"{label}: {x} !>= {y}" + else: assert np.all(x > y - tolerance), f"{x} !>= {y}"