Skip to content

Commit

Permalink
Extend support for symbolic assembly via spilling (#81)
Browse files Browse the repository at this point in the history
SLOTHY has long supported assembly which uses symbolic register
names instead of architectural ones. This commit improves the
compatibility of this feature with both (a) complex workloads
using numerous symbolic regisrters, and (b) architectures offering
only a small register file, by adding experimental support for
a simplified form of register spilling.
  • Loading branch information
hanno-becker committed Aug 3, 2024
1 parent 5b4c09a commit 7e929e4
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 15 deletions.
33 changes: 31 additions & 2 deletions slothy/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,12 @@ def objective_precision(self):

@property
def has_objective(self):
"""Indicates whether a secondary objective (beyond minimization of stalls)
"""Indicates whether a different objective than minimization of stalls
has been registered."""
objectives = sum([self.sw_pipelining.enabled and
self.sw_pipelining.minimize_overlapping is True,
self.constraints.maximize_register_lifetimes is True,
self.constraints.minimize_spills is True,
self.constraints.move_stalls_to_top is True,
self.constraints.move_stalls_to_bottom is True,
self.constraints.minimize_register_usage is not None,
Expand Down Expand Up @@ -896,6 +897,27 @@ def allow_renaming(self):
in order to find the number of model violations in a piece of code."""
return self._allow_renaming

@property
def allow_spills(self):
"""Allow Slothy to introduce stack spills
When this option is enabled, Slothy will consider the introduction
of stack spills to reduce register pressure.
This option should only be disabled if it is known that the input
assembly suffers from high register pressure. For example, this can
be the case for symbolic input assembly."""
return self._allow_spills

@property
def minimize_spills(self):
"""Minimize number of stack spills
When this option is enabled, the Slothy will pass minimization of
stack spills as the optimization objective to the solver.
"""
return self._minimize_spills

@property
def max_displacement(self):
"""The maximum relative displacement of an instruction.
Expand Down Expand Up @@ -926,7 +948,7 @@ def __init__(self):
self._max_displacement = 1.0

self.maximize_register_lifetimes = False

self.minimize_spills = False
self.move_stalls_to_top = False
self.move_stalls_to_bottom = False
self.minimize_register_usage = None
Expand All @@ -944,6 +966,7 @@ def __init__(self):
self._model_functional_units = True
self._allow_reordering = True
self._allow_renaming = True
self._allow_spills = False

self.lock()

Expand Down Expand Up @@ -980,6 +1003,12 @@ def allow_reordering(self,val):
@allow_renaming.setter
def allow_renaming(self,val):
self._allow_renaming = val
@allow_spills.setter
def allow_spills(self,val):
self._allow_spills = val
@minimize_spills.setter
def minimize_spills(self,val):
self._minimize_spills = val
@functional_only.setter
def functional_only(self,val):
self._model_latencies = val is False
Expand Down
164 changes: 152 additions & 12 deletions slothy/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,20 @@ def _get_code(self, visualize_reordering):
core_char = self.config.core_char
d = self.config.placeholder_char

def gen_restore(reg, loc, vis):
yield SourceLine(self.config.arch.Stack.restore(reg, loc)).\
set_length(self.fixlen).\
set_comment(vis).\
add_tag("is_restore", True).\
add_tag("reads", f"stack_{loc}")

def gen_spill(reg, loc, vis):
yield SourceLine(self.config.arch.Stack.spill(reg, loc)).\
set_length(self.fixlen).\
set_comment(vis).\
add_tag("is_spill", True).\
add_tag("writes", f"stack_{loc}")

def center_str_fixlen(txt, fixlen, char='-'):
txt = ' ' + txt + ' '
l = min(len(txt), fixlen)
Expand Down Expand Up @@ -614,6 +628,8 @@ def gen_vis(p, c):
yield SourceLine("").set_comment(legend2).set_length(fixlen)
for i in range(self.codesize_with_bubbles):
p = ri.get(i, None)
spills = self._spills.get(i, [])
restores = self._restores.get(i, [])
if p is not None:
c = core_char
if self.is_pre(p):
Expand All @@ -622,7 +638,11 @@ def gen_vis(p, c):
c = late_char
s = code[self.periodic_reordering[p]]
yield s.copy().set_length(fixlen).set_comment(gen_vis(p, c))
if p is None:
for (reg, loc) in restores:
yield from gen_restore(reg, loc, gen_vis(0, d))
for (reg, loc) in spills:
yield from gen_spill(reg, loc, gen_vis(0, d))
if p is None and len(s) == 0 and len(r) == 0:
gap_str = "gap"
yield SourceLine("") \
.set_comment(f"{gap_str:{fixlen-4}s}") \
Expand All @@ -648,6 +668,8 @@ def gen_vis(cc, c):
yield SourceLine("").set_comment(legend2).set_length(fixlen)
for i in range(cs):
p = ri.get(i, None)
spills = self._spills.get(i, [])
restores = self._restores.get(i, [])
cc = i // self.config.target.issue_rate
if p is not None:
c = core_char
Expand All @@ -657,6 +679,10 @@ def gen_vis(cc, c):
c = late_char
s = code[self.periodic_reordering[p]]
yield s.copy().set_length(fixlen).set_comment(gen_vis(cc, c))
for (reg, loc) in restores:
yield from gen_restore(reg, loc, gen_vis(cc, "r"))
for (reg, loc) in spills:
yield from gen_spill(reg, loc, gen_vis(cc, "s"))

def gen_visualized_code_with_old():
orig_code = self._orig_code
Expand All @@ -677,9 +703,15 @@ def gen_visualized_code_with_old():

for i in range(self.codesize_with_bubbles):
p = ri.get(i, None)
spills = self._spills.get(i, [])
restores = self._restores.get(i, [])
if p is not None:
s = code[self.periodic_reordering[p]]
yield s.copy().set_length(fixlen).set_comment(f"{old_code[i]:{old_maxlen + 8}s}")
for (reg, loc) in restores:
yield from gen_restore(reg, loc, f"{'-':{old_maxlen+8}s}")
for (reg, loc) in spills:
yield from gen_spill(reg, loc, f"{'-':{old_maxlen+8}s}")

def gen_visualized_code():
if self.config.visualize_expected_performance is True:
Expand Down Expand Up @@ -730,7 +762,8 @@ def gen_visualized_code():
@property
def code(self):
"""The optimized source code"""
return self._get_code(self.config.visualize_reordering)
c = self._get_code(self.config.visualize_reordering)
return c

@code.setter
def code(self, val):
Expand Down Expand Up @@ -1274,6 +1307,8 @@ def __init__(self, config):
self._codesize_with_bubbles = None
self._optimization_wall_time = None
self._optimization_user_time = None
self._spills = {}
self._restores = {}

self.lock()

Expand Down Expand Up @@ -1502,7 +1537,7 @@ def _init_model_internals(self):
self._model.intervals_for_unit = { k : [] for k in self.target.ExecutionUnit }
self._model.register_usages = {}
self._model.register_usage_vars = {}

self._model.spill_vars = []
self._model.variables = []

def _usage_check(self):
Expand Down Expand Up @@ -1640,7 +1675,12 @@ def _dump_avail_renaming_registers(self):
for ty in self.arch.RegisterType:
self.logger.input.debug(f"- {ty} available: {self._model.avail_renaming_regs[ty]}")

def _add_register_usage(self, t, reg, reg_ty, var, start_var, dur_var, end_var):
def _add_register_usage(self, t, reg, reg_ty, var, start_var, dur_var, end_var, spill_var):

spill_ival_active = self._NewBoolVar("")
self._AddMultiplicationEquality([var, spill_var], spill_ival_active)
spill_point_start = self._NewOptionalIntervalVar(t.program_start_var,
1, t.program_start_var + 1, spill_ival_active, "")

interval = self._NewOptionalIntervalVar(
start_var, dur_var, end_var, var, f"Usage({t.inst})({reg})<{var}>")
Expand All @@ -1651,6 +1691,7 @@ def _add_register_usage(self, t, reg, reg_ty, var, start_var, dur_var, end_var):

self._model.register_usages.setdefault(reg, [])
self._model.register_usages[reg].append(interval)
self._model.register_usages[reg].append(spill_point_start)

if var is None:
return
Expand Down Expand Up @@ -1831,6 +1872,7 @@ def _extract_result(self):
self._extract_register_renamings(get_value)
self._extract_input_output_renaming()

self._extract_spills()
self._extract_code()
self._result.selfcheck_with_fixup(self.logger.getChild("selfcheck"))
self._result.offset_fixup(self.logger.getChild("fixup"))
Expand Down Expand Up @@ -1920,6 +1962,10 @@ def _extract_true_key(var_dict):
t.inst.args_out = [ _extract_true_key(vars) for vars in t.alloc_out_var ]
t.inst.args_in = [ _extract_true_key(vars) for vars in t.alloc_in_var ]
t.inst.args_in_out = [ _extract_true_key(vars) for vars in t.alloc_in_out_var ]
t.out_spills = [(get_value(v) == 1) for v in t.out_spill_vars ]
t.in_out_spills = [(get_value(v) == 1) for v in t.in_out_spill_vars ]
t.out_lifetime_start = list(map(get_value, t.out_lifetime_start))
t.inout_lifetime_start = list(map(get_value, t.inout_lifetime_start))
def _dump_renaming(name,lst,inst):
for idx, reg in enumerate(lst):
self.logger.debug("%s %s of '%s' renamed to %s", name, idx, inst, reg)
Expand All @@ -1939,6 +1985,44 @@ def _extract_kernel_input_output(self):
DFG(self._result.code_raw, dfg_log,
DFGConfig(conf,inputs_are_outputs=True)).inputs)

def _extract_spills(self):
if self.config.constraints.allow_spills is None:
return
# Extract spills and restores with textual spill identifiers
spills = self._result._spills
restores = self._result._restores
for t in self._model.tree.nodes:
p = t.real_pos_program
def remember_spill(i, spilled, restore, arg, txt):
if spilled is True:
spills.setdefault(p, [])
restores.setdefault(restore, [])
spill_id = f"{txt}_{p}_{i}"
spills[p].append((arg, spill_id))
restores[restore].append((arg, spill_id))
for (i, (spilled, restore, arg)) in enumerate(zip(t.out_spills, t.out_lifetime_start, t.inst.args_out, strict=True)):
remember_spill(i, spilled, restore, arg, "out")
for (i, (spilled, restore, arg)) in enumerate(zip(t.in_out_spills, t.inout_lifetime_start, t.inst.args_in_out, strict=True)):
remember_spill(i, spilled, restore, arg, "inout")
# Map textual spill identifiers to numerical locations
free_locs = set(range(64))
spill_id_to_loc = {}
m = max(list(spills.keys()) + list(restores.keys()), default=0) + 1
for i in range(m):
s = spills.get(i, [])
r = restores.get(i, [])
for j in range(len(r)):
(reg, spill_id) = r[j]
loc = spill_id_to_loc[spill_id]
free_locs.add(loc)
r[j] = (reg, loc)
for j in range(len(s)):
(reg, spill_id) = s[j]
loc = min(free_locs)
free_locs.remove(loc)
spill_id_to_loc[spill_id] = loc
s[j] = (reg, loc)

def _extract_code(self):

def get_code(filter_func=None, top=False):
Expand Down Expand Up @@ -2354,19 +2438,22 @@ def add_arg_combination_vars(combinations, vs, name, t=t):

for t in self._get_nodes(allnodes=True):
self.logger.debug("Create register usage intervals for %s", t)

t.out_spill_vars = [ self._NewBoolVar("") for _ in t.inst.arg_types_out ]
t.in_out_spill_vars = [ self._NewBoolVar("") for _ in t.inst.arg_types_in_out ]
ivals = []
ivals += list(zip(t.inst.arg_types_out, t.alloc_out_var,
t.out_lifetime_start, t.out_lifetime_duration,
t.out_lifetime_end, strict=True))
t.out_lifetime_end, t.out_spill_vars, strict=True))
ivals += list(zip(t.inst.arg_types_in_out, t.alloc_in_out_var,
t.inout_lifetime_start, t.inout_lifetime_duration,
t.inout_lifetime_end, strict=True))
t.inout_lifetime_end, t.in_out_spill_vars, strict=True))

for arg_ty, var_dict, start_var, dur_var, end_var in ivals:
for arg_ty, var_dict, start_var, dur_var, end_var, spill_var in ivals:
self._model.spill_vars.append(spill_var)
self._Add(start_var == t.program_start_var).OnlyEnforceIf(spill_var.Not())
for reg, var in var_dict.items():
self._add_register_usage(t, reg, arg_ty, var,
start_var, dur_var, end_var)
start_var, dur_var, end_var, spill_var)

# ================================================================
# VARIABLES (Loop rolling) #
Expand Down Expand Up @@ -2471,7 +2558,6 @@ def _add_basic_constraints(start_list, end_list):
# one instruction. Otherwise, instructions producing outputs that
# are never used would be able to overwrite life registers.
self._Add(end_var > t.program_start_var)
self._Add(start_var == t.program_start_var)

_add_basic_constraints(t.out_lifetime_start, t.out_lifetime_end)
_add_basic_constraints(t.inout_lifetime_start, t.inout_lifetime_end)
Expand All @@ -2482,11 +2568,14 @@ def _add_constraints_lifetime_bounds(self):
self._add_constraints_lifetime_bounds_single(t)

# For every instruction depending on the output, add a lifetime bound
for (consumer, producer, _, _, _, end_var, _) in \
for (consumer, producer, _, _, start_var, end_var, _) in \
self._iter_dependencies_with_lifetime():
self._add_path_constraint(consumer, producer.src,
lambda end_var=end_var, consumer=consumer:
self._Add(end_var >= consumer.program_start_var))
self._add_path_constraint(consumer, producer.src,
lambda start_var=start_var, consumer=consumer:
self._Add(start_var < consumer.program_start_var))

# ================================================================
# CONSTRAINTS (Register allocation) #
Expand Down Expand Up @@ -2549,6 +2638,23 @@ def _add_constraints_register_renaming(self):
else:
self._Add(self._model._register_used[reg] == False)

for t in self._get_nodes(allnodes=True):
can_spill = True
if t.is_virtual is True:
can_spill = False
if self.config.constraints.allow_spills is False:
can_spill = False

if can_spill is False:
for v in t.out_spill_vars + t.in_out_spill_vars:
self._Add(v == False)
continue

for (ty, v) in list(zip(t.inst.arg_types_out, t.out_spill_vars)) + \
list(zip(t.inst.arg_types_in_out, t.out_spill_vars)):
if self.arch.RegisterType.spillable(ty) is False:
self._Add(v == False)

# Ensure that outputs are unambiguous
for t in self._get_nodes(allnodes=True):
self.logger.debug("Ensure unambiguous register renaming for %s", str(t.inst))
Expand Down Expand Up @@ -3097,6 +3203,26 @@ def _print_stalls(self, stalls):
(cycles, ipc) = r
return f" (Cycles ~ {cycles}, IPC ~ {ipc:.2f})"

def _print_stalls_and_spills(self, bound, variables):

if "stalls" in variables.keys():
stalls = variables["stalls"]
# TODO: This needs fixing once the objective measure is no longer stalls + spills
spills = bound - stalls
else:
stalls = bound
spills = None

r = self._stalls_to_stats(stalls)
if r is None:
return " (?)"
(cycles, ipc) = r

if spills is not None:
return f" (Cycles ~ {cycles}, IPC ~ {ipc:.2f}, Spills {spills})"
else:
return f" (Cycles ~ {cycles}, IPC ~ {ipc:.2f})"

def _add_objective(self, force_objective=False):
minlist = []
maxlist = []
Expand All @@ -3108,15 +3234,29 @@ def _add_objective(self, force_objective=False):
if force_objective is False and self.config.variable_size:
name = "minimize cycles"
if self.config.constraints.functional_only is False:
printer = self._print_stalls
if self.config.constraints.minimize_spills is False:
printer = self._print_stalls
else:
printer = self._print_stalls_and_spills
objective_vars = { "stalls" : self._model.stalls }
minlist = [self._model.stalls]

# Spill minimization is integrated into the variable-size optimization,
# currently by counting every spill as 1 stall
#
# TODO: This needs refining!
if self.config.constraints.minimize_spills:
minlist += self._model.spill_vars
elif self.config.has_objective and not self.config.ignore_objective:
if self.config.sw_pipelining.enabled is True and \
self.config.sw_pipelining.minimize_overlapping is True:
# Minimize the amount of iteration interleaving
corevars = [ t.core_var.Not() for t in self._get_nodes(low=True) ]
minlist = corevars
name = "minimize iteration overlapping"
elif self.config.constraints.minimize_spills:
name = "minimize spills"
minlist = [self._model.spill_vars]
elif self.config.constraints.maximize_register_lifetimes:
name = "maximize register lifetimes"
maxlist = [ v for t in self._get_nodes(allnodes=True)
Expand Down
Loading

0 comments on commit 7e929e4

Please sign in to comment.