Merge pull request #65 from RWKV/rwkv-x-selective-loss-exp
Rwkv x selective loss exp
PicoCreator authored Jan 24, 2024
2 parents 42cf1af + d57efd2 commit 66e37ad
Showing 24 changed files with 141,438 additions and 25 deletions.
10 changes: 10 additions & 0 deletions RWKV-v5/config-example.yaml
@@ -511,6 +511,16 @@ data:
# this can be used together with sort_by_length, otherwise a shuffle will be done
packing_in_sequence: False

# ----------------------------
# Special use case flags
# ----------------------------

# Reverse the training dataset order before saving. This is useful for
# optimizing the dataset packing process when using packing_in_sequence
# together with sort_by_length in desc order
reverse_train_dataset_before_save: False


# Path to the current checkpoint to continue training from
# this should be the directory path, and ends with `.ckpt/`
ckpt_path: null
16 changes: 16 additions & 0 deletions RWKV-v5/src/data.py
@@ -797,6 +797,13 @@ def merge_into_existing_samples(i):
# Get the subset of the dataset
src_dataset["train"] = src_dataset["train"].select(range(offset_val, offset_val + length_val))

# Dataset flipping (if needed)
if kargs["reverse_train_dataset_before_save"]:
    train_dataset = src_dataset["train"]
    def reverse_dataset(x, idx):
        return train_dataset[train_dataset.num_rows - idx - 1]
    src_dataset["train"] = src_dataset["train"].map(reverse_dataset, with_indices=True, num_proc=num_cpus)

# Save the dataset to disk
src_dataset.save_to_disk(kargs["data_path"])
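For context on the flip above: map() with with_indices=True passes each row's index, so returning the row at num_rows - idx - 1 rewrites the dataset in reverse order. Below is a minimal standalone sketch of the same effect, assuming only the Hugging Face datasets library; the select() call is an illustrative equivalent, not the code used in this PR.

from datasets import Dataset

# Toy dataset standing in for src_dataset["train"]
ds = Dataset.from_dict({"input_ids": [[1], [2, 2], [3, 3, 3]]})

# Equivalent reversal via select(); the map()-based flip above produces the
# same row order while allowing parallelism through num_proc
reversed_ds = ds.select(range(ds.num_rows - 1, -1, -1))

print(reversed_ds["input_ids"])  # [[3, 3, 3], [2, 2], [1]]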

@@ -961,6 +968,15 @@ def __init__(
# this can be used together with sort_by_length, otherwise a shuffle will be done
packing_in_sequence: bool = False,

# ----------------------------
# Special use case flags
# ----------------------------

# Reverse the training dataset order before saving. This is useful for
# optimizing the dataset packing process when using packing_in_sequence
# together with sort_by_length in desc order
reverse_train_dataset_before_save: bool = False,

# ----------------------------
# System tweaks
# ----------------------------
35 changes: 20 additions & 15 deletions RWKV-v5/src/model.py
@@ -1139,13 +1139,18 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
# https://lightning.ai/docs/pytorch/2.0.4/common/lightning_module.html#backward
learning_loss = segment_train_loss / gradient_accumulation_steps

# Undocumented multiple backward pass support
# https://github.com/Lightning-AI/lightning/blob/678f642808c54e4c490caee4df5d357301c976bb/tests/trainer/optimization/test_manual_optimization.py#L251
self.manual_backward(learning_loss, optimizer, retain_graph=True)

# Accumulate without gradient, as we already did the backward pass
# This does mean that a single backward pass is "wasted" at the end
training_loss = training_loss + segment_train_loss.clone().detach().requires_grad_(False)
# Perform the backward pass accordingly, for valid segments (besides the last segment)
if i == start_learning_segment + backward_segment_count - 1:
# This is the last backward pass; we let the default PyTorch Lightning flow handle the backward pass
# and return the segment loss as part of the total loss
training_loss = training_loss + segment_train_loss
else:
# Undocumented multiple backward pass support
# https://github.com/Lightning-AI/lightning/blob/678f642808c54e4c490caee4df5d357301c976bb/tests/trainer/optimization/test_manual_optimization.py#L251
self.manual_backward(learning_loss, optimizer, retain_graph=True)

# Accumulate without gradient, as we already did the backward pass
training_loss = training_loss + segment_train_loss.clone().detach().requires_grad_(False)
else:
# Even if it's not one of the segments we use for the backward pass, we still need to accumulate the loss
training_loss = training_loss + segment_train_loss.clone().detach().requires_grad_(False)
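The branch above splits the backward passes across segments: every learning segment except the last calls manual_backward() immediately (with retain_graph=True because segments share the recurrent-state graph), while the last segment's loss is returned so Lightning's normal flow performs the final backward. The following is a minimal standalone sketch of that pattern, assuming a toy model and using plain loss.backward() in place of Lightning's manual_backward(); because the toy segments do not share a graph, retain_graph is not needed here.

import torch

model = torch.nn.Linear(8, 8)
segments = torch.chunk(torch.randn(4, 8), 4, dim=0)   # stand-ins for BPTT segments
backward_segment_count = len(segments)

training_loss = torch.tensor(0.0)
returned_loss = None
for i, seg in enumerate(segments):
    segment_train_loss = model(seg).pow(2).mean()
    if i == backward_segment_count - 1:
        # Last segment: keep the graph attached and hand the loss back to the trainer
        returned_loss = segment_train_loss
        training_loss = training_loss + segment_train_loss
    else:
        # Earlier segments: backward now, accumulate a detached copy for logging
        segment_train_loss.backward()
        training_loss = training_loss + segment_train_loss.detach()

returned_loss.backward()   # the final backward pass Lightning would normally run
print(float(training_loss))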
@@ -1234,7 +1239,7 @@ def checkpointed_step(idx, targets, mask, last_shift_states,

# Throw if total loss is NaN
assert not torch.isnan(training_loss), "training_loss is NaN"
return training_loss
return sampling_loss, training_loss

#
# Training and validation steps
@@ -1244,9 +1249,9 @@ def training_step(self, batch, batch_idx):
# print("=== BATCH ID SHAPE ===", batch["input_ids"].shape)
# print("=== BATCH AM SHAPE ===", batch["attention_mask"].shape)

total_loss = self.compute_loss(batch, batch_idx, True)
sampling_loss, training_loss = self.compute_loss(batch, batch_idx, True)

self.log('train/loss', total_loss, prog_bar=True)
self.log('train/loss', training_loss, prog_bar=True)
# If set - forces the above train/loss log line to always be on a new line
if self.substep_logging:
print("")
@@ -1256,21 +1261,21 @@ def training_step(self, batch, batch_idx):
torch.cuda.empty_cache()

# if loss not a number return None
if torch.isnan(total_loss):
if torch.isnan(training_loss):
return None

return total_loss
return training_loss

@TCompileBaseline
def validation_step(self, batch, batch_idx):
total_loss = self.compute_loss(batch, batch_idx, False)
self.log('validation/loss', total_loss, prog_bar=True, sync_dist=True)
sampling_loss, training_loss = self.compute_loss(batch, batch_idx, False)
self.log('validation/loss', sampling_loss, prog_bar=True, sync_dist=True)

# Reset the token tracking accordingly
self._counting_tokens = 0
self._counting_time_start = time.time()

return total_loss
return sampling_loss

### ---
# SimpleRWKV, a wrapper for RWKV that allows for simple usage of the model
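The remaining model.py changes switch compute_loss() to return a (sampling_loss, training_loss) pair: training_step logs and returns training_loss, while validation_step logs sampling_loss. Below is a minimal sketch of that convention with toy tensors, assuming (the diff does not show the full definitions) that sampling_loss is the unmasked per-token loss and training_loss is the selectively masked one; compute_loss_stub and train_mask are hypothetical names used only for illustration.

import torch
import torch.nn.functional as F

def compute_loss_stub(logits, targets, train_mask):
    token_loss = F.cross_entropy(logits, targets, reduction="none")
    sampling_loss = token_loss.mean()                                   # all tokens
    training_loss = (token_loss * train_mask).sum() / train_mask.sum().clamp(min=1)  # selected tokens only
    return sampling_loss, training_loss

logits = torch.randn(6, 10)                             # 6 tokens, vocab size 10
targets = torch.randint(0, 10, (6,))
train_mask = torch.tensor([0., 0., 1., 1., 1., 1.])     # e.g. mask out prompt tokens

sampling_loss, training_loss = compute_loss_stub(logits, targets, train_mask)
print("train/loss      ->", float(training_loss))       # what training_step logs
print("validation/loss ->", float(sampling_loss))       # what validation_step logs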