Skip to content

Commit

Permalink
debugging_directory
Browse files Browse the repository at this point in the history
  • Loading branch information
YousefMetwally committed Feb 7, 2025
1 parent 20d0902 commit 8d4f76d
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 66 deletions.
1 change: 1 addition & 0 deletions tomotwin/embed_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def sliding_window_embedding(
positions = np.array(positions)
if padding is True:
print("Adjusting positions after padding")
odd_factor = box_size % 2
positions = positions - int((box_size - odd_factor) // 2)
embeddings = np.hstack([positions, embeddings])

Expand Down
142 changes: 76 additions & 66 deletions tomotwin/modules/tools/embedding_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,75 +161,85 @@ def median_mode(self,
batch_size: int,
threshold: float,
dilation: float,
padding: bool
padding: bool,
) -> np.array:
'''
Calculates a mask based on median embedding
'''
with tempfile.TemporaryDirectory() as tmp_pth:
# Embed
emb_out_pth = os.path.join(tmp_pth, "embed")

print ('median_moade.padding = ', padding)
conf = EmbedConfiguration(
model_path=model_pth,
volumes_path=tomo_pth,
output_path=emb_out_pth,
mode=EmbedMode.TOMO,
batchsize=batch_size,
stride=stride,
zrange=None,
maskpth=None,
distr_mode=DistrMode.DDP,
padding = padding
)
"""
Calculates a mask based on median embedding and saves all generated files.
"""
# Create output directories
output_dir= '/mnt/data1/D1/paddingtrial/TS_73_6/debugging'
os.makedirs(output_dir, exist_ok=True)
emb_out_pth = os.path.join(output_dir, "embed")
median_out_pth = os.path.join(output_dir, "median_emb")
map_out_pth = os.path.join(output_dir, "map")

os.makedirs(emb_out_pth, exist_ok=True)
os.makedirs(median_out_pth, exist_ok=True)
os.makedirs(map_out_pth, exist_ok=True)

print('median_mode.padding = ', padding)

# Embed
conf = EmbedConfiguration(
model_path=model_pth,
volumes_path=tomo_pth,
output_path=emb_out_pth,
mode=EmbedMode.TOMO,
batchsize=batch_size,
stride=stride,
zrange=None,
maskpth=None,
distr_mode=DistrMode.DDP,
padding=padding
)

embed.start(conf)

# Median embedding
median_out_pth = os.path.join(tmp_pth, "median_emb")
args = SimpleNamespace(input=glob(os.path.join(emb_out_pth, "*.temb"))[0], output=median_out_pth)
mtool = median_tool.MedianTool()
mtool.run(args)

# Map
map_out_pth = os.path.join(tmp_pth, "map")
map_conf = MapConfiguration(
reference_embeddings_path=glob(os.path.join(median_out_pth, "*.temb"))[0],
volume_embeddings_path=glob(os.path.join(emb_out_pth, "*.temb"))[0],
output_path=map_out_pth,
mode=MapMode.DISTANCE,
skip_refinement=True
)
tmap.run(map_conf)

# Heatmap
print("Calculate heatmap")
map_output = pd.read_pickle(glob(os.path.join(map_out_pth, "*.tmap"))[0])
raw_heatmap = FindMaximaLocator.to_volume(
df=map_output,
target_class=0,
stride=(stride, stride, stride),
window_size=map_output.attrs['window_size'],
embed.start(conf)

# Median embedding
args = SimpleNamespace(input=glob(os.path.join(emb_out_pth, "*.temb"))[0], output=median_out_pth)
mtool = median_tool.MedianTool()
mtool.run(args)

# Map
map_conf = MapConfiguration(
reference_embeddings_path=glob(os.path.join(median_out_pth, "*.temb"))[0],
volume_embeddings_path=glob(os.path.join(emb_out_pth, "*.temb"))[0],
output_path=map_out_pth,
mode=MapMode.DISTANCE,
skip_refinement=True
)
tmap.run(map_conf)

# Heatmap
print("Calculate heatmap")
map_output = pd.read_pickle(glob(os.path.join(map_out_pth, "*.tmap"))[0])
raw_heatmap = FindMaximaLocator.to_volume(
df=map_output,
target_class=0,
stride=(stride, stride, stride),
window_size=map_output.attrs['window_size'],
)
raw_heatmap = raw_heatmap.astype(np.float32)
heatmap = locate.scale_and_pad_heatmap(
raw_heatmap,
stride=stride,
tomo_input_shape=map_output.attrs['tomogram_input_shape']
)

print("Binarize heatmap")
# Binarize heatmap
mask = np.logical_and(heatmap < threshold, heatmap > np.min(heatmap))

if dilation != 0:
mask = morphology.binary_dilation(
mask, morphology.ball(radius=dilation)
)
raw_heatmap = raw_heatmap.astype(np.float32)
heatmap = locate.scale_and_pad_heatmap(raw_heatmap,
stride=stride,
tomo_input_shape=map_output.attrs['tomogram_input_shape'])
print("Binarize heatmap")
# Binarize heatmap
mask = np.logical_and(heatmap < threshold, heatmap > np.min(heatmap))

if dilation != 0:
mask = morphology.binary_dilation(
mask, morphology.ball(radius=dilation)
)

bin_mask = np.zeros_like(mask, dtype=np.float32)
print(bin_mask.shape)
bin_mask[mask] = 1

return bin_mask

bin_mask = np.zeros_like(mask, dtype=np.float32)
bin_mask[mask] = 1

return bin_mask


def intensity_mode(self, img: np.array) -> np.array:
'''
Expand Down

0 comments on commit 8d4f76d

Please sign in to comment.