Skip to content

Commit

Permalink
Automatic Inference of attach_index (#14)
Browse files Browse the repository at this point in the history
* automatic inference of attach index based on type signature

* added inference for input and x names
  • Loading branch information
kozlov721 committed Oct 9, 2024
1 parent d6081d4 commit d78dccb
Show file tree
Hide file tree
Showing 12 changed files with 45 additions and 53 deletions.
19 changes: 16 additions & 3 deletions luxonis_train/nodes/base_node.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

Expand Down Expand Up @@ -80,8 +81,6 @@ class BaseNode(
Provide only in case the `input_shapes` were not provided.
"""

attach_index: AttachIndexType = "all"

def __init__(
self,
*,
Expand All @@ -96,7 +95,21 @@ def __init__(
):
super().__init__()

self.attach_index = attach_index or self.attach_index
if attach_index is None:
parameters = inspect.signature(self.forward).parameters
inputs_forward_type = parameters.get(
"inputs", parameters.get("input", parameters.get("x", None))
)
if (
inputs_forward_type is not None
and inputs_forward_type.annotation == Tensor
):
self.attach_index = -1
else:
self.attach_index = "all"
else:
self.attach_index = attach_index

self.in_protocols = in_protocols or [FeaturesProtocol]
self.task_type = task_type

Expand Down
7 changes: 3 additions & 4 deletions luxonis_train/nodes/bisenet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@


class BiSeNetHead(BaseNode[Tensor, Tensor]):
attach_index: int = -1
in_height: int
in_channels: int

Expand Down Expand Up @@ -45,6 +44,6 @@ def wrap(self, output: Tensor) -> Packet[Tensor]:
return {"segmentation": [output]}

def forward(self, inputs: Tensor) -> Tensor:
inputs = self.conv_3x3(inputs)
inputs = self.conv_1x1(inputs)
return self.upscale(inputs)
x = self.conv_3x3(inputs)
x = self.conv_1x1(x)
return self.upscale(x)
1 change: 0 additions & 1 deletion luxonis_train/nodes/classification_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

class ClassificationHead(BaseNode[Tensor, Tensor]):
in_channels: int
attach_index: int = -1

def __init__(
self,
Expand Down
8 changes: 3 additions & 5 deletions luxonis_train/nodes/contextspatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@


class ContextSpatial(BaseNode[Tensor, list[Tensor]]):
attach_index: int = -1

def __init__(self, context_backbone: str = "MobileNetV2", **kwargs):
"""Context spatial backbone.
TODO: Add more documentation.
Expand All @@ -34,9 +32,9 @@ def __init__(self, context_backbone: str = "MobileNetV2", **kwargs):
self.spatial_path = SpatialPath(3, 128)
self.ffm = FeatureFusionBlock(256, 256)

def forward(self, x: Tensor) -> list[Tensor]:
spatial_out = self.spatial_path(x)
context16, _ = self.context_path(x)
def forward(self, inputs: Tensor) -> list[Tensor]:
spatial_out = self.spatial_path(inputs)
context16, _ = self.context_path(inputs)
fm_fuse = self.ffm(spatial_out, context16)
outs = [fm_fuse]
return outs
Expand Down
6 changes: 2 additions & 4 deletions luxonis_train/nodes/efficientrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@


class EfficientRep(BaseNode[Tensor, list[Tensor]]):
attach_index: int = -1

def __init__(
self,
channels_list: list[int] | None = None,
Expand Down Expand Up @@ -104,9 +102,9 @@ def set_export_mode(self, mode: bool = True) -> None:
if isinstance(module, RepVGGBlock):
module.reparametrize()

def forward(self, x: Tensor) -> list[Tensor]:
def forward(self, inputs: Tensor) -> list[Tensor]:
outputs = []
x = self.repvgg_encoder(x)
x = self.repvgg_encoder(inputs)
for block in self.blocks:
x = block(x)
outputs.append(x)
Expand Down
4 changes: 1 addition & 3 deletions luxonis_train/nodes/implicit_keypoint_bbox_head.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import math
from typing import Literal, cast
from typing import cast

import torch
from torch import Tensor, nn
Expand All @@ -22,8 +22,6 @@


class ImplicitKeypointBBoxHead(BaseNode):
attach_index: Literal["all"] = "all"

def __init__(
self,
n_keypoints: int | None = None,
Expand Down
24 changes: 10 additions & 14 deletions luxonis_train/nodes/micronet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ class MicroNet(BaseNode[Tensor, list[Tensor]]):
TODO: DOCS
"""

attach_index: int = -1

def __init__(self, variant: Literal["M1", "M2", "M3"] = "M1", **kwargs):
"""MicroNet backbone.
Expand Down Expand Up @@ -236,23 +234,21 @@ def __init__(
ChannelShuffle(out_channels // 2) if y3 != 0 else nn.Sequential(),
)

def forward(self, x: Tensor):
identity = x
out = self.layers(x)
def forward(self, inputs: Tensor) -> Tensor:
out = self.layers(inputs)
if self.identity:
out += identity
out += inputs
return out


class ChannelShuffle(nn.Module):
def __init__(self, groups: int):
super(ChannelShuffle, self).__init__()
super().__init__()
self.groups = groups

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
b, c, h, w = x.size()
channels_per_group = c // self.groups
# reshape
x = x.view(b, self.groups, channels_per_group, h, w)
x = torch.transpose(x, 1, 2).contiguous()
out = x.view(b, -1, h, w)
Expand Down Expand Up @@ -300,7 +296,7 @@ def __init__(
indexs = torch.cat([indexs[1], indexs[0]], dim=2)
self.index = indexs.view(in_channels).long()

def forward(self, x: Tensor):
def forward(self, x: Tensor) -> Tensor:
B, C, _, _ = x.shape
x_out = x

Expand Down Expand Up @@ -350,7 +346,7 @@ def __init__(self, in_channels: int, out_channels: int):
nn.Linear(in_channels, out_channels), nn.BatchNorm1d(out_channels), HSwish()
)

def forward(self, x: Tensor):
def forward(self, x: Tensor) -> Tensor:
return self.linear(x)


Expand Down Expand Up @@ -383,7 +379,7 @@ def __init__(
ChannelShuffle(out_channels1),
)

def forward(self, x: Tensor):
def forward(self, x: Tensor) -> Tensor:
return self.conv(x)


Expand All @@ -394,7 +390,7 @@ def __init__(self, in_channels: int, stride: int, outs: tuple[int, int] = (4, 4)
SpatialSepConvSF(in_channels, outs, 3, stride), nn.ReLU6(True)
)

def forward(self, x: Tensor):
def forward(self, x: Tensor) -> Tensor:
return self.stem(x)


Expand Down Expand Up @@ -430,7 +426,7 @@ def __init__(
nn.BatchNorm2d(out_channels),
)

def forward(self, x: Tensor):
def forward(self, x: Tensor) -> Tensor:
return self.conv(x)


Expand Down
6 changes: 2 additions & 4 deletions luxonis_train/nodes/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ class MobileNetV2(BaseNode[Tensor, list[Tensor]]):
TODO: add more info
"""

attach_index: int = -1

def __init__(self, download_weights: bool = False, **kwargs):
"""Constructor of the MobileNetV2 backbone.
Expand All @@ -37,8 +35,8 @@ def __init__(self, download_weights: bool = False, **kwargs):

def forward(self, x: Tensor) -> list[Tensor]:
outs = []
for i, m in enumerate(self.backbone.features):
x = m(x)
for i, module in enumerate(self.backbone.features):
x = module(x)
if i in self.out_indices:
outs.append(x)

Expand Down
5 changes: 2 additions & 3 deletions luxonis_train/nodes/mobileone.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class MobileOne(BaseNode[Tensor, list[Tensor]]):
TODO: add more details
"""

attach_index: int = -1
in_channels: int

VARIANTS_SETTINGS: dict[str, dict] = {
Expand Down Expand Up @@ -115,9 +114,9 @@ def __init__(self, variant: Literal["s0", "s1", "s2", "s3", "s4"] = "s0", **kwar
num_se_blocks=self.num_blocks_per_stage[3] if self.use_se else 0,
)

def forward(self, x: Tensor) -> list[Tensor]:
def forward(self, inputs: Tensor) -> list[Tensor]:
outs = []
x = self.stage0(x)
x = self.stage0(inputs)
outs.append(x)
x = self.stage1(x)
outs.append(x)
Expand Down
6 changes: 2 additions & 4 deletions luxonis_train/nodes/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@


class ResNet(BaseNode[Tensor, list[Tensor]]):
attach_index: int = -1

def __init__(
self,
variant: Literal["18", "34", "50", "101", "152"] = "18",
Expand Down Expand Up @@ -47,9 +45,9 @@ def __init__(
)
self.channels_list = channels_list or [64, 128, 256, 512]

def forward(self, x: Tensor) -> list[Tensor]:
def forward(self, inputs: Tensor) -> list[Tensor]:
outs = []
x = self.backbone.conv1(x)
x = self.backbone.conv1(inputs)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
Expand Down
11 changes: 4 additions & 7 deletions luxonis_train/nodes/rexnetv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@


class ReXNetV1_lite(BaseNode[Tensor, list[Tensor]]):
attach_index: int = -1

def __init__(
self,
fix_head_stem: bool = False,
Expand Down Expand Up @@ -129,8 +127,8 @@ def __init__(

def forward(self, x: Tensor) -> list[Tensor]:
outs = []
for i, m in enumerate(self.features):
x = m(x)
for i, module in enumerate(self.features):
x = module(x)
if i in self.out_indices:
outs.append(x)
return outs
Expand Down Expand Up @@ -186,12 +184,11 @@ def __init__(

self.out = nn.Sequential(*out)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
out = self.out(x)

if self.use_shortcut:
# this results in a ScatterND node which isn't supported yet in myriad
# out[:, 0:self.in_channels] += x
# NOTE: this results in a ScatterND node which isn't supported yet in myriad
a = out[:, : self.in_channels]
b = x
a = a + b
Expand Down
1 change: 0 additions & 1 deletion luxonis_train/nodes/segmentation_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


class SegmentationHead(BaseNode[Tensor, Tensor]):
attach_index: int = -1
in_height: int
in_channels: int

Expand Down

0 comments on commit d78dccb

Please sign in to comment.