Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write umap embeddings metadata + positions into umap attrs and remove label mask calculation #82

Merged
merged 4 commits into from
Dec 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 6 additions & 34 deletions tomotwin/modules/tools/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
except ImportError:
print("cuml can't be loaded")

import mrcfile
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
Expand Down Expand Up @@ -98,28 +97,6 @@

return embedding, reducer

def create_embedding_mask(self, embeddings: pd.DataFrame):
    """
    Build a label volume that gives every embedding row its own ID.

    A float32 array with the shape stored in
    ``embeddings.attrs["tomogram_input_shape"]`` (first three entries,
    used in that order as the array axes) is filled with zeros. For each
    row of ``embeddings`` the cube of edge length ``stride`` that starts
    at the row's integer ``(Z, Y, X)`` position is then written with the
    row's 1-based index. Voxels that no row touches stay 0.

    Only the first entry of ``embeddings.attrs["stride"]`` is used, and it
    is applied to all three axes.

    :param embeddings: Data frame with ``"X"``, ``"Y"`` and ``"Z"`` columns
        and the ``"tomogram_input_shape"`` and ``"stride"`` attrs.
    :return: The float32 label volume.
    """
    print("Create embedding mask")

    tomo_shape = embeddings.attrs["tomogram_input_shape"]
    step = embeddings.attrs["stride"][0]
    label_volume = np.zeros(
        shape=(tomo_shape[0], tomo_shape[1], tomo_shape[2]),
        dtype=np.float32,
    )

    # Integer start coordinates of every row, one array per axis.
    pos_z = np.asarray(embeddings["Z"], dtype=int)
    pos_y = np.asarray(embeddings["Y"], dtype=int)
    pos_x = np.asarray(embeddings["X"], dtype=int)

    # Label of row i is i + 1, so 0 remains the "no embedding" value.
    labels = np.arange(1, len(pos_x) + 1)

    # The loop nesting (x outermost, z innermost) determines which write
    # lands last wherever cubes of different rows overlap; keep it as is.
    # NOTE(review): positions + offset are not bounds-checked here; numpy
    # will raise IndexError if they exceed the volume - confirm callers
    # guarantee this cannot happen.
    for off_x in tqdm(range(step)):
        for off_y in range(step):
            for off_z in range(step):
                label_volume[pos_z + off_z, pos_y + off_y, pos_x + off_x] = labels

    return label_volume

def run(self, args):
print("Read data")
Expand All @@ -144,23 +121,18 @@
os.makedirs(out_pth,exist_ok=True)
fname = os.path.splitext(os.path.basename(args.input))[0]
df_embeddings = pd.DataFrame(umap_embeddings)
df_embeddings.reset_index(drop=True, inplace=True)
embeddings.reset_index(drop=True, inplace=True)

Check warning on line 125 in tomotwin/modules/tools/umap.py

View check run for this annotation

Codecov / codecov/patch

tomotwin/modules/tools/umap.py#L124-L125

Added lines #L124 - L125 were not covered by tests

print("Write embeedings to disk")
df_embeddings.columns = [f"umap_{i}" for i in range(umap_embeddings.shape[1])]
df_embeddings = pd.concat([embeddings[['X', 'Y', 'Z']], df_embeddings], axis=1)
df_embeddings.attrs['embeddings_attrs'] = embeddings.attrs
df_embeddings.attrs['embeddings_path'] = os.path.realpath(args.input)

Check warning on line 131 in tomotwin/modules/tools/umap.py

View check run for this annotation

Codecov / codecov/patch

tomotwin/modules/tools/umap.py#L129-L131

Added lines #L129 - L131 were not covered by tests

df_embeddings.to_pickle(os.path.join(out_pth,fname+".tumap"))

print("Write umap model to disk")
pickle.dump(fitted_umap, open(os.path.join(out_pth, fname + "_umap_model.pkl"), "wb"))

print("Calculate label mask and write it to disk")
embedding_mask = self.create_embedding_mask(embeddings)
with mrcfile.new(
os.path.join(
args.output,
fname + "_label_mask.mrci",
),
overwrite=True,
) as mrc:
mrc.set_data(embedding_mask)

print("Done")
Loading