Commit
Adding R3 changes for v5
PicoCreator committed Sep 6, 2023
1 parent 8582908 commit beb46d5
Showing 19 changed files with 170,392 additions and 21 deletions.
2 changes: 1 addition & 1 deletion RWKV-v4neo/config-example.yaml
@@ -345,7 +345,7 @@ data:

# Use data_dir, if you are using source=text/json/etc
# If using relative path, this should be relative to the trainer script path
- # source_data_dir: ../dataset-text/
+ source_data_dir: ../dataset-text/

# After loading the dataset, split out test data used for validation,
# This process is skipped if the dataset includes a test split
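
The config comment above says a relative `source_data_dir` is taken relative to the trainer script path. A minimal sketch of that rule follows; the helper name `resolve_data_dir` and the script path in the usage note are hypothetical, and the trainer's own resolution logic may differ.

```python
import os
from pathlib import Path


def resolve_data_dir(source_data_dir: str, trainer_script: str) -> Path:
    """Resolve a possibly relative data dir against the trainer script's folder.

    Hypothetical helper for illustration; not taken from the repository.
    """
    path = Path(source_data_dir)
    if path.is_absolute():
        # Absolute paths are used as given
        return path
    # Relative paths are anchored at the directory that contains the trainer script
    joined = Path(trainer_script).parent / path
    # normpath collapses ".." segments without touching the filesystem
    return Path(os.path.normpath(joined))
```

For example, with a trainer script at the assumed location `/repo/RWKV-v4neo/trainer.py`, the value `../dataset-text/` would point at `/repo/dataset-text`.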
2 changes: 1 addition & 1 deletion RWKV-v4neo/config-minimum-example.yaml
@@ -162,7 +162,7 @@ data:

# Use data_dir, if you are using source=text/json/etc
# If using relative path, this should be relative to the trainer script path
- # source_data_dir: ../dataset-json-dir/
+ source_data_dir: ../dataset-json-dir/

# Tokenizer to use, use either the inbuilt 'neox', or 'world' tokenizer
# If using a custom tokenizer, provide the HF tokenizer name/path
6 changes: 5 additions & 1 deletion RWKV-v4neo/src/model.py
@@ -1021,7 +1021,11 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
  if self.trainer.num_devices > 1:
      if self.bptt_learning_range <= 0:
          # We perform forward/backward on the shared max segment count across all GPUs
-         forward_segment_count = self.trainer.strategy.reduce(segment_count, reduce_op="max")
+         # ---
+         # we map it to be a tensor, instead of the int directly, as this is more reliable across certain versions of torch/lightning
+         # https://discord.com/channels/992359628979568762/1148755392638234697/1148821863749931008
+         forward_segment_count = self.trainer.strategy.reduce(torch.Tensor([segment_count]).to(torch.int), reduce_op="max")
+
          # Convert to int, if its a torch tensor
          if isinstance(forward_segment_count, torch.Tensor):
              forward_segment_count = forward_segment_count.item()
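
The pattern in this hunk (wrap the int in a tensor, reduce with `max`, convert back to a plain int) can be tried outside a multi-GPU run. The sketch below assumes a hypothetical `fake_reduce` stand-in for `self.trainer.strategy.reduce`, and `shared_max_segment_count` is an illustrative name, not a function from the repository. It builds the tensor with `torch.tensor(..., dtype=torch.int)`, which should be equivalent to the `torch.Tensor([...]).to(torch.int)` form used in the commit.

```python
import torch


def shared_max_segment_count(segment_count: int, reduce_fn) -> int:
    """Reduce a per-rank segment count to the max across ranks, returned as an int."""
    # Wrap the int in a 1-element int tensor before handing it to the reducer
    as_tensor = torch.tensor([segment_count], dtype=torch.int)
    reduced = reduce_fn(as_tensor, reduce_op="max")
    # The reducer may return a tensor or a plain number; normalise to int
    if isinstance(reduced, torch.Tensor):
        reduced = reduced.item()
    return int(reduced)


def fake_reduce(tensor, reduce_op="max"):
    """Hypothetical single-process stand-in: pretends another rank reported 7 segments."""
    other_rank = torch.tensor([7], dtype=torch.int)
    return torch.maximum(tensor, other_rank)
```

Calling `shared_max_segment_count(3, fake_reduce)` returns the larger of the local count and the simulated remote count, so every rank would then run the same number of forward segments.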