# heatmap_utils.py (forked from mahmoodlab/CLAM)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import pandas as pd
from utils.utils import *
from PIL import Image
from math import floor
import matplotlib.pyplot as plt
from datasets.wsi_dataset import Wsi_Region
import h5py
from wsi_core.WholeSlideImage import WholeSlideImage
from scipy.stats import percentileofscore
import math
from utils.file_utils import save_hdf5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def score2percentile(score, ref):
    # Map a raw attention score to its percentile rank within the reference scores.
    percentile = percentileofscore(ref, score)
    return percentile
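
# Usage sketch (values assumed for illustration): with ref = np.array([0.1, 0.5, 0.9]),
# score2percentile(0.8, ref) returns roughly 66.7, since 0.8 sits above two of the
# three reference scores.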

def drawHeatmap(scores, coords, slide_path=None, wsi_object=None, vis_level=-1, **kwargs):
    # Open the slide if an existing WholeSlideImage object was not supplied.
    if wsi_object is None:
        wsi_object = WholeSlideImage(slide_path)
        print(wsi_object.name)

    wsi = wsi_object.getOpenSlide()
    # Pick a reasonable visualization level automatically when none is given.
    if vis_level < 0:
        vis_level = wsi.get_best_level_for_downsample(32)

    heatmap = wsi_object.visHeatmap(scores=scores, coords=coords, vis_level=vis_level, **kwargs)
    return heatmap
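
# Usage sketch (hypothetical paths; extra styling kwargs are forwarded to
# WholeSlideImage.visHeatmap):
#
#   heatmap = drawHeatmap(scores, coords, slide_path='slides/slide_001.svs',
#                         vis_level=-1, cmap='jet')
#   heatmap.save('heatmaps/slide_001_heatmap.png')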

def initialize_wsi(wsi_path, seg_mask_path=None, seg_params=None, filter_params=None):
    wsi_object = WholeSlideImage(wsi_path)
    # A negative seg_level means: let OpenSlide pick the level closest to a 32x downsample.
    if seg_params['seg_level'] < 0:
        best_level = wsi_object.wsi.get_best_level_for_downsample(32)
        seg_params['seg_level'] = best_level

    wsi_object.segmentTissue(**seg_params, filter_params=filter_params)
    wsi_object.saveSegmentation(seg_mask_path)
    return wsi_object
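
# Usage sketch (parameter values are assumptions; CLAM's heatmap config files define
# the actual segmentation and filter parameters):
#
#   seg_params = {'seg_level': -1, 'sthresh': 8, 'mthresh': 7, 'close': 4, 'use_otsu': False}
#   filter_params = {'a_t': 100, 'a_h': 16, 'max_n_holes': 8}
#   wsi_object = initialize_wsi('slides/slide_001.svs', seg_mask_path='masks/slide_001.pkl',
#                               seg_params=seg_params, filter_params=filter_params)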

def compute_from_patches(wsi_object, clam_pred=None, model=None, feature_extractor=None, batch_size=512,
                         attn_save_path=None, ref_scores=None, feat_save_path=None, **wsi_kwargs):
    top_left = wsi_kwargs['top_left']
    bot_right = wsi_kwargs['bot_right']
    patch_size = wsi_kwargs['patch_size']

    # Dataset/loader over the patches covering the requested region of the slide.
    roi_dataset = Wsi_Region(wsi_object, **wsi_kwargs)
    roi_loader = get_simple_loader(roi_dataset, batch_size=batch_size, num_workers=8)
    print('total number of patches to process: ', len(roi_dataset))
    num_batches = len(roi_loader)
    print('number of batches: ', num_batches)

    mode = "w"  # overwrite on the first batch, then append
    for idx, (roi, coords) in enumerate(roi_loader):
        roi = roi.to(device)
        coords = coords.numpy()

        with torch.no_grad():
            features = feature_extractor(roi)

            if attn_save_path is not None:
                A = model(features, attention_only=True)

                if A.size(0) > 1:  # CLAM multi-branch attention: keep the branch of the predicted class
                    A = A[clam_pred]

                A = A.view(-1, 1).cpu().numpy()

                # Optionally convert raw attention scores to percentiles of a reference distribution.
                if ref_scores is not None:
                    for score_idx in range(len(A)):
                        A[score_idx] = score2percentile(A[score_idx], ref_scores)

                asset_dict = {'attention_scores': A, 'coords': coords}
                save_path = save_hdf5(attn_save_path, asset_dict, mode=mode)

        if idx % math.ceil(num_batches * 0.05) == 0:
            print('processed {} / {}'.format(idx, num_batches))

        if feat_save_path is not None:
            asset_dict = {'features': features.cpu().numpy(), 'coords': coords}
            save_hdf5(feat_save_path, asset_dict, mode=mode)

        mode = "a"

    return attn_save_path, feat_save_path, wsi_object
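
# Usage sketch (model, feature_extractor, and paths are assumptions; in CLAM's heatmap
# pipeline the feature extractor is a pretrained patch encoder and `model` is a trained
# CLAM checkpoint). Note that top_left, bot_right, and patch_size are required keys in
# wsi_kwargs:
#
#   wsi_kwargs = {'top_left': None, 'bot_right': None, 'patch_size': (256, 256),
#                 'step_size': (256, 256), 'level': 0, 'custom_downsample': 1}
#   attn_path, feat_path, wsi_object = compute_from_patches(
#       wsi_object, clam_pred=0, model=model, feature_extractor=feature_extractor,
#       batch_size=256, attn_save_path='results/attn.h5',
#       feat_save_path='results/feat.h5', ref_scores=None, **wsi_kwargs)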