Skip to content

Commit

Permalink
fix test_get
Browse files Browse the repository at this point in the history
  • Loading branch information
Aggrathon committed Nov 25, 2022
1 parent 41f286e commit b23d108
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 27 deletions.
12 changes: 7 additions & 5 deletions tests/test_slisemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,21 +156,23 @@ def test_predict():


def test_get():
sm = get_slisemap(40, 5, intercept=True, lasso=0, ridge=0, z_norm=0)
sm = get_slisemap(40, 5, intercept=True, lasso=0, ridge=0, z_norm=0, seed=459872)
sm.lbfgs(10)
assert torch.allclose(sm._Z, sm.get_Z(False, False, False))
assert torch.allclose(torch.sqrt(torch.sum(sm._Z**2) / sm.n), torch.ones(1))
Z = sm.get_Z(numpy=False)
assert torch.allclose(
torch.sqrt(torch.sum(Z**2) / sm.n) / sm.radius, torch.ones(1)
)
assert torch.allclose(sm.get_D(numpy=False), torch.cdist(Z, Z), 1e-4, 1e-6)
assert torch.allclose(sm.get_W(numpy=False), sm.kernel(torch.cdist(Z, Z)))
assert sm.get_X(intercept=False).shape[1] == sm.m - 1
L = sm.get_L(numpy=False)
W = sm.get_W(numpy=False)
assert np.allclose(sm.value(True), tonp(torch.sum(L * W, 1)), 1e-4, 1e-6)
assert torch.allclose(W, sm.kernel(torch.cdist(Z, Z)))
L = sm.get_L(numpy=False)
assert_allclose(sm.value(), torch.sum(L * W).cpu().item(), "loss", 1e-4, 1e-6)
assert_allclose(sm.value(True), tonp(torch.sum(L * W, 1)), "ind_loss", 1e-4, 1e-6)
assert sm.get_Y(False, True).shape == (40,)
assert sm.get_Y(False, False).shape == (40, 1)
assert sm.get_X(intercept=False).shape[1] == sm.m - 1
assert sm.get_X(False, False).shape == (40, 5)
assert sm.get_X(False, True).shape == (40, 6)
assert sm.get_B(False).shape == (40, 6)
Expand Down
46 changes: 24 additions & 22 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def get_slisemap(
**kwargs,
)
else:
y = npr.normal(size=n)
sm = Slisemap(X, y, lasso=lasso, **kwargs)
if randomB:
sm._B = torch.normal(0, 1, sm._B.shape, **sm.tensorargs)
return sm


y = npr.normal(size=n)
sm = Slisemap(X, y, lasso=lasso, **kwargs)
if randomB:
sm._B = torch.normal(0, 1, sm._B.shape, **sm.tensorargs)
return sm


def get_slisemap2(
n=20,
m=5,
Expand All @@ -65,13 +65,13 @@ def get_slisemap2(
**kwargs,
)
else:
sm = Slisemap(X, y, lasso=lasso, random_state=seed, **kwargs)
if randomB:
sm._B = torch.normal(
0, 1, sm._B.shape, **sm.tensorargs, generator=sm._random_state
)
if cheat:
angles = 2 * np.pi * cl / 3 # Assume k=3, d=2
sm = Slisemap(X, y, lasso=lasso, random_state=seed, **kwargs)
if randomB:
sm._B = torch.normal(
0, 1, sm._B.shape, **sm.tensorargs, generator=sm._random_state
)
if cheat:
angles = 2 * np.pi * cl / 3 # Assume k=3, d=2
Z = np.stack((np.sin(angles), np.cos(angles)), 1)
end = [sm.radius * 0.99, sm.radius * 1.01]
Z = Z * np.linspace(end, end[::-1], len(cl))
Expand All @@ -82,20 +82,22 @@ def get_slisemap2(


def assert_allclose(x, y, label="", *args, **kwargs):
    """Assert that ``x`` and ``y`` are element-wise close (via ``np.allclose``).

    Args:
        x: First value/array to compare.
        y: Second value/array to compare.
        label: Prefix for the failure message, identifying the comparison.
        *args, **kwargs: Forwarded to ``np.allclose`` (e.g. ``rtol``, ``atol``).

    Raises:
        AssertionError: If the values are not close; the message includes
            the maximum absolute difference to aid debugging.
    """
    assert np.allclose(
        x, y, *args, **kwargs
    ), f"{label}: {x} != {y}\nmax abs diff: {np.max(np.abs(x-y))}"


def all_finite(x: Union[float, np.ndarray], *args: Union[float, np.ndarray]) -> bool:
    """Check that all given scalars/arrays contain only finite values.

    Args:
        x: Scalar or array to check.
        *args: Additional scalars or arrays to check.

    Returns:
        True iff every element of every argument is finite (no NaN/inf).
    """
    if len(args) > 0:
        # Use the builtin `all` here: `np.all(<generator>)` wraps the
        # generator in a 0-d object array, which is always truthy, so the
        # original check returned True regardless of the inputs.
        return all(np.all(np.isfinite(y)) for y in [x, *args])
    return bool(np.all(np.isfinite(x)))


def assert_approx_ge(x, y, label=None, tolerance=0.05):
tolerance *= (np.abs(x) + np.abs(y)) * 0.5
if label:
assert np.all(x > y - tolerance), f"{label}: {x} !>= {y}"
else:


def assert_approx_ge(x, y, label=None, tolerance=0.05):
    """Assert that ``x >= y``, allowing a small relative tolerance.

    Args:
        x: Value/array expected to be (approximately) greater or equal.
        y: Value/array to compare against.
        label: Optional prefix for the failure message.
        tolerance: Relative slack, scaled by the mean magnitude of x and y.

    Raises:
        AssertionError: If any element of ``x`` falls below ``y`` minus the slack.
    """
    # Scale the relative tolerance by the average magnitude of the operands.
    slack = (np.abs(x) + np.abs(y)) * 0.5 * tolerance
    within = np.all(x > y - slack)
    assert within, (f"{label}: {x} !>= {y}" if label else f"{x} !>= {y}")


Expand Down

0 comments on commit b23d108

Please sign in to comment.