diff --git a/Dockerfile b/Dockerfile index cb2cb257..769124f3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,10 +19,10 @@ RUN apt-get update \ && sudo mv cuda-ubuntu2004-keyring.gpg /usr/share/keyrings/cuda-archive-keyring.gpg \ && rm -f cuda-keyring_1.0-1_all.deb && rm -f /etc/apt/sources.list.d/cuda.list -# Install miniconda +# Install Mamba directly ENV PATH $CONDA_DIR/bin:$PATH -RUN wget https://repo.continuum.io/miniconda/Miniconda$CONDA_PYTHON_VERSION-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \ - /bin/bash /tmp/miniconda.sh -b -p $CONDA_DIR && \ +RUN wget https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh -O /tmp/mamba.sh && \ + /bin/bash /tmp/mamba.sh -b -p $CONDA_DIR && \ rm -rf /tmp/* && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -32,13 +32,13 @@ RUN useradd --create-home -s /bin/bash --no-user-group -u $USERID $USERNAME && \ chown $USERNAME $CONDA_DIR -R && \ adduser $USERNAME sudo && \ echo "$USERNAME ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers + USER $USERNAME WORKDIR /home/$USERNAME/ RUN cd /home/$USERNAME && git clone --depth 1 "https://github.com/NRCan/geo-deep-learning.git" --branch $GIT_TAG RUN conda config --set ssl_verify no -RUN conda install libarchive mamba -c conda-forge RUN mamba env create -f /home/$USERNAME/geo-deep-learning/environment.yml ENV PATH $CONDA_DIR/envs/geo_deep_env/bin:$PATH -RUN echo "source activate geo_deep_env" > ~/.bashrc +RUN echo "source activate geo_deep_env" > ~/.bashrc \ No newline at end of file diff --git a/config/model/gdl_hrnet.yaml b/config/model/gdl_hrnet.yaml new file mode 100644 index 00000000..7df76c52 --- /dev/null +++ b/config/model/gdl_hrnet.yaml @@ -0,0 +1,4 @@ +# @package _global_ +model: + _target_: models.hrnet.hrnet_ocr.HRNet + pretrained: True \ No newline at end of file diff --git a/config/model/gdl_segformer.yaml b/config/model/gdl_segformer.yaml new file mode 100644 index 00000000..e481fb25 --- /dev/null +++ b/config/model/gdl_segformer.yaml @@ -0,0 +1,4 @@ +# @package _global_ +model: + _target_: models.segformer.SegFormer + encoder: "mit_b2" \ No newline at end of file diff --git a/docs/source/model.rst b/docs/source/model.rst index 8b9cce6f..c6dea199 100755 --- a/docs/source/model.rst +++ b/docs/source/model.rst @@ -52,3 +52,20 @@ folder to the complete list on different combinaisons. Also from the same library, another version of *DeepLabV3*, named *DeepLabV3+* of the *Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation* paper. + +Segformer +================================================ + +*Segformer* model implementation is based on the `SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers `_ paper. +The encoder is called from `SMP `_. For more code implementation details check this `repo `_. + +.. autoclass:: models.segformer.SegFormer + + +HRNet + OCR +================================================ + +*HRNet + OCR* model implementation is based on the `HRNet paper `_ and `OCR paper `_. +For more code implementation details check this `repo `_. + +.. 
autoclass:: models.hrnet.hrnet_ocr.HRNet \ No newline at end of file diff --git a/models/hrnet/__init__.py b/models/hrnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/hrnet/backbone.py b/models/hrnet/backbone.py new file mode 100644 index 00000000..280f91ca --- /dev/null +++ b/models/hrnet/backbone.py @@ -0,0 +1,455 @@ +""" +This HRNet implementation is modified from the following repository: +https://github.com/HRNet/HRNet-Semantic-Segmentation +""" + +import logging +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from models.hrnet.utils import ModelHelpers +from pytorch_lightning.utilities import rank_zero_only + +BatchNorm2d = ModelHelpers.batchnorm2d() +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + +__all__ = ['hrnetv2'] + + +model_urls = { + 'hrnetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + def 
_check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=(height_output, width_output), + 
mode='bilinear', + align_corners=False) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HRNetV2(nn.Module): + def __init__(self, n_class, **kwargs): + super(HRNetV2, self).__init__() + extra = { + 'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (48, 96), 'FUSE_METHOD': 'SUM'}, + 'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (48, 96, 192), 'FUSE_METHOD': 'SUM'}, + 'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (48, 96, 192, 384), 'FUSE_METHOD': 'SUM'}, + 'FINAL_CONV_KERNEL': 1 + } + + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(Bottleneck, 64, 64, 4) + + self.stage2_cfg = extra['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = extra['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = extra['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + self.high_level_ch = np.int_(np.sum(pre_stage_channels)) + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, 
planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate( + x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + x2 = F.interpolate( + x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + x3 = F.interpolate( + x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + + x = torch.cat([x[0], x1, x2, x3], 1) + + # x = self.last_layer(x) + return x + + +def hrnetv2(num_of_classes, pretrained=False, **kwargs): + model = HRNetV2(n_class=num_of_classes, **kwargs) + if pretrained: + weights_file = ModelHelpers.load_url(model_urls['hrnetv2'], download=True) + model.load_state_dict(torch.load(weights_file, map_location=None), strict=False) + return model + +if __name__ == "__main__": + from torchinfo import summary + + model = hrnetv2(num_of_classes=4, pretrained=True) + batch_size = 8 + summary(model, input_size=(batch_size, 3, 512, 512)) + + diff --git a/models/hrnet/hrnet_ocr.py b/models/hrnet/hrnet_ocr.py new file mode 100644 index 00000000..070d6d7f --- /dev/null +++ b/models/hrnet/hrnet_ocr.py @@ -0,0 +1,56 @@ +import logging +import torch.nn.functional as F + +from torch import nn +from models.hrnet.ocr import OCR +from models.hrnet.backbone import hrnetv2 + + + +class HRNet(nn.Module): + """High Resolution Network (hrnet_w48_v2) with Object Contextual Representation module + + Args: + pretrained (bool): use pretrained 
weights + in_channels (int): number of bands/channels + classes (int): number of classes + """ + def __init__(self, pretrained, in_channels, classes) -> None: + super(HRNet, self).__init__() + if in_channels != 3: + logging.critical(F"HRNet model expects three channels input") + self.encoder = hrnetv2(num_of_classes=classes, pretrained=pretrained) + high_level_ch = self.encoder.high_level_ch + self.decoder = OCR(num_classes=classes, high_level_ch=high_level_ch) + + def forward(self, input): + high_level_features = self.encoder(input) + cls_out, aux_out, _ = self.decoder(high_level_features) + + input_size = input.shape[2:] + aux_out = F.interpolate(aux_out, size=input_size, mode='bilinear', align_corners=False) + cls_out = F.interpolate(cls_out, size=input_size, mode='bilinear', align_corners=False) + if self.training: + return cls_out, aux_out + else: + return cls_out + +if __name__ == "__main__": + import torch + from torchinfo import summary + + model = HRNet(pretrained=True, in_channels=3, classes=4) + model.to("cuda") + batch_size = 4 + + mask_tensor = torch.randn([batch_size, 3, 512, 512]).cuda() + + output, output_aux = model(mask_tensor) + for name, para in model.named_parameters(): + print("-"*20) + print(f"name: {name}") + print(f"requires_grad: {para.requires_grad}") + # print(output.shape) + # print(output_aux.shape) + # summary(model, input_size=(batch_size, 3, 512, 512)) + \ No newline at end of file diff --git a/models/hrnet/ocr.py b/models/hrnet/ocr.py new file mode 100644 index 00000000..c9f33647 --- /dev/null +++ b/models/hrnet/ocr.py @@ -0,0 +1,46 @@ + +from torch import nn +from models.hrnet.utils import ModelHelpers +from models.hrnet.ocr_modules import SpatialGather_Module, SpatialOCR_Module + + +BNReLU = ModelHelpers.BNReLU + +class OCR(nn.Module): + + def __init__(self, num_classes, high_level_ch) -> None: + super(OCR, self).__init__() + + ocr_mid_channels = 512 + ocr_key_channels = 256 + + self.conv3x3_ocr = nn.Sequential( + nn.Conv2d(high_level_ch, ocr_mid_channels, + kernel_size=3, stride=1, padding=1), + BNReLU(ocr_mid_channels),) + self.ocr_gather_head = SpatialGather_Module(num_classes) + self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=1, + dropout=0.05, + ) + + self.cls_head = nn.Conv2d(ocr_mid_channels, num_classes, + kernel_size=1, stride=1, padding=0,bias=True) + + self.aux_head = nn.Sequential(nn.Conv2d(high_level_ch, high_level_ch, + kernel_size=1, stride=1, padding=0), + BNReLU(high_level_ch), + nn.Conv2d(high_level_ch, num_classes, + kernel_size=1, stride=1, padding=0, bias=True)) + + def forward(self, high_level_features): + feats = self.conv3x3_ocr(high_level_features) + aux_out = self.aux_head(high_level_features) + context = self.ocr_gather_head(feats, aux_out) + ocr_feats = self.ocr_distri_head(feats, context) + cls_out = self.cls_head(ocr_feats) + return cls_out, aux_out, ocr_feats + + \ No newline at end of file diff --git a/models/hrnet/ocr_modules.py b/models/hrnet/ocr_modules.py new file mode 100644 index 00000000..975e46db --- /dev/null +++ b/models/hrnet/ocr_modules.py @@ -0,0 +1,138 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.hrnet.utils import ModelHelpers + +BNReLU = ModelHelpers.BNReLU + +# BatchNorm2d = ModelHelpers.batchnorm2d(bn_type="torch_bn") +# def BNReLU(ch): +# return nn.Sequential(BatchNorm2d(ch), nn.ReLU()) + +class SpatialGather_Module(nn.Module): + """ + Aggregate the context features 
according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. + + Output: + The correlation of every class map with every feature map + shape = [n, num_feats, num_classes, 1] + + + """ + def __init__(self, scale=1): + super(SpatialGather_Module, self).__init__() + self.scale = scale + + def forward(self, feats, probs): + batch_size, c, = probs.size(0), probs.size(1) + + # each class image now a vector + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw + ocr_context = torch.matmul(probs, feats) + ocr_context = ocr_context.permute(0, 2, 1).unsqueeze(3) + return ocr_context + + +class ObjectAttentionBlock(nn.Module): + ''' + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature + maps (save memory cost) + Return: + N X C X H X W + ''' + def __init__(self, in_channels, key_channels, scale=1): + super(ObjectAttentionBlock, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.key_channels), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.key_channels), + ) + self.f_object = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.key_channels), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.key_channels), + ) + self.f_down = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.key_channels), + ) + self.f_up = nn.Sequential( + nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.in_channels), + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels**-.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True) + + return context + + +class SpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation + for each pixel. 
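+ The pixel features (feats) serve as queries against the object-region representations + produced by SpatialGather_Module; the resulting context is concatenated with the original + features and fused by a 1x1 convolution (hence the doubled in_channels in conv_bn_dropout).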
+ """ + + def __init__(self, in_channels, key_channels, out_channels, scale=1, dropout=0.1): + super(SpatialOCR_Module, self).__init__() + self.object_context_block = ObjectAttentionBlock(in_channels, + key_channels, + scale) + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(sum([in_channels,in_channels]), out_channels, + kernel_size=1, padding=0, bias=False), + BNReLU(out_channels), + nn.Dropout2d(dropout) + ) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + return output + \ No newline at end of file diff --git a/models/hrnet/utils.py b/models/hrnet/utils.py new file mode 100644 index 00000000..7fe00208 --- /dev/null +++ b/models/hrnet/utils.py @@ -0,0 +1,41 @@ +import sys +import os +import logging +import torch.nn as nn +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve +import torch +from typing import Union, Optional +from pathlib import Path +from pytorch_lightning.utilities import rank_zero_only + +class ModelHelpers: + + @staticmethod + def batchnorm2d(bn_type: Union[str, "torch_sync_bn", "torch_bn"] = "torch_bn"): + if bn_type == "torch_bn": + return nn.BatchNorm2d + if bn_type == "torch_sync_bn": + return nn.SyncBatchNorm + + @staticmethod + def BNReLU(ch: torch.Tensor): + batchnorm = ModelHelpers.batchnorm2d() + return nn.Sequential( + batchnorm(ch), + nn.ReLU()) + + @rank_zero_only + @staticmethod + def load_url(url: str, download: bool): + model_dir = Path.home() / ".cache" / "torch" / "checkpoints" + if not model_dir.is_dir(): + Path.mkdir(model_dir, parents=True) + filename = url.split('/')[-1] + cached_file = model_dir.joinpath(filename) + if not cached_file.is_file() and download: + logging.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, str(cached_file)) + return cached_file \ No newline at end of file diff --git a/models/segformer.py b/models/segformer.py new file mode 100644 index 00000000..b23a680b --- /dev/null +++ b/models/segformer.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import segmentation_models_pytorch as smp + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class Decoder(nn.Module): + def __init__(self, encoder="mit_b2", + in_channels=[64, 128, 320, 512], + feature_strides=[4, 8, 16, 32], + embedding_dim=768, + num_classes=1, dropout_ratio=0.1): + super(Decoder, self).__init__() + if encoder == "mit_b0": + in_channels = [32, 64, 160, 256] + if encoder == "mit_b0" or "mit_b1": + embedding_dim = 256 + assert len(feature_strides) == len(in_channels) + assert min(feature_strides) == feature_strides[0] + + self.num_classes = num_classes + self.in_channels = in_channels + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + + self.linear_fuse = nn.Sequential( + nn.Conv2d(in_channels=embedding_dim * 4, out_channels=embedding_dim, kernel_size=1, bias=False), + 
nn.BatchNorm2d(embedding_dim), nn.ReLU(inplace=True)) + self.dropout = nn.Dropout2d(dropout_ratio) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(input=_c4, size=c1.size()[2:], mode='bilinear', align_corners=False) + + _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(input=_c3, size=c1.size()[2:], mode='bilinear', align_corners=False) + + _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(input=_c2, size=c1.size()[2:], mode='bilinear', align_corners=False) + + _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class SegFormer(nn.Module): + """Segformer Model + Args: + encoder (str): encoder name + in_channels (int): number of bands/channels + classes (int): number of classes + """ + def __init__(self, encoder, in_channels, classes) -> None: + super().__init__() + self.encoder = smp.encoders.get_encoder(name=encoder, in_channels=in_channels, depth=5, drop_path_rate=0.1) + self.decoder = Decoder(encoder=encoder, num_classes=classes) + + def forward(self, img): + x = self.encoder(img)[2:] + x = self.decoder(x) + x = F.interpolate(input=x, size=img.shape[2:], scale_factor=None, mode='bilinear', align_corners=False) + return x \ No newline at end of file diff --git a/tests/model/test_models.py b/tests/model/test_models.py index 73f5c176..5b4d5e28 100644 --- a/tests/model/test_models.py +++ b/tests/model/test_models.py @@ -26,14 +26,17 @@ def test_net(self) -> None: hconf = HydraConfig() hconf.set_config(cfg) del cfg.loss.is_binary # prevent exception at instantiation - rand_img = torch.rand((2, 4, 64, 64)) + rand_img = torch.rand((2, 3, 64, 64)) print(cfg.model._target_) model = define_model_architecture( net_params=cfg.model, - in_channels=4, + in_channels=3, out_classes=4, ) - output = model(rand_img) + if cfg.model._target_ == "models.hrnet.hrnet_ocr.HRNet": + output, output_aux = model(rand_img) + else: + output = model(rand_img) print(output.shape) @@ -41,7 +44,7 @@ class TestReadCheckpoint(object): """ Tests reading a checkpoint saved outside GDL into memory """ - var = 4 + var = 3 dummy_model = models.unet.UNetSmall(classes=var, in_channels=var) dummy_optimizer = instantiate({'_target_': 'torch.optim.Adam'}, params=dummy_model.parameters()) filename = "test.pth.tar" @@ -80,7 +83,7 @@ class TestDefineModelMultigpu(object): """ Tests defining model architecture with weights from provided checkpoint and pushing to multiple devices if possible """ - dummy_model = unet.UNet(4, 4, True, 0.5) + dummy_model = unet.UNet(4, 3, True, 0.5) filename = "test.pth.tar" torch.save(dummy_model.state_dict(), filename) @@ -92,7 +95,7 @@ class TestDefineModelMultigpu(object): checkpoint = read_checkpoint(filename) model = define_model( net_params={'_target_': 'models.unet.UNet'}, - in_channels=4, + in_channels=3, out_classes=4, main_device=device, devices=list(gpu_devices_dict.keys()), diff --git a/tests/test_tiling_segmentation.py b/tests/test_tiling_segmentation.py index f03b0bf1..18478aec 100644 --- a/tests/test_tiling_segmentation.py +++ b/tests/test_tiling_segmentation.py 
@@ -170,14 +170,20 @@ def test_tiling_segmentation_parallel(self): } cfg = DictConfig(cfg) tiling(cfg) - out_labels = [ - (Path(f"{data_dir}/{proj}/trn/23322E759967N_clipped_1m_1of2/labels_burned"), (80, 95)), - (Path(f"{data_dir}/{proj}/val/23322E759967N_clipped_1m_1of2/labels_burned"), (5, 20)), - (Path(f"{data_dir}/{proj}/tst/23322E759967N_clipped_1m_2of2/labels_burned"), (170, 190)), - ] - for labels_burned_dir, lbls_nb in out_labels: - # exact number may vary because of random sort between "trn" and "val" - assert lbls_nb[0] <= len(list(labels_burned_dir.iterdir())) <= lbls_nb[1] + trn_labels = list(Path(f"{data_dir}/{proj}/trn/").glob("*/labels_burned/*.tif")) + val_labels = list(Path(f"{data_dir}/{proj}/val/").glob("*/labels_burned/*.tif")) + tst_labels = list(Path(f"{data_dir}/{proj}/tst/").glob("*/labels_burned/*.tif")) + assert len(trn_labels) > 0 + assert len(val_labels) > 0 + assert len(tst_labels) > 0 + + patch_size = cfg.tiling.patch_size + for label_list in [trn_labels, val_labels, tst_labels]: + num_tifs_to_check = min(5, len(label_list)) + for tif_file in label_list[:num_tifs_to_check]: + with rasterio.open(tif_file) as src: + width, height = src.width, src.height + assert width == patch_size and height == patch_size shutil.rmtree(Path(data_dir) / proj) def test_tiling_inference(self): diff --git a/train_segmentation.py b/train_segmentation.py index b244c929..335bbc09 100644 --- a/train_segmentation.py +++ b/train_segmentation.py @@ -297,6 +297,7 @@ def training(train_loader, device, scale, vis_params, + aux_output: bool = False, debug=False): """ Train the model and return the metrics of the training epoch @@ -327,7 +328,10 @@ def training(train_loader, # forward optimizer.zero_grad() - outputs = model(inputs) + if aux_output: + outputs, outputs_aux = model(inputs) + else: + outputs = model(inputs) # added for torchvision models that output an OrderedDict with outputs in 'out' key. 
# More info: https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/ if isinstance(outputs, OrderedDict): @@ -349,9 +353,13 @@ dataset='trn', ep_num=ep_idx + 1, scale=scale) - - loss = criterion(outputs, labels) if num_classes > 1 else criterion(outputs, labels.unsqueeze(1).float()) - + if aux_output: + loss_main = criterion(outputs, labels) if num_classes > 1 else criterion(outputs, labels.unsqueeze(1).float()) + loss_aux = criterion(outputs_aux, labels) if num_classes > 1 else criterion(outputs_aux, labels.unsqueeze(1).float()) + loss = 0.4 * loss_aux + loss_main + else: + loss = criterion(outputs, labels) if num_classes > 1 else criterion(outputs, labels.unsqueeze(1).float()) + train_metrics['loss'].update(loss.item(), batch_size) if device.type == 'cuda' and debug: @@ -628,6 +636,7 @@ def train(cfg: DictConfig) -> None: # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH checkpoint = read_checkpoint(train_state_dict_path) + aux_output = False model = define_model( net_params=cfg.model, in_channels=num_bands, @@ -637,7 +646,9 @@ checkpoint_dict=checkpoint, checkpoint_dict_strict_load=state_dict_strict ) - + + if cfg.model._target_ == "models.hrnet.hrnet_ocr.HRNet": + aux_output = True criterion = define_loss(loss_params=cfg.loss, class_weights=class_weights) criterion = criterion.to(device) optimizer = instantiate(cfg.optimizer, params=model.parameters()) @@ -717,6 +728,7 @@ device=device, scale=scale, vis_params=vis_params, + aux_output=aux_output, debug=debug) if 'trn_log' in locals(): # only save the value if a tracker is setup trn_log.add_values(trn_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])
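
The forward-pass contract introduced by this PR differs between the two new architectures: SegFormer returns a single logits tensor, while HRNet + OCR returns a (main, auxiliary) pair in training mode, which training() combines with a 0.4 weight on the auxiliary term. Below is a minimal sketch of that contract, assuming the geo-deep-learning root is on PYTHONPATH and a segmentation_models_pytorch version that ships the mit_* (MixVisionTransformer) encoders is installed; tensor shapes and class count are illustrative, not taken from the PR.

import torch
from torch import nn
from models.segformer import SegFormer
from models.hrnet.hrnet_ocr import HRNet

img = torch.randn(2, 3, 512, 512)             # illustrative RGB batch
labels = torch.randint(0, 4, (2, 512, 512))   # illustrative integer masks

segformer = SegFormer(encoder="mit_b2", in_channels=3, classes=4)
print(segformer(img).shape)                   # logits upsampled to input size: (2, 4, 512, 512)

hrnet = HRNet(pretrained=False, in_channels=3, classes=4)
criterion = nn.CrossEntropyLoss()

hrnet.train()
cls_out, aux_out = hrnet(img)                 # training mode: main + auxiliary logits
loss = criterion(cls_out, labels) + 0.4 * criterion(aux_out, labels)  # same weighting as training()

hrnet.eval()
with torch.no_grad():
    cls_out = hrnet(img)                      # eval mode: main logits only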
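
The two new config nodes (config/model/gdl_hrnet.yaml and config/model/gdl_segformer.yaml) only pin the _target_ class and one architecture argument; channel and class counts are supplied at instantiation time, as the updated test does through define_model_architecture. The sketch below shows the equivalent plain Hydra call; routing in_channels and classes as extra keyword arguments is an assumption for illustration, not a description of define_model_architecture's internals.

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Mirrors config/model/gdl_segformer.yaml
segformer_node = OmegaConf.create({"_target_": "models.segformer.SegFormer", "encoder": "mit_b2"})
segformer = instantiate(segformer_node, in_channels=3, classes=4)

# Mirrors config/model/gdl_hrnet.yaml (pretrained set to False here to avoid a weight download)
hrnet_node = OmegaConf.create({"_target_": "models.hrnet.hrnet_ocr.HRNet", "pretrained": False})
hrnet = instantiate(hrnet_node, in_channels=3, classes=4)

Selecting one of the architectures for a training run would then presumably be a Hydra override such as model=gdl_segformer, assuming the model config group is exposed in the defaults list.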