
Commit 4db76ce

extra documentation for ragged coll as per the comment
- #48 (comment)
1 parent f39916b commit 4db76ce


chebai/preprocessing/collate.py

Lines changed: 37 additions & 6 deletions
@@ -41,19 +41,41 @@ def __call__(self, data: List[Dict]) -> XYData:
 
 
 class RaggedCollator(Collator):
-    """Collator for handling ragged data samples."""
+    """
+    Collator for handling ragged data samples, designed to support scenarios where some labels may be missing (None).
+
+    This class prepares batches of "ragged" data, where the samples may have varying sizes, such as molecular
+    representations or variable-length protein sequences. It also supports cases where some of the data samples
+    are only partially labeled, which is useful for loss functions that allow training with incomplete or fuzzy
+    data (e.g., fuzzy loss).
+
+    During batching, the class pads the data samples to a uniform length, applies masks to differentiate between
+    valid and padded elements, and handles label misalignment by filtering out unlabeled data points. The indices
+    of valid labels are stored in the `non_null_labels` field, which can be used later for metrics computation
+    such as F1-score or MSE when some data points lack labels.
+
+    Reference: https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829
+    """
 
     def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
-        """Collate ragged data samples (i.e., samples of unequal size such as string representations of molecules) into
-        a batch.
+        """
+        Collate ragged data samples (i.e., samples of unequal size, such as molecular sequences) into a batch.
+
+        Handles both fully and partially labeled data, where some samples may have `None` as their label. The
+        indices of non-null labels are stored in the `non_null_labels` field, which is used to filter out
+        predictions for unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled
+        data, this method ensures alignment between features and labels.
 
         Args:
-            data (List[Union[Dict, Tuple]]): List of ragged data samples.
+            data (List[Union[Dict, Tuple]]): List of ragged data samples. Each sample can be a dictionary or
+                tuple with 'features', 'labels', and 'ident'.
 
         Returns:
-            XYData: Batched data with appropriate padding and masks.
+            XYData: A batch of padded sequences and labels, including masks for valid positions and indices of
+                non-null labels for metric computation.
         """
         model_kwargs: Dict = dict()
+        # Indices of non-null labels are stored in key `non_null_labels` of loss_kwargs.
         loss_kwargs: Dict = dict()
 
         if isinstance(data[0], tuple):
@@ -64,18 +86,23 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
                 *((d["features"], d["labels"], d.get("ident")) for d in data)
             )
         if any(x is not None for x in y):
+            # At least one label is not None, e.g. y = (None, None, `1`, None)
             if any(x is None for x in y):
+                # Some labels are None, e.g. y = (`None`, `None`, 1, `None`)
                 non_null_labels = [i for i, r in enumerate(y) if r is not None]
                 y = self.process_label_rows(
                     tuple(ye for i, ye in enumerate(y) if i in non_null_labels)
                 )
                 loss_kwargs["non_null_labels"] = non_null_labels
             else:
+                # All labels are not None, e.g. y = (`0`, `2`, `1`, `3`)
                 y = self.process_label_rows(y)
         else:
+            # All labels are None, e.g. y = (`None`, `None`, `None`, `None`)
             y = None
             loss_kwargs["non_null_labels"] = []
 
+        # Compute the length of each sequence and create a binary mask for valid (non-padded) positions.
         lens = torch.tensor(list(map(len, x)))
         model_kwargs["mask"] = torch.arange(max(lens))[None, :] < lens[:, None]
         model_kwargs["lens"] = lens
@@ -89,7 +116,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
         )
 
     def process_label_rows(self, labels: Tuple) -> torch.Tensor:
-        """Process label rows by padding sequences.
+        """
+        Process label rows by padding sequences to ensure a uniform shape across the batch.
+
+        This method pads the label rows, converting label sequences of different lengths into a uniform tensor.
+        It handles `None` values in the labels by substituting them with a default value (e.g., `False`).
 
         Args:
             labels (Tuple): Tuple of label rows.
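
The commit only extends the docstring of `process_label_rows`; the method body is not part of this diff. As a rough sketch of the padding behaviour the docstring describes, assuming boolean label rows and `torch.nn.utils.rnn.pad_sequence` (assumptions made for illustration, not confirmed by this commit):

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical label rows of unequal length (multi-label boolean targets).
label_rows = ([1, 0, 1], [0, 1], [1, 1, 0, 1])

# Pad every row to the longest length; padding_value=False mirrors the default
# value mentioned in the docstring for missing positions.
padded = pad_sequence(
    [torch.tensor(row, dtype=torch.bool) for row in label_rows],
    batch_first=True,
    padding_value=False,
)
print(padded)
# tensor([[ True, False,  True, False],
#         [False,  True, False, False],
#         [ True,  True, False,  True]])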