Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add stride into KJT pytree #2587

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

TroyGarden
Copy link
Contributor

Summary:

context

  • Previously for a KJT, only the following fields and _keys are stored in the pytree flatten specs. All other arguments/parameters would be derived accordingly.
    _fields = [
        "_values",
        "_weights",
        "_lengths",
        "_offsets",
    ]
  • Particularly, the stride (int) of a KJT, which represents the batch_size, is computed by _maybe_compute_stride_kjt:
def _maybe_compute_stride_kjt(
    keys: List[str],
    stride: Optional[int],
    lengths: Optional[torch.Tensor],
    offsets: Optional[torch.Tensor],
    stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
    if stride is None:
        if len(keys) == 0:
            stride = 0
        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
            stride = max([sum(s) for s in stride_per_key_per_rank])
        elif offsets is not None and offsets.numel() > 0:
            stride = (offsets.numel() - 1) // len(keys)
        elif lengths is not None:
            stride = lengths.numel() // len(keys)
        else:
            stride = 0
    return stride
  • The previously stored pytree flatten specs are enough if the batch_size is static, however, this no longer holds true in a variable batch size scenario, where the stride_per_key_per_rank is not None.
  • An example is that with dedup_ebc, where the actual batch_size is variable (depending on the dedup data), but the output of the ebc should always be the true stride (static).
  • During ir_export, the output shape will be calculated from kjt.stride() function, which would be incorrect if the pytree specs only contains the keys.
  • This diff adds the stride into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value.

Differential Revision: D66400821

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 23, 2024
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66400821

Summary:

# context
* Previously for a KJT, only the following fields and `_keys` are stored in the pytree flatten specs. All other arguments/parameters would be derived accordingly.
```
    _fields = [
        "_values",
        "_weights",
        "_lengths",
        "_offsets",
    ]
```
* Particularly, the `stride` (int) of a KJT, which represents the `batch_size`, is computed by `_maybe_compute_stride_kjt`:
```
def _maybe_compute_stride_kjt(
    keys: List[str],
    stride: Optional[int],
    lengths: Optional[torch.Tensor],
    offsets: Optional[torch.Tensor],
    stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
    if stride is None:
        if len(keys) == 0:
            stride = 0
        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
            stride = max([sum(s) for s in stride_per_key_per_rank])
        elif offsets is not None and offsets.numel() > 0:
            stride = (offsets.numel() - 1) // len(keys)
        elif lengths is not None:
            stride = lengths.numel() // len(keys)
        else:
            stride = 0
    return stride
```
* The previously stored pytree flatten specs are enough if the `batch_size` is static, however, this no longer holds true in a variable batch size scenario, where the `stride_per_key_per_rank` is not `None`. 
* An example is that with `dedup_ebc`, where the actual batch_size is variable (depending on the dedup data), but the output of the ebc should always be the **true** `stride` (static). 
* During ir_export, the output shape will be calculated from `kjt.stride()` function, which would be incorrect if the pytree specs only contains the `keys`. 
* This diff adds the `stride` into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value.

Differential Revision: D66400821
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66400821

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants