From 7e929e4de383fc9995a9465e4365a016e863bfc2 Mon Sep 17 00:00:00 2001 From: Hanno Becker Date: Sat, 3 Aug 2024 06:13:43 +0100 Subject: [PATCH] Extend support for symbolic assembly via spilling (#81) SLOTHY has long supported assembly which uses symbolic register names instead of architectural ones. This commit improves the compatibility of this feature with both (a) complex workloads using numerous symbolic regisrters, and (b) architectures offering only a small register file, by adding experimental support for a simplified form of register spilling. --- slothy/core/config.py | 33 ++++- slothy/core/core.py | 164 +++++++++++++++++++++++-- slothy/core/heuristics.py | 3 +- slothy/targets/aarch64/aarch64_neon.py | 14 +++ 4 files changed, 199 insertions(+), 15 deletions(-) diff --git a/slothy/core/config.py b/slothy/core/config.py index a653203..c02e351 100644 --- a/slothy/core/config.py +++ b/slothy/core/config.py @@ -452,11 +452,12 @@ def objective_precision(self): @property def has_objective(self): - """Indicates whether a secondary objective (beyond minimization of stalls) + """Indicates whether a different objective than minimization of stalls has been registered.""" objectives = sum([self.sw_pipelining.enabled and self.sw_pipelining.minimize_overlapping is True, self.constraints.maximize_register_lifetimes is True, + self.constraints.minimize_spills is True, self.constraints.move_stalls_to_top is True, self.constraints.move_stalls_to_bottom is True, self.constraints.minimize_register_usage is not None, @@ -896,6 +897,27 @@ def allow_renaming(self): in order to find the number of model violations in a piece of code.""" return self._allow_renaming + @property + def allow_spills(self): + """Allow Slothy to introduce stack spills + + When this option is enabled, Slothy will consider the introduction + of stack spills to reduce register pressure. + + This option should only be disabled if it is known that the input + assembly suffers from high register pressure. For example, this can + be the case for symbolic input assembly.""" + return self._allow_spills + + @property + def minimize_spills(self): + """Minimize number of stack spills + + When this option is enabled, the Slothy will pass minimization of + stack spills as the optimization objective to the solver. + """ + return self._minimize_spills + @property def max_displacement(self): """The maximum relative displacement of an instruction. @@ -926,7 +948,7 @@ def __init__(self): self._max_displacement = 1.0 self.maximize_register_lifetimes = False - + self.minimize_spills = False self.move_stalls_to_top = False self.move_stalls_to_bottom = False self.minimize_register_usage = None @@ -944,6 +966,7 @@ def __init__(self): self._model_functional_units = True self._allow_reordering = True self._allow_renaming = True + self._allow_spills = False self.lock() @@ -980,6 +1003,12 @@ def allow_reordering(self,val): @allow_renaming.setter def allow_renaming(self,val): self._allow_renaming = val + @allow_spills.setter + def allow_spills(self,val): + self._allow_spills = val + @minimize_spills.setter + def minimize_spills(self,val): + self._minimize_spills = val @functional_only.setter def functional_only(self,val): self._model_latencies = val is False diff --git a/slothy/core/core.py b/slothy/core/core.py index 75e21d9..7f356cb 100644 --- a/slothy/core/core.py +++ b/slothy/core/core.py @@ -587,6 +587,20 @@ def _get_code(self, visualize_reordering): core_char = self.config.core_char d = self.config.placeholder_char + def gen_restore(reg, loc, vis): + yield SourceLine(self.config.arch.Stack.restore(reg, loc)).\ + set_length(self.fixlen).\ + set_comment(vis).\ + add_tag("is_restore", True).\ + add_tag("reads", f"stack_{loc}") + + def gen_spill(reg, loc, vis): + yield SourceLine(self.config.arch.Stack.spill(reg, loc)).\ + set_length(self.fixlen).\ + set_comment(vis).\ + add_tag("is_spill", True).\ + add_tag("writes", f"stack_{loc}") + def center_str_fixlen(txt, fixlen, char='-'): txt = ' ' + txt + ' ' l = min(len(txt), fixlen) @@ -614,6 +628,8 @@ def gen_vis(p, c): yield SourceLine("").set_comment(legend2).set_length(fixlen) for i in range(self.codesize_with_bubbles): p = ri.get(i, None) + spills = self._spills.get(i, []) + restores = self._restores.get(i, []) if p is not None: c = core_char if self.is_pre(p): @@ -622,7 +638,11 @@ def gen_vis(p, c): c = late_char s = code[self.periodic_reordering[p]] yield s.copy().set_length(fixlen).set_comment(gen_vis(p, c)) - if p is None: + for (reg, loc) in restores: + yield from gen_restore(reg, loc, gen_vis(0, d)) + for (reg, loc) in spills: + yield from gen_spill(reg, loc, gen_vis(0, d)) + if p is None and len(s) == 0 and len(r) == 0: gap_str = "gap" yield SourceLine("") \ .set_comment(f"{gap_str:{fixlen-4}s}") \ @@ -648,6 +668,8 @@ def gen_vis(cc, c): yield SourceLine("").set_comment(legend2).set_length(fixlen) for i in range(cs): p = ri.get(i, None) + spills = self._spills.get(i, []) + restores = self._restores.get(i, []) cc = i // self.config.target.issue_rate if p is not None: c = core_char @@ -657,6 +679,10 @@ def gen_vis(cc, c): c = late_char s = code[self.periodic_reordering[p]] yield s.copy().set_length(fixlen).set_comment(gen_vis(cc, c)) + for (reg, loc) in restores: + yield from gen_restore(reg, loc, gen_vis(cc, "r")) + for (reg, loc) in spills: + yield from gen_spill(reg, loc, gen_vis(cc, "s")) def gen_visualized_code_with_old(): orig_code = self._orig_code @@ -677,9 +703,15 @@ def gen_visualized_code_with_old(): for i in range(self.codesize_with_bubbles): p = ri.get(i, None) + spills = self._spills.get(i, []) + restores = self._restores.get(i, []) if p is not None: s = code[self.periodic_reordering[p]] yield s.copy().set_length(fixlen).set_comment(f"{old_code[i]:{old_maxlen + 8}s}") + for (reg, loc) in restores: + yield from gen_restore(reg, loc, f"{'-':{old_maxlen+8}s}") + for (reg, loc) in spills: + yield from gen_spill(reg, loc, f"{'-':{old_maxlen+8}s}") def gen_visualized_code(): if self.config.visualize_expected_performance is True: @@ -730,7 +762,8 @@ def gen_visualized_code(): @property def code(self): """The optimized source code""" - return self._get_code(self.config.visualize_reordering) + c = self._get_code(self.config.visualize_reordering) + return c @code.setter def code(self, val): @@ -1274,6 +1307,8 @@ def __init__(self, config): self._codesize_with_bubbles = None self._optimization_wall_time = None self._optimization_user_time = None + self._spills = {} + self._restores = {} self.lock() @@ -1502,7 +1537,7 @@ def _init_model_internals(self): self._model.intervals_for_unit = { k : [] for k in self.target.ExecutionUnit } self._model.register_usages = {} self._model.register_usage_vars = {} - + self._model.spill_vars = [] self._model.variables = [] def _usage_check(self): @@ -1640,7 +1675,12 @@ def _dump_avail_renaming_registers(self): for ty in self.arch.RegisterType: self.logger.input.debug(f"- {ty} available: {self._model.avail_renaming_regs[ty]}") - def _add_register_usage(self, t, reg, reg_ty, var, start_var, dur_var, end_var): + def _add_register_usage(self, t, reg, reg_ty, var, start_var, dur_var, end_var, spill_var): + + spill_ival_active = self._NewBoolVar("") + self._AddMultiplicationEquality([var, spill_var], spill_ival_active) + spill_point_start = self._NewOptionalIntervalVar(t.program_start_var, + 1, t.program_start_var + 1, spill_ival_active, "") interval = self._NewOptionalIntervalVar( start_var, dur_var, end_var, var, f"Usage({t.inst})({reg})<{var}>") @@ -1651,6 +1691,7 @@ def _add_register_usage(self, t, reg, reg_ty, var, start_var, dur_var, end_var): self._model.register_usages.setdefault(reg, []) self._model.register_usages[reg].append(interval) + self._model.register_usages[reg].append(spill_point_start) if var is None: return @@ -1831,6 +1872,7 @@ def _extract_result(self): self._extract_register_renamings(get_value) self._extract_input_output_renaming() + self._extract_spills() self._extract_code() self._result.selfcheck_with_fixup(self.logger.getChild("selfcheck")) self._result.offset_fixup(self.logger.getChild("fixup")) @@ -1920,6 +1962,10 @@ def _extract_true_key(var_dict): t.inst.args_out = [ _extract_true_key(vars) for vars in t.alloc_out_var ] t.inst.args_in = [ _extract_true_key(vars) for vars in t.alloc_in_var ] t.inst.args_in_out = [ _extract_true_key(vars) for vars in t.alloc_in_out_var ] + t.out_spills = [(get_value(v) == 1) for v in t.out_spill_vars ] + t.in_out_spills = [(get_value(v) == 1) for v in t.in_out_spill_vars ] + t.out_lifetime_start = list(map(get_value, t.out_lifetime_start)) + t.inout_lifetime_start = list(map(get_value, t.inout_lifetime_start)) def _dump_renaming(name,lst,inst): for idx, reg in enumerate(lst): self.logger.debug("%s %s of '%s' renamed to %s", name, idx, inst, reg) @@ -1939,6 +1985,44 @@ def _extract_kernel_input_output(self): DFG(self._result.code_raw, dfg_log, DFGConfig(conf,inputs_are_outputs=True)).inputs) + def _extract_spills(self): + if self.config.constraints.allow_spills is None: + return + # Extract spills and restores with textual spill identifiers + spills = self._result._spills + restores = self._result._restores + for t in self._model.tree.nodes: + p = t.real_pos_program + def remember_spill(i, spilled, restore, arg, txt): + if spilled is True: + spills.setdefault(p, []) + restores.setdefault(restore, []) + spill_id = f"{txt}_{p}_{i}" + spills[p].append((arg, spill_id)) + restores[restore].append((arg, spill_id)) + for (i, (spilled, restore, arg)) in enumerate(zip(t.out_spills, t.out_lifetime_start, t.inst.args_out, strict=True)): + remember_spill(i, spilled, restore, arg, "out") + for (i, (spilled, restore, arg)) in enumerate(zip(t.in_out_spills, t.inout_lifetime_start, t.inst.args_in_out, strict=True)): + remember_spill(i, spilled, restore, arg, "inout") + # Map textual spill identifiers to numerical locations + free_locs = set(range(64)) + spill_id_to_loc = {} + m = max(list(spills.keys()) + list(restores.keys()), default=0) + 1 + for i in range(m): + s = spills.get(i, []) + r = restores.get(i, []) + for j in range(len(r)): + (reg, spill_id) = r[j] + loc = spill_id_to_loc[spill_id] + free_locs.add(loc) + r[j] = (reg, loc) + for j in range(len(s)): + (reg, spill_id) = s[j] + loc = min(free_locs) + free_locs.remove(loc) + spill_id_to_loc[spill_id] = loc + s[j] = (reg, loc) + def _extract_code(self): def get_code(filter_func=None, top=False): @@ -2354,19 +2438,22 @@ def add_arg_combination_vars(combinations, vs, name, t=t): for t in self._get_nodes(allnodes=True): self.logger.debug("Create register usage intervals for %s", t) - + t.out_spill_vars = [ self._NewBoolVar("") for _ in t.inst.arg_types_out ] + t.in_out_spill_vars = [ self._NewBoolVar("") for _ in t.inst.arg_types_in_out ] ivals = [] ivals += list(zip(t.inst.arg_types_out, t.alloc_out_var, t.out_lifetime_start, t.out_lifetime_duration, - t.out_lifetime_end, strict=True)) + t.out_lifetime_end, t.out_spill_vars, strict=True)) ivals += list(zip(t.inst.arg_types_in_out, t.alloc_in_out_var, t.inout_lifetime_start, t.inout_lifetime_duration, - t.inout_lifetime_end, strict=True)) + t.inout_lifetime_end, t.in_out_spill_vars, strict=True)) - for arg_ty, var_dict, start_var, dur_var, end_var in ivals: + for arg_ty, var_dict, start_var, dur_var, end_var, spill_var in ivals: + self._model.spill_vars.append(spill_var) + self._Add(start_var == t.program_start_var).OnlyEnforceIf(spill_var.Not()) for reg, var in var_dict.items(): self._add_register_usage(t, reg, arg_ty, var, - start_var, dur_var, end_var) + start_var, dur_var, end_var, spill_var) # ================================================================ # VARIABLES (Loop rolling) # @@ -2471,7 +2558,6 @@ def _add_basic_constraints(start_list, end_list): # one instruction. Otherwise, instructions producing outputs that # are never used would be able to overwrite life registers. self._Add(end_var > t.program_start_var) - self._Add(start_var == t.program_start_var) _add_basic_constraints(t.out_lifetime_start, t.out_lifetime_end) _add_basic_constraints(t.inout_lifetime_start, t.inout_lifetime_end) @@ -2482,11 +2568,14 @@ def _add_constraints_lifetime_bounds(self): self._add_constraints_lifetime_bounds_single(t) # For every instruction depending on the output, add a lifetime bound - for (consumer, producer, _, _, _, end_var, _) in \ + for (consumer, producer, _, _, start_var, end_var, _) in \ self._iter_dependencies_with_lifetime(): self._add_path_constraint(consumer, producer.src, lambda end_var=end_var, consumer=consumer: self._Add(end_var >= consumer.program_start_var)) + self._add_path_constraint(consumer, producer.src, + lambda start_var=start_var, consumer=consumer: + self._Add(start_var < consumer.program_start_var)) # ================================================================ # CONSTRAINTS (Register allocation) # @@ -2549,6 +2638,23 @@ def _add_constraints_register_renaming(self): else: self._Add(self._model._register_used[reg] == False) + for t in self._get_nodes(allnodes=True): + can_spill = True + if t.is_virtual is True: + can_spill = False + if self.config.constraints.allow_spills is False: + can_spill = False + + if can_spill is False: + for v in t.out_spill_vars + t.in_out_spill_vars: + self._Add(v == False) + continue + + for (ty, v) in list(zip(t.inst.arg_types_out, t.out_spill_vars)) + \ + list(zip(t.inst.arg_types_in_out, t.out_spill_vars)): + if self.arch.RegisterType.spillable(ty) is False: + self._Add(v == False) + # Ensure that outputs are unambiguous for t in self._get_nodes(allnodes=True): self.logger.debug("Ensure unambiguous register renaming for %s", str(t.inst)) @@ -3097,6 +3203,26 @@ def _print_stalls(self, stalls): (cycles, ipc) = r return f" (Cycles ~ {cycles}, IPC ~ {ipc:.2f})" + def _print_stalls_and_spills(self, bound, variables): + + if "stalls" in variables.keys(): + stalls = variables["stalls"] + # TODO: This needs fixing once the objective measure is no longer stalls + spills + spills = bound - stalls + else: + stalls = bound + spills = None + + r = self._stalls_to_stats(stalls) + if r is None: + return " (?)" + (cycles, ipc) = r + + if spills is not None: + return f" (Cycles ~ {cycles}, IPC ~ {ipc:.2f}, Spills {spills})" + else: + return f" (Cycles ~ {cycles}, IPC ~ {ipc:.2f})" + def _add_objective(self, force_objective=False): minlist = [] maxlist = [] @@ -3108,8 +3234,19 @@ def _add_objective(self, force_objective=False): if force_objective is False and self.config.variable_size: name = "minimize cycles" if self.config.constraints.functional_only is False: - printer = self._print_stalls + if self.config.constraints.minimize_spills is False: + printer = self._print_stalls + else: + printer = self._print_stalls_and_spills + objective_vars = { "stalls" : self._model.stalls } minlist = [self._model.stalls] + + # Spill minimization is integrated into the variable-size optimization, + # currently by counting every spill as 1 stall + # + # TODO: This needs refining! + if self.config.constraints.minimize_spills: + minlist += self._model.spill_vars elif self.config.has_objective and not self.config.ignore_objective: if self.config.sw_pipelining.enabled is True and \ self.config.sw_pipelining.minimize_overlapping is True: @@ -3117,6 +3254,9 @@ def _add_objective(self, force_objective=False): corevars = [ t.core_var.Not() for t in self._get_nodes(low=True) ] minlist = corevars name = "minimize iteration overlapping" + elif self.config.constraints.minimize_spills: + name = "minimize spills" + minlist = [self._model.spill_vars] elif self.config.constraints.maximize_register_lifetimes: name = "maximize register lifetimes" maxlist = [ v for t in self._get_nodes(allnodes=True) diff --git a/slothy/core/heuristics.py b/slothy/core/heuristics.py index 2361b24..c291d9c 100644 --- a/slothy/core/heuristics.py +++ b/slothy/core/heuristics.py @@ -258,7 +258,8 @@ def optimize_binsearch_internal(source, logger, conf, flexible=True, **kwargs): logger.info(f"Minimum number of stalls: {min_stalls}") - if conf.has_objective is False: + # Spill minimization is integrated into the stall minimization objective + if conf.has_objective is False or conf.constraints.minimize_spills is True: return core.result logger.info("Optimize again with minimal number of %d stalls, with objective...", diff --git a/slothy/targets/aarch64/aarch64_neon.py b/slothy/targets/aarch64/aarch64_neon.py index 79c24b2..a19ce37 100644 --- a/slothy/targets/aarch64/aarch64_neon.py +++ b/slothy/targets/aarch64/aarch64_neon.py @@ -61,6 +61,11 @@ def __str__(self): def __repr__(self): return self.name + @cache + @staticmethod + def spillable(reg_type): + return reg_type in [RegisterType.GPR, RegisterType.NEON] + @cache @staticmethod def list_registers(reg_type, only_extra=False, only_normal=False, with_variants=False): @@ -3025,6 +3030,15 @@ class cmge(ASimdCompare): # pylint: disable=missing-docstring,invalid-name inputs = ["Va", "Vb"] outputs = ["Vd"] + +class Stack: + def spill(reg, loc): + # TODO: Use store instruction + return f"str {reg}, [sp, #STACK_LOC_{loc}]" + def restore(reg, loc): + # TODO: Use load instruction + return f"ldr {reg}, [sp, #STACK_LOC_{loc}]" + # In a pair of vins writing both 64-bit lanes of a vector, mark the # target vector as output rather than input/output. This enables further # renaming opportunities.