Skip to content

Commit

Permalink
update test_model and test_phenotype
Browse files Browse the repository at this point in the history
  • Loading branch information
Ross committed Nov 25, 2023
1 parent 5d05bbd commit 2040b77
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
19 changes: 15 additions & 4 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from torch import nn
from torch.distributions import Gamma
import pandas as pd

from lantern.model import Model
from lantern.model.basis import Basis, VariationalBasis
Expand All @@ -12,6 +13,17 @@
)


from lantern.dataset import Dataset

df = pd.DataFrame({'substitutions':['A', 'B', 'A:B'], 'phen_0':[1,2,3], 'phen_0_var':[1,1,1]})
ds_single = Dataset(df, phenotypes = ['phen_0'], errors = ['phen_0_var'])

df_m = df.cop()
df_m['phen_1'] = [2,4,6]
df_m['phen_1_var'] = [2,2,2]
ds_multi = Dataset(df_m, phenotypes = ['phen_0', 'phen_1'], errors = ['phen_0_var', 'phen_1_var'])


def test_model_validator():
class DummyBasis(Basis):
@property
Expand All @@ -23,7 +35,7 @@ def K(self):
return 3

with pytest.raises(ValueError):
Model(DummyBasis(), Phenotype.build(4, 5,), MultitaskGaussianLikelihood(5))
Model(DummyBasis(), Phenotype.fromDataset(ds_multi, 5,), MultitaskGaussianLikelihood(5))


def test_forward():
Expand All @@ -38,7 +50,7 @@ def test_forward():
Gamma(0.001, 0.001),
)

m = Model(vb, Phenotype.build(4, K, Ni=100), GaussianLikelihood())
m = Model(vb, Phenotype.fromDataset(ds_multi, K, Ni=100), GaussianLikelihood())
m.eval()

X = torch.randn(30, 10)
Expand All @@ -52,7 +64,6 @@ def test_loss():

p = 10
K = 3
D = 4
vb = VariationalBasis(
nn.Parameter(torch.randn(p, K)),
nn.Parameter(torch.randn(p, K) - 3),
Expand All @@ -61,7 +72,7 @@ def test_loss():
Gamma(0.001, 0.001),
)

m = Model(vb, Phenotype.build(D, 3, Ni=100), MultitaskGaussianLikelihood(D))
m = Model(vb, Phenotype.fromDataset(ds_multi, 3, Ni=100), MultitaskGaussianLikelihood(ds_multi.D))
loss = m.loss(N=1000)

X = torch.randn(30, 10)
Expand Down
12 changes: 10 additions & 2 deletions tests/test_phenotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@
from lantern.loss import ELBO_GP
from lantern.dataset import Dataset

df = pd.DataFrame({'substitutions':['A', 'B', 'A:B'], 'phen_0':[1,2,3], 'phen_0_var':[1,1,1]})
ds_single = Dataset(df, phenotypes = ['phen_0'], errors = ['phen_0_var'])

df_m = df.cop()
df_m['phen_1'] = [2,4,6]
df_m['phen_1_var'] = [2,2,2]
ds_multi = Dataset(df_m, phenotypes = ['phen_0', 'phen_1'], errors = ['phen_0_var', 'phen_1_var'])


def test_1d():

phen = Phenotype.build(1, 10, Ni=100)
phen = Phenotype.fromDataset(ds_single, 10, Ni=100)

assert type(phen.variational_strategy) == VariationalStrategy

Expand All @@ -28,7 +36,7 @@ def test_1d():

def test_multid():

phen = Phenotype.build(4, 10, Ni=100)
phen = Phenotype.fromDataset(ds_multi, 10, Ni=100)

assert type(phen.variational_strategy) == IndependentMultitaskVariationalStrategy

Expand Down

0 comments on commit 2040b77

Please sign in to comment.