Commit 0665d21

fix lint issue
tocean committed Feb 5, 2024
1 parent c9d1bf5 commit 0665d21
Showing 5 changed files with 15 additions and 13 deletions.
.github/workflows/unit-tests.yaml (7 additions, 4 deletions)
@@ -16,23 +16,23 @@ jobs:
# 1.14.0a0+410ce96
- torch: "1.14"
nvcr: 22.12-py3
+ dir: torch1
# 2.1.0a0+fe05266f
- torch: "2.1"
nvcr: 23.10-py3
+ dir: torch2
container:
image: nvcr.io/nvidia/pytorch:${{ matrix.nvcr }}
options: --privileged --ipc=host --gpus=all
steps:
- - name: Clean submodules
- run: |
- rm -rf third_party
- name: Checkout msamp
uses: actions/checkout@v2
with:
submodules: true
+ path: ${{ matrix.dir }}
- name: Install MSCCL
run: |
- cd third_party/msccl
+ cd ${{ matrix.dir }}/third_party/msccl
make -j src.build NVCC_GENCODE="\
-gencode=arch=compute_70,code=sm_70 \
-gencode=arch=compute_80,code=sm_80 \
@@ -42,15 +42,18 @@ jobs:
run: |
export LD_LIBRARY_PATH="/usr/local/lib:$LD_LIBRARY_PATH"
python3 -m pip install --upgrade pip
+ cd ${{ matrix.dir }}/
python3 -m pip install .[test]
make postinstall
- name: Run code lint
run: |
+ cd ${{ matrix.dir }}/
python3 setup.py lint
- name: Run unit tests
run: |
export LD_LIBRARY_PATH="/usr/local/lib:$LD_LIBRARY_PATH"
export LD_PRELOAD="/usr/local/lib/libmsamp_dist.so:/usr/local/lib/libnccl.so:${LD_PRELOAD}"
+ cd ${{ matrix.dir }}/
python3 setup.py test
# - name: Report coverage results
# run: |
msamp/fsdp/_runtime_utils.py (0 additions, 1 deletion)
@@ -24,4 +24,3 @@ def _fp8_post_backward_hook(state, handle, *unused):
setattr(state, comm_hook_attr, state._get_fp8_comm_hook())
old_post_backward_hook(state, handle, *unused)
setattr(state, comm_hook_attr, old_communication_hook)
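
The hook body above temporarily replaces the state's communication hook with the FP8 one, invokes the original post-backward hook, and then restores the previous hook. A minimal standalone sketch of that swap-call-restore pattern (the helper name and the try/finally guard are illustrative additions, not msamp code):

def _call_with_temporary_attr(obj, attr_name, temp_value, fn, *args):
    """Run fn(*args) while the named attribute of obj is temporarily set to temp_value."""
    old_value = getattr(obj, attr_name)
    setattr(obj, attr_name, temp_value)
    try:
        fn(*args)
    finally:
        # Restore the original attribute even if fn raises.
        setattr(obj, attr_name, old_value)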

msamp/fsdp/flat_param.py (2 additions, 1 deletion)
@@ -10,7 +10,8 @@
class FP8FlatParamHandle(FlatParamHandle):
"""A handle for a flat parameter which may have fp32 and fp8."""
def __init__(self, *args, **kwargs):
- super().__init__( *args, **kwargs)
+ """Constructor."""
+ super().__init__(*args, **kwargs)
self._init_fp8_meta()

def _init_fp8_meta(self):
msamp/fsdp/fully_sharded_data_parallel.py (4 additions, 3 deletions)
@@ -5,7 +5,6 @@

import torch
from torch.distributed.fsdp import FullyShardedDataParallel
- from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.fsdp._init_utils import _get_default_comm_hook

from msamp.fsdp.flat_param import FP8FlatParamHandle
@@ -31,10 +30,12 @@ def _fp8_allreduce_hook(state, grad, output):
from msamp.operators.dist_op import DistOp
dtype = Dtypes.get_dtype_from_qtype(meta.qtype)
DistOp.enable_fp8(meta.qtype)
- torch.distributed.all_reduce(grad[start:end].view(dtype), group=state.process_group if state else None)
+ torch.distributed.all_reduce(
+ grad[start:end].view(dtype), group=state.process_group if state else None
+ )
DistOp.disable_fp8()
else:
- torch.distributed.all_reduce(grad[start:end], group=state.process_group if state else None)
+ torch.distributed.all_reduce(grad[start:end], group=state.process_group if state else None)

start = self.rank * output.numel()
end = (self.rank + 1) * output.numel()
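
For context, the reformatted call above all-reduces a slice of the flattened gradient after reinterpreting its storage as the FP8-backed dtype returned by Dtypes.get_dtype_from_qtype. A minimal sketch of that view-and-reduce step, assuming a process group is already initialized (the function name is illustrative, not msamp's):

import torch.distributed as dist

def _allreduce_fp8_slice(grad, start, end, dtype, group=None):
    """All-reduce grad[start:end], reinterpreting its bytes as `dtype`."""
    # Tensor.view(dtype) reinterprets the underlying storage without copying,
    # so the collective runs directly on the low-precision representation.
    dist.all_reduce(grad[start:end].view(dtype), group=group)
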
msamp/te/modules.py (2 additions, 4 deletions)
@@ -66,8 +66,7 @@ def set_fp8_weights(self):
return

setattr(
- self,
- weight_cast_attr,
+ self, weight_cast_attr,
Float8Tensor(
data=torch.empty(
shape,
@@ -80,8 +79,7 @@ def set_fp8_weights(self):
)

setattr(
- self,
- weight_transpose_attr,
+ self, weight_transpose_attr,
Float8Tensor(
data=torch.empty(
shape[1],
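
Both hunks above only reflow a setattr call whose target attribute name (weight_cast_attr or weight_transpose_attr) is computed at runtime, which is why setattr is used instead of a plain assignment. A minimal sketch of that dynamic-attribute pattern, with a hypothetical attribute name and a plain byte buffer standing in for the real Float8Tensor arguments:

import torch

class WeightHolder:
    def set_fp8_buffer(self, suffix, shape):
        """Attach an empty byte buffer under a dynamically named attribute."""
        attr_name = f'weight_{suffix}'  # e.g. 'weight_fp8'; the name is hypothetical
        setattr(self, attr_name, torch.empty(shape, dtype=torch.uint8))

holder = WeightHolder()
holder.set_fp8_buffer('fp8', (4, 4))
print(hasattr(holder, 'weight_fp8'))  # True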
