Skip to content

Commit

Permalink
Fix transpose and patch coords bug (#8047)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

Fix the bug that causes wrong results in model zoo finetuning. Patch
coords was not passed from sliding window to vista3d.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: heyufan1995 <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
3 people authored Aug 28, 2024
1 parent 1a8afd1 commit b62d1e1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ zarr
huggingface_hub
pyamg>=5.0.0
packaging
polygraphy
29 changes: 18 additions & 11 deletions monai/apps/vista3d/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import torch
from torch import Tensor

__all__ = ["sample_prompt_pairs"]

ENABLE_SPECIAL = True
SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
MERGE_LIST = {
Expand All @@ -30,6 +28,8 @@
132: [57], # overlap with trachea merge into airway
}

__all__ = ["sample_prompt_pairs"]


def _get_point_label(id: int) -> tuple[int, int]:
if id in SPECIAL_INDEX and ENABLE_SPECIAL:
Expand Down Expand Up @@ -66,22 +66,29 @@ def sample_prompt_pairs(
max_backprompt: int, max number of prompt from background.
max_point: maximum number of points for each object.
include_background: if include 0 into training prompt. If included, background 0 is treated
the same as foreground. Always be False for multi-partial-dataset training. If needed,
can be true for finetuning specific dataset, .
the same as foreground and points will be sampled. Can be true only if user want to segment
background 0 with point clicks, otherwise always be false.
drop_label_prob: probability to drop label prompt.
drop_point_prob: probability to drop point prompt.
point_sampler: sampler to augment masks with supervoxel.
point_sampler_kwargs: arguments for point_sampler.
Returns:
label_prompt: [B, 1]. The classes used for training automatic segmentation.
point: [B, N, 3]. The corresponding points for each class.
Note that background label prompt requires matching point as well ([0,0,0] is used).
point_label: [B, N]. The corresponding point labels for each point (negative or positive).
-1 is used for padding the background label prompt and will be ignored.
prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss.
label_prompt can be None, and prompt_class is used to identify point classes.
tuple:
- label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for
training automatic segmentation.
- point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points
for each class. Note that background label prompts require matching points as well
(e.g., [0, 0, 0] is used).
- point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point
labels for each point (negative or positive). -1 is used for padding the background
label prompt and will be ignored.
- prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt
for label indexing during training. If label_prompt is None, prompt_class is used to
identify point classes.
"""

# class label number
if not labels.shape[0] == 1:
raise ValueError("only support batch size 1")
Expand Down
7 changes: 5 additions & 2 deletions monai/networks/nets/vista3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,11 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
def forward(
self,
input_images: torch.Tensor,
patch_coords: Sequence[slice] | None = None,

This comment has been minimized.

Copy link
@yiheng-wang-nv

yiheng-wang-nv Aug 30, 2024

Contributor

This introduces an issue in update_point_to_patch.
The sliding window inferer inputs unravel_slice for patch_coords, and it has type: List[List[slice]]

point_coords: torch.Tensor | None = None,
point_labels: torch.Tensor | None = None,
class_vector: torch.Tensor | None = None,
prompt_class: torch.Tensor | None = None,
patch_coords: Sequence[slice] | None = None,
labels: torch.Tensor | None = None,
label_set: Sequence[int] | None = None,
prev_mask: torch.Tensor | None = None,
Expand Down Expand Up @@ -421,7 +421,10 @@ def forward(
point_coords, point_labels = None, None

if point_coords is None and class_vector is None:
return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
if transpose:
logits = logits.transpose(1, 0)
return logits

if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
out, out_auto = self.image_embeddings, None
Expand Down

0 comments on commit b62d1e1

Please sign in to comment.