Skip to content

Commit 7010a11

Browse files
authored
Enable mypy lintrunner, Part 3 (profiler/*)
Differential Revision: D67807621 Pull Request resolved: #7494
1 parent ae3d558 commit 7010a11

File tree

4 files changed

+23
-19
lines changed

4 files changed

+23
-19
lines changed

.lintrunner.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ include_patterns = [
299299
# 'exir/**/*.py',
300300
# 'extension/**/*.py',
301301
'kernels/**/*.py',
302-
# 'profiler/**/*.py',
302+
'profiler/**/*.py',
303303
'runtime/**/*.py',
304304
'scripts/**/*.py',
305305
# 'test/**/*.py',
@@ -310,6 +310,7 @@ exclude_patterns = [
310310
'third-party/**',
311311
'**/third-party/**',
312312
'scripts/check_binary_dependencies.py',
313+
'profiler/test/test_profiler_e2e.py',
313314
]
314315
command = [
315316
'python',

.mypy.ini

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ follow_untyped_imports = True
4040
[mypy-executorch.kernels.*]
4141
follow_untyped_imports = True
4242

43+
[mypy-executorch.profiler.*]
44+
follow_untyped_imports = True
45+
4346
[mypy-executorch.runtime.*]
4447
follow_untyped_imports = True
4548

profiler/parse_profiler_results.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from collections import OrderedDict
1010
from enum import Enum
1111

12-
from typing import Dict, List, Tuple
12+
from typing import Dict, List, Optional, Tuple
1313

14-
from prettytable import PrettyTable
14+
from prettytable import PrettyTable # type: ignore[import-not-found]
1515

1616
# This version number should match the one defined in profiler.h
1717
ET_PROF_VER = 0x00000001
@@ -89,8 +89,7 @@ class ProfileEvent:
8989
duration: List[float]
9090
chain_idx: int = -1
9191
instruction_idx: int = -1
92-
# pyre-ignore[8]: Incompatible attribute type
93-
stacktrace: str = None
92+
stacktrace: Optional[str] = None
9493

9594

9695
@dataclasses.dataclass
@@ -134,8 +133,8 @@ def parse_prof_blocks(
134133

135134
# Iterate through all the profiling blocks data that have been grouped by name.
136135
for name, data_list in prof_blocks.items():
137-
prof_data_list = []
138-
mem_prof_data_list = []
136+
prof_data_list: List[ProfileEvent] = []
137+
mem_prof_data_list: List[MemAllocation] = []
139138
# Each entry in data_list is a tuple in which the first entry is profiling data
140139
# and the second entry is memory allocation data, also each entry in data_list
141140
# represents one iteration of a code block.
@@ -168,13 +167,13 @@ def parse_prof_blocks(
168167

169168
# Group all the memory allocation events based on the allocator they were
170169
# allocated from.
171-
alloc_sum_dict = OrderedDict()
170+
alloc_sum_dict: OrderedDict[int, int] = OrderedDict()
172171
for alloc in mem_prof_data_list:
173172
alloc_sum_dict[alloc.allocator_id] = (
174173
alloc_sum_dict.get(alloc.allocator_id, 0) + alloc.allocation_size
175174
)
176175

177-
mem_prof_sum_list = []
176+
mem_prof_sum_list: List[MemEvent] = []
178177
for allocator_id, allocation_size in alloc_sum_dict.items():
179178
mem_prof_sum_list.append(
180179
MemEvent(allocator_dict[allocator_id], allocation_size)
@@ -243,7 +242,9 @@ def deserialize_profile_results(
243242
prof_allocator_struct_size = struct.calcsize(ALLOCATOR_STRUCT_FMT)
244243
prof_allocation_struct_size = struct.calcsize(ALLOCATION_STRUCT_FMT)
245244
prof_result_struct_size = struct.calcsize(PROF_RESULT_STRUCT_FMT)
246-
prof_blocks = OrderedDict()
245+
prof_blocks: OrderedDict[
246+
str, List[Tuple[List[ProfileData], List[MemAllocation]]]
247+
] = OrderedDict()
247248
allocator_dict = {}
248249
base_offset = 0
249250

@@ -375,19 +376,19 @@ def profile_aggregate_framework_tax(
375376
prof_framework_tax = OrderedDict()
376377

377378
for name, prof_data_list in prof_data.items():
378-
execute_max = []
379-
kernel_and_delegate_sum = []
379+
execute_max: List[int] = []
380+
kernel_and_delegate_sum: List[int] = []
380381

381382
for d in prof_data_list:
382383
if "Method::execute" in d.name:
383-
execute_max = max(execute_max, d.duration)
384+
execute_max = max(execute_max, d.duration) # type: ignore[arg-type]
384385

385386
if "native_call" in d.name or "delegate_execute" in d.name:
386387
for idx in range(len(d.duration)):
387388
if idx < len(kernel_and_delegate_sum):
388-
kernel_and_delegate_sum[idx] += d.duration[idx]
389+
kernel_and_delegate_sum[idx] += d.duration[idx] # type: ignore[call-overload]
389390
else:
390-
kernel_and_delegate_sum.append(d.duration[idx])
391+
kernel_and_delegate_sum.append(d.duration[idx]) # type: ignore[arg-type]
391392

392393
if len(execute_max) == 0 or len(kernel_and_delegate_sum) == 0:
393394
continue
@@ -408,10 +409,9 @@ def profile_aggregate_framework_tax(
408409

409410
def profile_framework_tax_table(
410411
prof_framework_tax_data: Dict[str, ProfileEventFrameworkTax]
411-
):
412-
tables = []
412+
) -> List[PrettyTable]:
413+
tables: List[PrettyTable] = []
413414
for name, prof_data_list in prof_framework_tax_data.items():
414-
tables = []
415415
table_agg = PrettyTable()
416416
table_agg.title = name + " framework tax calculations"
417417

profiler/profiler_results_cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import argparse
88
import sys
99

10-
from executorch.profiler.parse_profiler_results import (
10+
from executorch.profiler.parse_profiler_results import ( # type: ignore[import-not-found]
1111
deserialize_profile_results,
1212
mem_profile_table,
1313
profile_aggregate_framework_tax,

0 commit comments

Comments
 (0)