Skip to content

Commit

Permalink
tests and docs build.
Browse files Browse the repository at this point in the history
  • Loading branch information
iancze committed Dec 28, 2023
1 parent 94ab743 commit c2e5370
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions test/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,9 @@ def test_chi_squared_evaluation(
data = torch.tensor(data_re + 1.0j * data_im)
weight = torch.tensor(weight)

chi_squared = losses.chi_squared(loose_visibilities, data, weight)
chi_squared = losses._chi_squared(loose_visibilities, data, weight)
print("loose chi_squared", chi_squared)

# calculate the gridded chi^2
chi_squared_gridded = losses.chi_squared_gridded(gridded_visibilities, dataset)
print("gridded chi_squared", chi_squared_gridded)

# it's OK that the values are different


def test_log_likelihood_evaluation(
loose_visibilities, mock_visibility_data, gridded_visibilities, dataset
Expand All @@ -102,7 +96,7 @@ def test_nll_hermitian_pairs(loose_visibilities, mock_visibility_data):
data = torch.tensor(data_re + 1.0j * data_im)
weight = torch.tensor(weight)

log_like = losses.reduced_chi_squared(loose_visibilities, data, weight)
log_like = losses.r_chi_squared(loose_visibilities, data, weight)
print("loose nll", log_like)

# calculate it with Hermitian pairs
Expand All @@ -116,7 +110,7 @@ def test_nll_hermitian_pairs(loose_visibilities, mock_visibility_data):
data = torch.cat([data, torch.conj(data)], axis=1)
weight = torch.cat([weight, weight], axis=1)

log_like = losses.reduced_chi_squared(loose_visibilities, data, weight)
log_like = losses.r_chi_squared(loose_visibilities, data, weight)
print("loose nll w/ Hermitian", log_like)


Expand Down Expand Up @@ -156,7 +150,7 @@ def test_nll_1D_zero():
data_im = model_im
data_vis = torch.complex(data_re, data_im)

loss = losses.reduced_chi_squared(model_vis, data_vis, weights)
loss = losses.r_chi_squared(model_vis, data_vis, weights)
assert loss.item() == 0.0


Expand All @@ -175,7 +169,7 @@ def test_nll_1D_random():
data_im = torch.randn_like(weights)
data_vis = torch.complex(data_re, data_im)

losses.reduced_chi_squared(model_vis, data_vis, weights)
losses.r_chi_squared(model_vis, data_vis, weights)


def test_nll_2D_zero():
Expand All @@ -195,7 +189,7 @@ def test_nll_2D_zero():
data_im = model_im
data_vis = torch.complex(data_re, data_im)

loss = losses.reduced_chi_squared(model_vis, data_vis, weights)
loss = losses.r_chi_squared(model_vis, data_vis, weights)
assert loss.item() == 0.0


Expand All @@ -215,7 +209,7 @@ def test_nll_2D_random():
data_im = torch.randn_like(weights)
data_vis = torch.complex(data_re, data_im)

losses.reduced_chi_squared(model_vis, data_vis, weights)
losses.r_chi_squared(model_vis, data_vis, weights)


def test_entropy_raise_error_negative():
Expand Down

0 comments on commit c2e5370

Please sign in to comment.