Experimental interface for torch ops #189
See SHI-Labs#184

Only supports forward pass for now, due to current limitations of registering custom ops with torch compared to autograd functions. Some of those limitations are:

* No stable interface for supporting autocasting to fp16/bf16.
* Gradient scaling doesn't seem to be supported either, leading to training instability.
* Ops cannot indicate that they expect contiguous operands and need to call `.contiguous()` within; this incurs additional tensor copy costs and brings down throughput (in some cases it's hard to even tell the difference between compiled and eager).
@Birch-san could you check and see if the changes here resolve what you need when you get a chance?

Note: I kind of assumed you're using NATTEN ops (

I'm still thinking about how to unit test the new FLOP counter and torch compile, but I have already verified both work manually (fvcore reports exactly 0.5 of PyTorch, because I imagine it's set up to report MACs, not FLOPs).
hi @alihassanijr, thanks for implementing this and sorry for the delay.

I've tried now invoking dispatch works (FlopCounterMode sees a

I was able to fix this by declaring this decorated function before running my model under FlopCounterMode context:

import math
from typing import Optional, Sequence

import torch
from torch.utils.flop_counter import register_flop_formula


def fna_generic_flops_(
    q: torch.Size,
    k: torch.Size,
    v: torch.Size,
    has_bias: bool,
    kernel_size: Sequence[int],
) -> int:
    # q is indexed as (batch, *spatial, heads, dim)
    batch_size, heads, dim = (
        q[0],
        q[-2],
        q[-1],
    )
    spatial_extent: Sequence[int] = q[1 : len(kernel_size) + 1]
    spatial_extent_int = math.prod(spatial_extent)
    kernel_size_int = math.prod(kernel_size)
    flops = batch_size * heads * spatial_extent_int * dim * kernel_size_int  # QK
    # NOTE: PyTorch doesn't count softmax flops in SDPA;
    # Reference:
    # https://github.com/pytorch/pytorch/blob/7ced49d2ccf219ec896810e6d988709c3a3a2d9a/torch/utils/flop_counter.py#L241-L256
    # flops += batch_size * heads * spatial_extent_int * kernel_size_int  # softmax
    flops += batch_size * heads * spatial_extent_int * dim * kernel_size_int  # AV
    if has_bias:
        flops += batch_size * heads * spatial_extent_int * kernel_size_int  # RPB
    return flops


@register_flop_formula(torch.ops.natten.na2d_forward_op)
def na2d_flop(
    query: torch.Size,
    key: torch.Size,
    value: torch.Size,
    bias: Optional[torch.Size],
    kernel_size_: Sequence[int],
    dilation_: Sequence[int],
    is_causal_: Sequence[bool],
    scale: float,
    q_tiler_: Sequence[int],
    kv_tiler_: Sequence[int],
    *args,
    # keyword-only; annotated with its type rather than defaulting to the type object
    out_shape: tuple[torch.Size, torch.Size],
    **kwargs,
) -> int:
    return fna_generic_flops_(query, key, value, bias is not None, kernel_size_)

Now FLOPs are counted successfully:
I note that this flop count algorithm is different to the one that I wrote by guesswork, in my original post #184 (comment). my own FLOP count made this to be exactly 2x the amount, 36.541B operations. I think the torch counters return FLOs whereas your

also, I think it's problematic that users cannot import flops.py if they don't have fvcore installed. this is what forced me to write my own

I did initially try to delegate flop counting to your

@register_flop_formula(torch.ops.natten.na2d_forward_op)
def na2d_flop(
    query: torch.Size,
    key: torch.Size,
    value: torch.Size,
    bias: Optional[torch.Size],
    kernel_size_: Sequence[int],
    dilation_: Sequence[int],
    is_causal_: Sequence[bool],
    scale: float,
    q_tiler_: Sequence[int],
    kv_tiler_: Sequence[int],
    *args,
    # keyword-only; annotated with its type rather than defaulting to the type object
    out_shape: tuple[torch.Size, torch.Size],
    **kwargs,
) -> int:
    # bias is optional, so only pass it along when it is present
    inputs: list[torch.Size] = [query, key, value]
    if bias is not None:
        inputs.append(bias)
    outputs: list[torch.Size] = list(out_shape)
    return fna_generic_flops(inputs, outputs)

but it failed on this assertion anyway

I recommend to expose flop counter functions that can be used without having fvcore installed, and which take Sizes as inputs rather than tensors, and I also recommend to register flop counters via
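One way to act on that recommendation is to keep the shape-only arithmetic free of third-party imports and to gate any fvcore glue behind an availability check. The sketch below is hypothetical: `na_count_from_shapes` and `HAS_FVCORE` are illustrative names, not NATTEN's API, and the `(batch, *spatial, heads, dim)` layout is an assumption taken from the indexing in the snippet above. It mirrors the counting convention of that snippet (one unit per multiply-accumulate).

```python
import importlib.util
import math
from typing import Sequence

# True only if fvcore is installed; nothing below requires it.
HAS_FVCORE = importlib.util.find_spec("fvcore") is not None


def na_count_from_shapes(
    q_shape: Sequence[int],
    kernel_size: Sequence[int],
    has_bias: bool = False,
) -> int:
    """Count QK + AV (+ optional bias) work from plain shapes, no tensors needed.

    Assumes q_shape is (batch, *spatial, heads, dim); hypothetical helper.
    """
    batch, heads, dim = q_shape[0], q_shape[-2], q_shape[-1]
    spatial = math.prod(q_shape[1 : len(kernel_size) + 1])
    kernel = math.prod(kernel_size)
    count = 2 * batch * heads * spatial * dim * kernel  # QK and AV
    if has_bias:
        count += batch * heads * spatial * kernel  # relative position bias
    return count


# Works with tuples, lists, or torch.Size alike, with or without fvcore.
print(na_count_from_shapes((2, 32, 32, 8, 64), (7, 7)))
print("fvcore available:", HAS_FVCORE)
```

Any fvcore-specific handle would then be a thin wrapper defined only when `HAS_FVCORE` is true, so importing the module never fails for users without fvcore.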
as for graph breaks: I slapped a under
whereas under |
regarding FLOP counter unit test, how about:

import torch
from torch import no_grad
from torch._ops import OpOverloadPacket
from torch.utils.flop_counter import FlopCounterMode
from natten.experimental import na2d

bsz = 2
heads = 8
head_dim = 64
wid = 32
hei = 32
kernel_size = 7
dtype = torch.float16
device = torch.device('cuda')
gen = torch.Generator(device=device).manual_seed(42)
q = torch.randn(bsz, heads, hei, wid, head_dim, device=device, dtype=dtype, generator=gen)
k = torch.randn(bsz, heads, hei, wid, head_dim, device=device, dtype=dtype, generator=gen)
v = torch.randn(bsz, heads, hei, wid, head_dim, device=device, dtype=dtype, generator=gen)

counter = FlopCounterMode()
# we avoid FLOP counting under inference_mode() context because it's poorly-supported;
# even basics like matmuls dispatch different ops under inference_mode, for which no flop counter is registered
with counter, no_grad():
    na2d(q, k, v, kernel_size)

global_flops: dict[OpOverloadPacket, int] = counter.flop_counts['Global']
assert torch.ops.natten.na2d_forward_op in global_flops, "na2d FLOPs not counted"
na2d_flops: int = global_flops[torch.ops.natten.na2d_forward_op]
assert na2d_flops != 102760448, "na2d returned MACs instead of FLOs"
assert na2d_flops == 205520896, "na2d returned unexpected FLOP count"
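The two constants in those assertions follow directly from the shapes in the test. A minimal sketch of the arithmetic, assuming the QK and AV products each cost `batch * heads * spatial * dim * kernel` multiply-accumulates (as in `fna_generic_flops_` earlier in this thread), no bias term, and one MAC counted as two FLOPs; the helper name `na_macs` is made up for illustration:

```python
import math


def na_macs(batch, heads, spatial, dim, kernel):
    """Multiply-accumulates for neighborhood attention (QK plus AV), no bias."""
    per_product = batch * heads * math.prod(spatial) * dim * math.prod(kernel)
    return 2 * per_product  # one QK product and one AV product


# Shapes from the unit test: bsz=2, heads=8, 32x32 feature map, head_dim=64, 7x7 kernel.
macs = na_macs(batch=2, heads=8, spatial=(32, 32), dim=64, kernel=(7, 7))
flops = 2 * macs  # each MAC is one multiply plus one add

print(macs)   # 102760448
print(flops)  # 205520896
```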
for compile unit test, I think all you need to do is attempt to invoke the model under fullgraph compilation and see that the program doesn't explode.

import torch
from torch import Tensor, inference_mode
from torch.nn import Module
from torch._dynamo.exc import Unsupported
from einops import rearrange
from natten.experimental import na2d


class Attention(Module):
    def __init__(
        self,
        in_dim: int,
        heads: int,
        head_dim: int,
        kernel_size: int,
        device: torch.device = torch.device('cuda'),
        dtype: torch.dtype = torch.float16,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.kernel_size = kernel_size
        self.head_dim = head_dim
        self.qkv_proj = torch.nn.Linear(in_dim, 3*head_dim*heads, bias=False, **factory_kwargs)

    @torch.compile(fullgraph=True)
    def forward(self, x: Tensor):
        qkv = self.qkv_proj(x)
        q, k, v = rearrange(qkv, "... h w (proj heads head_dim) -> proj ... h w heads head_dim", proj=3, head_dim=self.head_dim)
        return na2d(q, k, v, self.kernel_size)


dtype = torch.float16
device = torch.device('cuda')
bsz = 2
inner_dim = 320
wid = 32
hei = 32
gen = torch.Generator(device=device).manual_seed(42)
x = torch.randn(bsz, hei, wid, inner_dim, device=device, dtype=dtype, generator=gen)
attn = Attention(
    in_dim=inner_dim,
    heads=5,
    head_dim=64,
    kernel_size=7,
    device=device,
    dtype=dtype,
)
try:
    with inference_mode():
        attn(x)
except Unsupported as e:
    # fullgraph=True turns a graph break into an Unsupported error
    if 'graph break' in e.msg.lower():
        raise AssertionError('Test failure')
    raise
print('Test success')
and if you're thinking of using torch FlopCounterMode yourself in order to benchmark NATTEN, beware this gotcha about how FlopCounterMode seems to make torch.compile fall back to eager mode in torch 2.5+: |
@Birch-san thank you so much for the feedback.

I'll check the FLOPs vs MACs issue -- I remember checking this and finding out fvcore computed one and torch the other, and specifically adjusted everything according to that. I think I even used torch's

Thanks for the feedback on

The only thing that worries me is that given fvcore and torch report different metrics (flops v macs), I might have to rename a few things to make them less confusing. I'll have to add some documentation for that.

And thanks for the unit test idea, and verifying compilation isn't breaking the graph. I think we should be ready to merge this in soon.
Okay according to a dummy example I set up, I'm getting flops reported by torch and fvcore as follows:
After checking Back to the discrepancy you were observing, I also noticed that you also used torch's
UPDATE: sorry, I just caught a mistake I made in the new flop count. I'll push the fix soon. |
+ Separate FLOPs doc
@Birch-san Turns out the "QK" part of the new FLOP counter was wrong. I didn't notice because my test case was using identical kernel size and feature map sizes.

I just pushed a commit doing what should've been done from the get go, which is a major refactor so that whether we use fvcore, torch, or manually count flops, it all ends up calling the same underlying API, so as to prevent mistakes due to code duplication.

Now the FLOP counting is correct, experimental ops are using

I'll work on adding those extra unit tests now, but already verified it's working as expected.
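The thread does not show the faulty formula, so the following is only a hypothetical illustration of how such a bug can hide. It assumes a mistaken QK count that multiplies by the feature map size where the kernel size belongs; both function names are invented for this sketch. When the kernel covers the whole feature map the two formulas agree, so a test with identical sizes cannot tell them apart.

```python
import math


def qk_macs_correct(batch, heads, spatial, dim, kernel):
    # Each query attends to prod(kernel) neighbors.
    return batch * heads * math.prod(spatial) * dim * math.prod(kernel)


def qk_macs_buggy(batch, heads, spatial, dim, kernel):
    # Hypothetical mistake: the feature map size stands in for the
    # neighborhood size, so the kernel argument is ignored.
    return batch * heads * math.prod(spatial) * dim * math.prod(spatial)


same = dict(batch=1, heads=1, spatial=(7, 7), dim=8, kernel=(7, 7))
diff = dict(batch=1, heads=1, spatial=(32, 32), dim=8, kernel=(7, 7))

print(qk_macs_correct(**same) == qk_macs_buggy(**same))  # True
print(qk_macs_correct(**diff) == qk_macs_buggy(**diff))  # False
```

A test case with a kernel strictly smaller than the feature map, like the 7x7 kernel on a 32x32 map used in the unit test proposed earlier, would expose this kind of error.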
+ Add interface for using experimental ops within NATTEN modules. + Update docs
ah, awesome! glad you found the QK thing. |
This should be ready to land now -- I'll wait until you get a chance to confirm the FLOPs from your use case are correct now @Birch-san. |
great work @alihassanijr; I've pulled the changes, deleted my own manual op wrappers and manual flop counter registration, and just invoked NATTEN experimental without tricks, like in the unit test (#189 (comment)). the assertion passes (it's the FLOP count I expect). so it's now FLOPs instead of MACs. and I also tried it on my NATTEN'd SDXL VAE decoder benchmark script, the official NATTEN FLOP counter reports 36.541B FLOPs, just like my FLOP counter did (#189 (comment)). so yeah I think that's ready for merge. 🙂 |
Awesome! Thanks again for your feedback, and bringing this up! Merging. |