diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py
index 4950f8aa..a4795e59 100644
--- a/fast_llm/data/data/gpt/data.py
+++ b/fast_llm/data/data/gpt/data.py
@@ -29,7 +29,7 @@
 @dataclasses.dataclass
 class GPTDataBatch:
     ids: torch.Tensor
-    spans: torch.Tensor
+    spans: list[torch.Tensor]
 
 
 def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTDataBatch:
diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py
index e553e03c..266c30ab 100644
--- a/fast_llm/data/dataset/gpt/fim.py
+++ b/fast_llm/data/dataset/gpt/fim.py
@@ -46,6 +46,8 @@ def name(self) -> str:
     def _fim(self, sample: GPTSample, np_rng: np.random.RandomState) -> GPTSample:
         # FIM
         # TODO: permute segments in sample_list, before concatenating.
+        if self._config.rate > 0.0 and sample.spans.size > 0:
+            raise NotImplementedError("FIM is currently not compatible with loss masking.")
         sample_len = sample.ids.shape[0]
         eod = self._tokenizer.eod
         segment_breaks = np.argwhere(sample.ids == eod)  # split sample by document
diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py
index ced4dcc8..92b4ba26 100644
--- a/fast_llm/data/tokenizer.py
+++ b/fast_llm/data/tokenizer.py
@@ -65,21 +65,22 @@ def tokenize_with_spans(
         for start, end in char_spans:
             if char_pos < start:
                 curr_text = text[char_pos:start]
-                tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text)
+                tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=False)
                 beginning_of_text = False
                 input_ids.extend(tokenized_text)
             curr_text = text[start : end + 1]
-            tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text)
+            if end >= len(text) - 1:
+                tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=True)
+            else:
+                tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=False)
             beginning_of_text = False
             token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1))
             input_ids.extend(tokenized_text)
             char_pos = end + 1
         if char_pos < len(text):
             curr_text = text[char_pos:]
-            tokenized_text = self.tokenize(curr_text)
+            tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=True)
             input_ids.extend(tokenized_text)
-        if self.special_tokens_mode in [SpecialTokensMode.eos_only, SpecialTokensMode.bos_eos]:
-            input_ids.append(self.eod_id)
         return input_ids, token_spans
 
     def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str:
diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index 0e806dd7..62f120f4 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -11,10 +11,8 @@
 def torch_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
-    loss_mask: torch.Tensor,
     grad_output: float | None,
     logits_scale_factor: float = 1.0,
-    ignore_index: int = -100,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     A wrapper for the pytorch implementation of cross-entropy.
@@ -29,7 +27,7 @@ def torch_cross_entropy_forward_backward(
     if grad_output is None:
         loss = None
     else:
-        loss = torch.nn.functional.cross_entropy(logits_, target, ignore_index=ignore_index).mean()
+        loss = torch.nn.functional.cross_entropy(logits_, target).mean()
         loss.backward(torch.full_like(loss, grad_output))
         loss.detach_()
     return loss.detach(), logits_.grad.detach().to(logits.dtype)
@@ -39,10 +37,8 @@
 def fused_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
-    loss_mask: torch.Tensor,
     grad_output: float | None,
     logits_scale_factor: float = 1.0,
-    ignore_index: int = -100,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     A fused implementation of cross-entropy with torch compile.
@@ -61,7 +57,6 @@ def fused_cross_entropy_forward_backward(
     if grad_output is None:
         grad = None
     else:
-        grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device)
         exp_logits = exp_logits.scatter(1, target, exp_logits.gather(1, target) - sum_exp_logits.unsqueeze(dim=-1))
         # exp_logits[torch.arange(0, logits.size(0), device=logits.device), target.squeeze(dim=-1)]-=sum_exp_logits
         exp_logits = exp_logits.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
@@ -69,7 +64,7 @@ def fused_cross_entropy_forward_backward(
         if logits_scale_factor != 1.0:
             exp_logits *= logits_scale_factor
 
-        grad.index_put_((loss_mask,), exp_logits.to(logits.dtype))
+        grad = exp_logits.to(logits.dtype)
 
     loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)).mean()
 
@@ -80,11 +75,9 @@
 def parallel_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
-    loss_mask: torch.Tensor,
     grad_output: float | None,
     group: ProcessGroup,
     logits_scale_factor: float = 1.0,
-    ignore_index: int = -100,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     A fused implementation of cross-entropy with torch compile, with support for tensor parallelism.
@@ -113,7 +106,6 @@ def parallel_cross_entropy_forward_backward(
     if grad_output is None:
         grad = None
     else:
-        grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device)
         exp_logits1 = exp_logits.scatter(
             1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1)
         )
@@ -121,7 +113,7 @@ def parallel_cross_entropy_forward_backward(
         if logits_scale_factor != 1.0:
             exp_logits2 *= logits_scale_factor
 
-        grad.index_put_((loss_mask,), exp_logits2.to(logits.dtype))
+        grad = exp_logits2.to(logits.dtype)
 
     predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1)
     all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
@@ -157,10 +149,16 @@ def cross_entropy_forward_backward(
         logits = logits[loss_mask]
     if group:
         Assert.eq(implementation, CrossEntropyImpl.fused)
-        return parallel_cross_entropy_forward_backward(
-            logits, target, loss_mask, grad_output, group, logits_scale_factor=logits_scale_factor
+        loss, grad_logits = parallel_cross_entropy_forward_backward(
+            logits, target, grad_output, group, logits_scale_factor=logits_scale_factor
         )
     else:
-        return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
-            logits, target, loss_mask, grad_output, logits_scale_factor=logits_scale_factor
+        loss, grad_logits = _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
+            logits, target, grad_output, logits_scale_factor=logits_scale_factor
         )
+    if grad_logits is not None:
+        grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device)
+        grad.index_put_((loss_mask,), grad_logits)
+        return loss, grad
+    else:
+        return loss, grad_logits
diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py
index f78f1cc0..9e6e697f 100644
--- a/fast_llm/functional/triton/cross_entropy.py
+++ b/fast_llm/functional/triton/cross_entropy.py
@@ -9,7 +9,6 @@
 def triton_cross_entropy_forward_backward_kernel(
     logits_ptr,
     labels_ptr,
-    loss_mask_ptr,
     grad_logits_ptr,
     losses_ptr,
     grad_losses,
@@ -82,7 +81,6 @@ def triton_cross_entropy_forward_backward(
     triton_cross_entropy_forward_backward_kernel[(n_rows,)](
         logits,
         target,
-        loss_mask,
         grad_logits,
         losses,
         1 if grad_output is None else grad_output / n_rows,
@@ -93,6 +91,4 @@ def triton_cross_entropy_forward_backward(
         block_size=block_size,
        num_warps=num_warps,
     )
-    full_grad_logits = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device)
-    full_grad_logits.index_put_((loss_mask,), grad_logits)
-    return losses.mean(), None if grad_output is None else full_grad_logits
+    return losses.mean(), None if grad_output is None else grad_logits
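The cross-entropy refactor above moves loss masking out of the individual implementations: `cross_entropy_forward_backward` selects the unmasked rows with `loss_mask` before dispatching, and afterwards scatters the per-row gradients back into a zero-initialized full-size tensor with `index_put_`. The following is a minimal, self-contained sketch of that select-then-scatter pattern using plain PyTorch autograd instead of the repository's fused/Triton kernels; the function name and example values are illustrative only, not code from the PR.

```python
import torch


def masked_cross_entropy_with_full_grad(
    logits: torch.Tensor,  # (num_tokens, vocab_size)
    targets: torch.Tensor,  # (num_tokens,)
    loss_mask: torch.Tensor,  # (num_tokens,) bool; True = token counts toward the loss
) -> tuple[torch.Tensor, torch.Tensor]:
    # 1. Select only the unmasked rows before the loss computation,
    #    so the inner implementation never has to know about masking.
    kept_logits = logits[loss_mask].detach().requires_grad_()
    kept_targets = targets[loss_mask]

    # 2. Compute the loss and per-row gradients on the kept rows only.
    loss = torch.nn.functional.cross_entropy(kept_logits, kept_targets)
    (grad_kept,) = torch.autograd.grad(loss, kept_logits)

    # 3. Scatter the gradients back into a full-size tensor, so masked rows
    #    receive an exact zero gradient (mirroring the
    #    grad.index_put_((loss_mask,), grad_logits) step in the diff).
    grad = torch.zeros_like(logits)
    grad.index_put_((loss_mask,), grad_kept)
    return loss.detach(), grad


# Example: mask out the second of three tokens.
logits = torch.randn(3, 5)
targets = torch.tensor([1, 0, 4])
loss_mask = torch.tensor([True, False, True])
loss, grad = masked_cross_entropy_with_full_grad(logits, targets, loss_mask)
assert grad[1].abs().sum() == 0  # the masked row contributes no gradient
```

The payoff of this layout is that the torch, fused, and parallel implementations all operate on dense, already-masked rows, while masked positions still come back with a zero gradient in the tensor returned to the caller.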
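The tokenizer change replaces the unconditional EOD append at the end of `tokenize_with_spans` with explicit `beginning_of_text` / `end_of_text` flags, so only the first chunk can receive BOS and only the chunk that actually reaches the end of the text receives EOS. Below is a rough standalone sketch of just that chunking decision; `split_with_spans` is a made-up helper used here for illustration, with no real tokenizer involved.

```python
def split_with_spans(text: str, char_spans: list[tuple[int, int]]) -> list[tuple[str, bool, bool]]:
    """Split `text` into chunks, recording which chunk is first / last.

    Each tuple is (chunk, beginning_of_text, end_of_text): only the first chunk
    is flagged for BOS, and only the chunk reaching the end of `text` for EOS.
    """
    chunks = []
    char_pos = 0
    beginning = True
    for start, end in char_spans:
        if char_pos < start:
            # Unmasked text before the span.
            chunks.append((text[char_pos:start], beginning, False))
            beginning = False
        # The span itself is terminal only if it reaches the end of the text.
        chunks.append((text[start : end + 1], beginning, end >= len(text) - 1))
        beginning = False
        char_pos = end + 1
    if char_pos < len(text):
        # Trailing unmasked text is always the last chunk.
        chunks.append((text[char_pos:], beginning, True))
    return chunks


# Example: one span covering "world" in the middle of the text.
print(split_with_spans("hello world!", [(6, 10)]))
# [('hello ', True, False), ('world', False, False), ('!', False, True)]
```

Feeding each `(chunk, beginning, end)` triple to `self.tokenize(chunk, beginning_of_text=..., end_of_text=...)` in order reproduces the control flow of the patched method, with the span boundaries in token space recorded as in the diff.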