This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

change x, w, dL_dY variable names to input, weight, grad_output (#323)
Summary:
Pull Request resolved: #323

The following naming scheme matches the rest of PyTorch better:

```python
# forward
output = input @ weight_t
# backward
grad_input = grad_output @ weight
grad_weight = input_t @ grad_output
```

This PR updates all previous references to `x`, `w`, and `dL_dY` to match the naming scheme above.
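To make the convention concrete, here is a small self-contained sketch (not code from this PR) that applies the naming scheme to a plain linear layer and cross-checks the hand-written backward against autograd; the shapes are arbitrary, and the variable names deliberately mirror the convention (including shadowing the `input` builtin).

```python
import torch

# Toy shapes: batch 4, in_features 8, out_features 16.
input = torch.randn(4, 8, requires_grad=True)
weight = torch.randn(16, 8, requires_grad=True)

# forward: output = input @ weight_t
output = input @ weight.t()

# Use a plain sum() as the loss, so grad_output is all ones.
grad_output = torch.ones_like(output)

# backward, written with the new names:
grad_input = grad_output @ weight      # dL/d(input), shape (4, 8)
# dL/d(weight): equal to (input_t @ grad_output) transposed into weight's (out, in) layout
grad_weight = grad_output.t() @ input  # shape (16, 8)

# Cross-check against autograd.
output.sum().backward()
torch.testing.assert_close(grad_input, input.grad)
torch.testing.assert_close(grad_weight, weight.grad)
```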

Reviewed By: drisspg

Differential Revision: D60072596

fbshipit-source-id: 74e89d154a698a0dae8c92f39e2267409b151642
vkuzo authored and facebook-github-bot committed Jul 23, 2024
1 parent 9d5f892 commit 603efc2
Showing 19 changed files with 349 additions and 305 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ufmt.yml
@@ -23,4 +23,7 @@ jobs:
pip install black==23.3.0 usort==1.0.6 ufmt==2.1.0 libcst==1.0.1
- name: Analyzing the code with ufmt
run: |
ufmt format .
git diff
git restore .
ufmt check .
10 changes: 5 additions & 5 deletions README.md
@@ -28,9 +28,9 @@ pip install -e ".[dev]"

# Single GPU User API

-We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`).

-## float8 linear with dynamic scaling for `x`, `w` and `dL_dY`
+## float8 linear with dynamic scaling for `input`, `weight` and `grad_output`

This is the most accurate recipe as every tensor is scaled dynamically.

@@ -95,9 +95,9 @@ m = Model(...)
# type
swap_linear_with_float8_linear(
m,
-scaling_type_x=TensorScalingType.DELAYED,
-scaling_type_w=TensorScalingType.DELAYED,
-scaling_type_dL_dY=TensorScalingType.DELAYED,
+scaling_type_input=TensorScalingType.DELAYED,
+scaling_type_weight=TensorScalingType.DELAYED,
+scaling_type_grad_output=TensorScalingType.DELAYED,
)

# optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
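As a companion to the delayed-scaling example above, a minimal dynamic-scaling sketch might look like the following. This is not taken from the README: the import paths and the `TensorScalingType.DYNAMIC` member are assumptions inferred from the rest of this diff (the benchmarks build the enum from the string `"dynamic"`), and the dtype, shapes, and device handling are illustrative and may need adjusting (e.g. to bfloat16 and matmul-friendly dimensions on a float8-capable GPU).

```python
import torch
import torch.nn as nn

# Assumed import locations; adjust to your checkout of float8_experimental.
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

m = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 32)).cuda()

# Swap nn.Linear modules in place, scaling all three tensors dynamically.
swap_linear_with_float8_linear(
    m,
    scaling_type_input=TensorScalingType.DYNAMIC,
    scaling_type_weight=TensorScalingType.DYNAMIC,
    scaling_type_grad_output=TensorScalingType.DYNAMIC,
)

# Fully dynamic scaling needs no amax/scale history sync between steps.
y = m(torch.randn(16, 64, device="cuda"))
y.sum().backward()
```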
40 changes: 21 additions & 19 deletions benchmarks/bench_linear_float8.py
@@ -95,16 +95,16 @@ def main(
n_limit: Optional[int] = None,
fast_accum_filter: Optional[bool] = None,
shape_name_filter: Optional[str] = None,
-scaling_type_x: str = "dynamic",
-scaling_type_w: str = "dynamic",
-scaling_type_dL_dY: str = "dynamic",
+scaling_type_input: str = "dynamic",
+scaling_type_weight: str = "dynamic",
+scaling_type_grad_output: str = "dynamic",
):
device = "cuda"
print(f"Compile is set to | {compile}")

-scaling_type_x = TensorScalingType(scaling_type_x)
-scaling_type_w = TensorScalingType(scaling_type_w)
-scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+scaling_type_input = TensorScalingType(scaling_type_input)
+scaling_type_weight = TensorScalingType(scaling_type_weight)
+scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
@@ -136,9 +136,9 @@ def main(
linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref),
emulate=False,
-scaling_type_x=scaling_type_x,
-scaling_type_w=scaling_type_w,
-scaling_type_dL_dY=scaling_type_dL_dY,
+scaling_type_input=scaling_type_input,
+scaling_type_weight=scaling_type_weight,
+scaling_type_grad_output=scaling_type_grad_output,
)
scaling_repr = linear_float8.scaling_repr()

@@ -153,7 +153,9 @@ def main(
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

def float8_forw_backward():
-if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
+if linear_requires_sync(
+scaling_type_input, scaling_type_weight, scaling_type_grad_output
+):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()
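The hunk above syncs the amax and scale history before the forward pass whenever any of the three tensors uses delayed scaling. A rough sketch of how that pattern might sit inside a full training step is below; the import paths, the optimizer handling, and the toy `sum()` loss are illustrative assumptions, not code from this PR.

```python
from float8_experimental.float8_linear_utils import (  # assumed import path
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)

def train_step(model, optimizer, batch,
               scaling_type_input, scaling_type_weight, scaling_type_grad_output):
    # Delayed scaling keeps per-tensor amax history buffers; they are synced
    # into scales before the forward pass. Dynamic scaling skips this step.
    if linear_requires_sync(
        scaling_type_input, scaling_type_weight, scaling_type_grad_output
    ):
        sync_float8_amax_and_scale_history(model)
    loss = model(batch).sum()  # toy loss, as in the benchmark above
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```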

@@ -278,18 +280,18 @@ def invoke_main() -> None:
parser.add_argument("-n", "--n_limit", type=int, required=False)
parser.add_argument("--fast_accum_filter", type=bool, required=False)
parser.add_argument("--shape_name_filter", type=str, required=False)
-parser.add_argument("--scaling_type_x", type=str, required=False)
-parser.add_argument("--scaling_type_w", type=str, required=False)
-parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
+parser.add_argument("--scaling_type_input", type=str, required=False)
+parser.add_argument("--scaling_type_weight", type=str, required=False)
+parser.add_argument("--scaling_type_grad_output", type=str, required=False)
args = parser.parse_args()
output_path = Path(args.output_path) if args.output_path is not None else None
kwargs = {}
-if args.scaling_type_x is not None:
-kwargs["scaling_type_x"] = args.scaling_type_x
-if args.scaling_type_w is not None:
-kwargs["scaling_type_w"] = args.scaling_type_w
-if args.scaling_type_dL_dY is not None:
-kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY
+if args.scaling_type_input is not None:
+kwargs["scaling_type_input"] = args.scaling_type_input
+if args.scaling_type_weight is not None:
+kwargs["scaling_type_weight"] = args.scaling_type_weight
+if args.scaling_type_grad_output is not None:
+kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
main(
output_path,
not args.disable_compile,
6 changes: 3 additions & 3 deletions benchmarks/bench_multi_gpu.py
@@ -68,9 +68,9 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
swap_linear_with_float8_linear(
m,
emulate=False,
-scaling_type_x=TensorScalingType.DELAYED,
-scaling_type_w=TensorScalingType.DELAYED,
-scaling_type_dL_dY=TensorScalingType.DELAYED,
+scaling_type_input=TensorScalingType.DELAYED,
+scaling_type_weight=TensorScalingType.DELAYED,
+scaling_type_grad_output=TensorScalingType.DELAYED,
)
return m

27 changes: 16 additions & 11 deletions benchmarks/profile_linear_float8.py
@@ -204,20 +204,23 @@ def profile_function(
def main(
profile_path_prefix: Path,
compile: bool = True,
-scaling_type_x: str = "dynamic",
-scaling_type_w: str = "dynamic",
-scaling_type_dL_dY: str = "dynamic",
+scaling_type_input: str = "dynamic",
+scaling_type_weight: str = "dynamic",
+scaling_type_grad_output: str = "dynamic",
model_type: str = "linear",
dtype_filter: str = "both",
):
assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
assert dtype_filter in ("both", "float8", "bfloat16")

-scaling_type_x = TensorScalingType(scaling_type_x)
-scaling_type_w = TensorScalingType(scaling_type_w)
-scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+scaling_type_input = TensorScalingType(scaling_type_input)
+scaling_type_weight = TensorScalingType(scaling_type_weight)
+scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
scaling_repr = "_".join(
-[s.short_str() for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)]
+[
+s.short_str()
+for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output)
+]
)

print(f"Compile is set to | {compile}")
@@ -254,9 +257,9 @@ def main(
m_ref = m_ref.to(device).to(ref_dtype)

extra_kwargs = {
-"scaling_type_x": scaling_type_x,
-"scaling_type_w": scaling_type_w,
-"scaling_type_dL_dY": scaling_type_dL_dY,
+"scaling_type_input": scaling_type_input,
+"scaling_type_weight": scaling_type_weight,
+"scaling_type_grad_output": scaling_type_grad_output,
}

m_float8 = copy.deepcopy(m_ref)
@@ -278,7 +281,9 @@ def float8_forw_backward_wrapper(x):
# inspection of the fw+bw torch.compile without the scale
# syncing code
# TODO(future): make this better
-if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
+if linear_requires_sync(
+scaling_type_input, scaling_type_weight, scaling_type_grad_output
+):
with record_function("scale_amax_and_scales"):
sync_amax_history(m_float8)
out = float8_forw(x)
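The `record_function("scale_amax_and_scales")` context in the hunk above gives the sync its own labeled block in profiler traces. As a generic illustration of that labeling pattern (unrelated to the float8 code, shown only for the API), a trace with a named region can be produced like this:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

def step(x, w):
    with record_function("matmul_region"):  # shows up as its own block in the trace
        return x @ w

x, w = torch.randn(256, 256), torch.randn(256, 256)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    step(x, w)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```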
(Diffs for the remaining 14 changed files are not shown here.)
