This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

fix linter
weifengpy committed Jul 9, 2024
1 parent 13d3198 commit 06571e7
Showing 12 changed files with 12 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -71,7 +71,7 @@ m = Model(...)
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
 # type
 swap_linear_with_float8_linear(
-    m,
+    m,
     Float8Linear,
     scaling_type_x=TensorScalingType.DELAYED,
     scaling_type_w=TensorScalingType.DELAYED,
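For context, the README hunk above edits one line inside the documented call to swap_linear_with_float8_linear. Below is a minimal sketch of that usage, assuming float8_experimental is installed; the TinyModel class and its layer sizes are made up for illustration, and any keyword arguments not visible in the truncated hunk are left at their defaults.

# Minimal sketch of the README usage around this one-line change.
# `TinyModel` is a hypothetical stand-in for `Model(...)`; the swap call
# mirrors the snippet shown in the diff.
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 16)

    def forward(self, x):
        return self.fc2(self.fc1(x))


m = TinyModel()
# convert all `torch.nn.Linear` modules to `Float8Linear`, using delayed
# scaling for the input (x) and weight (w) tensors
swap_linear_with_float8_linear(
    m,
    Float8Linear,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
)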
2 changes: 1 addition & 1 deletion benchmarks/bench_multi_gpu.py
@@ -14,7 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
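Most of the remaining one-line changes in this commit follow the pattern above: Float8Linear is dropped from an import line because it is no longer referenced in that file, which flake8 reports as F401. A tiny stand-alone illustration of that warning class, using only the standard library rather than the repo's modules:

# Hypothetical example, unrelated to this repo, of the same warning class:
# `tau` is imported but never used, so flake8 reports
# "F401 'math.tau' imported but unused"; the fix is to drop it from the import.
from math import pi, tau

print(pi)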
2 changes: 1 addition & 1 deletion benchmarks/profile_linear_float8.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
3 changes: 0 additions & 3 deletions float8_experimental/float8_dynamic_utils.py
@@ -9,10 +9,7 @@
 
 from typing import Any, Optional, Tuple
 
-import float8_experimental.config as config
-
 import torch
-import torch.nn as nn
 import torch.utils._pytree as pytree
 
 from float8_experimental.float8_tensor import (
4 changes: 1 addition & 3 deletions float8_experimental/float8_linear_utils.py
@@ -3,10 +3,8 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import copy
 import logging
-from enum import auto, Enum
-from typing import Callable, List, Optional, Type, Union
+from typing import Callable, List, Optional
 
 import torch
 import torch.distributed as dist
2 changes: 1 addition & 1 deletion test/test_dtensor.py
@@ -15,7 +15,7 @@
 import torch.nn.functional as F
 
 from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_tensor_parallel import (
4 changes: 2 additions & 2 deletions test/test_fsdp.py
@@ -21,7 +21,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
@@ -149,7 +149,7 @@ def forward_backward(model, optim, is_fp8, i):
 model_fp8 = torch.compile(model_fp8)
 y_local = forward_backward(model, optimizer, is_fp8=False, i=i)
 y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i)
-local_sqnr = compute_error(y_local, y_local_fp8)
+local_sqnr = compute_error(y_local, y_local_fp8)  # noqa: F841
 
 # get global y
 y_global = [
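The second hunk above is the other kind of linter fix in this commit: local_sqnr is assigned but never read afterwards, which flake8 flags as F841, and the added # noqa: F841 comment suppresses that warning on that single line. A short hypothetical sketch of the same pattern, not taken from the repo:

# `mean` is assigned but never used, so flake8 would report
# "F841 local variable 'mean' is assigned to but never used";
# the trailing `# noqa: F841` comment suppresses that warning.
def scale_and_report(values):
    total = sum(values)
    mean = total / len(values)  # noqa: F841
    return total


print(scale_and_report([1.0, 2.0, 3.0]))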
3 changes: 1 addition & 2 deletions test/test_fsdp2/test_fsdp2_common.py
@@ -1,12 +1,11 @@
 import contextlib
-from typing import List, Type
+from typing import List
 
 import float8_experimental.config as config
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear
 
 
 def check_parity_no_mp(
3 changes: 1 addition & 2 deletions test/test_fsdp2/test_fsdp2_eager.py
@@ -1,5 +1,4 @@
 import copy
-import itertools
 import threading
 import unittest
 from typing import Any, List
@@ -9,7 +8,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from test_fsdp2_common import (
     check_parity_bf16_mp,
2 changes: 1 addition & 1 deletion test/test_fsdp_compile.py
@@ -18,7 +18,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 from float8_experimental import config
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
2 changes: 1 addition & 1 deletion test/test_inference_flows.py
@@ -13,7 +13,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import compute_error
2 changes: 1 addition & 1 deletion test/test_numerics_integration.py
@@ -14,7 +14,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
