calc_flops.py (forked from raoyongming/DynamicViT)
import argparse
import copy
import os
import os.path as osp
import time
import warnings
import torch
from numbers import Number
from typing import Any, Callable, List, Optional, Union
from numpy import prod
import numpy as np
from fvcore.nn import FlopCountAnalysis


def rfft_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the rfft/rfftn operator.
    """
    # Traced input has shape (B, H, W, C); an FFT over N = H * W points costs
    # on the order of N * log2(N) operations per channel.
    input_shape = inputs[0].type().sizes()
    B, H, W, C = input_shape
    N = H * W
    flops = N * C * np.ceil(np.log2(N))
    return flops


def calc_flops(model, img_size=224, show_details=False, ratios=None):
    with torch.no_grad():
        x = torch.randn(1, 3, img_size, img_size)
        model.default_ratio = ratios
        # Count FLOPs with fvcore, registering custom handlers for the FFT ops,
        # which FlopCountAnalysis does not cover by default.
        fca1 = FlopCountAnalysis(model, x)
        handlers = {
            'aten::fft_rfft2': rfft_flop_jit,
            'aten::fft_irfft2': rfft_flop_jit,
        }
        fca1.set_op_handle(**handlers)
        flops1 = fca1.total()
        if show_details:
            print(fca1.by_module())
        print("#### GFLOPs: {} for ratio {}".format(flops1 / 1e9, ratios))
    return flops1 / 1e9


@torch.no_grad()
def throughput(images, model):
    model.eval()
    images = images.cuda(non_blocking=True)
    batch_size = images.shape[0]
    # Warm up with 50 forward passes before timing.
    for i in range(50):
        model(images)
    torch.cuda.synchronize()
    print("throughput averaged over 30 runs")
    tic1 = time.time()
    for i in range(30):
        model(images)
    torch.cuda.synchronize()
    tic2 = time.time()
    print(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
    MB = 1024.0 * 1024.0
    print('memory (MB):', torch.cuda.max_memory_allocated() / MB)
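

# Usage sketch: a minimal example, assuming timm's create_model is available and
# that the model exposes a DynamicViT-style `default_ratio` attribute; the model
# name and keeping ratios below are placeholder choices, not from the original file.
if __name__ == '__main__':
    from timm import create_model

    model = create_model('deit_small_patch16_224', pretrained=False)
    gflops = calc_flops(model, img_size=224, show_details=False, ratios=[0.7, 0.7, 0.7])

    if torch.cuda.is_available():
        # Throughput and peak-memory measurement require a CUDA device.
        images = torch.randn(64, 3, 224, 224)
        throughput(images, model.cuda())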