Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent error when FSC does not drop below threshold. #5

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/ttfsc/_cli.py
Original file line number Diff line number Diff line change
@@ -38,6 +38,12 @@ def ttfsc_cli(
str, typer.Option("--plot-matplotlib-style", rich_help_panel="Plotting options")
] = "default",
mask: Annotated[Masking, typer.Option("--mask", rich_help_panel="Masking options")] = Masking.none,
mask_file: Annotated[
Optional[Path],
typer.Option(
"--mask-file", help="Path to custom mask file (required when mask=custom)", rich_help_panel="Masking options"
),
] = None,
mask_radius_angstroms: Annotated[
float, typer.Option("--mask-radius-angstroms", rich_help_panel="Masking options")
] = 100.0,
@@ -54,12 +60,17 @@ def ttfsc_cli(
float, typer.Option("--correct-from-fraction-of-estimated-resolution", rich_help_panel="Masking correction options")
] = 0.5,
) -> None:
if mask == Masking.custom and mask_file is None:
raise typer.BadParameter("--mask-file is required when using --mask=custom")
if mask == Masking.sphere and mask_file is not None:
rprint("[yellow]Warning: --mask-file is ignored when using --mask=sphere[/yellow]")
result = ttfsc(
map1=map1,
map2=map2,
pixel_spacing_angstroms=pixel_spacing_angstroms,
fsc_threshold=fsc_threshold,
mask=mask,
mask_filename=mask_file,
mask_radius_angstroms=mask_radius_angstroms,
mask_soft_edge_width_pixels=mask_soft_edge_width_pixels,
correct_for_masking=correct_for_masking,
1 change: 1 addition & 0 deletions src/ttfsc/_data_models.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
class Masking(str, Enum):
none = "none"
sphere = "sphere"
custom = "custom"


class TTFSCResult(BaseModel):
66 changes: 54 additions & 12 deletions src/ttfsc/_masking.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mrcfile
import torch
from torch_fourier_shell_correlation import fsc

@@ -58,12 +59,18 @@ def calculate_noise_injected_fsc(result: TTFSCResult) -> None:
result.fsc_values_corrected[to_correct] - result.fsc_values_masked_randomized[to_correct]
) / (1.0 - result.fsc_values_masked_randomized[to_correct])

result.estimated_resolution_frequency_pixel = float(
result.frequency_pixels[(result.fsc_values_corrected < result.fsc_threshold).nonzero()[0] - 1]
)
result.estimated_resolution_angstrom = float(
result.resolution_angstroms[(result.fsc_values_corrected < result.fsc_threshold).nonzero()[0] - 1]
)
# Find indices where FSC is below threshold
below_threshold_indices = (result.fsc_values_corrected < result.fsc_threshold).nonzero()

if len(below_threshold_indices) > 0:
# Use the first crossing point if it exists
index = below_threshold_indices[0] - 1
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[index])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[index])
else:
# If no values below threshold, use the highest frequency (Nyquist)
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[-1])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[-1])
result.estimated_resolution_angstrom_corrected = result.estimated_resolution_angstrom


@@ -93,12 +100,47 @@ def calculate_masked_fsc(result: TTFSCResult) -> None:
map2_tensor_masked = result.map2_tensor * result.mask_tensor
result.fsc_values_masked = fsc(map1_tensor_masked, map2_tensor_masked)

result.estimated_resolution_frequency_pixel = float(
result.frequency_pixels[(result.fsc_values_masked < result.fsc_threshold).nonzero()[0] - 1]
)
result.estimated_resolution_angstrom = float(
result.resolution_angstroms[(result.fsc_values_masked < result.fsc_threshold).nonzero()[0] - 1]
)
# Find indices where FSC is below threshold
below_threshold_indices = (result.fsc_values_masked < result.fsc_threshold).nonzero()

if len(below_threshold_indices) > 0:
# Use the first crossing point if it exists
index = below_threshold_indices[0] - 1
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[index])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[index])
else:
# If no values below threshold, use the highest frequency (Nyquist)
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[-1])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[-1])
result.estimated_resolution_angstrom_masked = result.estimated_resolution_angstrom

return
elif result.mask == Masking.custom:
if result.mask_filename is None:
raise ValueError("Must provide mask_filename for custom mask")

with mrcfile.open(result.mask_filename) as f:
result.mask_tensor = torch.tensor(f.data)

if result.mask_tensor.shape != result.map1_tensor.shape:
raise ValueError(f"Mask shape {result.mask_tensor.shape} does not match map shape {result.map1_tensor.shape}")

map1_tensor_masked = result.map1_tensor * result.mask_tensor
map2_tensor_masked = result.map2_tensor * result.mask_tensor
result.fsc_values_masked = fsc(map1_tensor_masked, map2_tensor_masked)

# Find indices where FSC is below threshold
below_threshold_indices = (result.fsc_values_masked < result.fsc_threshold).nonzero()

if len(below_threshold_indices) > 0:
# Use the first crossing point if it exists
index = below_threshold_indices[0] - 1
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[index])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[index])
else:
# If no values below threshold, use the highest frequency (Nyquist)
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[-1])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[-1])
result.estimated_resolution_angstrom_masked = result.estimated_resolution_angstrom

return
19 changes: 17 additions & 2 deletions src/ttfsc/ttfsc.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@ def ttfsc(
pixel_spacing_angstroms: Optional[float] = None,
fsc_threshold: float = 0.143,
mask: Masking = Masking.none,
mask_filename: Optional[Path] = None,
mask_radius_angstroms: float = 100.0,
mask_soft_edge_width_pixels: int = 10,
correct_for_masking: bool = True,
@@ -40,6 +41,7 @@ def ttfsc(
pixel_spacing_angstroms (Optional[float]): Pixel spacing in Å/px. If not provided, it will be taken from the header.
fsc_threshold (float): FSC threshold value. Default is 0.143.
mask (Masking): Masking option to use. Default is Masking.none.
mask_filename (Optional[Path]): Path to the mask file. Default is None.
mask_radius_angstroms (float): Radius of the mask in Å. Default is 100.0.
mask_soft_edge_width_pixels (int): Width of the soft edge of the mask in pixels. Default is 10.
correct_for_masking (bool): Whether to correct for masking effects. Default is True.
@@ -59,6 +61,7 @@ def ttfsc(
pixel_spacing_angstroms=1.0,
fsc_threshold=0.143,
mask=Masking.soft,
mask_filename=Path("mask.mrc"),
mask_radius_angstroms=150.0,
mask_soft_edge_width_pixels=5,
correct_for_masking=True,
@@ -78,8 +81,19 @@ def ttfsc(

fsc_values_unmasked = fsc(map1_tensor, map2_tensor)

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1])
# Find indices where FSC is below threshold
below_threshold_indices = (fsc_values_unmasked < fsc_threshold).nonzero()

if len(below_threshold_indices) > 0:
# Use the first crossing point if it exists
index = below_threshold_indices[0] - 1
estimated_resolution_frequency_pixel = float(frequency_pixels[index])
estimated_resolution_angstrom = float(resolution_angstroms[index])
else:
# If no values below threshold, use the highest frequency (Nyquist)
estimated_resolution_frequency_pixel = float(frequency_pixels[-1])
estimated_resolution_angstrom = float(resolution_angstroms[-1])

result = TTFSCResult(
map1=map1,
map1_tensor=map1_tensor,
@@ -99,6 +113,7 @@ def ttfsc(
from ._masking import calculate_masked_fsc

result.mask = mask
result.mask_filename = mask_filename
result.mask_radius_angstroms = mask_radius_angstroms
result.mask_soft_edge_width_pixels = mask_soft_edge_width_pixels
calculate_masked_fsc(result)