From cdb78678d543178aab59f7216dc0458f2242f629 Mon Sep 17 00:00:00 2001
From: willfengg
Date: Mon, 20 May 2024 17:39:10 -0700
Subject: [PATCH] set vocab_size=32 to avoid must be divisible by 16 error
 (#265)

Summary:

`pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic`

```
E   File "/home/weif/local/pytorch-official/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 205, in forward
E     output = self.output(h).float()
E   File "/home/weif/local/pytorch-official/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
E     return self._call_impl(*args, **kwargs)
E   File "/home/weif/local/pytorch-official/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
E     return forward_call(*args, **kwargs)
E   File "/data/users/weif/float8_experimental/float8_experimental/float8_dynamic_linear.py", line 71, in forward
E     y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
E   File "/data/users/weif/float8_experimental/float8_experimental/float8_tensor.py", line 297, in __torch_dispatch__
E     return FLOAT8_OPS_TABLE[func](func, args, kwargs)
E   File "/data/users/weif/float8_experimental/float8_experimental/float8_ops.py", line 151, in float8_mm
E     tensor_out, amax = addmm_float8_unwrapped(
E   File "/data/users/weif/float8_experimental/float8_experimental/float8_python_api.py", line 55, in addmm_float8_unwrapped
E     output, output_amax = torch._scaled_mm(
E RuntimeError: mat2 shape (768x8 must be divisible by 16
E Exception raised from _scaled_mm_out_cuda at /data/users/weif/pytorch-official/pytorch/aten/src/ATen/native/cuda/Blas.cpp:874 (most recent call first):
```

Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/265

Reviewed By: drisspg, awgu

Differential Revision: D57596582

Pulled By: weifengpy

fbshipit-source-id: 8a00601457c4e72271adbba29dd2af8273173aa3
---
 test/test_fsdp2/test_fsdp2_eager.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py
index d9a0824..98ef92b 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -57,7 +57,12 @@ def init_multi_module(self) -> nn.Module:
     def init_transformer(self, weight_tying: bool) -> nn.Module:
         torch.manual_seed(42)
         args = ModelArgs(
-            n_layers=3, dim=768, n_heads=12, dropout_p=0.0, weight_tying=weight_tying
+            n_layers=3,
+            dim=768,
+            n_heads=12,
+            dropout_p=0.0,
+            weight_tying=weight_tying,
+            vocab_size=32,
         )
         module = Transformer(args).cuda()
         self.broadcast_module(module)
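
Context for the fix (not part of the patch): the fp8 linear here dispatches to `torch._scaled_mm`, whose CUDA path rejects matmul operands with dimensions that are not multiples of 16. The traceback's `(768x8` shape comes from the test transformer's output projection, whose weight is `(vocab_size, dim)`; setting `vocab_size=32` makes both sizes multiples of 16. The sketch below only illustrates that divisibility constraint; the helper name `check_scaled_mm_dims` is invented for illustration.

```python
# Hypothetical sketch (not from the patch): torch._scaled_mm requires matmul
# dimensions to be multiples of 16, so an fp8 output projection with weight
# shape (vocab_size, dim) needs both sizes divisible by 16.

def check_scaled_mm_dims(dim: int, vocab_size: int) -> None:
    """Raise if a weight of shape (vocab_size, dim) would violate the
    divisible-by-16 requirement enforced by torch._scaled_mm."""
    for name, value in (("dim", dim), ("vocab_size", vocab_size)):
        if value % 16 != 0:
            raise ValueError(
                f"{name}={value} is not divisible by 16; "
                "torch._scaled_mm would reject this fp8 matmul"
            )

check_scaled_mm_dims(dim=768, vocab_size=32)   # passes: both are multiples of 16
# check_scaled_mm_dims(dim=768, vocab_size=8)  # would raise, mirroring the
#                                              # RuntimeError quoted above
```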