Skip to content

Commit

Permalink
Merge pull request #53 from TheRealGremlin/patch-4
Browse files Browse the repository at this point in the history
Update api.py
  • Loading branch information
tanyapole authored Jan 7, 2024
2 parents 1e5ef98 + 2efc320 commit 6f7223d
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions eXNN/visualization/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def _plot(embedding, labels):
def reduce_dim(
data: torch.Tensor,
mode: str,
out_dim=2,
neighbors=20,
) -> np.ndarray:
"""This function reduces data dimensionality to 2 dimensions.
Expand All @@ -35,6 +37,8 @@ def reduce_dim(
where N is the number of data points,
C1,...,Ck are dimensions of each data point
mode (str): dimensionality reduction mode (`umap` or `pca`)
out_dim (int): dimension of output, 3 by default
neighbors (int): n_neighbors of umap method
Raises:
ValueError: returned if unsupported mode is provided
Expand All @@ -45,9 +49,9 @@ def reduce_dim(

data = data.detach().cpu().numpy().reshape((len(data), -1))
if mode == "pca":
return PCA(n_components=2).fit_transform(data)
return PCA(n_components=out_dim).fit_transform(data)
elif mode == "umap":
return umap.UMAP().fit_transform(data)
return umap.UMAP(n_components=out_dim, n_neighbors=neighbors).fit_transform(data)
else:
raise ValueError(f"Unsupported mode: `{mode}`")

Expand Down Expand Up @@ -113,9 +117,11 @@ def visualize_recurrent_layer_manifolds(
neighbors=20,
time_delay=1,
embedding_dim=10,
stride_mode="dimensional",
stride_mode='dimensional',
out_dim=3,
renderer="browser",
arr_reducer=1,
renderer='browser',
heatmap: Optional[str] = True,
layers: Optional[List[str]] = None,
labels: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
Expand All @@ -134,6 +140,7 @@ def visualize_recurrent_layer_manifolds(
stride_mode ('dimensional' or str): stride duration between two consecutive embedded points,
'dimensional' makes 'stride' equal to layer dimension
out_dim (int): dimension of output, 3 by default
arr_reducer (int): strips the output array of some data, leaving only each n_th
renderer (str): plotly renderer for image,
Available renderers:
['plotly_mimetype', 'jupyterlab', 'nteract', 'vscode',
Expand Down Expand Up @@ -164,25 +171,33 @@ def visualize_recurrent_layer_manifolds(
else:
stride = stride_mode
if layer_output.ndim > 2:
embedder = TakensEmbedding(time_delay=time_delay, dimension=10, stride=stride)
embedder = TakensEmbedding(time_delay=time_delay, dimension=embedding_dim,
stride=stride)
emb_res = embedder.fit_transform(layer_output)
else:
embedder = TakensEmbedding(time_delay=time_delay, dimension=10, stride=stride)
embedder = TakensEmbedding(time_delay=time_delay, dimension=embedding_dim,
stride=stride)
emb_res = embedder.fit_transform(layer_output.reshape(layer_output.shape[0],
1, layer_output.shape[1]))
emb_res = emb_res.reshape(emb_res.shape[0], 1, -1)
if mode.lower() == 'umap':
umapred = umap.UMAP(n_components=3, n_neighbors=neighbors)
reducing_output = umapred.fit_transform(emb_res[:, 0, :])
if mode.lower() == 'pca':
PCA_out = PCA(n_components=3)
reducing_output = PCA_out.fit_transform(emb_res[:, 0, :])
reducing_output = reduce_dim(data=emb_res[:, 0, :], mode=mode,
out_dim=out_dim, neighbors=neighbors)
df = pd.DataFrame(reducing_output)
if labels.shape[1] == 1:
if labels.ndim == 1:
df["category"] = labels.astype(str)
else:
df["category"] = np.where(labels == 1)[1].astype(str)
# df = df.iloc[::4, :]
df = df.iloc[::arr_reducer, :]
if heatmap is True:
labels_noncat = labels
center = np.zeros((len(np.unique(labels_noncat)), 3))
med_dist = np.zeros((len(np.unique(labels_noncat)), len(np.unique(labels_noncat))))
for i in range(len(np.unique(labels_noncat))):
center[i] = np.mean(reducing_output[np.where(labels_noncat == np.unique(
labels_noncat)[i])], axis=0)
for j in range(len(np.unique(labels_noncat))):
med_dist[i][j] = math.log(np.square(1 / np.mean(reducing_output[np.where(
labels_noncat == np.unique(labels_noncat)[j])] - center[i])))
emb_out = px.scatter_3d(df, x=0, y=1, z=2, color="category")
emb_out.update_traces(marker=dict(size=4))
emb_out.update_layout(
Expand Down

0 comments on commit 6f7223d

Please sign in to comment.