
Commit cf8658c

Merge pull request #46 from Mon-ius/main

rm fairscale

pengchongjin authored Mar 21, 2024
2 parents: acd24a8 + bbffa7a
Showing 6 changed files with 36 additions and 13 deletions.
docker/Dockerfile (0 additions & 2 deletions)

@@ -28,9 +28,7 @@ RUN apt-get install -y --no-install-recommends git
 # Install libraries.
 ENV PIP_ROOT_USER_ACTION=ignore
 RUN python -m pip install --upgrade pip
-RUN pip install fairscale==0.4.13
 RUN pip install numpy==1.24.4
-RUN pip install immutabledict==4.1.0
 RUN pip install sentencepiece==0.1.99
 
 # Install from source.
docker/xla.Dockerfile (0 additions & 2 deletions)

@@ -28,9 +28,7 @@ RUN apt-get install -y --no-install-recommends git
 # Install libraries.
 ENV PIP_ROOT_USER_ACTION=ignore
 RUN python3 -m pip install --upgrade pip
-RUN pip install fairscale==0.4.13
 RUN pip install numpy==1.24.4
-RUN pip install immutabledict==4.1.0
 RUN pip install sentencepiece==0.1.99
 
 # Install from source.
docker/xla_gpu.Dockerfile (0 additions & 2 deletions)

@@ -31,9 +31,7 @@ RUN python3 -m pip install --upgrade pip
 RUN pip uninstall -y torch
 RUN pip install torch==2.1.1
 RUN pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.2.0rc1-cp38-cp38-linux_x86_64.whl
-RUN pip install fairscale==0.4.13
 RUN pip install numpy==1.24.4
-RUN pip install immutabledict==4.1.0
 RUN pip install sentencepiece==0.1.99
 
 # Install from source.
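
All three images drop the same two pinned packages. A minimal sanity check for a rebuilt image (a hypothetical snippet, not part of the commit), verifying that the remaining pins still resolve and that the removed packages are truly gone:

    import importlib.util

    import numpy
    import sentencepiece

    # The pins this commit keeps should still resolve inside the image.
    assert numpy.__version__ == '1.24.4'
    assert sentencepiece.__version__ == '0.1.99'

    # The removed packages should no longer be importable.
    assert importlib.util.find_spec('fairscale') is None
    assert importlib.util.find_spec('immutabledict') is None
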
gemma/config.py (2 additions & 3 deletions)

@@ -15,13 +15,12 @@
 """Gemma model config."""
 
 import dataclasses
-import immutabledict
 import torch
 from typing import Optional
 
 
 # Keep a mapping from dtype strings to the supported torch dtypes.
-_STR_DTYPE_TO_TORCH_DTYPE = immutabledict.immutabledict({
+_STR_DTYPE_TO_TORCH_DTYPE = dict({
     'float16': torch.float16,
     'float': torch.float32,
     'float32': torch.float32,
@@ -81,4 +80,4 @@ def get_model_config(variant: str) -> GemmaConfig:
     elif variant == '2b':
         return get_config_for_2b()
     return ValueError(f'Invalid variant {variant}. Supported variants are "2b"'
-                      'and "7b"')
+                      'and "7b"')
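
Replacing immutabledict.immutabledict with a plain dict preserves lookup behavior; only the write-protection (and the dependency) goes away. A short sketch of the swap (hypothetical, not part of the commit; it uses only the mapping entries visible in the hunk above, which truncates the file's full mapping):

    import torch

    # Entries shown in the hunk above; the full mapping is truncated there.
    _STR_DTYPE_TO_TORCH_DTYPE = {
        'float16': torch.float16,
        'float': torch.float32,
        'float32': torch.float32,
    }

    # Lookups are unchanged by the immutabledict -> dict swap.
    assert _STR_DTYPE_TO_TORCH_DTYPE['float16'] is torch.float16

    # What is lost is immutability: a plain dict permits writes, where
    # immutabledict would raise a TypeError on item assignment.
    _STR_DTYPE_TO_TORCH_DTYPE['half'] = torch.float16
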
gemma/xla_model_parallel.py (34 additions & 2 deletions)

@@ -15,9 +15,8 @@
 from copy import deepcopy
 from dataclasses import dataclass
 import os
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Tuple
 
-from fairscale.nn.model_parallel.utils import divide_and_check_no_remainder, split_tensor_along_last_dim
 import torch
 import torch.ao.quantization.fx._decomposed
 import torch.distributed as dist
@@ -212,6 +211,39 @@ def gather_from_model_parallel_region(input_: torch.Tensor, groups, world_size,
                                       rank)
 
 
+def ensure_divisibility(numerator: int, denominator: int) -> None:
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
+
+
+def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
+    """Ensure that numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+
+def split_tensor_along_last_dim(
+    tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False
+) -> Tuple[torch.Tensor, ...]:
+    """Split a tensor along its last dimension.
+    Arguments:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor
+        contiguous_split_chunks: If True, make each chunk contiguous
+                                 in memory.
+    """
+    # Get the size and dimension.
+    last_dim = tensor.dim() - 1
+    last_dim_size = divide_and_check_no_remainder(tensor.size()[last_dim], num_partitions)
+    # Split.
+    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    # Note: torch.split does not create contiguous tensors by default.
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+
+    return tensor_list
+
 # Below copied from fairscale/nn/model_parallel/layers.py
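
The functions added above are vendored copies of the two fairscale utilities the module previously imported. A quick usage sketch (hypothetical, not part of the commit), assuming the package is importable as gemma.xla_model_parallel:

    import torch

    from gemma.xla_model_parallel import (
        divide_and_check_no_remainder,
        split_tensor_along_last_dim,
    )

    # Shard a (2, 8) activation tensor across 4 model-parallel partitions.
    x = torch.arange(16, dtype=torch.float32).reshape(2, 8)
    shards = split_tensor_along_last_dim(x, num_partitions=4,
                                         contiguous_split_chunks=True)
    assert len(shards) == 4
    assert all(shard.shape == (2, 2) for shard in shards)

    # The guard returns the per-shard size for even splits and raises
    # (via assert) when the dimension does not divide evenly.
    assert divide_and_check_no_remainder(8, 4) == 2
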
requirements.txt (0 additions & 2 deletions)

@@ -1,4 +1,2 @@
-fairscale == 0.4.13
 numpy == 1.24.4
-immutabledict == 4.1.0
 sentencepiece == 0.1.99
