Skip to content

Commit ccd32ca

Browse files
authored
[WIP] 7145 common factory class (Project-MONAI#7159)
Fixes Project-MONAI#7145 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham <[email protected]>
1 parent 487f98b commit ccd32ca

File tree

4 files changed

+369
-64
lines changed

4 files changed

+369
-64
lines changed

Diff for: monai/networks/layers/factories.py

+171-64
Original file line numberDiff line numberDiff line change
@@ -68,40 +68,40 @@ def use_factory(fact_args):
6868
import torch.nn as nn
6969

7070
from monai.networks.utils import has_nvfuser_instance_norm
71-
from monai.utils import look_up_option, optional_import
71+
from monai.utils import ComponentStore, look_up_option, optional_import
7272

7373
__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
7474

7575

76-
class LayerFactory:
76+
class LayerFactory(ComponentStore):
7777
"""
7878
Factory object for creating layers, this uses given factory functions to actually produce the types or constructing
7979
callables. These functions are referred to by name and can be added at any time.
8080
"""
8181

82-
def __init__(self) -> None:
83-
self.factories: dict[str, Callable] = {}
82+
def __init__(self, name: str, description: str) -> None:
83+
super().__init__(name, description)
84+
self.__doc__ = (
85+
f"Layer Factory '{name}': {description}\n".strip()
86+
+ "\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
87+
+ "\n\nThe supported members are:"
88+
)
8489

85-
@property
86-
def names(self) -> tuple[str, ...]:
90+
def add_factory_callable(self, name: str, func: Callable, desc: str | None = None) -> None:
8791
"""
88-
Produces all factory names.
92+
Add the factory function to this object under the given name, with optional description.
8993
"""
94+
description: str = desc or func.__doc__ or ""
95+
self.add(name.upper(), description, func)
96+
# append name to the docstring
97+
assert self.__doc__ is not None
98+
self.__doc__ += f"{', ' if len(self.names)>1 else ' '}``{name}``"
9099

91-
return tuple(self.factories)
92-
93-
def add_factory_callable(self, name: str, func: Callable) -> None:
100+
def add_factory_class(self, name: str, cls: type, desc: str | None = None) -> None:
94101
"""
95-
Add the factory function to this object under the given name.
102+
Adds a factory function which returns the supplied class under the given name, with optional description.
96103
"""
97-
98-
self.factories[name.upper()] = func
99-
self.__doc__ = (
100-
"The supported member"
101-
+ ("s are: " if len(self.names) > 1 else " is: ")
102-
+ ", ".join(f"``{name}``" for name in self.names)
103-
+ ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
104-
)
104+
self.add_factory_callable(name, lambda x=None: cls, desc)
105105

106106
def factory_function(self, name: str) -> Callable:
107107
"""
@@ -126,8 +126,9 @@ def get_constructor(self, factory_name: str, *args) -> Any:
126126
if not isinstance(factory_name, str):
127127
raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.")
128128

129-
func = look_up_option(factory_name.upper(), self.factories)
130-
return func(*args)
129+
component = look_up_option(factory_name.upper(), self.components)
130+
131+
return component.value(*args)
131132

132133
def __getitem__(self, args) -> Any:
133134
"""
@@ -153,7 +154,7 @@ def __getattr__(self, key):
153154
as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo.
154155
"""
155156

156-
if key in self.factories:
157+
if key in self.components:
157158
return key
158159

159160
return super().__getattribute__(key)
@@ -194,56 +195,60 @@ def split_args(args):
194195

195196

196197
# Define factories for these layer types
197-
198-
Dropout = LayerFactory()
199-
Norm = LayerFactory()
200-
Act = LayerFactory()
201-
Conv = LayerFactory()
202-
Pool = LayerFactory()
203-
Pad = LayerFactory()
198+
Dropout = LayerFactory(name="Dropout layers", description="Factory for creating dropout layers.")
199+
Norm = LayerFactory(name="Normalization layers", description="Factory for creating normalization layers.")
200+
Act = LayerFactory(name="Activation layers", description="Factory for creating activation layers.")
201+
Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.")
202+
Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.")
203+
Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.")
204204

205205

206206
@Dropout.factory_function("dropout")
207207
def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]:
208+
"""
209+
Dropout layers in 1,2,3 dimensions.
210+
211+
Args:
212+
dim: desired dimension of the dropout layer
213+
214+
Returns:
215+
Dropout[dim]d
216+
"""
208217
types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
209218
return types[dim - 1]
210219

211220

212-
@Dropout.factory_function("alphadropout")
213-
def alpha_dropout_factory(_dim):
214-
return nn.AlphaDropout
221+
Dropout.add_factory_class("alphadropout", nn.AlphaDropout)
215222

216223

217224
@Norm.factory_function("instance")
218225
def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]:
226+
"""
227+
Instance normalization layers in 1,2,3 dimensions.
228+
229+
Args:
230+
dim: desired dimension of the instance normalization layer
231+
232+
Returns:
233+
InstanceNorm[dim]d
234+
"""
219235
types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
220236
return types[dim - 1]
221237

222238

223239
@Norm.factory_function("batch")
224240
def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]:
225-
types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
226-
return types[dim - 1]
227-
228-
229-
@Norm.factory_function("group")
230-
def group_factory(_dim) -> type[nn.GroupNorm]:
231-
return nn.GroupNorm
232-
233-
234-
@Norm.factory_function("layer")
235-
def layer_factory(_dim) -> type[nn.LayerNorm]:
236-
return nn.LayerNorm
237-
238-
239-
@Norm.factory_function("localresponse")
240-
def local_response_factory(_dim) -> type[nn.LocalResponseNorm]:
241-
return nn.LocalResponseNorm
241+
"""
242+
Batch normalization layers in 1,2,3 dimensions.
242243
244+
Args:
245+
dim: desired dimension of the batch normalization layer
243246
244-
@Norm.factory_function("syncbatch")
245-
def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]:
246-
return nn.SyncBatchNorm
247+
Returns:
248+
BatchNorm[dim]d
249+
"""
250+
types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
251+
return types[dim - 1]
247252

248253

249254
@Norm.factory_function("instance_nvfuser")
@@ -274,91 +279,193 @@ def instance_nvfuser_factory(dim):
274279
return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0]
275280

276281

277-
Act.add_factory_callable("elu", lambda: nn.modules.ELU)
278-
Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
279-
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
280-
Act.add_factory_callable("prelu", lambda: nn.modules.PReLU)
281-
Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6)
282-
Act.add_factory_callable("selu", lambda: nn.modules.SELU)
283-
Act.add_factory_callable("celu", lambda: nn.modules.CELU)
284-
Act.add_factory_callable("gelu", lambda: nn.modules.GELU)
285-
Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid)
286-
Act.add_factory_callable("tanh", lambda: nn.modules.Tanh)
287-
Act.add_factory_callable("softmax", lambda: nn.modules.Softmax)
288-
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)
282+
Norm.add_factory_class("group", nn.GroupNorm)
283+
Norm.add_factory_class("layer", nn.LayerNorm)
284+
Norm.add_factory_class("localresponse", nn.LocalResponseNorm)
285+
Norm.add_factory_class("syncbatch", nn.SyncBatchNorm)
286+
287+
288+
Act.add_factory_class("elu", nn.modules.ELU)
289+
Act.add_factory_class("relu", nn.modules.ReLU)
290+
Act.add_factory_class("leakyrelu", nn.modules.LeakyReLU)
291+
Act.add_factory_class("prelu", nn.modules.PReLU)
292+
Act.add_factory_class("relu6", nn.modules.ReLU6)
293+
Act.add_factory_class("selu", nn.modules.SELU)
294+
Act.add_factory_class("celu", nn.modules.CELU)
295+
Act.add_factory_class("gelu", nn.modules.GELU)
296+
Act.add_factory_class("sigmoid", nn.modules.Sigmoid)
297+
Act.add_factory_class("tanh", nn.modules.Tanh)
298+
Act.add_factory_class("softmax", nn.modules.Softmax)
299+
Act.add_factory_class("logsoftmax", nn.modules.LogSoftmax)
289300

290301

291302
@Act.factory_function("swish")
292303
def swish_factory():
304+
"""
305+
Swish activation layer.
306+
307+
Returns:
308+
Swish
309+
"""
293310
from monai.networks.blocks.activation import Swish
294311

295312
return Swish
296313

297314

298315
@Act.factory_function("memswish")
299316
def memswish_factory():
317+
"""
318+
Memory efficient swish activation layer.
319+
320+
Returns:
321+
MemoryEfficientSwish
322+
"""
300323
from monai.networks.blocks.activation import MemoryEfficientSwish
301324

302325
return MemoryEfficientSwish
303326

304327

305328
@Act.factory_function("mish")
306329
def mish_factory():
330+
"""
331+
Mish activation layer.
332+
333+
Returns:
334+
Mish
335+
"""
307336
from monai.networks.blocks.activation import Mish
308337

309338
return Mish
310339

311340

312341
@Act.factory_function("geglu")
313342
def geglu_factory():
343+
"""
344+
GEGLU activation layer.
345+
346+
Returns:
347+
GEGLU
348+
"""
314349
from monai.networks.blocks.activation import GEGLU
315350

316351
return GEGLU
317352

318353

319354
@Conv.factory_function("conv")
320355
def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:
356+
"""
357+
Convolutional layers in 1,2,3 dimensions.
358+
359+
Args:
360+
dim: desired dimension of the convolutional layer
361+
362+
Returns:
363+
Conv[dim]d
364+
"""
321365
types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
322366
return types[dim - 1]
323367

324368

325369
@Conv.factory_function("convtrans")
326370
def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]:
371+
"""
372+
Transposed convolutional layers in 1,2,3 dimensions.
373+
374+
Args:
375+
dim: desired dimension of the transposed convolutional layer
376+
377+
Returns:
378+
ConvTranspose[dim]d
379+
"""
327380
types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
328381
return types[dim - 1]
329382

330383

331384
@Pool.factory_function("max")
332385
def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]:
386+
"""
387+
Max pooling layers in 1,2,3 dimensions.
388+
389+
Args:
390+
dim: desired dimension of the max pooling layer
391+
392+
Returns:
393+
MaxPool[dim]d
394+
"""
333395
types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)
334396
return types[dim - 1]
335397

336398

337399
@Pool.factory_function("adaptivemax")
338400
def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]:
401+
"""
402+
Adaptive max pooling layers in 1,2,3 dimensions.
403+
404+
Args:
405+
dim: desired dimension of the adaptive max pooling layer
406+
407+
Returns:
408+
AdaptiveMaxPool[dim]d
409+
"""
339410
types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d)
340411
return types[dim - 1]
341412

342413

343414
@Pool.factory_function("avg")
344415
def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]:
416+
"""
417+
Average pooling layers in 1,2,3 dimensions.
418+
419+
Args:
420+
dim: desired dimension of the average pooling layer
421+
422+
Returns:
423+
AvgPool[dim]d
424+
"""
345425
types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
346426
return types[dim - 1]
347427

348428

349429
@Pool.factory_function("adaptiveavg")
350430
def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]:
431+
"""
432+
Adaptive average pooling layers in 1,2,3 dimensions.
433+
434+
Args:
435+
dim: desired dimension of the adaptive average pooling layer
436+
437+
Returns:
438+
AdaptiveAvgPool[dim]d
439+
"""
351440
types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d)
352441
return types[dim - 1]
353442

354443

355444
@Pad.factory_function("replicationpad")
356445
def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]:
446+
"""
447+
Replication padding layers in 1,2,3 dimensions.
448+
449+
Args:
450+
dim: desired dimension of the replication padding layer
451+
452+
Returns:
453+
ReplicationPad[dim]d
454+
"""
357455
types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d)
358456
return types[dim - 1]
359457

360458

361459
@Pad.factory_function("constantpad")
362460
def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]:
461+
"""
462+
Constant padding layers in 1,2,3 dimensions.
463+
464+
Args:
465+
dim: desired dimension of the constant padding layer
466+
467+
Returns:
468+
ConstantPad[dim]d
469+
"""
363470
types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
364471
return types[dim - 1]

Diff for: monai/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
# have to explicitly bring these in here to resolve circular import issues
1515
from .aliases import alias, resolve_name
16+
from .component_store import ComponentStore
1617
from .decorators import MethodReplacer, RestartGenerator
1718
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
1819
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather

0 commit comments

Comments
 (0)