From 4e69be430d93130815c5d565d49ede2c4f9a454f Mon Sep 17 00:00:00 2001 From: Pavel Sukachev <92699977+TheRealGremlin@users.noreply.github.com> Date: Sun, 7 Jan 2024 23:43:38 +0400 Subject: [PATCH 1/3] Update api.py --- eXNN/visualization/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/eXNN/visualization/api.py b/eXNN/visualization/api.py index 6cbe216..2e763ac 100644 --- a/eXNN/visualization/api.py +++ b/eXNN/visualization/api.py @@ -161,6 +161,7 @@ def visualize_recurrent_layer_manifolds( """ model2 = create_feature_extractor(model, return_nodes=layers) labels = labels.detach().numpy() + emb_viz = {} for layer in layers: if torch.is_tensor(model2(data)[layer]): layer_output = model2(data)[layer].cpu().detach().numpy() @@ -206,6 +207,7 @@ def visualize_recurrent_layer_manifolds( width=1000, height=1000) emb_out.show(renderer="colab") + emb_viz[layer] = emb_out def get_random_input(dims: List[int]) -> torch.Tensor: From ed40149617d9923152f598e2b10c7d789dc59c0a Mon Sep 17 00:00:00 2001 From: Pavel Sukachev <92699977+TheRealGremlin@users.noreply.github.com> Date: Mon, 8 Jan 2024 00:08:51 +0400 Subject: [PATCH 2/3] Update api.py --- eXNN/visualization/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/eXNN/visualization/api.py b/eXNN/visualization/api.py index 2e763ac..6904d4f 100644 --- a/eXNN/visualization/api.py +++ b/eXNN/visualization/api.py @@ -208,6 +208,7 @@ def visualize_recurrent_layer_manifolds( height=1000) emb_out.show(renderer="colab") emb_viz[layer] = emb_out + return emb_viz def get_random_input(dims: List[int]) -> torch.Tensor: From 195f542586141bce4bbff747db5f82107630fe3d Mon Sep 17 00:00:00 2001 From: Pavel Sukachev <92699977+TheRealGremlin@users.noreply.github.com> Date: Mon, 8 Jan 2024 02:05:04 +0400 Subject: [PATCH 3/3] Update api.py --- eXNN/visualization/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eXNN/visualization/api.py b/eXNN/visualization/api.py index 6904d4f..d396d7c 100644 --- a/eXNN/visualization/api.py +++ b/eXNN/visualization/api.py @@ -208,7 +208,7 @@ def visualize_recurrent_layer_manifolds( height=1000) emb_out.show(renderer="colab") emb_viz[layer] = emb_out - return emb_viz + return emb_viz def get_random_input(dims: List[int]) -> torch.Tensor: