Improve build_loader documentation
opedromartins committed Aug 1, 2024
1 parent 899eb98 commit 5c161da
Showing 2 changed files with 43 additions and 5 deletions.
42 changes: 40 additions & 2 deletions nbs/03_build_loader.ipynb
@@ -136,6 +136,44 @@
" print(f\"{key}: {value}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build DataLoader\n",
"\n",
"The `build_dataloader` function is a utility for creating a PyTorch `DataLoader` with added support for distributed training. Here's a breakdown of what the function does:\n",
"\n",
"1. **Distributed Training Support**:\n",
" - The function first checks if distributed training is initialized using `dist.is_initialized()`, if distributed training is active, it retrieves the rank and world size of the current process using `dist.get_rank()` and `dist.get_world_size()`.\n",
" - It then creates a `DistributedSampler`, which ensures that each process gets a different subset of the dataset. This sampler is used to handle data loading in a distributed manner.\n",
" - If distributed training is not initialized, it defaults to using no sampler.\n",
"\n",
"2. **Creating the DataLoader**:\n",
" - The function creates a `DataLoader` using the provided dataset, batch size, number of workers, shuffle, and pin memory options.\n",
" - It uses the sampler if one was created; otherwise, it shuffles the data if `shuffle` is set to `True`.\n",
"\n",
"### Parameters Abstracted from PyTorch Direct Implementation\n",
"\n",
"The function abstracts away the following details from a direct PyTorch `DataLoader` implementation:\n",
"- **DistributedSampler**: Automatically handles creating and using a `DistributedSampler` when distributed training is initialized.\n",
"- **Sampler Management**: Abstracts the logic for deciding when to use a sampler and whether to shuffle the data.\n",
"- **Collate Function**: Assumes a specific `collate_fn` (`collate`) is used, simplifying the `DataLoader` creation by not requiring the user to specify it.\n",
"\n",
"### Limitations\n",
"\n",
"- **Fixed Collate Function**: The function uses a predefined `collate_fn`. If a different collate function is needed, the user must manually modify the function.\n",
"- **Limited Customization**: The function only exposes a subset of possible `DataLoader` parameters (batch size, number of workers, shuffle, and pin memory). For more advanced customization, the user might need to modify the function or revert to directly creating a `DataLoader`. PyTorch `DataLoader` supports advanced features such as `persistent_workers`, `worker_init_fn`, and `timeout`. The function does not expose these features, limiting its flexibility for more complex use cases.\n",
"- **Distributed Training Dependency**: The function relies on PyTorch's distributed package (`torch.distributed`) to determine if distributed training is initialized. If used in a non-distributed context without the appropriate setup, the distributed checks and sampler creation might add unnecessary complexity.\n",
"\n",
"### Further Enhancements\n",
"\n",
"Some potential enhancements to the function include:\n",
"\n",
"- **Custom Collate Function**: Allow users to specify a custom `collate_fn` for more flexibility in data processing.\n",
"- **Expose Advanced DataLoader Parameters**: Provide additional parameters for more advanced `DataLoader` configurations using **kwargs."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -148,8 +186,8 @@
" num_workers=8, # Number of workers\n",
" shuffle:bool=False, # Shuffle the data\n",
" pin_memory=False # Pin memory\n",
" ):\n",
" \"\"\"This function is designed to build a DataLoader object for a given dataset.\"\"\"\n",
" ): # A PyTorch DataLoader instance with the specified configuration.\n",
" \"\"\"This function is designed to build a DataLoader object for a given dataset with optional distributed training support.\"\"\"\n",
" if dist.is_initialized():\n",
" rank = dist.get_rank()\n",
" world_size = dist.get_world_size()\n",
6 changes: 3 additions & 3 deletions pillarnext_explained/datasets/build_loader.py
@@ -44,14 +44,14 @@ def collate(batch_list):

return ret

# %% ../../nbs/03_build_loader.ipynb 6
# %% ../../nbs/03_build_loader.ipynb 7
def build_dataloader(dataset, # Dataset object
batch_size=4, # Batch size
num_workers=8, # Number of workers
shuffle:bool=False, # Shuffle the data
pin_memory=False # Pin memory
):
"""This function is designed to build a DataLoader object for a given dataset."""
): # A PyTorch DataLoader instance with the specified configuration.
"""This function is designed to build a DataLoader object for a given dataset with optional distributed training support."""
if dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
