Skip to content

Commit

Permalink
Improve model_utils documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
opedromartins committed Aug 1, 2024
1 parent 5c161da commit a75581b
Showing 2 changed files with 38 additions and 50 deletions.
66 changes: 33 additions & 33 deletions nbs/04_model_utils.ipynb
Original file line number Diff line number Diff line change
@@ -16,6 +16,36 @@
"#| default_exp models/model_utils"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These classes define various convolutional blocks for both dense (regular) and sparse convolutional neural networks (CNNs), abstracting some of the complexities and repetitive code that is often encountered when building such networks directly using PyTorch. Below is an explanation of what these classes are doing, their differences from standard PyTorch implementations, and their limitations.\n",
"\n",
"### Module differences and limitations\n",
"\n",
"#### Differences from PyTorch Direct Implementation\n",
"\n",
"- **Abstraction**: These classes encapsulate common patterns (convolution + normalization + activation) into single modules, reducing repetitive code and making the network definitions more concise and easier to read.\n",
"- **Configuration**: They provide a higher-level interface for configuring layers, automatically setting common parameters such as padding.\n",
"- **Sparse Convolution Support**: The sparse convolution blocks use the `spconv` library, which is not part of standard PyTorch, to handle sparse input data more efficiently.\n",
"\n",
"#### Parameters Abstracted from PyTorch Direct Implementation\n",
"\n",
"- **Padding Calculation**: Automatically calculates padding based on the kernel size if not provided.\n",
"- **Layer Initialization**: Automatically initializes convolutional, normalization, and activation layers within the block, so users don't need to explicitly define each component.\n",
"- **Residual Connections**: For the basic blocks, the residual connections (identity mappings) are integrated within the block, simplifying the addition of these connections.\n",
"\n",
"#### Limitations\n",
"\n",
"- **Flexibility**: While these classes simplify the creation of common patterns, they can be less flexible than directly using PyTorch when non-standard configurations or additional customizations are required.\n",
"- **Dependency on `spconv`**: The sparse convolution blocks depend on the `spconv` library, which might not be as widely used or supported as PyTorch's native functionality.\n",
"- **Debugging**: Abstracting layers into higher-level blocks can make debugging more difficult, as the internal operations are hidden away. Users may need to dig into the class implementations to troubleshoot issues.\n",
"- **Performance Overhead**: Although the abstraction can simplify code, it might introduce slight performance overhead due to additional function calls and encapsulation.\n",
"\n",
"Overall, these classes provide a convenient and structured way to build CNNs, particularly when using common patterns and when working with sparse data. However, for highly customized or performance-critical applications, a more direct approach using PyTorch's lower-level APIs might be preferable."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -259,36 +289,6 @@
"print(\"Output shape:\", output_tensor.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\ndef replace_feature(out, new_features):\\n if \"replace_feature\" in out.__dir__():\\n # spconv 2.x behaviour\\n return out.replace_feature(new_features)\\n else:\\n out.features = new_features\\n return out\\n'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#|export\n",
"#|hide\n",
"\"\"\"\n",
"def replace_feature(out, new_features):\n",
" if \"replace_feature\" in out.__dir__():\n",
" # spconv 2.x behaviour\n",
" return out.replace_feature(new_features)\n",
" else:\n",
" out.features = new_features\n",
" return out\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -299,11 +299,11 @@
"#|hide\n",
"def replace_feature(out, new_features):\n",
" if \"replace_feature\" in out.__dir__():\n",
" # Use the replace_feature method for SparseConvTensor\n",
" # spconv 2.x behaviour\n",
" return out.replace_feature(new_features)\n",
" else:\n",
" # Assuming `out` is a SparseConvTensor and it does not have replace_feature method\n",
" return spconv.pytorch.SparseConvTensor(new_features, out.indices, out.spatial_shape, out.batch_size)\n"
" out.features = new_features\n",
" return out"
]
},
{
22 changes: 5 additions & 17 deletions pillarnext_explained/models/model_utils.py
Original file line number Diff line number Diff line change
@@ -4,13 +4,13 @@
__all__ = ['Conv', 'ConvBlock', 'BasicBlock', 'replace_feature', 'SparseConvBlock', 'SparseBasicBlock', 'SparseConv3dBlock',
'SparseBasicBlock3d']

# %% ../../nbs/04_model_utils.ipynb 2
# %% ../../nbs/04_model_utils.ipynb 3
import torch.nn as nn
import spconv
import spconv.pytorch
from spconv.core import ConvAlgo

# %% ../../nbs/04_model_utils.ipynb 4
# %% ../../nbs/04_model_utils.ipynb 5
class Conv(nn.Module):
"""
A convolutional layer module for neural networks.
@@ -38,7 +38,7 @@ def __init__(self,
def forward(self, x):
return self.conv(x)

# %% ../../nbs/04_model_utils.ipynb 6
# %% ../../nbs/04_model_utils.ipynb 7
class ConvBlock(nn.Module):
"""
A convolutional block module combining a convolutional layer, a normalization layer,
@@ -73,7 +73,7 @@ def forward(self, x):
out = self.act(out)
return out

# %% ../../nbs/04_model_utils.ipynb 8
# %% ../../nbs/04_model_utils.ipynb 9
class BasicBlock(nn.Module):
"""
A basic residual block module for neural networks.
@@ -102,26 +102,14 @@ def forward(self, x):

return out

# %% ../../nbs/04_model_utils.ipynb 10
"""
# %% ../../nbs/04_model_utils.ipynb 11
def replace_feature(out, new_features):
if "replace_feature" in out.__dir__():
# spconv 2.x behaviour
return out.replace_feature(new_features)
else:
out.features = new_features
return out
"""

# %% ../../nbs/04_model_utils.ipynb 11
def replace_feature(out, new_features):
if "replace_feature" in out.__dir__():
# Use the replace_feature method for SparseConvTensor
return out.replace_feature(new_features)
else:
# Assuming `out` is a SparseConvTensor and it does not have replace_feature method
return spconv.pytorch.SparseConvTensor(new_features, out.indices, out.spatial_shape, out.batch_size)


# %% ../../nbs/04_model_utils.ipynb 12
class SparseConvBlock(spconv.pytorch.SparseModule):

0 comments on commit a75581b

Please sign in to comment.