Skip to content

Commit

Permalink
added code for generating slide-level local features (of shape [M, npatch, 384] where M = number of regions in the slide, and npatch = number of (256,256) patches in each region)
Browse files Browse the repository at this point in the history
  • Loading branch information
clemsgrs committed Apr 16, 2024
1 parent 79e28ab commit 2ea47f4
Show file tree
Hide file tree
Showing 14 changed files with 501 additions and 62 deletions.
25 changes: 25 additions & 0 deletions dinov2/configs/eval/knn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
data:
query_dataset:
test_dataset:
batch_size: 256
image_size: 256
student:
pretrained_weights:
knn:
nb_knn: [10, 20, 100, 200]
temperature: 0.07
n_tries: 1
n_per_class_list: -1
output_dir: ./output
speed:
num_workers: 8
gather_on_cpu: false
wandb:
enable: false
project: 'vision'
username: 'vlfm'
exp_name: 'knn'
tags: ['${wandb.exp_name}', '${student.arch}']
dir: '/home/user'
group:
resume_id:
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ train:
centering: sinkhorn_knopp
inference:
data_dir: /root/data
image_size: 256
batch_size: 64
num_workers: 8
student:
Expand Down
33 changes: 33 additions & 0 deletions dinov2/configs/inference/vits14_slide.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
dino:
head_bottleneck_dim: 384
train:
centering: sinkhorn_knopp
experiment_name: 'feature_extraction'
inference:
root_dir: '/data/pathology/projects/ais-cap/dataset/panda/hs2p/patches/otsu/${inference.region_size}/jpg'
slide_list:
level: local
region_size: 2048
patch_size: 256
num_workers: 8
save_region_features: false
student:
arch: vit_small
patch_size: 14
num_register_tokens: 0
pretrained_weights: '/data/pathology/projects/ais-cap/clement/code/dinov2/output/769naczt/eval/training_649999/teacher_checkpoint.pth'
drop_path_rate: 0.4
ffn_layer: swiglufused
block_chunks: 4
crops:
global_crops_size: 224
local_crops_size: 98
wandb:
enable: true
project: 'hipt'
username: 'clemsg'
exp_name: '${train.experiment_name}'
tags: ['${wandb.exp_name}', '${inference.level}', '${inference.region_size}', '${student.arch}']
dir: '/home/user'
group:
resume_id:
1 change: 1 addition & 0 deletions dinov2/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
from .knn import KNNDataset
from .foundation import PathologyFoundationDataset
from .image_folder import ImageFolderWithNameDataset
from .regions import SlideIDsDataset, SlideRegionDataset
64 changes: 64 additions & 0 deletions dinov2/data/datasets/regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch

from einops import rearrange
from pathlib import Path
from torchvision import transforms
from typing import Callable, Optional
from torchvision.datasets.folder import default_loader


class SlideIDsDataset(torch.utils.data.Dataset):
    """Thin dataset wrapper that yields slide IDs one at a time.

    Useful for distributing per-slide work across data-loader workers:
    each item is simply the slide identifier string at that index.
    """

    def __init__(self, slide_ids):
        """
        Args:
            slide_ids (list of str): List of slide IDs.
        """
        self.slide_ids = slide_ids

    def __len__(self):
        return len(self.slide_ids)

    def __getitem__(self, index):
        return self.slide_ids[index]


class SlideRegionDataset(torch.utils.data.Dataset):
    """Dataset over the region images of a single slide.

    Each item is one region image tiled into a grid of non-overlapping
    ``image_size`` x ``image_size`` patches, returned as a tensor of shape
    [num_patches, 3, image_size, image_size] together with the region index
    and its file path.
    """

    def __init__(
        self,
        root_dir: Path,
        slide_id: str,
        fmt: str = "jpg",
        image_size: int = 256,
        loader: Callable = default_loader,
        transform: Optional[Callable] = None,
    ):
        """
        Args:
            root_dir: Root directory holding one sub-directory per slide.
            slide_id: Identifier of the slide whose regions are loaded.
            fmt: Image file extension of the region images.
            image_size: Side length of the square patches cut from each region.
            loader: Callable mapping a path to a PIL image.
            transform: Optional callable applied to the stacked patch tensor.
        """
        self.root_dir = root_dir
        self.slide_id = slide_id
        self.format = fmt
        self.image_size = image_size
        self.transform = transform
        self.loader = loader
        self.region_paths = self._find_region_paths()

    def _find_region_paths(self):
        # Regions are expected under <root_dir>/<slide_id>/imgs/*.<fmt>;
        # sort for a deterministic region order across runs.
        region_dir = Path(self.root_dir, self.slide_id, "imgs")
        return sorted(str(fp) for fp in region_dir.glob(f"*.{self.format}"))

    def __len__(self):
        return len(self.region_paths)

    def __getitem__(self, idx):
        path = self.region_paths[idx]
        image = self.loader(path)
        tensor = transforms.functional.to_tensor(image)  # [3, region_size, region_size]
        # Cut the region into a grid of non-overlapping image_size tiles.
        ps = self.image_size
        tiles = tensor.unfold(1, ps, ps).unfold(2, ps, ps)  # [3, npatch, npatch, image_size, image_size]
        patches = rearrange(tiles, "c p1 p2 w h -> (p1 p2) c w h")  # [num_patches, 3, image_size, image_size]
        if self.transform is not None:
            patches = self.transform(patches)
        return idx, patches, path
37 changes: 30 additions & 7 deletions dinov2/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,40 @@ def make_classification_train_transform(
# This matches (roughly) torchvision's preset for classification evaluation:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
def make_classification_eval_transform(
    *,
    image_size: int,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """Build the evaluation transform: optional resize/center-crop, to-tensor, normalize.

    Args:
        image_size: Side length of the incoming images.
        resize_size: Target size of the shorter side before cropping.
        interpolation: Interpolation mode used by the resize.
        crop_size: Side length of the final center crop.
        mean: Per-channel normalization mean.
        std: Per-channel normalization std.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    if image_size == crop_size:
        # Images already have the target size: skip the resize/crop entirely.
        transforms_list = [
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    else:
        # Bug fix: the previous if/elif only covered image_size > crop_size and
        # left `transforms_list` unbound (UnboundLocalError) for smaller images.
        # Resize + CenterCrop handles both up- and down-scaling.
        transforms_list = [
            transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
            transforms.CenterCrop(crop_size),
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    return transforms.Compose(transforms_list)


def make_feature_extraction_transform(
    image_size: int,
    resize_size: int = 224,
    interpolation=transforms.InterpolationMode.BICUBIC,
) -> transforms.Compose:
    """Build the feature-extraction transform: optional resize, then to-tensor.

    Unlike the classification eval transform, no center crop or normalization
    is applied here.

    Args:
        image_size: Side length of the incoming images.
        resize_size: Target size the images are resized to when needed.
        interpolation: Interpolation mode used by the resize.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    if image_size == resize_size:
        # Images already match the model input size: only convert to tensor.
        transforms_list = [
            MaybeToTensor(),
        ]
    else:
        # Bug fix: the previous if/elif only covered image_size > resize_size and
        # left `transforms_list` unbound (UnboundLocalError) for smaller images.
        transforms_list = [
            transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
            MaybeToTensor(),
        ]
    return transforms.Compose(transforms_list)
4 changes: 2 additions & 2 deletions dinov2/eval/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dinov2.eval.metrics import AccuracyAveraging, build_metric
from dinov2.utils.utils import initialize_wandb
from dinov2.utils.config import setup, write_config
from dinov2.eval.setup import setup_and_build_model
from dinov2.models import setup_and_build_model
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
from dinov2.data.transforms import make_classification_eval_transform

Expand Down Expand Up @@ -387,7 +387,7 @@ def main(args):

model, autocast_dtype = setup_and_build_model(cfg)

transform = make_classification_eval_transform()
transform = make_classification_eval_transform(image_size=cfg.data.image_size)
query_dataset_str = cfg.data.query_dataset
test_dataset_str = cfg.data.test_dataset
query_dataset = make_dataset(
Expand Down
5 changes: 2 additions & 3 deletions dinov2/eval/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import torch
import torch.backends.cudnn as cudnn

from dinov2.models import build_model_from_cfg
import dinov2.utils.utils as dinov2_utils
from dinov2.models import build_model_from_cfg, load_pretrained_weights


def get_args_parser(
Expand Down Expand Up @@ -60,7 +59,7 @@ def get_autocast_dtype(config):

def build_model_for_eval(config):
    """Build the teacher model for evaluation.

    Loads the pretrained teacher weights referenced by the config, switches the
    model to eval mode, and moves it to the GPU.

    Args:
        config: Config object with ``student.pretrained_weights`` set.

    Returns:
        The teacher model, in eval mode, on CUDA.
    """
    model, _ = build_model_from_cfg(config, only_teacher=True)
    # Diff residue fix: keep a single weight-loading call (the old
    # `dinov2_utils.load_pretrained_weights(...)` line duplicated this one).
    load_pretrained_weights(model, config.student.pretrained_weights, "teacher")
    model.eval()
    model.cuda()
    return model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from dinov2.utils.utils import initialize_wandb
from dinov2.data import SamplerType, make_data_loader
from dinov2.data.datasets import ImageFolderWithNameDataset
from dinov2.data.transforms import make_classification_eval_transform

# from dinov2.data.transforms import make_classification_eval_transform
from dinov2.data.transforms import make_feature_extraction_transform


def get_args_parser(add_help: bool = True):
Expand Down Expand Up @@ -83,7 +85,8 @@ def main(args):
verbose=distributed.is_main_process(),
)

transform = make_classification_eval_transform()
# transform = make_classification_eval_transform()
transform = make_feature_extraction_transform(image_size=cfg.inference.image_size)
dataset = ImageFolderWithNameDataset(cfg.inference.data_dir, transform)

if run_distributed:
Expand Down
Loading

0 comments on commit 2ea47f4

Please sign in to comment.