diff --git a/README.md b/README.md index ed69fb2..46038e2 100644 --- a/README.md +++ b/README.md @@ -1 +1,157 @@ -# x-mlps +# X-MLPs + +An MLP model that provides a flexible foundation to implement, mix-and-match, and test various state-of-the-art MLP building blocks and architectures. +Built on Jax and Haiku. + +## Installation + +```sh +pip install x-mlps +``` + +**Note**: X-MLPs will not install Jax for you (see [here](https://github.com/google/jax#installation) for install instructions). + +## Getting Started + +The `XMLP` class provides the foundation from which all MLP architectures are built on, and is the primary class you use. +Additionally, X-MLPs relies heavily on factory functions to customize and instantiate the building blocks that make up a particular `XMLP` instance. +Fortunately, this library provides several SOTA MLP blocks out-of-the-box as factory functions. +For example, to implement the ResMLP architecture, you can implement the follow model function: + +```python +import haiku as hk +import jax +from einops import rearrange +from x_mlps import XMLP, Affine, resmlp_block_factory + +def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10): + # NOTE: Operating directly on batched data is supported as well. + @hk.vmap + def model_fn(x: jnp.ndarray) -> jnp.ndarray: + # Reformat input image into a sequence of patches + x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size) + return XMLP( + num_patches=x.shape[-2], + dim=dim, + depth=depth, + block=resmlp_block_factory, + normalization=lambda num_patches, dim, depth, **kwargs: Affine(dim, **kwargs), + num_classes=num_classes, + )(x) + + return model_fn + +model = create_model(patch_size=4, dim=384, depth=12) +model_fn = hk.transform(model) +model_fn = hk.without_apply_rng(model_fn) + +rng = jax.random.PRNGKey(0) +params = model_fn.init(rng, jnp.ones((1, 32, 32, 3))) +``` + +It's important to note the `XMLP` module _does not_ reformat input data to the form appropriate for whatever block you make use of (e.g., a sequence of patches). +As such, you must reformat data manually before feeding it to an `XMLP` module. +The [einops](https://github.com/arogozhnikov/einops) library (which is installed by X-MLPs) provides functions (e.g., `rearrange`) that can help here. + +**Note**: Like the core Haiku modules, all modules implemented in X-MLPs support both batched and vectorized data. + +## X-MLPs Architecture Details + +X-MLPs uses a layered approach to construct arbitrary MLP networks. There are three core modules used to create a network's structure: + +1. `XSublayer` - bottom level module which wraps arbitrary feedforward functionality. +2. `XBlock` - mid level module consisting of one or more `XSublayer` modules. +3. `XMLP` - top level module which represents a generic MLP network, and is composed of a stack of repeated `XBlock` modules. + +To support user-defined modules, each of the above modules support passing arbitrary keyword arguments to child modules. +This is accomplished by prepending arguments with one or more predefined prefixes (including user defined prefixes). +Built-in prefixes include: + +1. "block\_" - arguments fed directly to the `XBlock` module. +2. "sublayers\_" - arguments fed to all `XSublayer`s in each `XBlock`. +3. "sublayers{i}\_" - arguments fed to the i-th `XSublayer` in each `XBlock` (where 1 <= i <= # of sublayers). +4. "ff\_" - arguments fed to the feedforward module in a `XSublayer`. + +This must be combined in order when passing them to the `XMLP` module (e.g., "block_sublayer1_ff\_"). + +### XSublayer + +The `XSublayer` module is a flexible sublayer wrapper module providing skip connections and pre/post-normalization to an arbitrary child module (specifically, arbitrary feedforward modules e.g., `XChannelFeedForward`). +Child module instances are not passed directly, rather a factory function which creates the child module is instead. +This ensures that individual sublayers can be configured automatically based on depth. + +### XBlock + +The `XBlock` module is a generic MLP block. It is composed of one or more `XSublayer` modules, passed as factory functions. + +### Top Layer - XMLP + +At the top level is the `XMLP` module, which represents a generic MLP network. +N `XBlock` modules are stacked together to form a network, created via a common factory function. + +## Built-in MLP Architectures + +The following architectures have been implemented in the form of `XBlock`s and have corresponding factory functions. + +- ResMLP - `resmlp_block_factory` +- MLP-Mixer - `mlpmixer_block_factory` +- gMLP - `gmlp_block_factory` +- S²-MLP `s2mlp_block_factory` + +See their respective docstrings for more information. + +## LICENSE + +See [LICENSE](LICENSE). + +## Citations + +```bibtex +@article{Touvron2021ResMLPFN, + title={ResMLP: Feedforward networks for image classification with data-efficient training}, + author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and Edouard Grave and Gautier Izacard and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Herv'e J'egou}, + journal={ArXiv}, + year={2021}, + volume={abs/2105.03404} +} +``` + +```bibtex +@article{Tolstikhin2021MLPMixerAA, + title={MLP-Mixer: An all-MLP Architecture for Vision}, + author={Ilya O. Tolstikhin and Neil Houlsby and Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Thomas Unterthiner and Jessica Yung and Daniel Keysers and Jakob Uszkoreit and Mario Lucic and Alexey Dosovitskiy}, + journal={ArXiv}, + year={2021}, + volume={abs/2105.01601} +} +``` + +```bibtex +@article{Liu2021PayAT, + title={Pay Attention to MLPs}, + author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le}, + journal={ArXiv}, + year={2021}, + volume={abs/2105.08050} +} +``` + +```bibtex +@article{Yu2021S2MLPSM, + title={S2-MLP: Spatial-Shift MLP Architecture for Vision}, + author={Tan Yu and Xu Li and Yunfeng Cai and Mingming Sun and Ping Li}, + journal={ArXiv}, + year={2021}, + volume={abs/2106.07477} +} +``` + +```bibtex +@article{Touvron2021GoingDW, + title={Going deeper with Image Transformers}, + author={Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Herv'e J'egou}, + journal={ArXiv}, + year={2021}, + volume={abs/2103.17239} +} +``` diff --git a/src/x_mlps/_x_mlps.py b/src/x_mlps/_x_mlps.py index ee298e3..374e319 100644 --- a/src/x_mlps/_x_mlps.py +++ b/src/x_mlps/_x_mlps.py @@ -9,6 +9,8 @@ class XModuleFactory(Protocol): + """Defines a common factory function interface for all X-MLP modules.""" + def __call__(self, num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any) -> hk.Module: ... @@ -23,7 +25,21 @@ def _calc_layer_scale_eps(depth: int) -> float: return init_eps -def create_shift2d_op(height: int, width: int, amount: int = 1) -> Callable: +def create_shift2d_op(height: int, width: int, amount: int = 1) -> Callable[[jnp.ndarray], jnp.ndarray]: + """Create a 2D shift operator based on spatial shift algorithm introduced in S^2-MLP¹. + + Args: + height: Height of the original input image divided by the patch size. + width: Width of the original input image divided by the patch size. + amount: Amount of shift. + + Returns: + Callable[[jnp.ndarray], jnp.ndarray]: The configured shift operator. + + References: + 1. S2-MLP: Spatial-Shift MLP Architecture for Vision (https://arxiv.org/abs/2106.07477). + """ + def shift2d(x: jnp.ndarray) -> jnp.ndarray: c = x.shape[-1] x = rearrange(x, "... (h w) c -> ... h w c", h=height, w=width) @@ -38,6 +54,20 @@ def shift2d(x: jnp.ndarray) -> jnp.ndarray: class Affine(hk.Module): + """Affine transform layer as described in ResMLP¹. + + Briefly, this operator rescales and shifts its input by a learned weight and bias of size `dim`. + + Args: + dim (int): Size of the channel dimension. + init_scale (float): Initial weight value of alpha. + name (str, optional): The name of the module. Defaults to None. + + References: + 1. ResMLP: Feedforward networks for image classification with data-efficient training + (https://arxiv.org/abs/2105.03404). + """ + def __init__(self, dim: int, init_scale: float = 1.0, name: Optional[str] = None): super().__init__(name=name) @@ -53,6 +83,20 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: class LayerScale(hk.Module): + """LayerScale layer as described in *Going deeper with Image Transformers*¹. + + Briefly, rescales the input by a learned weight of size `dim`. + + Args: + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. This is used to determine the + initial weight values. Note that depth starts from 1. + name (str, optional): The name of the module. Defaults to None. + + References: + 1. Going deeper with Image Transformers (https://arxiv.org/abs/2103.17239). + """ + def __init__(self, dim: int, depth: int, name: Optional[str] = None): super().__init__(name=name) @@ -67,6 +111,25 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: class SpatialGatingUnit(hk.Module): + """Spatial Gating Unit as described in *Pay Attention to MLPs*¹. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + norm (XModuleFactory, optional): Normalization layer factory function. Defaults to LayerNorm via + `layernorm_factory`. + activation (Callable[[jnp.ndarray], jnp.ndarray], optional): Activation function. Applied to the gate values + spatial projection. Defaults to the identity function. + init_eps (float): Initial weight of the spatial projection layer. Scaled by the number of patches. Defaults to + 1e-3. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Currently unused. + + References: + 1. Pay Attention to MLPs (https://arxiv.org/abs/2105.08050). + """ + def __init__( self, num_patches: int, @@ -107,6 +170,25 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: class MLPMixerXPatchFeedForward(hk.Module): + """Patch (token) mixing feedforward layer as described in *MLP-Mixer: An all-MLP Architecture for Vision*¹. + + Note that this module does not implement normalization nor a skip connection. This module can be combined with the + `XSublayer` module to add these functionalities. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. Not used directly, rather it's included as an argument to establish + a consistent interface with other modules. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + dim_hidden (int, optional): Hidden dimension size. Defaults to 4 x num_patches. + activation (Callable[[jnp.ndarray], jnp.ndarray], optional): Activation function. Defaults to the GELU function. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Currently unused. + + References: + 1. MLP-Mixer: An all-MLP Architecture for Vision (https://arxiv.org/abs/2105.01601). + """ + def __init__( self, num_patches: int, @@ -126,12 +208,30 @@ def __init__( self.activation = activation def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + # NOTE: Equivalent to a transpose and pair of linear layers. x = hk.Conv1D(self.dim_hidden, 1, data_format="NCW", name="conv_1")(inputs) x = self.activation(x) return hk.Conv1D(self.num_patches, 1, data_format="NCW", name="conv_2")(x) class ResMLPXPatchFeedForward(hk.Module): + """Patch (token) mixing feedforward layer as described in ResMLP¹. + + Note that this module does not implement normalization nor a skip connection. This module can be combined with the + `XSublayer` module to add these functionalities. + + Args + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. Not used directly, rather it's included as an argument to establish + a consistent interface with other modules. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + **kwargs: Currently unused. + + References: + 1. ResMLP: Feedforward networks for image classification with data-efficient training + (https://arxiv.org/abs/2105.03404). + """ + def __init__(self, num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any): super().__init__(name=name) @@ -140,10 +240,31 @@ def __init__(self, num_patches: int, dim: int, depth: int, name: Optional[str] = self.depth = depth def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + # NOTE: Equivalent to a transpose and linear layer. return hk.Conv1D(self.num_patches, 1, data_format="NCW", name="conv")(inputs) class gMLPFeedForward(hk.Module): + """gMLP feedforward layer as described in *Pay Attention to MLPs*¹. + + Note that this module does not implement normalization nor a skip connection. This module can be combined with the + `XSublayer` module to add these functionalities. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + dim_hidden (int, optional): Hidden dimension size of the first projection. Defaults to 4 x dim. + sgu (XModuleFactory, optional): Spatial gating unit factory function. Defaults to the standard SpatialGatingUnit + module via `sgu_factory`. + activation (Callable[[jnp.ndarray], jnp.ndarray], optional): Activation function. Defaults to the GELU function. + name (str, optional): The name of the module. Defaults to None. + **kwargs: All arguments starting with "sgu_" are passed to the spatial gating unit factory function. + + References: + 1. Pay Attention to MLPs (https://arxiv.org/abs/2105.08050). + """ + def __init__( self, num_patches: int, @@ -179,6 +300,24 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: class XChannelFeedForward(hk.Module): + """Common channel mixing feedforward layer. + + Note that this module does not implement normalization nor a skip connection. This module can be combined with the + `XSublayer` module to add these functionalities. + + Args: + num_patches (int): Number of patches in the input. Not used directly, rather it's included as an argument to + establish a consistent interface with other modules. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + dim_hidden (int, optional): Hidden dimension size. Defaults to 4 x dim. + activation (Callable[[jnp.ndarray], jnp.ndarray], optional): Activation function. Defaults to the GELU function. + shift (Callable[[jnp.ndarray], jnp.ndarray], optional): Token shifting function (e.g., spatial shift in S^2-MLP). + Defaults to `None`. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Currently unused. + """ + def __init__( self, num_patches: int, @@ -209,6 +348,22 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: class XSublayer(hk.Module): + """Flexible sublayer wrapper module providing skip connections and pre/post-normalization to arbitrary layers. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + ff (XModuleFactory): Feedforward layer factory function. + prenorm (XModuleFactory, optional): Pre-normalization layer factory function. Defaults to `None`. + postnorm (XModuleFactory, optional): Post-normalization layer factory function. Defaults to `None`. + residual (bool): Whether to add a residual/skip connection. Defaults to `True`. + name (str, optional): The name of the module. Defaults to None. + **kwargs: All arguments starting with "ff_" are passed to the feedforward layer factory function. + All arguments starting with "prenorm_" are passed to the pre-normalization layer factory function. + All arguments starting with "postnorm_" are passed to the post-normalization layer factory function. + """ + def __init__( self, num_patches: int, @@ -252,6 +407,27 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: class XBlock(hk.Module): + """Generic MLP block. + + One or more `XSublayer` modules are stacked together to form a block. Optionally, a skip connection can be added. + Arbitrary arguments can be passed to `XSublayer` modules two different ways: + + 1. As keyword arguments prefixed with "sublayers_". These arguments are passed to all the sublayers. + 2. As keyword arguments prefixed with "sublayer{i}_" where 1 <= i <= len(sublayers). These arguments are passed to + to the i-th sublayer. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of this block in the network. Note that depth starts from 1. + sublayers (Sequence[XSublayerFactory]): Sublayer factory functions. Created sublayers will be stacked in the + order of their respective factory function in the sequence. + residual (bool): Whether to add a residual/skip connection. Defaults to `False`. + name (str, optional): The name of the module. Defaults to None. + **kwargs: All arguments starting with "sublayers_" are passed to all sublayers. All arguments starting with + "sublayer{i}_" are passed to the i-th sublayer. + """ + def __init__( self, num_patches: int, @@ -290,6 +466,34 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: class XMLP(hk.Module): + """Generic MLP network. + + N `XBlock` modules are stacked together to form a network (where N is set to `depth`). Importantly, this network + assumes the input has been formatted appropriately (e.g., a sequence of patches). Before data is processed by the + stack of `XBlock` modules, it is first projected to the specified dimension `dim` via a linear layer. + + Optionally, the network can be configured to have a classification layer at the end by setting `num_classes` to a + non-zero value. In this case, the resulting sequence from stack of `XBlock` modules will be averaged over the + sequence dimension before being fed to the classification layer. + + Arbitrary arguments can be passed to `XBlock` modules by prepending the argument name with "block_". Further, to + ensure arguments are passed to child modules of each `XBlock` module, each argument name should additionally be + prefixed with that child module's identifier, starting with "block" and working down the hierarchy. For example, + to pass an argument to the feedforward module of the first sublayer of each block, the argument name should be + "block_sublayer1_ff_". + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. Inputs fed to this network are projected to this dimension. + depth (int): The number of blocks in the network. + block (XBlockFactory): Block factory function. + normalization (XModuleFactory, optional): Normalization module factory function. Occurs after the stack of + `XBlock` modules. Useful for pre-normalization architectures. Defaults to None. + num_classes (int, optional): Number of classes in the classification layer. Defaults to None. + name (str, optional): The name of the module. Defaults to None. + **kwargs: All arguments starting with "block_" are passed to all blocks. + """ + def __init__( self, num_patches: int, @@ -326,6 +530,28 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: def layernorm_factory(num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any) -> hk.Module: + """Layer normalization module factory function. + + Standard `hk.LayerNorm` arguments can be passed via the `kwargs` dictionary. If no arguments are passed, the + following defaults are used: + + 1. `axis` is set to `-1` (the last dimension). + 2. `create_scale` is set to `True`. + 3. `create_offset` is set to `True`. + + Satifies the `XModuleFactory` interface. + + Args: + num_patches (int): Number of patches in the input. Unused. + dim (int): Size of the channel dimension. Unused. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + Unused. + name (str, optional): The name of the module. Defaults to None. + **kwargs: `hk.LayerNorm` arguments. + + Returns: + hk.Module: `hk.LayerNorm` module. + """ axis, create_scale, create_offset = pick_and_pop(["axis", "create_scale", "create_offset"], kwargs) if axis is None: axis = -1 @@ -338,34 +564,127 @@ def layernorm_factory(num_patches: int, dim: int, depth: int, name: Optional[str def sgu_factory(num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any) -> hk.Module: + """`SpatialGatingUnit` module factory function. + + Satifies the `XModuleFactory` interface. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Additional `SpatialGatingUnit` arguments. + + Returns: + hk.Module: `SpatialGatingUnit` module. + """ + return SpatialGatingUnit(num_patches, dim, depth, **kwargs, name=name) def gmlp_feedforward_factory( num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any ) -> hk.Module: + """gMLP feedforward module factory function. + + Satifies the `XModuleFactory` interface. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Additional `gMLPFeedForward` arguments. + + Returns: + hk.Module: `gMLPFeedForward` module. + """ + return gMLPFeedForward(num_patches, dim, depth, **kwargs, name=name) def mlpmixer_xpatch_feedforward_factory( num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any ) -> hk.Module: + """MLP-Mixer cross-patch feedforward module factory function. + + Satifies the `XModuleFactory` interface. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Additional `MLPMixerXPatchFeedForward` arguments. + + Returns: + hk.Module: `MLPMixerXPatchFeedForward` module. + """ + return MLPMixerXPatchFeedForward(num_patches, dim, depth, **kwargs, name=name) def resmlp_xpatch_feedforward_factory( num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any ) -> hk.Module: + """ResMLP cross-patch feedforward module factory function. + + Satifies the `XModuleFactory` interface. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Additional `ResMLPXPatchFeedForward` arguments. + + Returns: + hk.Module: `ResMLPXPatchFeedForward` module. + """ return ResMLPXPatchFeedForward(num_patches, dim, depth, **kwargs, name=name) def xchannel_feedforward_factory( num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any ) -> hk.Module: + """Standard cross-channel feedforward module factory function. + + Satifies the `XModuleFactory` interface. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Additional `XChannelFeedForward` arguments. + + Returns: + hk.Module: `XChannelFeedForward` module. + """ return XChannelFeedForward(num_patches, dim, depth, **kwargs, name=name) def gmlp_block_factory(num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any) -> hk.Module: + """gMLP block module factory function. + + Builds a `XBlock` module with the gMLP block structure as defined in *Pay Attention to MLPs*¹. Specifically, this + consists of a single `XSublayer` with a `gMLPFeedForward` module and layer normalization (pre-normalization). + + Satifies the `XModuleFactory` interface. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Additional block and child module arguments. + + Returns: + hk.Module: `XBlock` module. + + References: + 1. Pay Attention to MLPs (https://arxiv.org/abs/2105.08050). + """ return XBlock( num_patches, dim, @@ -383,6 +702,27 @@ def gmlp_block_factory(num_patches: int, dim: int, depth: int, name: Optional[st def mlpmixer_block_factory( num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any ) -> hk.Module: + """MLP-Mixer block module factory function. + + Builds a `XBlock` module with the MLP-Mixer block structure as defined in *MLP-Mixer: An all-MLP Architecture for + Vision*¹. Specifically, this consists of two `XSublayer`s: 1) a `MLPMixerXPatchFeedForward` module and 2) a + `XChannelFeedForward` module. Both make use of layer normalization (pre-normalization). + + Satifies the `XModuleFactory` interface. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Additional block and child module arguments. + + Returns: + hk.Module: `XBlock` module. + + References: + 1. MLP-Mixer: An all-MLP Architecture for Vision (https://arxiv.org/abs/2105.01601). + """ return XBlock( num_patches, dim, @@ -405,6 +745,29 @@ def mlpmixer_block_factory( def resmlp_block_factory( num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any ) -> hk.Module: + """ResMLP block module factory function. + + Builds a `XBlock` module with the ResMLP block structure as defined in *ResMLP: Feedforward networks for image + classification with data-efficient training*¹. Specifically, this consists of two `XSublayer`s: 1) a + `ResMLPXPatchFeedForward` module and 2) a `XChannelFeedForward` module. Both make use of `Affine` pre-normalization + and `LayerScale` post-normalization. + + Satifies the `XModuleFactory` interface. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Additional block and child module arguments. + + Returns: + hk.Module: `XBlock` module. + + References: + 1. ResMLP: Feedforward networks for image classification with data-efficient training + (https://arxiv.org/abs/2105.03404). + """ return XBlock( num_patches, dim, @@ -435,6 +798,32 @@ def resmlp_block_factory( def s2mlp_block_factory(num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any) -> hk.Module: + """S^2-MLP block module factory function. + + Builds a `XBlock` module with the S^2-MLP block structure as defined in *S^2-MLP: Spatial-Shift MLP Architecture + for Vision*¹. Specifically, this consists of two `XSublayer`s: 1) a `XChannelFeedForward` module with a shift + function and 2) a second `XChannelFeedForward` module. Both make use of layer normalization (pre-normalization). + + Note: currently, the spatial shift function must be passed as the keyword argument "sublayer1_ff_shift". This is + due to the fact that `create_shift2d_op` requires post-patching height and width dimensions, which the `XMLP` + module has no knowledge of. + + Satifies the `XModuleFactory` interface. + + Args: + num_patches (int): Number of patches in the input. + dim (int): Size of the channel dimension. + depth (int): The depth of the block which contains this layer in the network. Note that depth starts from 1. + name (str, optional): The name of the module. Defaults to None. + **kwargs: Additional block and child module arguments. + + Returns: + hk.Module: `XBlock` module. + + References: + 1. S2-MLP: Spatial-Shift MLP Architecture for Vision (https://arxiv.org/abs/2106.07477). + """ + if "sublayer1_ff_shift" not in kwargs: raise ValueError("s2mlp_block_factory requires sublayer1_ff_shift to be specified") if "sublayer1_ff_dim_hidden" not in kwargs and "sublayers_ff_dim_hidden" not in kwargs: