@@ -41,19 +41,41 @@ def __call__(self, data: List[Dict]) -> XYData:
41
41
42
42
43
43
class RaggedCollator (Collator ):
44
- """Collator for handling ragged data samples."""
44
+ """
45
+ Collator for handling ragged data samples, designed to support scenarios where some labels may be missing (None).
46
+
47
+ This class is specifically designed for preparing batches of "ragged" data, where the samples may have varying sizes,
48
+ such as molecular representations or variable-length protein sequences. Additionally, it supports cases where some
49
+ of the data samples might be partially labeled, which is useful for certain loss functions that allow training
50
+ with incomplete or fuzzy data (e.g., fuzzy loss).
51
+
52
+ During batching, the class pads the data samples to a uniform length, applies appropriate masks to differentiate
53
+ between valid and padded elements, and ensures that label misalignment is handled by filtering out unlabelled
54
+ data points. The indices of valid labels are stored in the `non_null_labels` field, which can be used later for
55
+ metrics computation such as F1-score or MSE, especially in cases where some data points lack labels.
56
+
57
+ Reference: https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829
58
+ """
45
59
46
60
def __call__ (self , data : List [Union [Dict , Tuple ]]) -> XYData :
47
- """Collate ragged data samples (i.e., samples of unequal size such as string representations of molecules) into
48
- a batch.
61
+ """
62
+ Collate ragged data samples (i.e., samples of unequal size, such as molecular sequences) into a batch.
63
+
64
+ Handles both fully and partially labeled data, where some samples may have `None` as their label. The indices
65
+ of non-null labels are stored in the `non_null_labels` field, which is used to filter out predictions for
66
+ unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled data, this method
67
+ ensures alignment between features and labels.
49
68
50
69
Args:
51
- data (List[Union[Dict, Tuple]]): List of ragged data samples.
70
+ data (List[Union[Dict, Tuple]]): List of ragged data samples. Each sample can be a dictionary or tuple
71
+ with 'features', 'labels', and 'ident'.
52
72
53
73
Returns:
54
- XYData: Batched data with appropriate padding and masks.
74
+ XYData: A batch of padded sequences and labels, including masks for valid positions and indices of
75
+ non-null labels for metric computation.
55
76
"""
56
77
model_kwargs : Dict = dict ()
78
+ # Indices of non-null labels are stored in key `non_null_labels` of loss_kwargs.
57
79
loss_kwargs : Dict = dict ()
58
80
59
81
if isinstance (data [0 ], tuple ):
@@ -64,18 +86,23 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
64
86
* ((d ["features" ], d ["labels" ], d .get ("ident" )) for d in data )
65
87
)
66
88
if any (x is not None for x in y ):
89
+ # At least one label is present, e.g. (None, None, `1`, None)
67
90
if any (x is None for x in y ):
91
+ # Partially labeled batch: some labels are None, others are not, e.g. (`None`, `None`, 1, `None`)
68
92
non_null_labels = [i for i , r in enumerate (y ) if r is not None ]
69
93
y = self .process_label_rows (
70
94
tuple (ye for i , ye in enumerate (y ) if i in non_null_labels )
71
95
)
72
96
loss_kwargs ["non_null_labels" ] = non_null_labels
73
97
else :
98
+ # Fully labeled batch: no label is None, e.g. (`0`, `2`, `1`, `3`)
74
99
y = self .process_label_rows (y )
75
100
else :
101
+ # If all labels are None: (`None`, `None`, `None`, `None`)
76
102
y = None
77
103
loss_kwargs ["non_null_labels" ] = []
78
104
105
+ # Calculate the length of each sequence and create a boolean mask marking valid (non-padded) positions
79
106
lens = torch .tensor (list (map (len , x )))
80
107
model_kwargs ["mask" ] = torch .arange (max (lens ))[None , :] < lens [:, None ]
81
108
model_kwargs ["lens" ] = lens
@@ -89,7 +116,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
89
116
)
90
117
91
118
def process_label_rows (self , labels : Tuple ) -> torch .Tensor :
92
- """Process label rows by padding sequences.
119
+ """
120
+ Process label rows by padding sequences to ensure uniform shape across the batch.
121
+
122
+ This method pads the label rows, converting sequences of labels of different lengths into a uniform tensor.
123
+ It ensures that `None` values in the labels are handled by substituting them with a default value (e.g., `False`).
93
124
94
125
Args:
95
126
labels (Tuple): Tuple of label rows.
0 commit comments