Automatic Inference of attach_index #14

Merged · 2 commits · Feb 24, 2024
19 changes: 16 additions & 3 deletions luxonis_train/nodes/base_node.py
@@ -1,3 +1,4 @@
+import inspect
 from abc import ABC, abstractmethod
 from typing import Generic, TypeVar

@@ -80,8 +81,6 @@ class BaseNode(
     Provide only in case the `input_shapes` were not provided.
     """

-    attach_index: AttachIndexType = "all"
-
     def __init__(
         self,
         *,
@@ -96,7 +95,21 @@ def __init__(
     ):
         super().__init__()

-        self.attach_index = attach_index or self.attach_index
+        if attach_index is None:
+            parameters = inspect.signature(self.forward).parameters
+            inputs_forward_type = parameters.get(
+                "inputs", parameters.get("input", parameters.get("x", None))
+            )
+            if (
+                inputs_forward_type is not None
+                and inputs_forward_type.annotation == Tensor
+            ):
+                self.attach_index = -1
+            else:
+                self.attach_index = "all"
+        else:
+            self.attach_index = attach_index

         self.in_protocols = in_protocols or [FeaturesProtocol]
         self.task_type = task_type

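In short: when `attach_index` is not given explicitly, `BaseNode.__init__` now inspects the annotation of the `forward` input parameter (looked up under the names `inputs`, `input`, or `x`) and uses `-1` (attach to the last previous output) for a plain `Tensor`, falling back to `"all"` otherwise. A minimal standalone sketch of the same rule; the helper and example function names here are illustrative, not part of the PR:

import inspect

from torch import Tensor


def infer_attach_index(forward) -> int | str:
    # Same lookup as in BaseNode.__init__ above: find the input-like
    # parameter and check whether it is annotated as a single Tensor.
    parameters = inspect.signature(forward).parameters
    param = parameters.get("inputs", parameters.get("input", parameters.get("x", None)))
    if param is not None and param.annotation == Tensor:
        return -1  # consumes one feature map -> attach to the last output
    return "all"  # consumes a list of feature maps -> attach to all outputs


def head_forward(inputs: Tensor) -> Tensor: ...  # hypothetical single-input head


def neck_forward(inputs: list[Tensor]) -> list[Tensor]: ...  # hypothetical multi-input neck


assert infer_attach_index(head_forward) == -1
assert infer_attach_index(neck_forward) == "all"

This is why the per-node `attach_index` class attributes in the files below can be dropped: each node's `forward` annotation already encodes the same information.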
7 changes: 3 additions & 4 deletions luxonis_train/nodes/bisenet_head.py
@@ -15,7 +15,6 @@


 class BiSeNetHead(BaseNode[Tensor, Tensor]):
-    attach_index: int = -1
     in_height: int
     in_channels: int

@@ -45,6 +44,6 @@ def wrap(self, output: Tensor) -> Packet[Tensor]:
         return {"segmentation": [output]}

     def forward(self, inputs: Tensor) -> Tensor:
-        inputs = self.conv_3x3(inputs)
-        inputs = self.conv_1x1(inputs)
-        return self.upscale(inputs)
+        x = self.conv_3x3(inputs)
+        x = self.conv_1x1(x)
+        return self.upscale(x)
1 change: 0 additions & 1 deletion luxonis_train/nodes/classification_head.py
@@ -7,7 +7,6 @@

 class ClassificationHead(BaseNode[Tensor, Tensor]):
     in_channels: int
-    attach_index: int = -1

     def __init__(
         self,
8 changes: 3 additions & 5 deletions luxonis_train/nodes/contextspatial.py
@@ -18,8 +18,6 @@


 class ContextSpatial(BaseNode[Tensor, list[Tensor]]):
-    attach_index: int = -1
-
     def __init__(self, context_backbone: str = "MobileNetV2", **kwargs):
         """Context spatial backbone.
         TODO: Add more documentation.
@@ -34,9 +32,9 @@ def __init__(self, context_backbone: str = "MobileNetV2", **kwargs):
         self.spatial_path = SpatialPath(3, 128)
         self.ffm = FeatureFusionBlock(256, 256)

-    def forward(self, x: Tensor) -> list[Tensor]:
-        spatial_out = self.spatial_path(x)
-        context16, _ = self.context_path(x)
+    def forward(self, inputs: Tensor) -> list[Tensor]:
+        spatial_out = self.spatial_path(inputs)
+        context16, _ = self.context_path(inputs)
         fm_fuse = self.ffm(spatial_out, context16)
         outs = [fm_fuse]
         return outs
6 changes: 2 additions & 4 deletions luxonis_train/nodes/efficientrep.py
@@ -19,8 +19,6 @@


 class EfficientRep(BaseNode[Tensor, list[Tensor]]):
-    attach_index: int = -1
-
     def __init__(
         self,
         channels_list: list[int] | None = None,
@@ -104,9 +102,9 @@ def set_export_mode(self, mode: bool = True) -> None:
             if isinstance(module, RepVGGBlock):
                 module.reparametrize()

-    def forward(self, x: Tensor) -> list[Tensor]:
+    def forward(self, inputs: Tensor) -> list[Tensor]:
         outputs = []
-        x = self.repvgg_encoder(x)
+        x = self.repvgg_encoder(inputs)
         for block in self.blocks:
             x = block(x)
             outputs.append(x)
4 changes: 1 addition & 3 deletions luxonis_train/nodes/implicit_keypoint_bbox_head.py
@@ -1,6 +1,6 @@
 import logging
 import math
-from typing import Literal, cast
+from typing import cast

 import torch
 from torch import Tensor, nn
@@ -22,8 +22,6 @@


 class ImplicitKeypointBBoxHead(BaseNode):
-    attach_index: Literal["all"] = "all"
-
     def __init__(
         self,
         n_keypoints: int | None = None,
24 changes: 10 additions & 14 deletions luxonis_train/nodes/micronet.py
@@ -15,8 +15,6 @@ class MicroNet(BaseNode[Tensor, list[Tensor]]):
     TODO: DOCS
     """

-    attach_index: int = -1
-
     def __init__(self, variant: Literal["M1", "M2", "M3"] = "M1", **kwargs):
         """MicroNet backbone.
@@ -236,23 +234,21 @@ def __init__(
             ChannelShuffle(out_channels // 2) if y3 != 0 else nn.Sequential(),
         )

-    def forward(self, x: Tensor):
-        identity = x
-        out = self.layers(x)
+    def forward(self, inputs: Tensor) -> Tensor:
+        out = self.layers(inputs)
         if self.identity:
-            out += identity
+            out += inputs
         return out


 class ChannelShuffle(nn.Module):
     def __init__(self, groups: int):
-        super(ChannelShuffle, self).__init__()
+        super().__init__()
         self.groups = groups

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         b, c, h, w = x.size()
         channels_per_group = c // self.groups
-        # reshape
         x = x.view(b, self.groups, channels_per_group, h, w)
         x = torch.transpose(x, 1, 2).contiguous()
         out = x.view(b, -1, h, w)
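As a quick sanity check on the reshape-transpose-reshape above: with `groups=2` on four channels, the shuffle interleaves the two groups. A toy example (values chosen here for illustration, not from the PR):

import torch

x = torch.arange(4.0).view(1, 4, 1, 1)  # channels [0, 1, 2, 3]
b, c, h, w = x.size()
groups = 2
# (b, groups, c_per_group, h, w) -> swap group and channel axes -> flatten back
out = x.view(b, groups, c // groups, h, w).transpose(1, 2).contiguous().view(b, -1, h, w)
assert out.flatten().tolist() == [0.0, 2.0, 1.0, 3.0]  # groups interleaved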
@@ -300,7 +296,7 @@ def __init__(
             indexs = torch.cat([indexs[1], indexs[0]], dim=2)
             self.index = indexs.view(in_channels).long()

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         B, C, _, _ = x.shape
         x_out = x

@@ -350,7 +346,7 @@ def __init__(self, in_channels: int, out_channels: int):
             nn.Linear(in_channels, out_channels), nn.BatchNorm1d(out_channels), HSwish()
         )

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         return self.linear(x)
return self.linear(x)


@@ -383,7 +379,7 @@ def __init__(
                 ChannelShuffle(out_channels1),
             )

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         return self.conv(x)
return self.conv(x)


@@ -394,7 +390,7 @@ def __init__(self, in_channels: int, stride: int, outs: tuple[int, int] = (4, 4)
             SpatialSepConvSF(in_channels, outs, 3, stride), nn.ReLU6(True)
         )

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         return self.stem(x)
return self.stem(x)


@@ -430,7 +426,7 @@ def __init__(
             nn.BatchNorm2d(out_channels),
         )

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         return self.conv(x)
return self.conv(x)


6 changes: 2 additions & 4 deletions luxonis_train/nodes/mobilenetv2.py
@@ -15,8 +15,6 @@ class MobileNetV2(BaseNode[Tensor, list[Tensor]]):
     TODO: add more info
     """

-    attach_index: int = -1
-
     def __init__(self, download_weights: bool = False, **kwargs):
         """Constructor of the MobileNetV2 backbone.
@@ -37,8 +35,8 @@ def __init__(self, download_weights: bool = False, **kwargs):

     def forward(self, x: Tensor) -> list[Tensor]:
         outs = []
-        for i, m in enumerate(self.backbone.features):
-            x = m(x)
+        for i, module in enumerate(self.backbone.features):
+            x = module(x)
             if i in self.out_indices:
                 outs.append(x)

5 changes: 2 additions & 3 deletions luxonis_train/nodes/mobileone.py
@@ -52,7 +52,6 @@ class MobileOne(BaseNode[Tensor, list[Tensor]]):
     TODO: add more details
     """

-    attach_index: int = -1
     in_channels: int

     VARIANTS_SETTINGS: dict[str, dict] = {
@@ -115,9 +114,9 @@ def __init__(self, variant: Literal["s0", "s1", "s2", "s3", "s4"] = "s0", **kwar
             num_se_blocks=self.num_blocks_per_stage[3] if self.use_se else 0,
         )

-    def forward(self, x: Tensor) -> list[Tensor]:
+    def forward(self, inputs: Tensor) -> list[Tensor]:
         outs = []
-        x = self.stage0(x)
+        x = self.stage0(inputs)
         outs.append(x)
         x = self.stage1(x)
         outs.append(x)
6 changes: 2 additions & 4 deletions luxonis_train/nodes/resnet.py
@@ -12,8 +12,6 @@


 class ResNet(BaseNode[Tensor, list[Tensor]]):
-    attach_index: int = -1
-
     def __init__(
         self,
         variant: Literal["18", "34", "50", "101", "152"] = "18",
@@ -47,9 +45,9 @@ def __init__(
         )
         self.channels_list = channels_list or [64, 128, 256, 512]

-    def forward(self, x: Tensor) -> list[Tensor]:
+    def forward(self, inputs: Tensor) -> list[Tensor]:
         outs = []
-        x = self.backbone.conv1(x)
+        x = self.backbone.conv1(inputs)
         x = self.backbone.bn1(x)
         x = self.backbone.relu(x)
         x = self.backbone.maxpool(x)
11 changes: 4 additions & 7 deletions luxonis_train/nodes/rexnetv1.py
@@ -17,8 +17,6 @@


 class ReXNetV1_lite(BaseNode[Tensor, list[Tensor]]):
-    attach_index: int = -1
-
     def __init__(
         self,
         fix_head_stem: bool = False,
@@ -129,8 +127,8 @@ def __init__(

     def forward(self, x: Tensor) -> list[Tensor]:
         outs = []
-        for i, m in enumerate(self.features):
-            x = m(x)
+        for i, module in enumerate(self.features):
+            x = module(x)
             if i in self.out_indices:
                 outs.append(x)
         return outs
@@ -186,12 +184,11 @@ def __init__(

         self.out = nn.Sequential(*out)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         out = self.out(x)

         if self.use_shortcut:
-            # this results in a ScatterND node which isn't supported yet in myriad
-            # out[:, 0:self.in_channels] += x
+            # NOTE: this results in a ScatterND node which isn't supported yet in myriad
             a = out[:, : self.in_channels]
             b = x
             a = a + b
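The rewritten comment refers to the export workaround in this block: an in-place slice update such as `out[:, :c] += x` lowers to a ScatterND op that the Myriad backend cannot run, so the shortcut is instead added out of place on a slice. A hedged sketch of the pattern (shapes and the final `torch.cat` are assumed here; the rest of the original `forward` is collapsed in this view):

import torch

c = 3  # stands in for self.in_channels
out = torch.randn(1, 8, 4, 4)
x = torch.randn(1, c, 4, 4)

# export-unfriendly: out[:, :c] += x  (in-place slice assignment -> ScatterND)
# export-friendly: add on a slice, then concatenate the untouched channels back
fused = torch.cat([out[:, :c] + x, out[:, c:]], dim=1)
assert fused.shape == out.shape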
1 change: 0 additions & 1 deletion luxonis_train/nodes/segmentation_head.py
@@ -16,7 +16,6 @@


 class SegmentationHead(BaseNode[Tensor, Tensor]):
-    attach_index: int = -1
     in_height: int
     in_channels: int
