Decomposed conv forward mods and some fixes (#76)
* update SVDOptimization for forward_mode
N-31V authored Aug 7, 2023
1 parent c3e8006 commit 4b31871
Showing 13 changed files with 440 additions and 329 deletions.
55 changes: 54 additions & 1 deletion fedot_ind/core/architecture/datasets/visualization.py
@@ -1,10 +1,17 @@
from typing import Dict, Optional, Union, List
from typing import Dict, Optional, Union, List, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

_PALETTE = ((255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255), (128, 0, 0),
(0, 128, 0), (128, 0, 128), (0, 128, 128), (0, 0, 128), (128, 0, 0), (220, 20, 60), (255, 165, 0),
(218, 165, 32), (240, 230, 140), (154, 205, 50), (107, 142, 35), (0, 100, 0), (46, 139, 87),
(32, 178, 170), (64, 224, 208), (70, 130, 180), (138, 43, 226), (72, 61, 139), (147, 112, 219),
(139, 0, 139), (218, 112, 214), (219, 112, 147), (255, 20, 147), (255, 228, 196), (139, 69, 19),
(210, 105, 30), (244, 164, 96), (188, 143, 143), (112, 128, 144), (230, 230, 250), (245, 245, 245))


def _2nparray(arrs: List[Union[torch.Tensor, List]]) -> List[np.ndarray]:
return [arr.numpy() if isinstance(arr, torch.Tensor) else np.array(arr) for arr in arrs]
@@ -67,3 +74,49 @@ def draw_sample_with_bboxes(
ax.set_axis_off()
ax.imshow(pimage)
return fig


def _put_mask(axis: plt.axis, image: np.ndarray, mask: np.ndarray, palette: Tuple):
thickness = 1 + int(image.shape[-2] / 500)
image = image.copy()
for ch in range(mask.shape[0]):
contours, _ = cv2.findContours(mask[ch, :, :], cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
for contour in contours:
cv2.polylines(image, contour, True, palette[ch], thickness)
axis.set_axis_off()
axis.imshow(image)


def draw_sample_with_masks(
image: torch.Tensor,
target: torch.Tensor = None,
prediction: torch.Tensor = None,
palette: Tuple = _PALETTE
) -> plt.Figure:
"""
Returns the image with masks.
Args:
image: Image tensor.
target: N-channel target tensor with masks, where N is the number of classes.
prediction: N-channel prediction tensor with masks, where N is the number of classes.
palette: Color palette for each class.
Returns:
`matplotlib.pyplot.Figure` of the image with masks.
"""
assert prediction is not None or target is not None, "At least one parameter from 'target' and 'prediction' must not be None"

image = image.permute(1, 2, 0).numpy()

n = 1 if prediction is None or target is None else 2
fig = plt.figure(figsize=(10 * n, 10))

if target is not None:
ax = plt.subplot(1, n, 1)
_put_mask(ax, image, target.numpy().astype(np.uint8), palette)

if prediction is not None:
ax = plt.subplot(1, n, n)
_put_mask(ax, image, prediction.numpy().astype(np.uint8), palette)
return fig
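
A minimal usage sketch for the new draw_sample_with_masks helper, assuming a channels-first uint8 image tensor and N x H x W mask tensors (the shapes follow from the permute(1, 2, 0) call and the per-channel loop in _put_mask; the random data and output filename are placeholders):

import torch
from fedot_ind.core.architecture.datasets.visualization import draw_sample_with_masks

image = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)   # RGB image, channels first
target = (torch.rand(4, 256, 256) > 0.5).float()                  # 4-class ground-truth masks
prediction = (torch.rand(4, 256, 256) > 0.5).float()              # 4-class predicted masks

fig = draw_sample_with_masks(image, target=target, prediction=prediction)
fig.savefig('sample_with_masks.png')

With both target and prediction supplied the figure contains two panels; passing only one of them produces a single panel.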
53 changes: 34 additions & 19 deletions fedot_ind/core/architecture/experiment/nn_experimenter.py
@@ -2,6 +2,7 @@
import logging
import os
import shutil
from functools import partial
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Optional, Type, Union
@@ -15,7 +16,7 @@

from fedot_ind.core.architecture.abstraction.writers import CSVWriter, TFWriter, WriterComposer
from fedot_ind.core.architecture.abstraction.сheckers import parameter_value_check
from fedot_ind.core.metrics.cv_metrics import ClassificationMetricCounter, LossesAverager, \
from fedot_ind.core.metrics.cv_metrics import MetricCounter, ClassificationMetricCounter, LossesAverager, \
ObjectDetectionMetricCounter, SegmentationMetricCounter


@@ -32,6 +33,7 @@ class FitParameters:
lr_scheduler: Type of learning rate scheduler, e.g ``torch.optim.lr_scheduler.StepLR``.
models_path: Path to folder for saving models.
summary_path: Path to folder for writing experiment summary info.
validation_period: Validation frequency in epochs.
class_metrics: If ``True``, calculates validation metrics for each class.
description: Additional line describing the experiment.
@@ -41,10 +43,11 @@
train_dl: DataLoader
val_dl: DataLoader
num_epochs: int
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
lr_scheduler: Optional[Type[torch.optim.lr_scheduler.LRScheduler]] = None
optimizer: Union[Type[torch.optim.Optimizer], partial] = torch.optim.Adam
lr_scheduler: Optional[Union[Type[torch.optim.lr_scheduler.LRScheduler], partial]] = None
models_path: Union[Path, str] = 'models'
summary_path: Union[Path, str] = 'summary'
validation_period: int = 1
class_metrics: bool = False
description: str = ''

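
Binding hyperparameters with functools.partial is the point of widening the optimizer and lr_scheduler annotations: fit() still instantiates them as p.optimizer(self.model.parameters()) and p.lr_scheduler(optimizer), so a partial lets the learning rate and scheduler settings ride along. A hedged sketch (the dataset name, dummy dataloaders and epoch count are placeholders; dataset_name is assumed to be the remaining required field, based on the p.dataset_name reference in fit()):

from functools import partial
import torch
from torch.utils.data import DataLoader, TensorDataset
from fedot_ind.core.architecture.experiment.nn_experimenter import FitParameters

# Dummy dataloaders purely to make the sketch self-contained.
dummy = TensorDataset(torch.rand(8, 3, 32, 32), torch.randint(0, 10, (8,)))
train_dl = DataLoader(dummy, batch_size=4)
val_dl = DataLoader(dummy, batch_size=4)

params = FitParameters(
    dataset_name='my_dataset',   # placeholder value; the field is read as p.dataset_name in fit()
    train_dl=train_dl,
    val_dl=val_dl,
    num_epochs=30,
    optimizer=partial(torch.optim.AdamW, lr=3e-4, weight_decay=1e-4),
    lr_scheduler=partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
    validation_period=5,         # run the validation loop only every fifth epoch
)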
@@ -67,7 +70,7 @@ def __init__(
self,
model: torch.nn.Module,
metric: str,
metric_counter,
metric_counter: Type[MetricCounter],
name: Optional[str],
weights: Optional[str],
device: str,
@@ -87,31 +90,37 @@ def fit(self,
p: FitParameters,
phase: str = 'train',
model_losses: Optional[Callable] = None,
start_epoch: int = 0
filter_pruning: Optional[Dict] = None,
start_epoch: int = 0,
initial_validation: bool = False
) -> None:
"""Run model training.
Args:
p: An object containing training parameters.
phase: Name of the training phase (used in log messages and output paths).
model_losses: Function for calculating losses from model weights.
filter_pruning: Parameters (pruning function and condition) passed to the ``_apply_function`` method.
start_epoch: Initial training epoch.
initial_validation: If ``True`` run validation loop before training.
"""
model_path = os.path.join(p.models_path, p.dataset_name, self.name, p.description, phase)
summary_path = os.path.join(p.summary_path, p.dataset_name, self.name, p.description, phase)
writer = WriterComposer(summary_path, [TFWriter, CSVWriter])

self.logger.info(f"{phase}: {self.name}, using device: {self.device}")
init_scores = self.val_loop(dataloader=p.val_dl, class_metrics=p.class_metrics)
writer.write_scores('val', init_scores, start_epoch)
self._save_model_sd_if_best(val_scores=init_scores, file_path=model_path)

if initial_validation:
init_scores = self.val_loop(dataloader=p.val_dl, class_metrics=p.class_metrics)
writer.write_scores('val', init_scores, start_epoch)
self._save_model_sd_if_best(val_scores=init_scores, file_path=model_path)
start_epoch += 1

optimizer = p.optimizer(self.model.parameters())
lr_scheduler = None
if p.lr_scheduler is not None:
lr_scheduler = p.lr_scheduler(optimizer)

for epoch in range(start_epoch, start_epoch + p.num_epochs):
self.logger.info(f"Epoch {epoch}")
train_scores = self.train_loop(
@@ -120,15 +129,21 @@ def fit(self,
model_losses=model_losses
)
writer.write_scores('train', train_scores, epoch)
val_scores = self.val_loop(
dataloader=p.val_dl,
class_metrics=p.class_metrics
)
writer.write_scores('val', val_scores, epoch)
self._save_model_sd_if_best(val_scores=val_scores, file_path=model_path)
if isinstance(lr_scheduler, ReduceLROnPlateau):
lr_scheduler.step(val_scores[self.metric])
elif isinstance(lr_scheduler, LRScheduler):

if filter_pruning is not None:
self._apply_function(**filter_pruning)

if epoch % p.validation_period == 0:
val_scores = self.val_loop(
dataloader=p.val_dl,
class_metrics=p.class_metrics
)
writer.write_scores('val', val_scores, epoch)
self._save_model_sd_if_best(val_scores=val_scores, file_path=model_path)
if isinstance(lr_scheduler, ReduceLROnPlateau):
lr_scheduler.step(val_scores[self.metric])

if isinstance(lr_scheduler, LRScheduler) and not isinstance(lr_scheduler, ReduceLROnPlateau):
lr_scheduler.step()
self.load_model(model_path)
self.logger.info(f'{self.metric} score: {self.best_score}')
@@ -208,7 +223,7 @@ def number_of_model_params(self) -> int:
"""Returns number of model parameters."""
return sum(p.numel() for p in self.model.parameters())

def _apply_func(
def _apply_function(
self,
func: Callable,
condition: Optional[Callable] = None
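
Putting the new fit() arguments together: initial_validation restores the pre-training validation pass that used to run unconditionally, validation_period thins out the validation loop, and filter_pruning is unpacked into the renamed _apply_function after every epoch. The experimenter instance and the pruning callable below are hypothetical, and since _apply_function's body is not shown in this diff, the sketch only assumes that func is applied to each module matching condition:

import torch

def zero_small_conv_filters(module: torch.nn.Conv2d) -> None:
    # Hypothetical pruning step: zero out filters whose L1 norm is far below average.
    with torch.no_grad():
        norms = module.weight.abs().sum(dim=(1, 2, 3))
        module.weight[norms < 0.1 * norms.mean()] = 0

experimenter.fit(                    # `experimenter` is an existing experimenter instance
    p=params,                        # a FitParameters object, e.g. the sketch above
    phase='train',
    filter_pruning={
        'func': zero_small_conv_filters,
        'condition': lambda m: isinstance(m, torch.nn.Conv2d),
    },
    initial_validation=True,
)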
62 changes: 38 additions & 24 deletions fedot_ind/core/metrics/cv_metrics.py
@@ -1,6 +1,7 @@
"""This module contains functions and classes for computing metrics
in computer vision tasks.
"""
from abc import ABC, abstractmethod
from typing import Dict, List

import torch
@@ -10,52 +11,68 @@
from torchmetrics.detection.mean_ap import MeanAveragePrecision


def iou_score(outputs: torch.Tensor, masks: torch.Tensor, smooth=1e-10) -> torch.Tensor:
def iou_score(
outputs: torch.Tensor,
masks: torch.Tensor,
threshold: float = 0.5,
smooth: float = 1e-10
) -> torch.Tensor:
"""Computes intersection over union (masks) on batch.
Args:
outputs: Output from semantic segmentation model.
masks: True masks.
threshold: Binarization threshold for output.
smooth: Additional constant to avoid division by zero.
Returns:
Intersection over union for batch.
"""
outputs = (outputs > 0.5).float()
outputs = (outputs > threshold).float()
intersection = torch.logical_and(outputs, masks).float().sum((2, 3))
union = torch.logical_or(outputs, masks).float().sum((2, 3))
iou = (intersection + smooth) / (union + smooth)
iou[union == 0] = -1
return iou


def dice_score(outputs: torch.Tensor, masks: torch.Tensor, smooth=1e-10) -> torch.Tensor:
def dice_score(
outputs: torch.Tensor,
masks: torch.Tensor,
threshold: float = 0.5,
smooth: float = 1e-10
) -> torch.Tensor:
"""Computes dice coefficient (masks) on batch.
Args:
outputs: Output from semantic segmentation model.
masks: True masks.
threshold: Binarization threshold for output.
smooth: Additional constant to avoid division by zero.
Returns:
Dice for batch.
"""
outputs = (outputs > 0.5).float()
outputs = (outputs > threshold).float()
intersection = torch.logical_and(outputs, masks).float().sum((2, 3))
total = (outputs + masks).sum((2, 3))
dice = (2 * intersection + smooth) / (total + smooth)
dice[total == 0] = -1
return dice

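
Both scores are computed per sample and per class, with a -1 sentinel marking classes that are absent from both the prediction and the target. A tiny worked example under assumed batch x classes x height x width shapes (implied by the reduction over dims (2, 3) above; the random tensors are placeholders):

import torch
from fedot_ind.core.metrics.cv_metrics import iou_score, dice_score

outputs = torch.rand(2, 3, 4, 4)                  # model probabilities for 2 samples, 3 classes
masks = (torch.rand(2, 3, 4, 4) > 0.5).float()    # binary ground-truth masks

iou = iou_score(outputs, masks, threshold=0.5)    # shape (2, 3); -1 where both masks are empty
dice = dice_score(outputs, masks, threshold=0.5)  # shape (2, 3); -1 where both masks are empty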

class MetricCounter:
class MetricCounter(ABC):
"""Generalized class for calculating metrics"""

def __init__(self) -> None:
def __init__(self, **kwargs) -> None:
pass

@abstractmethod
def update(self, **kwargs) -> None:
"""Have to implement updating, taking model outputs as input."""
raise NotImplementedError

@abstractmethod
def compute(self) -> Dict[str, float]:
"""Have to implement metrics computing."""
raise NotImplementedError
@@ -112,37 +129,34 @@ class SegmentationMetricCounter(MetricCounter):

def __init__(self, class_metrics: bool = False) -> None:
super().__init__()
self.iou = None
self.dice = None
self.n = 0
self.iou = []
self.dice = []
self.class_metrics = class_metrics

def update(self, predictions: torch.Tensor, targets: torch.Tensor) -> None:
"""Accumulates iou and dice."""
masks = torch.zeros_like(predictions)
for i in range(predictions.size()[1]):
for i in range(predictions.shape[1]):
masks[:, i, :, :] = torch.squeeze(targets == i)
self.n += predictions.size()[0]
if self.iou is None:
self.iou = iou_score(predictions, masks).sum(0)
else:
self.iou += iou_score(predictions, masks).sum(0)
if self.dice is None:
self.dice = dice_score(predictions, masks).sum(0)
else:
self.dice += dice_score(predictions, masks).sum(0)
self.iou.append(iou_score(predictions, masks))
self.dice.append(dice_score(predictions, masks))

def compute(self) -> Dict[str, float]:
"""Compute average metrics.
Returns:
Dictionary: `{metric: score}`.
"""
iou = self.iou / self.n
dice = self.dice/ self.n
scores = {'iou': iou.mean().item(), 'dice': dice.mean().item()}
iou = torch.cat(self.iou).T
dice = torch.cat(self.dice).T

scores = {
'iou': iou[1:][iou[1:] >= 0].mean().item(),
'dice': dice[1:][dice[1:] >= 0].mean().item()
}
if self.class_metrics:
scores.update({f'iou_for_class_{i}': s.item() for i, s in enumerate(iou)})
scores.update({f'dice_for_class_{i}': s.item() for i, s in enumerate(dice)})
scores.update({f'iou_for_class_{i}': s[s >= 0].mean().item() for i, s in enumerate(iou)})
scores.update({f'dice_for_class_{i}': s[s >= 0].mean().item() for i, s in enumerate(dice)})
return scores
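
SegmentationMetricCounter now keeps the raw per-sample scores and averages only the valid (non -1) entries at compute time, skipping class 0 (presumably the background) in the aggregate scores. A usage sketch with assumed shapes (predictions: batch x classes x H x W probabilities, targets: batch x 1 x H x W integer labels; the random tensors are placeholders):

import torch
from fedot_ind.core.metrics.cv_metrics import SegmentationMetricCounter

counter = SegmentationMetricCounter(class_metrics=True)
for _ in range(3):                                    # e.g. three validation batches
    predictions = torch.rand(4, 5, 32, 32)
    targets = torch.randint(0, 5, (4, 1, 32, 32))
    counter.update(predictions=predictions, targets=targets)

scores = counter.compute()
# keys: 'iou', 'dice', plus 'iou_for_class_i' / 'dice_for_class_i' when class_metrics=True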


