diff --git a/nbs/03_build_loader.ipynb b/nbs/03_build_loader.ipynb
index e895c15..9363471 100644
--- a/nbs/03_build_loader.ipynb
+++ b/nbs/03_build_loader.ipynb
@@ -136,6 +136,62 @@
     "    print(f\"{key}: {value}\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Build DataLoader\n",
+    "\n",
+    "The `build_dataloader` function is a utility for creating a PyTorch `DataLoader` with added support for distributed training. Here's a breakdown of what the function does (a usage sketch follows at the end of this section):\n",
+    "\n",
+    "1. **Distributed Training Support**:\n",
+    "   - The function first checks whether distributed training is initialized using `dist.is_initialized()`. If it is, it retrieves the rank and world size of the current process via `dist.get_rank()` and `dist.get_world_size()`.\n",
+    "   - It then creates a `DistributedSampler`, which ensures that each process receives a different subset of the dataset; this sampler handles data loading in a distributed manner.\n",
+    "   - If distributed training is not initialized, no sampler is used.\n",
+    "\n",
+    "2. **Creating the DataLoader**:\n",
+    "   - The function creates a `DataLoader` using the provided dataset, batch size, number of workers, shuffle, and pin memory options.\n",
+    "   - It uses the sampler if one was created; otherwise, it shuffles the data if `shuffle` is set to `True`.\n",
+    "\n",
+    "### Details Abstracted from a Direct PyTorch Implementation\n",
+    "\n",
+    "The function abstracts away the following details of a direct PyTorch `DataLoader` implementation:\n",
+    "- **DistributedSampler**: Automatically creates and uses a `DistributedSampler` when distributed training is initialized.\n",
+    "- **Sampler Management**: Handles the logic for deciding when to use a sampler and whether to shuffle the data.\n",
+    "- **Collate Function**: Assumes a specific `collate_fn` (`collate`), simplifying `DataLoader` creation by not requiring the user to specify it.\n",
+    "\n",
+    "### Limitations\n",
+    "\n",
+    "- **Fixed Collate Function**: The function uses a predefined `collate_fn`. If a different collate function is needed, the user must modify the function.\n",
+    "- **Limited Customization**: The function exposes only a subset of `DataLoader` parameters (batch size, number of workers, shuffle, and pin memory). PyTorch's `DataLoader` also supports advanced options such as `persistent_workers`, `worker_init_fn`, and `timeout`; using them requires modifying the function or creating a `DataLoader` directly.\n",
+    "- **Distributed Training Dependency**: The function relies on PyTorch's distributed package (`torch.distributed`) to determine whether distributed training is initialized. In a non-distributed context, the distributed checks and sampler creation might add unnecessary complexity.\n",
+    "\n",
+    "### Further Enhancements\n",
+    "\n",
+    "Some potential enhancements to the function include:\n",
+    "\n",
+    "- **Custom Collate Function**: Allow users to specify a custom `collate_fn` for more flexibility in data processing.\n",
+    "- **Expose Advanced DataLoader Parameters**: Provide additional parameters for more advanced `DataLoader` configurations via `**kwargs`.\n",
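+    "\n",
+    "### Usage Sketch\n",
+    "\n",
+    "A minimal sketch of how `build_dataloader` might be called. `PillarDataset` and its `root` argument are hypothetical placeholders for any `torch.utils.data.Dataset`, not part of this repository's API:\n",
+    "\n",
+    "```python\n",
+    "# Hypothetical dataset class; substitute any torch.utils.data.Dataset.\n",
+    "dataset = PillarDataset(root='data/train')\n",
+    "\n",
+    "# Single-process use: no sampler is created and shuffle is honored directly.\n",
+    "loader = build_dataloader(dataset, batch_size=4, num_workers=8,\n",
+    "                          shuffle=True, pin_memory=True)\n",
+    "\n",
+    "# Under torch.distributed (after dist.init_process_group), the same call\n",
+    "# attaches a DistributedSampler so each rank iterates over a distinct shard.\n",
+    "for batch in loader:\n",
+    "    pass  # training step goes here\n",
+    "```"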
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -148,8 +186,8 @@
     "                     num_workers=8, # Number of workers\n",
     "                     shuffle:bool=False, # Shuffle the data\n",
     "                     pin_memory=False # Pin memory\n",
-    "                     ):\n",
-    "    \"\"\"This function is designed to build a DataLoader object for a given dataset.\"\"\"\n",
+    "                     ): # A PyTorch DataLoader instance with the specified configuration.\n",
+    "    \"\"\"This function is designed to build a DataLoader object for a given dataset with optional distributed training support.\"\"\"\n",
     "    if dist.is_initialized():\n",
     "        rank = dist.get_rank()\n",
     "        world_size = dist.get_world_size()\n",
diff --git a/pillarnext_explained/datasets/build_loader.py b/pillarnext_explained/datasets/build_loader.py
index 00e68d1..d1ceee5 100644
--- a/pillarnext_explained/datasets/build_loader.py
+++ b/pillarnext_explained/datasets/build_loader.py
@@ -44,14 +44,14 @@ def collate(batch_list):
     return ret
 
 
-# %% ../../nbs/03_build_loader.ipynb 6
+# %% ../../nbs/03_build_loader.ipynb 7
 def build_dataloader(dataset, # Dataset object
                      batch_size=4, # Batch size
                      num_workers=8, # Number of workers
                      shuffle:bool=False, # Shuffle the data
                      pin_memory=False # Pin memory
-                     ):
-    """This function is designed to build a DataLoader object for a given dataset."""
+                     ): # A PyTorch DataLoader instance with the specified configuration.
+    """This function is designed to build a DataLoader object for a given dataset with optional distributed training support."""
     if dist.is_initialized():
         rank = dist.get_rank()
         world_size = dist.get_world_size()