Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests #69

Merged
merged 6 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 45 additions & 28 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,57 @@ def compare_values(expected, got, message_header=None):
), f"{_form_message_header(message_header)}: expected {expected}, got {got}"


def create_testing_data():
N = 20
dim = 256
data = torch.randn((N, dim))
return N, dim, data
def create_testing_data(architecture='fcn'):
architecture = architecture.lower()
if architecture == 'fcn':
return torch.randn((20, 256))
elif architecture == 'cnn':
return torch.randn((20, 3, 32, 32))
else:
raise Exception(f'Unsupported architecture type: {architecture}')


def create_testing_model(num_classes=10):
return nn.Sequential(
OrderedDict(
[
("first_layer", nn.Linear(256, 128)),
("second_layer", nn.Linear(128, 64)),
("third_layer", nn.Linear(64, num_classes)),
],
),
)
def create_testing_model(architecture='fcn', num_classes=10):
architecture = architecture.lower()
if architecture == 'fcn':
return nn.Sequential(
OrderedDict(
[
("first_layer", nn.Linear(256, 128)),
("second_layer", nn.Linear(128, 64)),
("third_layer", nn.Linear(64, num_classes)),
],
),
)
elif architecture == 'cnn':
return nn.Sequential(
OrderedDict(
[
("first_layer", nn.Conv2d(in_channels=3, out_channels=10, kernel_size=7)),
("second_layer", nn.Conv2d(in_channels=10, out_channels=20, kernel_size=7)),
("avgpool", nn.AdaptiveAvgPool2d(1)),
("flatten", nn.Flatten()),
("fc", nn.Linear(20, num_classes)),
],
),
)
elif architecture == 'rnn':
return nn.Sequential(
OrderedDict(
[
('first_layer', nn.LSTM(256, 128, 1, batch_first=True)),
('extract', ExtractTensor()),
('second_layer', nn.Linear(128, 64)),
('third_layer', nn.Linear(64, num_classes)),
],
),
)
else:
raise Exception(f'Unsupported architecture type: {architecture}')


class ExtractTensor(nn.Module):
def forward(self, x):
tensor, _ = x
x = x.to(torch.float32)
return tensor[:, :]


def create_testing_model_lstm(num_classes=10):
return nn.Sequential(
OrderedDict(
[
('first_layer', nn.LSTM(256, 128, 1, batch_first=True)),
('extract', ExtractTensor()),
('second_layer', nn.Linear(128, 64)),
('third_layer', nn.Linear(64, num_classes)),
],
),
)
68 changes: 55 additions & 13 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def test_check_random_input():


def _check_reduce_dim(mode):
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
reduced_data = viz_api.reduce_dim(data, mode)
utils.compare_values(np.ndarray, type(reduced_data), "Wrong result type")
utils.compare_values((N, 2), reduced_data.shape, "Wrong result shape")
utils.compare_values((len(data), 2), reduced_data.shape, "Wrong result shape")


def test_reduce_dim_umap():
Expand All @@ -31,8 +31,8 @@ def test_reduce_dim_pca():
_check_reduce_dim("pca")


def test_visualization():
N, dim, data = utils.create_testing_data()
def test_visualization_fcn():
data = utils.create_testing_data()
model = utils.create_testing_model()
layers = ["second_layer", "third_layer"]
res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers)
Expand All @@ -52,10 +52,31 @@ def test_visualization():
)


def test_visualization_cnn():
data = utils.create_testing_data(architecture='cnn')
model = utils.create_testing_model(architecture='cnn')
layers = ["first_layer", "second_layer", "avgpool", "flatten", "fc"]
res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers)

utils.compare_values(dict, type(res), "Wrong result type")
utils.compare_values(6, len(res), "Wrong dictionary length")
utils.compare_values(
set(["input"] + layers),
set(res.keys()),
"Wrong dictionary keys",
)
for key, plot in res.items():
utils.compare_values(
matplotlib.figure.Figure,
type(plot),
f"Wrong value type for key {key}",
)


def test_embed_visualization():
data = torch.randn((20, 1, 256))
labels = torch.randn((20))
model = utils.create_testing_model_lstm()
model = utils.create_testing_model('rnn')
layers = ["second_layer", "third_layer"]
res = viz_api.visualize_recurrent_layer_manifolds(model, "umap",
data, layers=layers, labels=labels)
Expand All @@ -74,16 +95,16 @@ def test_embed_visualization():
)


def _test_bayes_prediction(mode: str):
def _test_bayes_prediction(mode: str, architecture='fcn'):
params = {
"basic": dict(mode="basic", p=0.5),
"beta": dict(mode="beta", a=0.9, b=0.2),
"gauss": dict(sigma=1e-2),
}

N, dim, data = utils.create_testing_data()
data = utils.create_testing_data(architecture=architecture)
num_classes = 17
model = utils.create_testing_model(num_classes=num_classes)
model = utils.create_testing_model(architecture=architecture, num_classes=num_classes)
n_iter = 7
if mode != 'gauss':
res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter)
Expand All @@ -94,6 +115,7 @@ def _test_bayes_prediction(mode: str):
utils.compare_values(dict, type(res), "Wrong result type")
utils.compare_values(2, len(res), "Wrong dictionary length")
utils.compare_values(set(["mean", "std"]), set(res.keys()), "Wrong dictionary keys")
N = len(data)
utils.compare_values(torch.Size([N, num_classes]), res["mean"].shape, "Wrong mean shape")
utils.compare_values(torch.Size([N, num_classes]), res["std"].shape, "Wrong mean std")

Expand All @@ -110,14 +132,18 @@ def test_gauss_bayes_wrapper():
_test_bayes_prediction("gauss")


def test_bayes_wrapper_cnn():
_test_bayes_prediction("basic", architecture='cnn')


def test_data_barcode():
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
res = topology_api.get_data_barcode(data, "standard", "3")
utils.compare_values(dict, type(res), "Wrong result type")


def test_nn_barcodes():
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
model = utils.create_testing_model()
layers = ["second_layer", "third_layer"]
res = topology_api.get_nn_barcodes(model, data, layers, "standard", "3")
Expand All @@ -132,15 +158,31 @@ def test_nn_barcodes():
)


def test_nn_barcodes_cnn():
data = utils.create_testing_data(architecture='cnn')
model = utils.create_testing_model(architecture='cnn')
layers = ["second_layer", "flatten"]
res = topology_api.get_nn_barcodes(model, data, layers, "standard", "3")
utils.compare_values(dict, type(res), "Wrong result type")
utils.compare_values(2, len(res), "Wrong dictionary length")
utils.compare_values(set(layers), set(res.keys()), "Wrong dictionary keys")
for layer, barcode in res.items():
utils.compare_values(
dict,
type(barcode),
f"Wrong result type for key {layer}",
)


def test_barcode_plot():
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
barcode = topology_api.get_data_barcode(data, "standard", "3")
plot = topology_api.plot_barcode(barcode)
utils.compare_values(matplotlib.figure.Figure, type(plot), "Wrong result type")


def test_barcode_evaluate_all_metrics():
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
barcode = topology_api.get_data_barcode(data, "standard", "3")
result = topology_api.evaluate_barcode(barcode)
utils.compare_values(dict, type(result), "Wrong result type")
Expand All @@ -166,7 +208,7 @@ def test_barcode_evaluate_all_metrics():


def test_barcode_evaluate_one_metric():
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
barcode = topology_api.get_data_barcode(data, "standard", "3")
result = topology_api.evaluate_barcode(barcode, metric_name="mean_length")
utils.compare_values(float, type(result), "Wrong result type")
Loading