diff --git a/ft_datasets/alpaca_dataset.py b/ft_datasets/alpaca_dataset.py
index 4d492460f..77cbb27ea 100644
--- a/ft_datasets/alpaca_dataset.py
+++ b/ft_datasets/alpaca_dataset.py
@@ -42,6 +42,9 @@ def __len__(self):
         return len(self.ann)
 
     def __getitem__(self, index):
+        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
+
+
         ann = self.ann[index]
         if ann.get("input", "") == "":
             prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
@@ -66,7 +69,7 @@ def __getitem__(self, index):
         example_mask = example.ge(0)
         label_mask = labels.ge(0)
         example[~example_mask] = 0
-        labels[~label_mask] = 0
+        labels[~label_mask] = IGNORE_INDEX
         example_mask = example_mask.float()
         label_mask = label_mask.float()
 