diff --git a/luxonis_train/nodes/base_node.py b/luxonis_train/nodes/base_node.py index 7338a802..c3124f82 100644 --- a/luxonis_train/nodes/base_node.py +++ b/luxonis_train/nodes/base_node.py @@ -1,3 +1,4 @@ +import inspect from abc import ABC, abstractmethod from typing import Generic, TypeVar @@ -80,8 +81,6 @@ class BaseNode( Provide only in case the `input_shapes` were not provided. """ - attach_index: AttachIndexType = "all" - def __init__( self, *, @@ -96,7 +95,21 @@ def __init__( ): super().__init__() - self.attach_index = attach_index or self.attach_index + if attach_index is None: + parameters = inspect.signature(self.forward).parameters + inputs_forward_type = parameters.get( + "inputs", parameters.get("input", parameters.get("x", None)) + ) + if ( + inputs_forward_type is not None + and inputs_forward_type.annotation == Tensor + ): + self.attach_index = -1 + else: + self.attach_index = "all" + else: + self.attach_index = attach_index + self.in_protocols = in_protocols or [FeaturesProtocol] self.task_type = task_type diff --git a/luxonis_train/nodes/bisenet_head.py b/luxonis_train/nodes/bisenet_head.py index 99845177..a3b11df6 100644 --- a/luxonis_train/nodes/bisenet_head.py +++ b/luxonis_train/nodes/bisenet_head.py @@ -15,7 +15,6 @@ class BiSeNetHead(BaseNode[Tensor, Tensor]): - attach_index: int = -1 in_height: int in_channels: int @@ -45,6 +44,6 @@ def wrap(self, output: Tensor) -> Packet[Tensor]: return {"segmentation": [output]} def forward(self, inputs: Tensor) -> Tensor: - inputs = self.conv_3x3(inputs) - inputs = self.conv_1x1(inputs) - return self.upscale(inputs) + x = self.conv_3x3(inputs) + x = self.conv_1x1(x) + return self.upscale(x) diff --git a/luxonis_train/nodes/classification_head.py b/luxonis_train/nodes/classification_head.py index 10f9b3c9..d96e6b72 100644 --- a/luxonis_train/nodes/classification_head.py +++ b/luxonis_train/nodes/classification_head.py @@ -7,7 +7,6 @@ class ClassificationHead(BaseNode[Tensor, Tensor]): in_channels: int - attach_index: int = -1 def __init__( self, diff --git a/luxonis_train/nodes/contextspatial.py b/luxonis_train/nodes/contextspatial.py index adbb84bc..1ca1460d 100644 --- a/luxonis_train/nodes/contextspatial.py +++ b/luxonis_train/nodes/contextspatial.py @@ -18,8 +18,6 @@ class ContextSpatial(BaseNode[Tensor, list[Tensor]]): - attach_index: int = -1 - def __init__(self, context_backbone: str = "MobileNetV2", **kwargs): """Context spatial backbone. TODO: Add more documentation. @@ -34,9 +32,9 @@ def __init__(self, context_backbone: str = "MobileNetV2", **kwargs): self.spatial_path = SpatialPath(3, 128) self.ffm = FeatureFusionBlock(256, 256) - def forward(self, x: Tensor) -> list[Tensor]: - spatial_out = self.spatial_path(x) - context16, _ = self.context_path(x) + def forward(self, inputs: Tensor) -> list[Tensor]: + spatial_out = self.spatial_path(inputs) + context16, _ = self.context_path(inputs) fm_fuse = self.ffm(spatial_out, context16) outs = [fm_fuse] return outs diff --git a/luxonis_train/nodes/efficientrep.py b/luxonis_train/nodes/efficientrep.py index e6a014af..ccff4189 100644 --- a/luxonis_train/nodes/efficientrep.py +++ b/luxonis_train/nodes/efficientrep.py @@ -19,8 +19,6 @@ class EfficientRep(BaseNode[Tensor, list[Tensor]]): - attach_index: int = -1 - def __init__( self, channels_list: list[int] | None = None, @@ -104,9 +102,9 @@ def set_export_mode(self, mode: bool = True) -> None: if isinstance(module, RepVGGBlock): module.reparametrize() - def forward(self, x: Tensor) -> list[Tensor]: + def forward(self, inputs: Tensor) -> list[Tensor]: outputs = [] - x = self.repvgg_encoder(x) + x = self.repvgg_encoder(inputs) for block in self.blocks: x = block(x) outputs.append(x) diff --git a/luxonis_train/nodes/implicit_keypoint_bbox_head.py b/luxonis_train/nodes/implicit_keypoint_bbox_head.py index 0fdca420..aff2b5a6 100644 --- a/luxonis_train/nodes/implicit_keypoint_bbox_head.py +++ b/luxonis_train/nodes/implicit_keypoint_bbox_head.py @@ -1,6 +1,6 @@ import logging import math -from typing import Literal, cast +from typing import cast import torch from torch import Tensor, nn @@ -22,8 +22,6 @@ class ImplicitKeypointBBoxHead(BaseNode): - attach_index: Literal["all"] = "all" - def __init__( self, n_keypoints: int | None = None, diff --git a/luxonis_train/nodes/micronet.py b/luxonis_train/nodes/micronet.py index 03b43e1f..603eabde 100644 --- a/luxonis_train/nodes/micronet.py +++ b/luxonis_train/nodes/micronet.py @@ -15,8 +15,6 @@ class MicroNet(BaseNode[Tensor, list[Tensor]]): TODO: DOCS """ - attach_index: int = -1 - def __init__(self, variant: Literal["M1", "M2", "M3"] = "M1", **kwargs): """MicroNet backbone. @@ -236,23 +234,21 @@ def __init__( ChannelShuffle(out_channels // 2) if y3 != 0 else nn.Sequential(), ) - def forward(self, x: Tensor): - identity = x - out = self.layers(x) + def forward(self, inputs: Tensor) -> Tensor: + out = self.layers(inputs) if self.identity: - out += identity + out += inputs return out class ChannelShuffle(nn.Module): def __init__(self, groups: int): - super(ChannelShuffle, self).__init__() + super().__init__() self.groups = groups - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: b, c, h, w = x.size() channels_per_group = c // self.groups - # reshape x = x.view(b, self.groups, channels_per_group, h, w) x = torch.transpose(x, 1, 2).contiguous() out = x.view(b, -1, h, w) @@ -300,7 +296,7 @@ def __init__( indexs = torch.cat([indexs[1], indexs[0]], dim=2) self.index = indexs.view(in_channels).long() - def forward(self, x: Tensor): + def forward(self, x: Tensor) -> Tensor: B, C, _, _ = x.shape x_out = x @@ -350,7 +346,7 @@ def __init__(self, in_channels: int, out_channels: int): nn.Linear(in_channels, out_channels), nn.BatchNorm1d(out_channels), HSwish() ) - def forward(self, x: Tensor): + def forward(self, x: Tensor) -> Tensor: return self.linear(x) @@ -383,7 +379,7 @@ def __init__( ChannelShuffle(out_channels1), ) - def forward(self, x: Tensor): + def forward(self, x: Tensor) -> Tensor: return self.conv(x) @@ -394,7 +390,7 @@ def __init__(self, in_channels: int, stride: int, outs: tuple[int, int] = (4, 4) SpatialSepConvSF(in_channels, outs, 3, stride), nn.ReLU6(True) ) - def forward(self, x: Tensor): + def forward(self, x: Tensor) -> Tensor: return self.stem(x) @@ -430,7 +426,7 @@ def __init__( nn.BatchNorm2d(out_channels), ) - def forward(self, x: Tensor): + def forward(self, x: Tensor) -> Tensor: return self.conv(x) diff --git a/luxonis_train/nodes/mobilenetv2.py b/luxonis_train/nodes/mobilenetv2.py index 27fe87ec..732d0b12 100644 --- a/luxonis_train/nodes/mobilenetv2.py +++ b/luxonis_train/nodes/mobilenetv2.py @@ -15,8 +15,6 @@ class MobileNetV2(BaseNode[Tensor, list[Tensor]]): TODO: add more info """ - attach_index: int = -1 - def __init__(self, download_weights: bool = False, **kwargs): """Constructor of the MobileNetV2 backbone. @@ -37,8 +35,8 @@ def __init__(self, download_weights: bool = False, **kwargs): def forward(self, x: Tensor) -> list[Tensor]: outs = [] - for i, m in enumerate(self.backbone.features): - x = m(x) + for i, module in enumerate(self.backbone.features): + x = module(x) if i in self.out_indices: outs.append(x) diff --git a/luxonis_train/nodes/mobileone.py b/luxonis_train/nodes/mobileone.py index e92d3225..14e6e02b 100644 --- a/luxonis_train/nodes/mobileone.py +++ b/luxonis_train/nodes/mobileone.py @@ -52,7 +52,6 @@ class MobileOne(BaseNode[Tensor, list[Tensor]]): TODO: add more details """ - attach_index: int = -1 in_channels: int VARIANTS_SETTINGS: dict[str, dict] = { @@ -115,9 +114,9 @@ def __init__(self, variant: Literal["s0", "s1", "s2", "s3", "s4"] = "s0", **kwar num_se_blocks=self.num_blocks_per_stage[3] if self.use_se else 0, ) - def forward(self, x: Tensor) -> list[Tensor]: + def forward(self, inputs: Tensor) -> list[Tensor]: outs = [] - x = self.stage0(x) + x = self.stage0(inputs) outs.append(x) x = self.stage1(x) outs.append(x) diff --git a/luxonis_train/nodes/resnet.py b/luxonis_train/nodes/resnet.py index 14ff8066..8228d37a 100644 --- a/luxonis_train/nodes/resnet.py +++ b/luxonis_train/nodes/resnet.py @@ -12,8 +12,6 @@ class ResNet(BaseNode[Tensor, list[Tensor]]): - attach_index: int = -1 - def __init__( self, variant: Literal["18", "34", "50", "101", "152"] = "18", @@ -47,9 +45,9 @@ def __init__( ) self.channels_list = channels_list or [64, 128, 256, 512] - def forward(self, x: Tensor) -> list[Tensor]: + def forward(self, inputs: Tensor) -> list[Tensor]: outs = [] - x = self.backbone.conv1(x) + x = self.backbone.conv1(inputs) x = self.backbone.bn1(x) x = self.backbone.relu(x) x = self.backbone.maxpool(x) diff --git a/luxonis_train/nodes/rexnetv1.py b/luxonis_train/nodes/rexnetv1.py index fb4de4b1..de2c08ae 100644 --- a/luxonis_train/nodes/rexnetv1.py +++ b/luxonis_train/nodes/rexnetv1.py @@ -17,8 +17,6 @@ class ReXNetV1_lite(BaseNode[Tensor, list[Tensor]]): - attach_index: int = -1 - def __init__( self, fix_head_stem: bool = False, @@ -129,8 +127,8 @@ def __init__( def forward(self, x: Tensor) -> list[Tensor]: outs = [] - for i, m in enumerate(self.features): - x = m(x) + for i, module in enumerate(self.features): + x = module(x) if i in self.out_indices: outs.append(x) return outs @@ -186,12 +184,11 @@ def __init__( self.out = nn.Sequential(*out) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: out = self.out(x) if self.use_shortcut: - # this results in a ScatterND node which isn't supported yet in myriad - # out[:, 0:self.in_channels] += x + # NOTE: this results in a ScatterND node which isn't supported yet in myriad a = out[:, : self.in_channels] b = x a = a + b diff --git a/luxonis_train/nodes/segmentation_head.py b/luxonis_train/nodes/segmentation_head.py index bdfe814d..a3420491 100644 --- a/luxonis_train/nodes/segmentation_head.py +++ b/luxonis_train/nodes/segmentation_head.py @@ -16,7 +16,6 @@ class SegmentationHead(BaseNode[Tensor, Tensor]): - attach_index: int = -1 in_height: int in_channels: int