Skip to content

Commit

Permalink
Using Pathology bundles for nuclick and classification models (#1172)
Browse files Browse the repository at this point in the history
* support bundles for nuclick and classify models

Signed-off-by: Sachidanand Alle <[email protected]>

* sync up changes

Signed-off-by: Sachidanand Alle <[email protected]>

* remove nuclick transform copy and use monai app instead

Signed-off-by: Sachidanand Alle <[email protected]>

Signed-off-by: Sachidanand Alle <[email protected]>
  • Loading branch information
SachidanandAlle committed Dec 3, 2022
1 parent 4e72ec6 commit 3148ff0
Show file tree
Hide file tree
Showing 14 changed files with 158 additions and 1,241 deletions.
1 change: 1 addition & 0 deletions monailabel/interfaces/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self):
self.path = None
self.labels = None
self.label_colors = None
self.bundle_path = None

def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs):
self.name = name
Expand Down
1 change: 1 addition & 0 deletions monailabel/tasks/train/basic_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ def _create_trainer(self, context: Context):
amp=self._amp,
postprocessing=self._validate_transforms(self.train_post_transforms(context), "Training", "post"),
key_train_metric=self.train_key_metric(context),
additional_metrics=self.train_additional_metrics(context),
train_handlers=train_handlers,
iteration_update=self.train_iteration_update(context),
event_names=self.event_names(context),
Expand Down
8 changes: 5 additions & 3 deletions monailabel/tasks/train/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def config(self):
"gpus": "all", # COMMA SEPARATE DEVICE INDEX
}

def _fetch_datalist(self, datastore: Datastore):
def _fetch_datalist(self, request, datastore: Datastore):
return datastore.datalist()

def _partition_datalist(self, datalist, request, shuffle=False):
Expand Down Expand Up @@ -144,16 +144,18 @@ def _load_checkpoint(self, output_dir, pretrained, train_handlers):
train_handlers.insert(0, loader)

def __call__(self, request, datastore: Datastore):
ds = self._fetch_datalist(datastore)
logger.info(f"Train Request: {request}")
ds = self._fetch_datalist(request, datastore)
train_ds, val_ds = self._partition_datalist(ds, request)

max_epochs = request.get("max_epochs", 50)
pretrained = request.get("pretrained", True)
multi_gpu = request.get("multi_gpu", False)
multi_gpu = request.get("multi_gpu", True)
multi_gpu = multi_gpu if torch.cuda.device_count() > 1 else False

gpus = request.get("gpus", "all")
gpus = list(range(torch.cuda.device_count())) if gpus == "all" else [int(g) for g in gpus.split(",")]
multi_gpu = True if multi_gpu and len(gpus) > 1 else False
logger.info(f"Using Multi GPU: {multi_gpu}; GPUS: {gpus}")
logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

Expand Down
4 changes: 2 additions & 2 deletions sample-apps/endoscopy/lib/trainers/inbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@


class InBody(BundleTrainTask):
def _fetch_datalist(self, datastore: Datastore):
ds = super()._fetch_datalist(datastore)
def _fetch_datalist(self, request, datastore: Datastore):
ds = super()._fetch_datalist(request, datastore)

out_body = datastore.label_map.get("OutBody", 3) if isinstance(datastore, CVATDatastore) else 1
load = LoadImage(dtype=np.uint8, image_only=True)
Expand Down
83 changes: 8 additions & 75 deletions sample-apps/pathology/lib/configs/classification_nuclei.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
from typing import Any, Dict, Optional, Union

import lib.infers
import lib.trainers
from monai.networks.nets import DenseNet121
from monai.bundle import download

from monailabel.interfaces.config import TaskConfig
from monailabel.interfaces.tasks.infer_v2 import InferTask
from monailabel.interfaces.tasks.train import TrainTask
from monailabel.utils.others.generic import download_file, strtobool

logger = logging.getLogger(__name__)

Expand All @@ -30,81 +28,16 @@ class ClassificationNuclei(TaskConfig):
def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs):
super().init(name, model_dir, conf, planner, **kwargs)

# Labels
self.labels = {
"Neoplastic cells": 1,
"Inflammatory": 2,
"Connective/Soft tissue cells": 3,
"Dead Cells": 4,
"Epithelial": 5,
}
self.label_colors = {
"Neoplastic cells": (255, 0, 0),
"Inflammatory": (255, 255, 0),
"Connective/Soft tissue cells": (0, 255, 0),
"Dead Cells": (0, 0, 0),
"Epithelial": (0, 0, 255),
}

consep = strtobool(self.conf.get("consep", "false"))
if consep:
self.labels = {
"Other": 1,
"Inflammatory": 2,
"Epithelial": 3,
"Spindle-Shaped": 4,
}
self.label_colors = {
"Other": (255, 0, 0),
"Inflammatory": (255, 255, 0),
"Epithelial": (0, 0, 255),
"Spindle-Shaped": (0, 255, 0),
}

# Model Files
self.path = [
os.path.join(self.model_dir, f"pretrained_{name}{'_consep' if consep else ''}.pt"), # pretrained
os.path.join(self.model_dir, f"{name}{'_consep' if consep else ''}.pt"), # published
]

# Download PreTrained Model
if strtobool(self.conf.get("use_pretrained_model", "true")):
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
url = f"{url}/pathology_classification_densenet121_nuclei{'_consep' if consep else ''}.pt"
download_file(url, self.path[0])

# Network
self.network = DenseNet121(spatial_dims=2, in_channels=4, out_channels=len(self.labels))
bundle_name = conf.get("bundle_name", "pathology_nuclei_classification")
bundle_version = conf.get("bundle_version", "0.0.1")
self.bundle_path = os.path.join(self.model_dir, bundle_name)
if not os.path.exists(self.bundle_path):
download(name=bundle_name, version=bundle_version, bundle_dir=self.model_dir)

def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
task: InferTask = lib.infers.ClassificationNuclei(
path=self.path,
network=self.network,
labels=self.labels,
preload=strtobool(self.conf.get("preload", "false")),
roi_size=json.loads(self.conf.get("roi_size", "[128, 128]")),
config={
"label_colors": self.label_colors,
},
)
task: InferTask = lib.infers.ClassificationNuclei(self.bundle_path, self.conf)
return task

def trainer(self) -> Optional[TrainTask]:
output_dir = os.path.join(self.model_dir, self.name)
load_path = self.path[0] if os.path.exists(self.path[0]) else self.path[1]

task: TrainTask = lib.trainers.ClassificationNuclei(
model_dir=output_dir,
network=self.network,
load_path=load_path,
publish_path=self.path[1],
labels=self.labels,
description="Train Nuclei Classification Model",
train_save_interval=1,
config={
"max_epochs": 10,
"train_batch_size": 16,
"val_batch_size": 16,
},
)
task: TrainTask = lib.trainers.ClassificationNuclei(self.bundle_path, self.conf)
return task
61 changes: 8 additions & 53 deletions sample-apps/pathology/lib/configs/nuclick.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
from typing import Any, Dict, Optional, Union

import lib.infers
import lib.trainers
from monai.networks.nets import BasicUNet
from monai.bundle import download

from monailabel.interfaces.config import TaskConfig
from monailabel.interfaces.tasks.infer_v2 import InferTask
from monailabel.interfaces.tasks.train import TrainTask
from monailabel.utils.others.generic import download_file, strtobool

logger = logging.getLogger(__name__)

Expand All @@ -30,59 +28,16 @@ class NuClick(TaskConfig):
def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs):
super().init(name, model_dir, conf, planner, **kwargs)

# Labels
self.labels = {"Nuclei": 1}
self.label_colors = {"Nuclei": (0, 255, 255)}

consep = strtobool(self.conf.get("consep", "false"))

# Model Files
self.path = [
os.path.join(self.model_dir, f"pretrained_{name}{'_consep' if consep else ''}.pt"), # pretrained
os.path.join(self.model_dir, f"{name}{'_consep' if consep else ''}.pt"), # published
]

# Download PreTrained Model
if strtobool(self.conf.get("use_pretrained_model", "true")):
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
url = f"{url}/pathology_nuclick_bunet_nuclei{'_consep' if consep else ''}.pt"
download_file(url, self.path[0])

# Network
self.network = BasicUNet(
spatial_dims=2,
in_channels=5,
out_channels=1,
features=(32, 64, 128, 256, 512, 32),
)
bundle_name = conf.get("bundle_name", "pathology_nuclick_annotation")
bundle_version = conf.get("bundle_version", "0.0.1")
self.bundle_path = os.path.join(self.model_dir, bundle_name)
if not os.path.exists(self.bundle_path):
download(name=bundle_name, version=bundle_version, bundle_dir=self.model_dir)

def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
task: InferTask = lib.infers.NuClick(
path=self.path,
network=self.network,
labels=self.labels,
preload=strtobool(self.conf.get("preload", "false")),
roi_size=json.loads(self.conf.get("roi_size", "[512, 512]")),
config={"label_colors": self.label_colors, "ignore_non_click_patches": True},
)
task: InferTask = lib.infers.NuClick(self.bundle_path, self.conf)
return task

def trainer(self) -> Optional[TrainTask]:
output_dir = os.path.join(self.model_dir, self.name)
load_path = self.path[0] if os.path.exists(self.path[0]) else self.path[1]

task: TrainTask = lib.trainers.NuClick(
model_dir=output_dir,
network=self.network,
load_path=load_path,
publish_path=self.path[1],
labels=self.labels,
description="Train Nuclei DeepEdit Model",
train_save_interval=1,
config={
"max_epochs": 10,
"train_batch_size": 16,
"val_batch_size": 16,
},
)
task: TrainTask = lib.trainers.NuClick(self.bundle_path, self.conf)
return task
Loading

0 comments on commit 3148ff0

Please sign in to comment.