Skip to content

Commit

Permalink
add define_preprocess function in _Model
Browse files Browse the repository at this point in the history
  • Loading branch information
ain-soph committed Mar 8, 2021
1 parent 94eaf1f commit 80dd972
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
4 changes: 3 additions & 1 deletion trojanvision/models/imagemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def __init__(self, norm_par: dict[str, list[float]] = {'mean': [0.0], 'std': [1.
num_classes=None, **kwargs):
if num_classes is None:
num_classes = 1000
super().__init__(num_classes=num_classes, **kwargs)
super().__init__(num_classes=num_classes, norm_par=norm_par, **kwargs)

def define_preprocess(self, norm_par: dict[str, list[float]] = {'mean': [0.0], 'std': [1.0]}, **kwargs):
self.normalize = Normalize(mean=norm_par['mean'], std=norm_par['std'])

# get feature map
Expand Down
4 changes: 4 additions & 0 deletions trojanzoo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class _Model(nn.Module):

def __init__(self, num_classes: int = None, **kwargs):
super().__init__()
self.define_preprocess(**kwargs)
self.num_classes = num_classes
self.features = self.define_features(**kwargs) # feature extractor
self.pool = nn.AdaptiveAvgPool2d((1, 1)) # average pooling
Expand All @@ -47,6 +48,9 @@ def __init__(self, num_classes: int = None, **kwargs):
self.softmax = nn.Softmax(dim=1)
self.layer_name_list: list[str] = None

def define_preprocess(self, **kwargs):
pass

@staticmethod
def define_features(**kwargs) -> nn.Module:
return nn.Identity()
Expand Down

0 comments on commit 80dd972

Please sign in to comment.