add scaling functionality, add comments
fedshyvana committed Aug 24, 2020
1 parent 718c523 commit a126ee3
Showing 6 changed files with 34 additions and 16 deletions.
11 changes: 10 additions & 1 deletion datasets/dataset_h5.py
@@ -39,7 +39,8 @@ class Whole_Slide_Bag(Dataset):
 	def __init__(self,
 		file_path,
 		pretrained=False,
-		custom_transforms=None
+		custom_transforms=None,
+		target_patch_size=-1,
 		):
 		"""
 		Args:
@@ -48,6 +49,10 @@ def __init__(self,
 			custom_transforms (callable, optional): Optional transform to be applied on a sample
 		"""
 		self.pretrained=pretrained
+		if target_patch_size > 0:
+			self.target_patch_size = (target_patch_size, target_patch_size)
+		else:
+			self.target_patch_size = None

 		if not custom_transforms:
 			self.roi_transforms = eval_transforms(pretrained=pretrained)
@@ -73,13 +78,17 @@ def summary(self):
 		print('pretrained:', self.pretrained)
 		print('transformations:', self.roi_transforms)
+		if self.target_patch_size is not None:
+			print('target_size: ', self.target_patch_size)

 	def __getitem__(self, idx):
 		with h5py.File(self.file_path,'r') as hdf5_file:
 			img = hdf5_file['imgs'][idx]
 			coord = hdf5_file['coords'][idx]

 		img = Image.fromarray(img)
+		if self.target_patch_size is not None:
+			img = img.resize(self.target_patch_size)
 		img = self.roi_transforms(img).unsqueeze(0)
 		return img, coord
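Note on the new argument: the default target_patch_size=-1 leaves patches at their stored size, while any positive value resizes each patch to a square of that size before the ROI transforms run. A minimal standalone sketch of the same resize logic (load_patch is a hypothetical helper written for illustration, not part of this commit):

    import numpy as np
    from PIL import Image

    def load_patch(img_array, target_patch_size=-1):
        # mirrors Whole_Slide_Bag: a positive value becomes a (w, h) tuple,
        # anything else disables resizing
        target = (target_patch_size, target_patch_size) if target_patch_size > 0 else None
        img = Image.fromarray(img_array)
        if target is not None:
            img = img.resize(target)
        return img

    patch = load_patch(np.zeros((512, 512, 3), dtype=np.uint8), target_patch_size=256)
    print(patch.size)  # (256, 256)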
12 changes: 9 additions & 3 deletions extract_features.py
@@ -36,7 +36,8 @@ def save_hdf5(output_dir, asset_dict, mode='a'):
 	return output_dir


-def compute_w_loader(file_path, output_path, model, batch_size = 8, verbose = 0, print_every=20, pretrained=True):
+def compute_w_loader(file_path, output_path, model, batch_size = 8, verbose = 0,
+	print_every=20, pretrained=True, target_patch_size=-1):
 	"""
 	args:
 		file_path: directory of bag (.h5 file)
@@ -46,7 +47,8 @@ def compute_w_loader(file_path, output_path, model, batch_size = 8, verbose = 0,
 		verbose: level of feedback
 		pretrained: use weights pretrained on imagenet
 	"""
-	dataset = Whole_Slide_Bag(file_path=file_path, pretrained=pretrained)
+	dataset = Whole_Slide_Bag(file_path=file_path, pretrained=pretrained,
+		target_patch_size=target_patch_size)
 	x, y = dataset[0]
 	kwargs = {'num_workers': 4, 'pin_memory': True} if device.type == "cuda" else {}
 	loader = DataLoader(dataset=dataset, batch_size=batch_size, **kwargs, collate_fn=collate_features)
@@ -79,6 +81,8 @@ def compute_w_loader(file_path, output_path, model, batch_size = 8, verbose = 0,
 parser.add_argument('--feat_dir', type=str)
 parser.add_argument('--batch_size', type=int, default=256)
 parser.add_argument('--no_auto_skip', default=False, action='store_true')
+parser.add_argument('--target_patch_size', type=int, default=-1,
+	help='the desired size of patches for optional scaling before feature embedding')
 args = parser.parse_args()
@@ -118,7 +122,9 @@ def compute_w_loader(file_path, output_path, model, batch_size = 8, verbose = 0,
 		file_path = bag_candidate
 		time_start = time.time()
 		output_file_path = compute_w_loader(file_path, output_path,
-			model = model, batch_size = args.batch_size, verbose = 1, print_every = 20)
+			model = model, batch_size = args.batch_size,
+			verbose = 1, print_every = 20,
+			target_patch_size=args.target_patch_size)
 		time_elapsed = time.time() - time_start
 		print('\ncomputing features for {} took {} s'.format(output_file_path, time_elapsed))
 		file = h5py.File(output_file_path, "r")
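Note: the new flag plugs into the existing parser unchanged. A self-contained sketch of how it parses (only the argument definitions visible in this diff are reproduced; the script defines more):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--no_auto_skip', default=False, action='store_true')
    parser.add_argument('--target_patch_size', type=int, default=-1,
        help='the desired size of patches for optional scaling before feature embedding')

    args = parser.parse_args(['--target_patch_size', '224'])
    print(args.target_patch_size)  # 224; omitting the flag leaves -1, i.e. no resizing

Leaving the flag at its default of -1 reproduces the previous behavior end to end, since Whole_Slide_Bag then sets target_patch_size to None and skips the resize.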
2 changes: 1 addition & 1 deletion models/model_clam.py
@@ -197,7 +197,7 @@ def forward(self, h, label=None, instance_eval=False, return_features=False, att
 class CLAM_MB(CLAM_SB):
 	def __init__(self, gate = True, size_arg = "small", dropout = False, k_sample=8, n_classes=2,
 		instance_loss_fn=nn.CrossEntropyLoss(), subtyping=False):
-		super(CLAM_MB, self).__init__()
+		nn.Module.__init__(self)
 		self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
 		size = self.size_dict[size_arg]
 		fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
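Note: the one-line change here is subtle. super(CLAM_MB, self).__init__() would run CLAM_SB's constructor, building the single-branch layers only for CLAM_MB to replace them; nn.Module.__init__(self) still initializes the nn.Module machinery (required before any submodule is assigned) while skipping the parent's layer construction. A toy sketch of the pattern, using hypothetical stand-in classes rather than the CLAM models:

    import torch.nn as nn

    class SingleBranch(nn.Module):
        def __init__(self):
            super().__init__()
            self.head = nn.Linear(512, 2)  # layer the subclass does not want

    class MultiBranch(SingleBranch):
        def __init__(self, n_classes=2):
            nn.Module.__init__(self)  # set up nn.Module, skip SingleBranch.__init__
            self.heads = nn.ModuleList([nn.Linear(512, 1) for _ in range(n_classes)])

    m = MultiBranch()
    print(hasattr(m, 'head'))  # False: the parent constructor never ran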
8 changes: 4 additions & 4 deletions models/model_mil.py
@@ -26,19 +26,19 @@ def relocate(self):
 			self.classifier = nn.DataParallel(self.classifier, device_ids=device_ids).to('cuda:0')
 		else:
 			self.classifier.to(device)

 	def forward(self, h, return_features=False):
 		if return_features:
 			h = self.classifier.module[:3](h)
 			logits = self.classifier.module[3](h)
 		else:
 			logits = self.classifier(h) # K x 1
-		top_instance_idx = torch.topk(logits[:, 1], self.top_k, dim=0)[1].view(1,)
+
+		y_probs = F.softmax(logits, dim = 1)
+		top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1,)
 		top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
 		Y_hat = torch.topk(top_instance, 1, dim = 1)[1]
-		Y_prob = F.softmax(top_instance, dim = 1)
-		y_probs = F.softmax(logits, dim = 1)
+		Y_prob = F.softmax(top_instance, dim = 1)
 		results_dict = {}

 		if return_features:
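Note: this reordering is a behavioral fix, not a cleanup. Ranking instances by raw logits[:, 1] looks only at the class-1 logit, whereas ranking by softmax probabilities compares p(y = 1), which for two classes equals sigmoid(logit_1 - logit_0) and depends on both logits, so the two orderings can disagree. A small numeric check with invented values:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[0.0, 2.0],   # instance 0: class 1 clearly dominates
                           [5.0, 3.0]])  # instance 1: larger raw class-1 logit, but class 0 wins

    by_logit = torch.topk(logits[:, 1], 1, dim=0)[1]  # picks instance 1 (3.0 > 2.0)
    y_probs = F.softmax(logits, dim=1)
    by_prob = torch.topk(y_probs[:, 1], 1, dim=0)[1]  # picks instance 0 (0.88 > 0.12)
    print(by_logit.item(), by_prob.item())  # prints: 1 0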
6 changes: 1 addition & 5 deletions utils/eval_utils.py
@@ -33,9 +33,6 @@ def initiate_model(args, ckpt_path):
 	print_network(model)

-	# ckpt = torch.load(ckpt_path)
-	# model.load_state_dict(ckpt, strict=False)
-
 	ckpt = torch.load(ckpt_path)
 	ckpt_clean = {}
 	for key in ckpt.keys():
@@ -92,15 +89,14 @@ def summary(model, loader, args):
 		del data
 	test_error /= len(loader)

+	aucs = []
 	if len(np.unique(all_labels)) == 1:
 		auc_score = -1

 	else:
 		if args.n_classes == 2:
 			auc_score = roc_auc_score(all_labels, all_probs[:, 1])
-			aucs = []
 		else:
-			aucs = []
 			binary_labels = label_binarize(all_labels, classes=[i for i in range(args.n_classes)])
 			for class_idx in range(args.n_classes):
 				if class_idx in all_labels:
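Note: hoisting aucs = [] above the branch removes the duplicated initialization and guarantees the list exists on the single-class path (len(np.unique(all_labels)) == 1), where neither of the branches that previously defined it would run. A hedged standalone sketch of the surrounding logic; the nan padding for absent classes and the nanmean aggregation are assumptions, since the rest of summary() is not shown in this diff:

    import numpy as np
    from sklearn.metrics import roc_auc_score
    from sklearn.preprocessing import label_binarize

    def compute_auc(all_labels, all_probs, n_classes):
        aucs = []
        if len(np.unique(all_labels)) == 1:
            return -1, aucs  # AUC is undefined when only one class is present
        if n_classes == 2:
            return roc_auc_score(all_labels, all_probs[:, 1]), aucs
        binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)])
        for class_idx in range(n_classes):
            if class_idx in all_labels:
                aucs.append(roc_auc_score(binary_labels[:, class_idx], all_probs[:, class_idx]))
            else:
                aucs.append(float('nan'))
        return float(np.nanmean(np.array(aucs))), aucs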
11 changes: 9 additions & 2 deletions wsi_core/WholeSlideImage.py
@@ -243,15 +243,21 @@ def _getPatchGenerator(self, cont, cont_idx, patch_level, save_path, patch_size=
 		if custom_downsample > 1:
 			assert custom_downsample == 2
-			target_patch_size = patch_size
-			patch_size = target_patch_size * 2
+			# the target size is what's specified by patch_size
+			target_patch_size = patch_size
+			# the actual patches that we want to take are 2 * target_size in each dimension
+			patch_size = target_patch_size * 2
+			# similarly, the step size is 2 * what's specified
 			step_size = step_size * 2
 			print("Custom Downsample: {}, Patching at {} x {}, But Final Patch Size is {} x {}".format(custom_downsample, patch_size, patch_size,
 				target_patch_size, target_patch_size))

+		# the downsample corresponding to the patch_level
 		patch_downsample = (int(self.level_downsamples[patch_level][0]), int(self.level_downsamples[patch_level][1]))
+		# size of patch at level 0 (reference size)
 		ref_patch_size = (patch_size*patch_downsample[0], patch_size*patch_downsample[1])

+		# step sizes to take at level 0
 		step_size_x = step_size * patch_downsample[0]
 		step_size_y = step_size * patch_downsample[1]
@@ -288,6 +294,7 @@ def _getPatchGenerator(self, cont, cont_idx, patch_level, save_path, patch_size=
 				if self.isBlackPatch(np.array(patch_PIL), rgbThresh=black_thresh) or self.isWhitePatch(np.array(patch_PIL), satThresh=white_thresh):
 					continue

+				# x, y coordinates become the coordinates in the downsample, and no longer correspond to level 0 of the WSI
 				patch_info = {'x':x // (patch_downsample[0] * custom_downsample), 'y':y // (patch_downsample[1] * custom_downsample), 'cont_idx':cont_idx, 'patch_level':patch_level,
 					'downsample': self.level_downsamples[patch_level], 'downsampled_level_dim': tuple(np.array(self.level_dim[patch_level])//custom_downsample), 'level_dim': self.level_dim[patch_level],
 					'patch_PIL':patch_PIL, 'name':self.name, 'save_path':save_path}
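Note: a worked example of the quantities the new comments describe, with illustrative numbers. With patch_size=256 and custom_downsample=2, the generator reads 512 x 512 regions with a 512 stride; at a pyramid level whose downsample is (4, 4), each read covers 2048 x 2048 level-0 pixels, and a level-0 coordinate is divided by patch_downsample * custom_downsample = 8 to land on the downsampled grid:

    # illustrative values, not taken from a real slide
    patch_size, step_size, custom_downsample = 256, 256, 2
    patch_downsample = (4, 4)  # downsample factor of the chosen pyramid level

    target_patch_size = patch_size        # 256: final size the caller asked for
    patch_size = target_patch_size * 2    # 512: region actually read from the slide
    step_size = step_size * 2             # 512: stride between reads

    ref_patch_size = (patch_size * patch_downsample[0], patch_size * patch_downsample[1])
    print(ref_patch_size)                 # (2048, 2048), in level-0 pixels

    x = 4096                              # a level-0 x coordinate
    print(x // (patch_downsample[0] * custom_downsample))  # 512, on the downsampled grid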
