Skip to content

Commit

Permalink
Merge pull request #251 from KernelTuner/simulation-searchspace-impro…
Browse files Browse the repository at this point in the history
…vements

Small improvements to searchspaces and simulation mode
  • Loading branch information
benvanwerkhoven authored May 23, 2024
2 parents 9e8a59a + ea7ca58 commit 0fc4ad2
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 41 deletions.
2 changes: 2 additions & 0 deletions kernel_tuner/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,8 @@ def tune_kernel(

# create search space
searchspace = Searchspace(tune_params, restrictions, runner.dev.max_threads)
restrictions = searchspace._modified_restrictions
tuning_options.restrictions = restrictions
if verbose:
print(f"Searchspace has {searchspace.size} configurations after restrictions.")

Expand Down
6 changes: 4 additions & 2 deletions kernel_tuner/runners/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def run(self, parameter_space, tuning_options):
continue

# if the element is not in the cache, raise an error
logging.debug(f"kernel configuration {element} not in cache")
raise ValueError(f"Kernel configuration {element} not in cache - in simulation mode, all configurations must be present in the cache")
check = util.check_restrictions(tuning_options.restrictions, dict(zip(tuning_options['tune_params'].keys(), element)), True)
err_string = f"kernel configuration {element} not in cache, does {'' if check else 'not '}pass extra restriction check ({check})"
logging.debug(err_string)
raise ValueError(f"{err_string} - in simulation mode, all configurations must be present in the cache")

return results
15 changes: 15 additions & 0 deletions kernel_tuner/searchspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(
restrictions = restrictions if restrictions is not None else []
self.tune_params = tune_params
self.restrictions = restrictions
# the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
self._modified_restrictions = restrictions
self.param_names = list(self.tune_params.keys())
self.params_values = tuple(tuple(param_vals) for param_vals in self.tune_params.values())
self.params_values_indices = None
Expand Down Expand Up @@ -166,6 +168,10 @@ def __build_searchspace_bruteforce(self, block_size_names: list, max_threads: in
block_size_restriction_unspaced = f"{'*'.join(used_block_size_names)} <= {max_threads}"
if block_size_restriction_spaced not in restrictions and block_size_restriction_unspaced not in restrictions:
restrictions.append(block_size_restriction_spaced)
if isinstance(self._modified_restrictions, list) and block_size_restriction_spaced not in self._modified_restrictions:
self._modified_restrictions.append(block_size_restriction_spaced)
if isinstance(self.restrictions, list):
self.restrictions.append(block_size_restriction_spaced)

# check for search space restrictions
if restrictions is not None:
Expand Down Expand Up @@ -293,6 +299,11 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
)
if len(valid_block_size_names) > 0:
parameter_space.addConstraint(MaxProdConstraint(max_threads), valid_block_size_names)
max_block_size_product = f"{' * '.join(valid_block_size_names)} <= {max_threads}"
if isinstance(self._modified_restrictions, list) and max_block_size_product not in self._modified_restrictions:
self._modified_restrictions.append(max_block_size_product)
if isinstance(self.restrictions, list):
self.restrictions.append((MaxProdConstraint(max_threads), valid_block_size_names))

# construct the parameter space with the constraints applied
return parameter_space.getSolutionsAsListDict(order=self.param_names)
Expand All @@ -302,10 +313,14 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
if isinstance(self.restrictions, list):
for restriction in self.restrictions:
required_params = self.param_names

# convert to a Constraint type if necessary
if isinstance(restriction, tuple):
restriction, required_params = restriction
if callable(restriction) and not isinstance(restriction, Constraint):
restriction = FunctionConstraint(restriction)

# add the Constraint
if isinstance(restriction, FunctionConstraint):
parameter_space.addConstraint(restriction, required_params)
elif isinstance(restriction, Constraint):
Expand Down
2 changes: 1 addition & 1 deletion kernel_tuner/strategies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __call__(self, x, check_restrictions=True):
# check if max_fevals is reached or time limit is exceeded
util.check_stop_criterion(self.tuning_options)

# snap values in x to nearest actual value for each parameter unscale x if needed
# snap values in x to nearest actual value for each parameter, unscale x if needed
if self.snap:
if self.scaling:
params = unscale_and_snap_to_nearest(x, self.searchspace.tune_params, self.tuning_options.eps)
Expand Down
87 changes: 49 additions & 38 deletions kernel_tuner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ class StopCriterionReached(Exception):
"block_size_x",
"block_size_y",
"block_size_z",
"ngangs",
"nworkers",
"vlength",
]


Expand Down Expand Up @@ -248,9 +245,37 @@ def check_block_size_params_names_list(block_size_names, tune_params):
UserWarning,
)

def check_restriction(restrict, params: dict) -> bool:
"""Check whether a configuration meets a search space restriction."""
# if it's a python-constraint, convert to function and execute
if isinstance(restrict, Constraint):
restrict = convert_constraint_restriction(restrict)
return restrict(list(params.values()))
# if it's a string, fill in the parameters and evaluate
elif isinstance(restrict, str):
return eval(replace_param_occurrences(restrict, params))
# if it's a function, call it
elif callable(restrict):
return restrict(**params)
# if it's a tuple, use only the parameters in the second argument to call the restriction
elif (isinstance(restrict, tuple) and len(restrict) == 2
and callable(restrict[0]) and isinstance(restrict[1], (list, tuple))):
# unpack the tuple
restrict, selected_params = restrict
# look up the selected parameters and their value
selected_params = dict((key, params[key]) for key in selected_params)
# call the restriction
if isinstance(restrict, Constraint):
restrict = convert_constraint_restriction(restrict)
return restrict(list(selected_params.values()))
else:
return restrict(**selected_params)
# otherwise, raise an error
else:
raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})")

def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
"""Check whether a specific configuration meets the search space restrictions."""
"""Check whether a configuration meets the search space restrictions."""
if callable(restrictions):
valid = restrictions(params)
if not valid and verbose is True:
Expand All @@ -260,40 +285,13 @@ def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
for restrict in restrictions:
# Check the type of each restriction and validate accordingly. Re-implement as a switch when Python >= 3.10.
try:
# if it's a python-constraint, convert to function and execute
if isinstance(restrict, Constraint):
restrict = convert_constraint_restriction(restrict)
if not restrict(params.values()):
valid = False
break
# if it's a string, fill in the parameters and evaluate
elif isinstance(restrict, str):
if not eval(replace_param_occurrences(restrict, params)):
valid = False
break
# if it's a function, call it
elif callable(restrict):
if not restrict(**params):
valid = False
break
# if it's a tuple, use only the parameters in the second argument to call the restriction
elif (isinstance(restrict, tuple) and len(restrict) == 2
and callable(restrict[0]) and isinstance(restrict[1], (list, tuple))):
# unpack the tuple
restrict, selected_params = restrict
# look up the selected parameters and their value
selected_params = dict((key, params[key]) for key in selected_params)
# call the restriction
if not restrict(**selected_params):
valid = False
break
# otherwise, raise an error
else:
raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})")
valid = check_restriction(restrict, params)
if not valid:
break
except ZeroDivisionError:
logging.debug(f"Restriction {restrict} with configuration {get_instance_string(params)} divides by zero.")
if not valid and verbose is True:
print(f"skipping config {get_instance_string(params)}, reason: config fails restriction")
print(f"skipping config {get_instance_string(params)}, reason: config fails restriction {restrict}")
return valid


Expand All @@ -311,6 +309,9 @@ def f_restrict(p):
elif isinstance(restrict, MaxProdConstraint):
def f_restrict(p):
return np.prod(p) <= restrict._maxprod
elif isinstance(restrict, MinProdConstraint):
def f_restrict(p):
return np.prod(p) >= restrict._minprod
elif isinstance(restrict, MaxSumConstraint):
def f_restrict(p):
return sum(p) <= restrict._maxsum
Expand Down Expand Up @@ -1005,6 +1006,9 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio
params_used = list(params_used)
finalized_constraint = None
if try_to_constraint and " or " not in res and " and " not in res:
# if applicable, strip the outermost round brackets
while parsed_restriction[0] == '(' and parsed_restriction[-1] == ')' and '(' not in parsed_restriction[1:] and ')' not in parsed_restriction[:1]:
parsed_restriction = parsed_restriction[1:-1]
# check if we can turn this into the built-in numeric comparison constraint
finalized_constraint = to_numeric_constraint(parsed_restriction, params_used)
if finalized_constraint is None:
Expand Down Expand Up @@ -1059,8 +1063,15 @@ def compile_restrictions(restrictions: list, tune_params: dict, monolithic = Fal
# return the restrictions and used parameters
if len(restrictions_ignore) == 0:
return compiled_restrictions
restrictions_ignore = list(zip(restrictions_ignore, (() for _ in restrictions_ignore)))
return restrictions_ignore + compiled_restrictions

# use the required parameters or add an empty tuple for unknown parameters of ignored restrictions
noncompiled_restrictions = []
for r in restrictions_ignore:
if isinstance(r, tuple) and len(r) == 2 and isinstance(r[1], (list, tuple)):
noncompiled_restrictions.append(r)
else:
noncompiled_restrictions.append((r, ()))
return noncompiled_restrictions + compiled_restrictions


def process_cache(cache, kernel_options, tuning_options, runner):
Expand Down Expand Up @@ -1181,7 +1192,7 @@ def correct_open_cache(cache, open_cache=True):
filestr = cachefile.read().strip()

# if file was not properly closed, pretend it was properly closed
if len(filestr) > 0 and not filestr[-3:] == "}\n}":
if len(filestr) > 0 and not filestr[-3:] in ["}\n}", "}}}"]:
# remove the trailing comma if any, and append closing brackets
if filestr[-1] == ",":
filestr = filestr[:-1]
Expand Down

0 comments on commit 0fc4ad2

Please sign in to comment.