diff --git a/slothy/core/config.py b/slothy/core/config.py index a653203..c02e351 100644 --- a/slothy/core/config.py +++ b/slothy/core/config.py @@ -452,11 +452,12 @@ def objective_precision(self): @property def has_objective(self): - """Indicates whether a secondary objective (beyond minimization of stalls) + """Indicates whether a different objective than minimization of stalls has been registered.""" objectives = sum([self.sw_pipelining.enabled and self.sw_pipelining.minimize_overlapping is True, self.constraints.maximize_register_lifetimes is True, + self.constraints.minimize_spills is True, self.constraints.move_stalls_to_top is True, self.constraints.move_stalls_to_bottom is True, self.constraints.minimize_register_usage is not None, @@ -896,6 +897,27 @@ def allow_renaming(self): in order to find the number of model violations in a piece of code.""" return self._allow_renaming + @property + def allow_spills(self): + """Allow Slothy to introduce stack spills + + When this option is enabled, Slothy will consider the introduction + of stack spills to reduce register pressure. + + This option should only be disabled if it is known that the input + assembly suffers from high register pressure. For example, this can + be the case for symbolic input assembly.""" + return self._allow_spills + + @property + def minimize_spills(self): + """Minimize number of stack spills + + When this option is enabled, the Slothy will pass minimization of + stack spills as the optimization objective to the solver. + """ + return self._minimize_spills + @property def max_displacement(self): """The maximum relative displacement of an instruction. @@ -926,7 +948,7 @@ def __init__(self): self._max_displacement = 1.0 self.maximize_register_lifetimes = False - + self.minimize_spills = False self.move_stalls_to_top = False self.move_stalls_to_bottom = False self.minimize_register_usage = None @@ -944,6 +966,7 @@ def __init__(self): self._model_functional_units = True self._allow_reordering = True self._allow_renaming = True + self._allow_spills = False self.lock() @@ -980,6 +1003,12 @@ def allow_reordering(self,val): @allow_renaming.setter def allow_renaming(self,val): self._allow_renaming = val + @allow_spills.setter + def allow_spills(self,val): + self._allow_spills = val + @minimize_spills.setter + def minimize_spills(self,val): + self._minimize_spills = val @functional_only.setter def functional_only(self,val): self._model_latencies = val is False diff --git a/slothy/core/core.py b/slothy/core/core.py index 75e21d9..7f356cb 100644 --- a/slothy/core/core.py +++ b/slothy/core/core.py @@ -587,6 +587,20 @@ def _get_code(self, visualize_reordering): core_char = self.config.core_char d = self.config.placeholder_char + def gen_restore(reg, loc, vis): + yield SourceLine(self.config.arch.Stack.restore(reg, loc)).\ + set_length(self.fixlen).\ + set_comment(vis).\ + add_tag("is_restore", True).\ + add_tag("reads", f"stack_{loc}") + + def gen_spill(reg, loc, vis): + yield SourceLine(self.config.arch.Stack.spill(reg, loc)).\ + set_length(self.fixlen).\ + set_comment(vis).\ + add_tag("is_spill", True).\ + add_tag("writes", f"stack_{loc}") + def center_str_fixlen(txt, fixlen, char='-'): txt = ' ' + txt + ' ' l = min(len(txt), fixlen) @@ -614,6 +628,8 @@ def gen_vis(p, c): yield SourceLine("").set_comment(legend2).set_length(fixlen) for i in range(self.codesize_with_bubbles): p = ri.get(i, None) + spills = self._spills.get(i, []) + restores = self._restores.get(i, []) if p is not None: c = core_char if self.is_pre(p): @@ -622,7 +638,11 @@ def gen_vis(p, c): c = late_char s = code[self.periodic_reordering[p]] yield s.copy().set_length(fixlen).set_comment(gen_vis(p, c)) - if p is None: + for (reg, loc) in restores: + yield from gen_restore(reg, loc, gen_vis(0, d)) + for (reg, loc) in spills: + yield from gen_spill(reg, loc, gen_vis(0, d)) + if p is None and len(s) == 0 and len(r) == 0: gap_str = "gap" yield SourceLine("") \ .set_comment(f"{gap_str:{fixlen-4}s}") \ @@ -648,6 +668,8 @@ def gen_vis(cc, c): yield SourceLine("").set_comment(legend2).set_length(fixlen) for i in range(cs): p = ri.get(i, None) + spills = self._spills.get(i, []) + restores = self._restores.get(i, []) cc = i // self.config.target.issue_rate if p is not None: c = core_char @@ -657,6 +679,10 @@ def gen_vis(cc, c): c = late_char s = code[self.periodic_reordering[p]] yield s.copy().set_length(fixlen).set_comment(gen_vis(cc, c)) + for (reg, loc) in restores: + yield from gen_restore(reg, loc, gen_vis(cc, "r")) + for (reg, loc) in spills: + yield from gen_spill(reg, loc, gen_vis(cc, "s")) def gen_visualized_code_with_old(): orig_code = self._orig_code @@ -677,9 +703,15 @@ def gen_visualized_code_with_old(): for i in range(self.codesize_with_bubbles): p = ri.get(i, None) + spills = self._spills.get(i, []) + restores = self._restores.get(i, []) if p is not None: s = code[self.periodic_reordering[p]] yield s.copy().set_length(fixlen).set_comment(f"{old_code[i]:{old_maxlen + 8}s}") + for (reg, loc) in restores: + yield from gen_restore(reg, loc, f"{'-':{old_maxlen+8}s}") + for (reg, loc) in spills: + yield from gen_spill(reg, loc, f"{'-':{old_maxlen+8}s}") def gen_visualized_code(): if self.config.visualize_expected_performance is True: @@ -730,7 +762,8 @@ def gen_visualized_code(): @property def code(self): """The optimized source code""" - return self._get_code(self.config.visualize_reordering) + c = self._get_code(self.config.visualize_reordering) + return c @code.setter def code(self, val): @@ -1274,6 +1307,8 @@ def __init__(self, config): self._codesize_with_bubbles = None self._optimization_wall_time = None self._optimization_user_time = None + self._spills = {} + self._restores = {} self.lock() @@ -1502,7 +1537,7 @@ def _init_model_internals(self): self._model.intervals_for_unit = { k : [] for k in self.target.ExecutionUnit } self._model.register_usages = {} self._model.register_usage_vars = {} - + self._model.spill_vars = [] self._model.variables = [] def _usage_check(self): @@ -1640,7 +1675,12 @@ def _dump_avail_renaming_registers(self): for ty in self.arch.RegisterType: self.logger.input.debug(f"- {ty} available: {self._model.avail_renaming_regs[ty]}") - def _add_register_usage(self, t, reg, reg_ty, var, start_var, dur_var, end_var): + def _add_register_usage(self, t, reg, reg_ty, var, start_var, dur_var, end_var, spill_var): + + spill_ival_active = self._NewBoolVar("") + self._AddMultiplicationEquality([var, spill_var], spill_ival_active) + spill_point_start = self._NewOptionalIntervalVar(t.program_start_var, + 1, t.program_start_var + 1, spill_ival_active, "") interval = self._NewOptionalIntervalVar( start_var, dur_var, end_var, var, f"Usage({t.inst})({reg})<{var}>") @@ -1651,6 +1691,7 @@ def _add_register_usage(self, t, reg, reg_ty, var, start_var, dur_var, end_var): self._model.register_usages.setdefault(reg, []) self._model.register_usages[reg].append(interval) + self._model.register_usages[reg].append(spill_point_start) if var is None: return @@ -1831,6 +1872,7 @@ def _extract_result(self): self._extract_register_renamings(get_value) self._extract_input_output_renaming() + self._extract_spills() self._extract_code() self._result.selfcheck_with_fixup(self.logger.getChild("selfcheck")) self._result.offset_fixup(self.logger.getChild("fixup")) @@ -1920,6 +1962,10 @@ def _extract_true_key(var_dict): t.inst.args_out = [ _extract_true_key(vars) for vars in t.alloc_out_var ] t.inst.args_in = [ _extract_true_key(vars) for vars in t.alloc_in_var ] t.inst.args_in_out = [ _extract_true_key(vars) for vars in t.alloc_in_out_var ] + t.out_spills = [(get_value(v) == 1) for v in t.out_spill_vars ] + t.in_out_spills = [(get_value(v) == 1) for v in t.in_out_spill_vars ] + t.out_lifetime_start = list(map(get_value, t.out_lifetime_start)) + t.inout_lifetime_start = list(map(get_value, t.inout_lifetime_start)) def _dump_renaming(name,lst,inst): for idx, reg in enumerate(lst): self.logger.debug("%s %s of '%s' renamed to %s", name, idx, inst, reg) @@ -1939,6 +1985,44 @@ def _extract_kernel_input_output(self): DFG(self._result.code_raw, dfg_log, DFGConfig(conf,inputs_are_outputs=True)).inputs) + def _extract_spills(self): + if self.config.constraints.allow_spills is None: + return + # Extract spills and restores with textual spill identifiers + spills = self._result._spills + restores = self._result._restores + for t in self._model.tree.nodes: + p = t.real_pos_program + def remember_spill(i, spilled, restore, arg, txt): + if spilled is True: + spills.setdefault(p, []) + restores.setdefault(restore, []) + spill_id = f"{txt}_{p}_{i}" + spills[p].append((arg, spill_id)) + restores[restore].append((arg, spill_id)) + for (i, (spilled, restore, arg)) in enumerate(zip(t.out_spills, t.out_lifetime_start, t.inst.args_out, strict=True)): + remember_spill(i, spilled, restore, arg, "out") + for (i, (spilled, restore, arg)) in enumerate(zip(t.in_out_spills, t.inout_lifetime_start, t.inst.args_in_out, strict=True)): + remember_spill(i, spilled, restore, arg, "inout") + # Map textual spill identifiers to numerical locations + free_locs = set(range(64)) + spill_id_to_loc = {} + m = max(list(spills.keys()) + list(restores.keys()), default=0) + 1 + for i in range(m): + s = spills.get(i, []) + r = restores.get(i, []) + for j in range(len(r)): + (reg, spill_id) = r[j] + loc = spill_id_to_loc[spill_id] + free_locs.add(loc) + r[j] = (reg, loc) + for j in range(len(s)): + (reg, spill_id) = s[j] + loc = min(free_locs) + free_locs.remove(loc) + spill_id_to_loc[spill_id] = loc + s[j] = (reg, loc) + def _extract_code(self): def get_code(filter_func=None, top=False): @@ -2354,19 +2438,22 @@ def add_arg_combination_vars(combinations, vs, name, t=t): for t in self._get_nodes(allnodes=True): self.logger.debug("Create register usage intervals for %s", t) - + t.out_spill_vars = [ self._NewBoolVar("") for _ in t.inst.arg_types_out ] + t.in_out_spill_vars = [ self._NewBoolVar("") for _ in t.inst.arg_types_in_out ] ivals = [] ivals += list(zip(t.inst.arg_types_out, t.alloc_out_var, t.out_lifetime_start, t.out_lifetime_duration, - t.out_lifetime_end, strict=True)) + t.out_lifetime_end, t.out_spill_vars, strict=True)) ivals += list(zip(t.inst.arg_types_in_out, t.alloc_in_out_var, t.inout_lifetime_start, t.inout_lifetime_duration, - t.inout_lifetime_end, strict=True)) + t.inout_lifetime_end, t.in_out_spill_vars, strict=True)) - for arg_ty, var_dict, start_var, dur_var, end_var in ivals: + for arg_ty, var_dict, start_var, dur_var, end_var, spill_var in ivals: + self._model.spill_vars.append(spill_var) + self._Add(start_var == t.program_start_var).OnlyEnforceIf(spill_var.Not()) for reg, var in var_dict.items(): self._add_register_usage(t, reg, arg_ty, var, - start_var, dur_var, end_var) + start_var, dur_var, end_var, spill_var) # ================================================================ # VARIABLES (Loop rolling) # @@ -2471,7 +2558,6 @@ def _add_basic_constraints(start_list, end_list): # one instruction. Otherwise, instructions producing outputs that # are never used would be able to overwrite life registers. self._Add(end_var > t.program_start_var) - self._Add(start_var == t.program_start_var) _add_basic_constraints(t.out_lifetime_start, t.out_lifetime_end) _add_basic_constraints(t.inout_lifetime_start, t.inout_lifetime_end) @@ -2482,11 +2568,14 @@ def _add_constraints_lifetime_bounds(self): self._add_constraints_lifetime_bounds_single(t) # For every instruction depending on the output, add a lifetime bound - for (consumer, producer, _, _, _, end_var, _) in \ + for (consumer, producer, _, _, start_var, end_var, _) in \ self._iter_dependencies_with_lifetime(): self._add_path_constraint(consumer, producer.src, lambda end_var=end_var, consumer=consumer: self._Add(end_var >= consumer.program_start_var)) + self._add_path_constraint(consumer, producer.src, + lambda start_var=start_var, consumer=consumer: + self._Add(start_var < consumer.program_start_var)) # ================================================================ # CONSTRAINTS (Register allocation) # @@ -2549,6 +2638,23 @@ def _add_constraints_register_renaming(self): else: self._Add(self._model._register_used[reg] == False) + for t in self._get_nodes(allnodes=True): + can_spill = True + if t.is_virtual is True: + can_spill = False + if self.config.constraints.allow_spills is False: + can_spill = False + + if can_spill is False: + for v in t.out_spill_vars + t.in_out_spill_vars: + self._Add(v == False) + continue + + for (ty, v) in list(zip(t.inst.arg_types_out, t.out_spill_vars)) + \ + list(zip(t.inst.arg_types_in_out, t.out_spill_vars)): + if self.arch.RegisterType.spillable(ty) is False: + self._Add(v == False) + # Ensure that outputs are unambiguous for t in self._get_nodes(allnodes=True): self.logger.debug("Ensure unambiguous register renaming for %s", str(t.inst)) @@ -3097,6 +3203,26 @@ def _print_stalls(self, stalls): (cycles, ipc) = r return f" (Cycles ~ {cycles}, IPC ~ {ipc:.2f})" + def _print_stalls_and_spills(self, bound, variables): + + if "stalls" in variables.keys(): + stalls = variables["stalls"] + # TODO: This needs fixing once the objective measure is no longer stalls + spills + spills = bound - stalls + else: + stalls = bound + spills = None + + r = self._stalls_to_stats(stalls) + if r is None: + return " (?)" + (cycles, ipc) = r + + if spills is not None: + return f" (Cycles ~ {cycles}, IPC ~ {ipc:.2f}, Spills {spills})" + else: + return f" (Cycles ~ {cycles}, IPC ~ {ipc:.2f})" + def _add_objective(self, force_objective=False): minlist = [] maxlist = [] @@ -3108,8 +3234,19 @@ def _add_objective(self, force_objective=False): if force_objective is False and self.config.variable_size: name = "minimize cycles" if self.config.constraints.functional_only is False: - printer = self._print_stalls + if self.config.constraints.minimize_spills is False: + printer = self._print_stalls + else: + printer = self._print_stalls_and_spills + objective_vars = { "stalls" : self._model.stalls } minlist = [self._model.stalls] + + # Spill minimization is integrated into the variable-size optimization, + # currently by counting every spill as 1 stall + # + # TODO: This needs refining! + if self.config.constraints.minimize_spills: + minlist += self._model.spill_vars elif self.config.has_objective and not self.config.ignore_objective: if self.config.sw_pipelining.enabled is True and \ self.config.sw_pipelining.minimize_overlapping is True: @@ -3117,6 +3254,9 @@ def _add_objective(self, force_objective=False): corevars = [ t.core_var.Not() for t in self._get_nodes(low=True) ] minlist = corevars name = "minimize iteration overlapping" + elif self.config.constraints.minimize_spills: + name = "minimize spills" + minlist = [self._model.spill_vars] elif self.config.constraints.maximize_register_lifetimes: name = "maximize register lifetimes" maxlist = [ v for t in self._get_nodes(allnodes=True) diff --git a/slothy/core/heuristics.py b/slothy/core/heuristics.py index 2361b24..c291d9c 100644 --- a/slothy/core/heuristics.py +++ b/slothy/core/heuristics.py @@ -258,7 +258,8 @@ def optimize_binsearch_internal(source, logger, conf, flexible=True, **kwargs): logger.info(f"Minimum number of stalls: {min_stalls}") - if conf.has_objective is False: + # Spill minimization is integrated into the stall minimization objective + if conf.has_objective is False or conf.constraints.minimize_spills is True: return core.result logger.info("Optimize again with minimal number of %d stalls, with objective...", diff --git a/slothy/targets/aarch64/aarch64_neon.py b/slothy/targets/aarch64/aarch64_neon.py index 79c24b2..a19ce37 100644 --- a/slothy/targets/aarch64/aarch64_neon.py +++ b/slothy/targets/aarch64/aarch64_neon.py @@ -61,6 +61,11 @@ def __str__(self): def __repr__(self): return self.name + @cache + @staticmethod + def spillable(reg_type): + return reg_type in [RegisterType.GPR, RegisterType.NEON] + @cache @staticmethod def list_registers(reg_type, only_extra=False, only_normal=False, with_variants=False): @@ -3025,6 +3030,15 @@ class cmge(ASimdCompare): # pylint: disable=missing-docstring,invalid-name inputs = ["Va", "Vb"] outputs = ["Vd"] + +class Stack: + def spill(reg, loc): + # TODO: Use store instruction + return f"str {reg}, [sp, #STACK_LOC_{loc}]" + def restore(reg, loc): + # TODO: Use load instruction + return f"ldr {reg}, [sp, #STACK_LOC_{loc}]" + # In a pair of vins writing both 64-bit lanes of a vector, mark the # target vector as output rather than input/output. This enables further # renaming opportunities.