Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prompt encoder wrong device #18

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,20 @@ def update_slice(
continue

inputs = inputs_l[..., start_idx - (n_z_slices // 2) : start_idx + (n_z_slices // 2) + 1].permute(2, 0, 1)
if device and (device == "cuda" or isinstance(device, torch.device) and device.type == "cuda"):
if device and (
(isinstance(device, str) and device.startswith("cuda"))
or isinstance(device, torch.device)
and device.type == "cuda"
):
inputs = inputs.cuda()
data, unique_labels = prepare_sam_val_input(
inputs, class_prompts, point_prompts, start_idx, original_affine, device=device
)

predictor.eval()
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
if (isinstance(device, str) and device.startswith("cuda")) or (
isinstance(device, torch.device) and device.type == "cuda"
):
with torch.cuda.amp.autocast():
outputs = predictor(data)
logit = outputs[0]["high_res_logits"]
Expand Down Expand Up @@ -297,14 +303,18 @@ def iterate_all(
)
for start_idx in start_range:
inputs = inputs_l[..., start_idx - n_z_slices // 2 : start_idx + n_z_slices // 2 + 1].permute(2, 0, 1)
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
if (isinstance(device, str) and device.startswith("cuda")) or (
isinstance(device, torch.device) and device.type == "cuda"
):
inputs = inputs.cuda()
data, unique_labels = prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, device=device)
predictor = predictor.eval()
with autocast():
if cachedEmbedding:
curr_embedding = cachedEmbedding[start_idx]
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
if (isinstance(device, str) and device.startswith("cuda")) or (
isinstance(device, torch.device) and device.type == "cuda"
):
curr_embedding = curr_embedding.cuda()
outputs = predictor.get_mask_prediction(data, curr_embedding)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, origi

class_list = [[i + 1] for i in class_prompts]
unique_labels = torch.tensor(class_list).long()
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
if (isinstance(device, str) and device.startswith("cuda")) or (
isinstance(device, torch.device) and device.type == "cuda"
):
unique_labels = unique_labels.cuda()

volume_point_coords = [cp for cp in foreground_all]
Expand Down Expand Up @@ -133,7 +135,9 @@ def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, origi
if point_coords:
point_coords = torch.tensor(point_coords).long()
point_labels = torch.tensor(point_labels).long()
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
if (isinstance(device, str) and device.startswith("cuda")) or (
isinstance(device, torch.device) and device.type == "cuda"
):
point_coords = point_coords.cuda()
point_labels = point_labels.cuda()

Expand Down
47 changes: 47 additions & 0 deletions scripts/remap_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import argparse
import os
from typing import Dict

import SimpleITK as sitk

# Label remapping table passed to remap_labels() by main().
# Key: label value currently stored in the file; value: label value written in
# its place (the direction is swapped when remap_labels() runs with reverse=True).
# NOTE(review): the targets 105-111 presumably append these structures after an
# existing label set — confirm against the training configuration.
MAP = {
    0: 0, # background
    1: 105, # brainstem
    2: 106, # opticChiasm
    3: 107, # opticNerveL
    4: 108, # opticNerveR
    5: 109, # parotidGlandL
    6: 110, # parotidGlandR
    7: 111, # mandible
}


def remap_labels(data_dir: str, file_substr_id: str, mapping: Dict[int, int], reverse: bool = False):
    """Rewrite label values, in place, for every matching image under ``data_dir``.

    The directory tree is walked recursively. Each file whose *name* contains
    ``file_substr_id`` is read as a ``sitkUInt32`` image, its voxel values are
    remapped, and the result is written back to the same path with the image
    information copied from the original via ``CopyInformation``.

    All replacements are decided from the voxel values as they were read from
    disk, so a mapping whose keys and values overlap (e.g. ``{1: 2, 2: 3}``)
    never chains: a voxel is remapped at most once.

    Args:
        data_dir: Root directory to search.
        file_substr_id: Substring that identifies a label file by its name.
        mapping: ``{old_value: new_value}`` pairs. Values absent from the
            mapping are left untouched.
        reverse: When True, apply the mapping backwards (``new_value`` ->
            ``old_value``). If several keys share one value, the first key in
            iteration order is the one restored.

    Returns:
        None. Matching files are overwritten on disk.
    """
    # Orient the table once so the per-file loop is direction-agnostic.
    if reverse:
        lookup = {}
        for src, dst in mapping.items():
            # setdefault keeps the first key when the mapping is not one-to-one.
            lookup.setdefault(dst, src)
    else:
        lookup = dict(mapping)
    # Identity pairs (e.g. background 0 -> 0) cannot change any voxel.
    pairs = [(old, new) for old, new in lookup.items() if old != new]

    for root, _dirs, files in os.walk(data_dir):
        for file in files:
            if file_substr_id not in file:
                continue
            file_path = os.path.join(root, file)
            print(f"Processing: {file_path}")
            label_itk = sitk.ReadImage(file_path, sitk.sitkUInt32)
            label_array = sitk.GetArrayFromImage(label_itk)
            # Masks are computed against this untouched snapshot; comparing
            # against the array being edited would let an already-remapped
            # voxel match a later pair and be remapped a second time.
            original = label_array.copy()
            for old, new in pairs:
                label_array[original == old] = new
            new_label_itk = sitk.GetImageFromArray(label_array)
            new_label_itk.CopyInformation(label_itk)
            sitk.WriteImage(new_label_itk, file_path)


def main():
    """Command-line entry point: remap label files under ``--data_dir`` using ``MAP``.

    Arguments:
        --data_dir: Root directory to search (required).
        --file_substr_id: Substring identifying label files by name
            (default: ``"segmentation"``).
        --reverse: Apply ``MAP`` backwards instead of forwards (default: off).
    """
    parser = argparse.ArgumentParser()
    # Required: without it None would be handed to os.walk further down.
    parser.add_argument("--data_dir", required=True, help="data directory")
    parser.add_argument("--file_substr_id", default="segmentation", help="substr used to identify the label file")
    # Off by default, so existing invocations keep the forward remapping.
    parser.add_argument("--reverse", action="store_true", help="apply the mapping in reverse")
    args = parser.parse_args()
    remap_labels(args.data_dir, args.file_substr_id, MAP, reverse=args.reverse)


# Run the remapping only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
11 changes: 11 additions & 0 deletions train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Launches training/main_2pt5d.py with point and label prompts, starting from
# the checkpoint ./runs/9s_2dembed_model.pt and logging under finetune_ckpt_example.
# NOTE(review): the script, checkpoint and data paths are all relative, so this
# is presumably meant to be run from the repository root with the dataset in
# ../data/han — confirm.
# NOTE(review): --num_classes 112 presumably covers label values 0..111 — confirm
# it matches the highest label value present in the dataset.
python training/main_2pt5d.py --max_epochs 100 --val_every 1 --optim_lr 0.000005 \
--num_patch 24 --num_prompt 32 \
--json_list ./../data/han/task_HaN_small.json \
--data_dir ./../data/han \
--roi_z_iter 9 --save_checkpoint \
--sam_base_model vit_b \
--logdir finetune_ckpt_example --point_prompt --label_prompt --distributed --seed 12346 \
--iterative_training_warm_up_epoch 50 --reuse_img_embedding \
--label_prompt_warm_up_epoch 25 \
--checkpoint ./runs/9s_2dembed_model.pt \
--num_classes 112
3 changes: 2 additions & 1 deletion training/example_train_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ python main_2pt5d.py --max_epochs 100 --val_every 1 --optim_lr 0.000005 \
--logdir finetune_ckpt_example --point_prompt --label_prompt --distributed --seed 12346 \
--iterative_training_warm_up_epoch 50 --reuse_img_embedding \
--label_prompt_warm_up_epoch 25 \
--checkpoint ./runs/9s_2dembed_model.pt
--checkpoint ./runs/9s_2dembed_model.pt \
--num_classes 112
8 changes: 7 additions & 1 deletion training/main_2pt5d.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@
parser.add_argument("--skip_bk", action="store_true", help="skip background (0) during training")
parser.add_argument("--patch_embed_3d", action="store_true", help="using 3d patch embedding layer")

parser.add_argument("--num_classes", default=105, type=int, help="number of output classes")


def start_tb(log_dir):
cmd = ["tensorboard", "--logdir", log_dir]
Expand All @@ -123,6 +125,10 @@ def main():
args = parser.parse_args()
args.amp = not args.noamp
args.logdir = "./runs/" + args.logdir

if args.num_classes == 0:
warnings.warn("consider setting the correct number of classes")

# start_tb(args.logdir)
if args.seed > -1:
set_determinism(seed=args.seed)
Expand Down Expand Up @@ -162,7 +168,7 @@ def main_worker(gpu, args):

dice_loss = DiceCELoss(sigmoid=True)

post_label = AsDiscrete(to_onehot=105)
post_label = AsDiscrete(to_onehot=args.num_classes)
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
dice_acc = DiceMetric(include_background=False, reduction=MetricReduction.MEAN, get_not_nans=True)

Expand Down
15 changes: 9 additions & 6 deletions training/trainer_2pt5d.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def prepare_sam_training_input(inputs, labels, args, model):
unique_labels = unique_labels[: args.num_prompt]

# add 4 background labels to every batch
background_labels = list(set([i for i in range(1, 105)]) - set(unique_labels.cpu().numpy()))
background_labels = list(set([i for i in range(1, args.num_classes)]) - set(unique_labels.cpu().numpy()))
random.shuffle(background_labels)
unique_labels = torch.cat([unique_labels, torch.tensor(background_labels[:4]).cuda(args.rank)])

Expand Down Expand Up @@ -375,7 +375,7 @@ def train_epoch_iterative(model, loader, optimizer, scaler, epoch, loss_func, ar


def prepare_sam_test_input(inputs, labels, args, previous_pred=None):
unique_labels = torch.tensor([i for i in range(1, 105)]).cuda(args.rank)
unique_labels = torch.tensor([i for i in range(1, args.num_classes)]).cuda(args.rank)

# preprocess: make the size of the label the same as high_res_logit
batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float()
Expand All @@ -400,7 +400,7 @@ def prepare_sam_test_input(inputs, labels, args, previous_pred=None):

def prepare_sam_val_input_cp_only(inputs, labels, args):
# Don't exclude background in val but will ignore it in metric calculation
unique_labels = torch.tensor([i for i in range(1, 105)]).cuda(args.rank)
unique_labels = torch.tensor([i for i in range(1, args.num_classes)]).cuda(args.rank)

# preprocess: make the size of the label the same as high_res_logit
batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float()
Expand Down Expand Up @@ -457,15 +457,18 @@ def val_epoch(model, loader, epoch, acc_func, args, iterative=False, post_label=
y_pred = torch.stack(post_pred(decollate_batch(logit)), 0)

# TODO: we compute metric for each prompt for simplicity in validation.
acc_batch = compute_dice(y_pred=y_pred, y=target)
acc_batch = compute_dice(y_pred=y_pred[None,], y=target[None,])
acc_sum, not_nans = (
torch.nansum(acc_batch).item(),
104 - torch.sum(torch.isnan(acc_batch).float()).item(),
(args.num_classes - 1) - torch.sum(torch.isnan(acc_batch).float()).item(),
)
acc_sum_total += acc_sum
not_nans_total += not_nans

acc, not_nans = acc_sum_total / not_nans_total, not_nans_total
if not_nans_total > 0:
acc, not_nans = acc_sum_total / not_nans_total, not_nans_total
else:
acc, not_nans = 0, 0
f_name = batch_data["image"].meta["filename_or_obj"]
print(f"Rank: {args.rank}, Case: {f_name}, Acc: {acc:.4f}, N_prompts: {int(not_nans)} ")

Expand Down