12
12
import analyze
13
13
from analyze import Logs
14
14
15
+ # Some reference point to compute occupancies against.
16
+ # This would ideally be the maximum possible occupancy so that the .cost property will never be negative
17
+ OCCUPANCY_REFERENCE_POINT = 10
18
+
19
+ SPILL_COST_WEIGHT = 0
20
+
15
21
16
22
@dataclass
17
23
class DagInfo :
@@ -24,10 +30,20 @@ class DagInfo:
24
30
relative_cost : int
25
31
length : int
26
32
is_optimal : bool
33
+ # Spill cost is not absolute for SCF = TARGET. By recording the baseline, we can adjust the costs.
34
+ target_occupancy : Optional [int ]
35
+ spill_cost : int
27
36
28
37
@property
29
38
def cost (self ):
30
- return self .lower_bound + self .relative_cost
39
+ cost = self .lower_bound + self .relative_cost
40
+ if self .target_occupancy is not None :
41
+ # TargetOcc - SC is a "complement"-like operation, meaning that it undoes itself.
42
+ actual_occupancy = self .target_occupancy - self .spill_cost
43
+ absolute_spill_cost = OCCUPANCY_REFERENCE_POINT - actual_occupancy
44
+ cost += SPILL_COST_WEIGHT * absolute_spill_cost
45
+
46
+ return cost
31
47
32
48
33
49
class MismatchKind (Enum ):
@@ -133,6 +149,8 @@ def extract_dag_info(logs: Logs) -> Dict[str, List[List[DagInfo]]]:
133
149
print (block .raw_log )
134
150
exit (2 )
135
151
152
+ target_occ = block .single ('TargetOccupancy' )['target' ] if 'TargetOccupancy' in block else None
153
+
136
154
dags .setdefault (block .name , []).append (DagInfo (
137
155
id = block .name ,
138
156
benchmark = block .benchmark ,
@@ -142,6 +160,8 @@ def extract_dag_info(logs: Logs) -> Dict[str, List[List[DagInfo]]]:
142
160
relative_cost = best_result ['cost' ],
143
161
length = best_result ['length' ],
144
162
is_optimal = is_optimal ,
163
+ spill_cost = best_result ['spill_cost' ],
164
+ target_occupancy = target_occ ,
145
165
))
146
166
147
167
for k , block_passes in dags .items ():
@@ -338,6 +358,8 @@ def print_small_summary(mismatches: List[Mismatch]):
338
358
339
359
parser .add_argument ('-q' , '--quiet' , action = 'store_true' ,
340
360
help = 'Only print mismatch info, and only if there are mismatches' )
361
+ parser .add_argument ('--scw' , '--spill-cost-weight' , type = int , required = True ,
362
+ help = 'The weight of the spill cost in the cost calculation. Only relevant if the reported spill costs are not absolute (e.g. SCF = TARGET); put any value otherwise.' , dest = 'spill_cost_weight' , metavar = 'SCW' )
341
363
parser .add_argument ('--no-summarize-largest-cost-difference' , action = 'store_true' ,
342
364
help = 'Do not summarize the mismatches with the biggest difference in cost' )
343
365
parser .add_argument ('--no-summarize-smallest-mismatches' , action = 'store_true' ,
@@ -358,6 +380,7 @@ def print_small_summary(mismatches: List[Mismatch]):
358
380
NUM_SMALLEST_BLOCKS_PRINT = args .num_smallest_mismatches_print
359
381
MISSING_LOWER_BOUND_DUMP_COUNT = args .missing_lb_dump_count
360
382
MISSING_LOWER_BOUND_DUMP_LINES = args .missing_lb_dump_lines
383
+ SPILL_COST_WEIGHT = args .spill_cost_weight
361
384
362
385
main (
363
386
args .first , args .second ,
0 commit comments