
Commit

address review comments
sohamparikh committed Jan 28, 2025
1 parent d8e3ae1 commit a887dd6
Showing 5 changed files with 23 additions and 26 deletions.
2 changes: 1 addition & 1 deletion fast_llm/data/data/gpt/data.py
@@ -29,7 +29,7 @@
 @dataclasses.dataclass
 class GPTDataBatch:
     ids: torch.Tensor
-    spans: torch.Tensor
+    spans: list[torch.Tensor]


 def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTDataBatch:
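
Note: the spans field becomes a list because each sample can carry a different number of loss-masking spans, so they cannot be stacked into a single tensor the way the ids can. A minimal sketch of a collate step consistent with this shape change (the GPTSample field layout and the fixed sequence length are assumptions for illustration, not the repo's code):

import dataclasses

import numpy as np
import torch


@dataclasses.dataclass
class GPTSample:
    ids: np.ndarray    # token ids for one sample, shape (seq_len,); assumed layout
    spans: np.ndarray  # loss-masking spans, shape (num_spans, 2); num_spans varies per sample


@dataclasses.dataclass
class GPTDataBatch:
    ids: torch.Tensor          # stacked ids, shape (batch_size, seq_len)
    spans: list[torch.Tensor]  # one (num_spans, 2) tensor per sample


def gpt_data_collate_fn_sketch(batch: list[GPTSample]) -> GPTDataBatch:
    # ids share a fixed sequence length, so they stack into one tensor;
    # span counts differ between samples, so spans stay a list of tensors.
    ids = torch.from_numpy(np.stack([sample.ids for sample in batch]))
    spans = [torch.from_numpy(sample.spans) for sample in batch]
    return GPTDataBatch(ids=ids, spans=spans)
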
2 changes: 2 additions & 0 deletions fast_llm/data/dataset/gpt/fim.py
@@ -46,6 +46,8 @@ def name(self) -> str:
     def _fim(self, sample: GPTSample, np_rng: np.random.RandomState) -> GPTSample:
         # FIM
         # TODO: permute segments in sample_list, before concatenating.
+        if self._config.rate > 0.0 and sample.spans.size > 0:
+            raise NotImplementedError("FIM is currently not compatible with loss masking.")
         sample_len = sample.ids.shape[0]
         eod = self._tokenizer.eod
         segment_breaks = np.argwhere(sample.ids == eod)  # split sample by document
11 changes: 6 additions & 5 deletions fast_llm/data/tokenizer.py
@@ -65,21 +65,22 @@ def tokenize_with_spans(
         for start, end in char_spans:
             if char_pos < start:
                 curr_text = text[char_pos:start]
-                tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text)
+                tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=False)
                 beginning_of_text = False
                 input_ids.extend(tokenized_text)
             curr_text = text[start : end + 1]
-            tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text)
+            if end >= len(text) - 1:
+                tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=True)
+            else:
+                tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=False)
             beginning_of_text = False
             token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1))
             input_ids.extend(tokenized_text)
             char_pos = end + 1
         if char_pos < len(text):
             curr_text = text[char_pos:]
-            tokenized_text = self.tokenize(curr_text)
+            tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=True)
             input_ids.extend(tokenized_text)
-        if self.special_tokens_mode in [SpecialTokensMode.eos_only, SpecialTokensMode.bos_eos]:
-            input_ids.append(self.eod_id)
         return input_ids, token_spans

     def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str:
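
Note: with this change, BOS is only prepended to the first chunk and EOS is only appended to the chunk that reaches the end of the text, so concatenating the per-chunk tokenizations matches tokenizing the whole text once while still recording token-level spans. A toy sketch of the span bookkeeping itself (the whitespace tokenizer below is a stand-in for illustration, not the repo's Tokenizer, and special tokens are omitted):

def tokenize_with_spans_sketch(text: str, char_spans: list[tuple[int, int]]):
    # Stand-in tokenizer: one token per whitespace-separated piece.
    def tokenize(chunk: str) -> list[str]:
        return chunk.split()

    input_ids: list[str] = []
    token_spans: list[tuple[int, int]] = []
    char_pos = 0
    for start, end in char_spans:
        if char_pos < start:
            # Text before the span contributes tokens but no span entry.
            input_ids.extend(tokenize(text[char_pos:start]))
        span_tokens = tokenize(text[start : end + 1])
        # The span covers an inclusive range of token positions.
        token_spans.append((len(input_ids), len(input_ids) + len(span_tokens) - 1))
        input_ids.extend(span_tokens)
        char_pos = end + 1
    if char_pos < len(text):
        input_ids.extend(tokenize(text[char_pos:]))
    return input_ids, token_spans


tokens, spans = tokenize_with_spans_sketch("mask this span only", [(5, 13)])
# tokens == ['mask', 'this', 'span', 'only'], spans == [(1, 2)]
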
28 changes: 13 additions & 15 deletions fast_llm/functional/cross_entropy.py
@@ -11,10 +11,8 @@
 def torch_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
-    loss_mask: torch.Tensor,
     grad_output: float | None,
     logits_scale_factor: float = 1.0,
-    ignore_index: int = -100,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     A wrapper for the pytorch implementation of cross-entropy.
@@ -29,7 +27,7 @@ def torch_cross_entropy_forward_backward(
     if grad_output is None:
         loss = None
     else:
-        loss = torch.nn.functional.cross_entropy(logits_, target, ignore_index=ignore_index).mean()
+        loss = torch.nn.functional.cross_entropy(logits_, target).mean()
         loss.backward(torch.full_like(loss, grad_output))
         loss.detach_()
     return loss.detach(), logits_.grad.detach().to(logits.dtype)
@@ -39,10 +37,8 @@ def torch_cross_entropy_forward_backward(
 def fused_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
-    loss_mask: torch.Tensor,
     grad_output: float | None,
     logits_scale_factor: float = 1.0,
-    ignore_index: int = -100,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     A fused implementation of cross-entropy with torch compile.
@@ -61,15 +57,14 @@ def fused_cross_entropy_forward_backward(
     if grad_output is None:
         grad = None
     else:
-        grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device)
         exp_logits = exp_logits.scatter(1, target, exp_logits.gather(1, target) - sum_exp_logits.unsqueeze(dim=-1))
         # exp_logits[torch.arange(0, logits.size(0), device=logits.device), target.squeeze(dim=-1)]-=sum_exp_logits
         exp_logits = exp_logits.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))

         if logits_scale_factor != 1.0:
             exp_logits *= logits_scale_factor

-        grad.index_put_((loss_mask,), exp_logits.to(logits.dtype))
+        grad = exp_logits.to(logits.dtype)

     loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)).mean()

@@ -80,11 +75,9 @@ def fused_cross_entropy_forward_backward(
 def parallel_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
-    loss_mask: torch.Tensor,
     grad_output: float | None,
     group: ProcessGroup,
     logits_scale_factor: float = 1.0,
-    ignore_index: int = -100,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     A fused implementation of cross-entropy with torch compile, with support for tensor parallelism.
@@ -113,15 +106,14 @@ def parallel_cross_entropy_forward_backward(
     if grad_output is None:
         grad = None
     else:
-        grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device)
         exp_logits1 = exp_logits.scatter(
             1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1)
         )
         exp_logits2 = exp_logits1.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
         if logits_scale_factor != 1.0:
             exp_logits2 *= logits_scale_factor

-        grad.index_put_((loss_mask,), exp_logits2.to(logits.dtype))
+        grad = exp_logits2.to(logits.dtype)

     predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1)
     all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
@@ -157,10 +149,16 @@ def cross_entropy_forward_backward(
     logits = logits[loss_mask]
     if group:
         Assert.eq(implementation, CrossEntropyImpl.fused)
-        return parallel_cross_entropy_forward_backward(
-            logits, target, loss_mask, grad_output, group, logits_scale_factor=logits_scale_factor
+        loss, grad_logits = parallel_cross_entropy_forward_backward(
+            logits, target, grad_output, group, logits_scale_factor=logits_scale_factor
         )
     else:
-        return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
-            logits, target, loss_mask, grad_output, logits_scale_factor=logits_scale_factor
+        loss, grad_logits = _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
+            logits, target, grad_output, logits_scale_factor=logits_scale_factor
         )
+    if grad_logits is not None:
+        grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device)
+        grad.index_put_((loss_mask,), grad_logits)
+        return loss, grad
+    else:
+        return loss, grad_logits
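
Note: the refactor moves loss masking out of the individual cross-entropy implementations. They now receive pre-masked logits and return a gradient for those rows only, and the wrapper scatters that gradient back into a full-size buffer with zeros at the masked positions. A self-contained sketch of that pattern (simplified names, plain PyTorch cross-entropy standing in for the fused/parallel implementations):

import torch


def masked_cross_entropy_sketch(
    logits: torch.Tensor,     # (num_tokens, vocab_size)
    target: torch.Tensor,     # (num_tokens,)
    loss_mask: torch.Tensor,  # (num_tokens,) bool; True means the token contributes to the loss
) -> tuple[torch.Tensor, torch.Tensor]:
    # The inner implementation only sees the unmasked rows ...
    masked_logits = logits[loss_mask].detach().requires_grad_()
    loss = torch.nn.functional.cross_entropy(masked_logits, target[loss_mask])
    loss.backward()
    # ... and the wrapper scatters their gradient back into a full-size buffer,
    # leaving masked rows at zero.
    grad = torch.zeros_like(logits)
    grad.index_put_((loss_mask,), masked_logits.grad)
    return loss.detach(), grad


logits = torch.randn(8, 32)
target = torch.randint(0, 32, (8,))
loss_mask = torch.tensor([True, True, False, True, True, True, False, True])
loss, grad = masked_cross_entropy_sketch(logits, target, loss_mask)
assert grad[~loss_mask].abs().sum() == 0
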
6 changes: 1 addition & 5 deletions fast_llm/functional/triton/cross_entropy.py
@@ -9,7 +9,6 @@
 def triton_cross_entropy_forward_backward_kernel(
     logits_ptr,
     labels_ptr,
-    loss_mask_ptr,
     grad_logits_ptr,
     losses_ptr,
     grad_losses,
@@ -82,7 +81,6 @@ def triton_cross_entropy_forward_backward(
     triton_cross_entropy_forward_backward_kernel[(n_rows,)](
         logits,
         target,
-        loss_mask,
         grad_logits,
         losses,
         1 if grad_output is None else grad_output / n_rows,
@@ -93,6 +91,4 @@ def triton_cross_entropy_forward_backward(
         block_size=block_size,
         num_warps=num_warps,
     )
-    full_grad_logits = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device)
-    full_grad_logits.index_put_((loss_mask,), grad_logits)
-    return losses.mean(), None if grad_output is None else full_grad_logits
+    return losses.mean(), None if grad_output is None else grad_logits
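
Note: dropping the masking from these implementations does not change the math they compute for the rows they receive. The gradient of mean cross-entropy with respect to the logits has the closed form (softmax(logits) - onehot(target)) / N, which is what the fused scatter-and-scale above builds. A quick numerical check of that identity with plain PyTorch (independent of the repo's kernels):

import torch

torch.manual_seed(0)
logits = torch.randn(4, 10, dtype=torch.float64, requires_grad=True)
target = torch.randint(0, 10, (4,))

loss = torch.nn.functional.cross_entropy(logits, target)  # mean over the batch
loss.backward()

# Closed form: (softmax - one-hot), averaged over the batch.
expected = torch.softmax(logits.detach(), dim=-1)
expected[torch.arange(4), target] -= 1.0
expected /= 4

assert torch.allclose(logits.grad, expected)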
