Parallel sharding #21

Merged: 22 commits, Apr 10, 2024

Commits on Apr 9, 2024

  1. d75ba94
  2. feat: import transformer's gemma modeling code (0ee7430)

    It will be used to adapt it for sharding. Only the imports have been
    adapted, and only code relevant to GemmaForCausalLM has been added.
    tengomucho committed Apr 9, 2024
  3. ca88068
  4. a3de4d7
  5. 80170a9
  6. 9a9bcf8
  7. 5bf6c70
  8. fix(TpuGemma): avoid using device_map when loading model (9dfb7b6)

    It seems that the device_map parameter triggers a chain of calls that
    try to use accelerate to load the model with less memory. The problem
    is that this path skips the load state pre-hooks, making it impossible
    to load the weights.
    tengomucho committed Apr 9, 2024
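The pre-hook problem can be illustrated with a toy loader (a minimal plain-Python sketch, not the PR's actual code; `ShardedLinear`, `register_load_pre_hook`, and `shard_rows` are hypothetical names): a load-state pre-hook transforms the incoming state dict before weights are copied, so a low-memory fast path that copies tensors directly, skipping the hooks, would install the wrong (unsliced) weights.

```python
class ShardedLinear:
    """Toy module whose weights must be transformed before loading.

    Pre-hooks rewrite the incoming state dict (here: slice the weight for
    this shard). A loading path that bypasses load_state_dict, as the
    device_map/accelerate path does, would never run them.
    """

    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.weight = None
        self._pre_hooks = []

    def register_load_pre_hook(self, hook):
        self._pre_hooks.append(hook)

    def load_state_dict(self, state_dict):
        # Run every registered pre-hook before installing the weights.
        for hook in self._pre_hooks:
            state_dict = hook(state_dict)
        self.weight = state_dict["weight"]


def shard_rows(rank, world_size):
    """Pre-hook factory: keep only this shard's rows of the weight."""
    def hook(state_dict):
        rows = state_dict["weight"]
        per_shard = len(rows) // world_size
        return {"weight": rows[rank * per_shard:(rank + 1) * per_shard]}
    return hook


full_state = {"weight": [[1, 2], [3, 4], [5, 6], [7, 8]]}
module = ShardedLinear(rank=1, world_size=2)
module.register_load_pre_hook(shard_rows(rank=1, world_size=2))
module.load_state_dict(dict(full_state))
print(module.weight)  # → [[5, 6], [7, 8]]
```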
  9. feat(gemma): sharding o_proj (ec3b752)

    It will now run in parallel. More changes to come.
    tengomucho committed Apr 9, 2024
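Row-parallel sharding of a projection like o_proj can be sketched as follows (a plain-Python illustration of the standard tensor-parallel scheme, not the PR's implementation): each device keeps the rows of the weight matching its slice of the input, and summing the partial products, the all-reduce step, recovers the full matmul.

```python
def matmul(a, b):
    """Naive list-of-lists matrix multiply."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def row_parallel_outputs(x, w, world_size):
    """Simulate row-parallel sharding: shard r holds columns r*k..(r+1)*k
    of x and the matching rows of w; summing the partial products
    (the all-reduce) equals x @ w."""
    k = len(w) // world_size
    partials = []
    for r in range(world_size):
        x_shard = [row[r * k:(r + 1) * k] for row in x]
        w_shard = w[r * k:(r + 1) * k]
        partials.append(matmul(x_shard, w_shard))
    # all-reduce: elementwise sum of the partial outputs
    return [[sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
            for i in range(len(partials[0]))]


x = [[1, 2, 3, 4]]                      # one token, hidden size 4
w = [[1, 0], [0, 1], [1, 1], [2, 2]]    # 4x2 projection weight
print(row_parallel_outputs(x, w, world_size=2) == matmul(x, w))  # → True
```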
  10. a7d7c0b
  11. b6fe32e
  12. e13d9ec
  13. 6cdede2
  14. feat: model is loaded using torch_dtype from config (cd99226)

    This will lead to loading the model in bfloat16 when that is specified
    in the config.
    tengomucho committed Apr 9, 2024
  15. 550e1fb
  16. 2215595

Commits on Apr 10, 2024

  1. fe888a9
  2. fix: get_generation_mode is now a method of generation_config (dbf11f7)

    API change when transformers was updated.
    tengomucho committed Apr 10, 2024
  3. a96903b
  4. fix(generator): fix sample generation again (6e6b44e)

    I wrongly chose the model's generation config instead of the one from
    the token selector.
    tengomucho committed Apr 10, 2024
  5. fix: better handle torch_dtype (92e9e31)

    bfloat16 will be set by default for gemma models; other models will
    still load in float32 by default.
    tengomucho committed Apr 10, 2024
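The dtype handling described above amounts to a small resolution rule: honor an explicit torch_dtype from the model config, default gemma models to bfloat16, and fall back to float32 otherwise. A minimal sketch (`resolve_dtype` is a hypothetical helper, not the repository's code):

```python
def resolve_dtype(config: dict) -> str:
    """Pick the dtype to load the model in.

    An explicit torch_dtype in the config wins; gemma models default to
    bfloat16; every other model type defaults to float32.
    """
    explicit = config.get("torch_dtype")
    if explicit:
        return explicit
    return "bfloat16" if config.get("model_type") == "gemma" else "float32"


print(resolve_dtype({"model_type": "gemma"}))                             # → bfloat16
print(resolve_dtype({"model_type": "llama"}))                             # → float32
print(resolve_dtype({"model_type": "llama", "torch_dtype": "bfloat16"}))  # → bfloat16
```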
  6. fix: remove unused import (7901d91)

    tengomucho committed Apr 10, 2024