From 3623ced55fed34bbb7f822c4e49e12f9c1ea07e5 Mon Sep 17 00:00:00 2001 From: alpavlenko Date: Wed, 27 Dec 2023 17:52:18 +0300 Subject: [PATCH] added partition space, divide function, updated combine solver --- core/impl/__init__.py | 2 - core/impl/combine_t.py | 460 ++++++++++++++-------------- core/impl/growing_t.py | 244 --------------- function/impl/__init__.py | 6 +- function/impl/function_div.py | 82 +++++ function/impl/function_gad.py | 5 +- function/impl/function_ibs.py | 5 +- function/impl/function_ips.py | 5 +- function/impl/function_rho.py | 33 +- function/impl/function_rho_t.py | 5 +- lib_satprob/solver/impl/__init__.py | 2 + lib_satprob/solver/impl/external.py | 221 +++++++++++++ lib_satprob/solver/impl/py2sat.py | 11 +- lib_satprob/solver/impl/pysat.py | 10 +- space/_utility.py | 14 +- space/impl/interval_set.py | 4 +- space/impl/partition_set.py | 41 +++ space/model/__init__.py | 1 + space/model/partition.py | 56 ++++ util/iterable.py | 7 +- util/wrapppers.py | 33 ++ 21 files changed, 739 insertions(+), 508 deletions(-) delete mode 100644 core/impl/growing_t.py create mode 100644 function/impl/function_div.py create mode 100644 lib_satprob/solver/impl/external.py create mode 100644 space/impl/partition_set.py create mode 100644 space/model/partition.py create mode 100644 util/wrapppers.py diff --git a/core/impl/__init__.py b/core/impl/__init__.py index aca3ed4..d3267d8 100644 --- a/core/impl/__init__.py +++ b/core/impl/__init__.py @@ -2,12 +2,10 @@ from .solving import * from .optimize import * from .combine_t import * -from .growing_t import * cores = { Solving.slug: Solving, Combine.slug: Combine, Optimize.slug: Optimize, CombineT.slug: CombineT, - GrowingT.slug: GrowingT } diff --git a/core/impl/combine_t.py b/core/impl/combine_t.py index 1214afd..27d5b51 100644 --- a/core/impl/combine_t.py +++ b/core/impl/combine_t.py @@ -2,173 +2,161 @@ import json from time import time as now -from itertools import product -from tempfile import NamedTemporaryFile as NTFile -from typing import Any, List, Dict, Optional, Tuple +from tempfile import NamedTemporaryFile +from typing import Any, List, Dict, Optional, Tuple, Iterable + +from space.model import Backdoor +from ..abc import Core from output import Logger from executor import Executor -from lib_satprob.encoding import WCNF - -from ..abc import Core -from lib_satprob.solver import Report from lib_satprob.problem import Problem -from lib_satprob.variables import Assumptions, Supplements, combine +from lib_satprob.encoding import Clauses +from lib_satprob.solver import Report, _Solver +from lib_satprob.derived import get_derived_by +from lib_satprob.variables import Assumptions, Supplements from function.model import Estimation from function.module.measure import Measure from function.module.budget import TaskBudget, KeyLimit +from util.wrapppers import timed from typings.searchable import Searchable from util.iterable import slice_into, split_by UnWeightTask = Tuple[int, Supplements] - # +# Formula patching logic (create, load and initialize solver) +# ============================================================================== +FORMULAS: Dict[int, Any] = {} +VERSIONS: Dict[int, Clauses] = {} + + +def create_patch( + clauses: Clauses +) -> Tuple[str, int]: + version = max(VERSIONS.keys()) \ + if len(VERSIONS) > 0 else 1 + VERSIONS[version] = clauses + + with NamedTemporaryFile( + delete=False, mode='w+' + ) as handle: + json.dump(clauses, handle) + return handle.name, version + + +def get_formula( + problem: Problem, + filename: 
str, + version: int, +) -> Any: + if version not in FORMULAS: + print('loading...', version, filename) + formula = problem.encoding.get_formula() + if filename is not None and version > 0: + with open(filename, 'r+') as handle: + clauses = json.load(handle) + formula.extend(clauses) + + FORMULAS[version] = formula + + return FORMULAS[version] + + +# def get_solver( +# problem: Problem, +# filename: str, +# version: int, +# ) -> _Solver: +# if version not in SOLVERS: +# for solver in SOLVERS.values(): +# solver.__exit__() +# SOLVERS.clear() # +# formula = get_formula(problem, filename, version) +# solver = problem.solver.get_instance(formula) +# SOLVERS[version] = solver.__enter__() # -def bool2sign(b): - return -1 if b else 1 - - -def signed(x, s): - return bool2sign(s) * x - - -def minimize_dnf(dnf): - from pyeda.inter import espresso_exprs - - min_dnf = espresso_exprs(dnf) - return min_dnf - - -def cnf_to_clauses(cnf): - assert cnf.is_cnf() - - litmap, nvars, clauses = cnf.encode_cnf() - result = [] - for clause in clauses: - c = [] - for lit in clause: - v = litmap[abs(lit)].indices[0] # 1-based variable index - s = lit < 0 # sign - c.append(signed(v, s)) - c.sort(key=lambda x: abs(x)) - result.append(c) +# return SOLVERS[version] - clauses = result - clauses.sort(key=lambda x: (len(x), tuple(map(abs, x)))) - return clauses - -def cubes_to_dnf(cubes): - from pyeda.inter import exprvar, And, Or - - var_map = dict() - cubes_expr = [] - - for cube in cubes: - lits_expr = [] - for lit in cube: - var = abs(lit) - if var not in var_map: - var_map[var] = exprvar("x", var) - if lit < 0: - lits_expr.append(~var_map[var]) - else: - lits_expr.append(var_map[var]) - cubes_expr.append(And(*lits_expr)) - - dnf = Or(*cubes_expr) - assert dnf.is_dnf() - return dnf - - -def backdoor_to_clauses_via_easy(easy): - # Note: here, 'dnf' represents the negation of characteristic function, - # because we use "easy" tasks here. 
- dnf = cubes_to_dnf(easy) - (min_dnf,) = minimize_dnf(dnf) - min_cnf = (~min_dnf).to_cnf() # here, we negate the function back - clauses = cnf_to_clauses(min_cnf) - return clauses - - -def grow_worker(easy_tasks: List[Supplements]) -> Tuple[Supplements, float]: - _stamp, easy_cubes = now(), [sups[0] for sups in easy_tasks] - clauses = backdoor_to_clauses_via_easy(easy_cubes) - constr, one_lit = split_by(clauses, lambda x: len(x) > 1) - return ([clause[0] for clause in one_lit], constr), now() - _stamp - - -# -# # - - -def is_unsat(clause: List[int], value_map: Dict[int, int]) -> bool: - size = len(clause) - for literal in clause: - value = value_map.get(abs(literal)) - if literal == value: - return False - if value is not None: - size -= 1 - return True if size == 0 else None - - -def calc_cost(formula, literals: Assumptions) -> int: - value_map = {abs(lit): lit for lit in literals} - return sum([ - weight if is_unsat(clause, value_map) else 0 for - weight, clause in zip(formula.wght, formula.soft) - ]) - - -def prep_worker( - problem: Problem, searchable: Searchable -) -> Tuple[Searchable, List[Supplements], List[Supplements]]: - clauses = problem.encoding.get_formula(copy=False).hard - with problem.solver.get_instance(clauses) as solver: - easy, hard = split_by( - searchable.enumerate(), lambda sups: - solver.propagate(sups).status is False - ) - return searchable, easy, hard +# Hard task product logic (combine, worker) +# ============================================================================== +# def is_unsat(clause: List[int], value_map: Dict[int, int]) -> bool: +# size = len(clause) +# for literal in clause: +# value = value_map.get(abs(literal)) +# if literal == value: +# return False +# if value is not None: +# size -= 1 +# return True if size == 0 else None + + +# def calc_cost(formula, literals: Assumptions) -> int: +# value_map = {abs(lit): lit for lit in literals} +# return sum([ +# weight if is_unsat(clause, value_map) else 0 for +# weight, clause in zip(formula.wght, formula.soft) +# ]) + +def prod_combine( + acc_tasks: List[Assumptions], tasks: List[Assumptions] +) -> Iterable[Supplements]: + for assumptions in map(set, tasks): + for acc_assumptions in acc_tasks: + if sum([ + -literal in assumptions for + literal in acc_assumptions + ]) > 0: continue + + yield assumptions.union( + set(acc_assumptions) + ), [] def prod_worker( - problem: Problem, acc_tasks: List[UnWeightTask], - tasks: List[UnWeightTask], -) -> Tuple[List[UnWeightTask], float]: - tasks, acc_tasks = [t[1] for t in tasks], [t[1] for t in acc_tasks] - _stamp, formula = now(), problem.encoding.get_formula(copy=False) - w_acc_hard_task, prod = [], product(acc_tasks, tasks) - - with problem.solver.get_instance(formula.hard) as solver: - for acc_hard_task in [combine(*prs) for prs in prod]: + acc_tasks: List[Assumptions], tasks: List[Assumptions], + problem: Problem, patch: str, version: int +) -> Tuple[List[Assumptions], float]: + _stamp, prod_hard_task = now(), [] + formula = get_formula(problem, patch, version) + with problem.solver.get_instance(formula, False) as solver: + for acc_hard_task in prod_combine(acc_tasks, tasks): report = solver.propagate(acc_hard_task) if report.status is None or report.status: - cost = 0 if len(report.model) == 0 \ - else calc_cost(formula, report.model) - w_acc_hard_task.append((cost, acc_hard_task)) + # cost = 0 if len(report.model) == 0 \ + # else calc_cost(formula, report.model) + prod_hard_task.append(acc_hard_task[0]) - return w_acc_hard_task, now() - _stamp + return 
prod_hard_task, now() - _stamp -def limit_worker( - problem: Problem, task: UnWeightTask, limit: KeyLimit -) -> Tuple[UnWeightTask, Report]: +# +# ============================================================================== +def prep_worker( + problem: Problem, backdoor: Backdoor +) -> Tuple[Backdoor, List[Supplements], List[Supplements]]: formula = problem.encoding.get_formula(copy=False) with problem.solver.get_instance(formula) as solver: - return task, solver.solve(task[1], limit, extract_model=False) + easy, hard = split_by( + backdoor.enumerate(), lambda sups: + solver.propagate(sups).status is False + ) + return backdoor, easy, hard -def hard_worker(problem: Problem, hard_task: Assumptions) -> Report: - formula = problem.encoding.get_formula(copy=False) - return problem.solver.solve(formula, (hard_task, [])) +def hard_worker( + task: Assumptions, limit: KeyLimit, + problem: Problem, patch: str, version: int +) -> Tuple[Assumptions, Report]: + formula = get_formula(problem, patch, version) + with problem.solver.get_instance(formula) as solver: + return task, solver.solve((task, []), limit) class CombineT(Core): @@ -183,14 +171,18 @@ def __init__(self, logger: Logger, measure: Measure, problem: Problem, super().__init__(logger, problem, random_seed) self.clauses = [] - self.stats_sum = {} + self.stats_sum = { + 'prod_time': 0., + 'grow_time': 0. + } self.best_model = (None, []) - def sifting(self, tasks: List[UnWeightTask]) -> List[UnWeightTask]: + def sifting( + self, tasks: List[Assumptions], patch: str, version: int + ) -> List[Assumptions]: hard_tasks, limit = [], self.measure.get_limit(self.budget) - future_all, count = self.executor.submit_all(limit_worker, *( - (self.problem, task, limit) for task in tasks if - task[0] is None or task[0] < self.best_model[0] + future_all, count = self.executor.submit_all(hard_worker, *( + (task, limit, self.problem, patch, version) for task in tasks )), len(tasks) print('weight penalty:', f'{len(tasks)} -> {len(future_all)}') @@ -212,7 +204,7 @@ def sifting(self, tasks: List[UnWeightTask]) -> List[UnWeightTask]: return hard_tasks - def _preprocess(self, *searchables: Searchable) -> List[List[UnWeightTask]]: + def _preprocess(self, *backdoors: Backdoor) -> List[List[Assumptions]]: current_var_set, all_hard_tasks = set(), [] all_assumptions, all_constraints = set(), set() @@ -223,7 +215,7 @@ def var_distance(_searchable: Searchable) -> int: ]) results = [future.result() for future in self.executor.submit_all( - prep_worker, *((self.problem, sch) for sch in searchables) + prep_worker, *((self.problem, bd) for bd in backdoors) ).as_complete()] one_hard, results = split_by(results, lambda r: len(r[2]) == 1) processed = sorted(results, key=lambda r: len(r[2])) @@ -239,105 +231,127 @@ def add_supplements(_supplements: Supplements): for searchable, _, hard_tasks in one_hard: add_supplements(hard_tasks[0]) - for future in self.executor.submit_all(grow_worker, *( + for future in self.executor.submit_all(timed(get_derived_by), *( (easy_tasks,) for _, easy_tasks, _ in processed )).as_complete(): - supplements, grow_time = future.result() - self.stats_sum['grow_time'] += grow_time + supplements, _time = future.result() + self.stats_sum['grow_time'] += _time add_supplements(supplements) + print(all_assumptions) + for searchable, _, hard_tasks in processed: + fil_hard_tasks = [] + for hard_task in hard_tasks: + fil_hard_task = [] + for literal in hard_task[0]: + if -literal in all_assumptions: + break + if literal not in all_assumptions: + 
fil_hard_task.append(literal) + else: + print(hard_task[0], '->', fil_hard_task) + fil_hard_tasks.append((fil_hard_task, [])) + + print(searchable, len(hard_tasks), '->', len(fil_hard_tasks)) + if var_distance(searchable) > 1: - all_hard_tasks.append(hard_tasks) + all_hard_tasks.append(fil_hard_tasks) for var in searchable.variables(): current_var_set.add(var.name) - if len(all_assumptions) > 0: - assumptions = list(all_assumptions) - all_hard_tasks.insert(0, [(assumptions, [])]) + # if len(all_assumptions) > 0: + # assumptions = list(all_assumptions) + # # all_hard_tasks.insert(0, [(assumptions, [])]) - if len(all_constraints) > 0: - constraints = map(list, all_constraints) - self.clauses = list(constraints) + # if len(all_constraints) > 0: + # constraints = map(list, all_constraints) + + constraints = map(list, all_constraints) + self.clauses = list(constraints) + [ + [lit] for lit in all_assumptions + ] return [ - [(0, task) for task in tasks] + [task[0] for task in tasks] for tasks in all_hard_tasks ] - def launch(self, *searchables: Searchable) -> Estimation: - formula, files = self.problem.encoding.get_formula(), [] - self.best_model, start_stamp = (sum(formula.wght), []), now() - self.stats_sum['prod_time'], self.stats_sum['grow_time'] = 0, 0 - - all_hard_tasks = self._preprocess(*searchables) - [acc_hard_tasks, *all_hard_tasks] = all_hard_tasks - - with NTFile(delete=False) as wcnf_file: - formula.extend(self.clauses) - formula.to_file(wcnf_file.name) - files.append(wcnf_file.name) - self.problem.encoding = WCNF( - from_file=wcnf_file.name - ) - - plot_data = [( - len(set(map(abs, acc_hard_tasks[0][1][0]))), - len(acc_hard_tasks), len(acc_hard_tasks) - )] - - for i, hard_tasks in enumerate(all_hard_tasks): - next_acc_hard_tasks = [] - prod_size = len(acc_hard_tasks) * len(hard_tasks) - for future in self.executor.submit_all(prod_worker, *(( - self.problem, acc_part_hard_tasks, hard_tasks - ) for acc_part_hard_tasks in slice_into( - acc_hard_tasks, self.executor.max_workers - ))).as_complete(): - prod_tasks, prod_time = future.result() - self.stats_sum['prod_time'] += prod_time - next_acc_hard_tasks.extend(prod_tasks) - - var_set = sorted(set(map(abs, next_acc_hard_tasks[0][1][0]))) - print(f'var set ({len(var_set)}):', ' '.join(map(str, var_set))) - print('stats:', self.stats_sum) - print(f'time ({self.executor.max_workers})', - now() - start_stamp) - - print( - f'reduced: {prod_size} -> {len(next_acc_hard_tasks)}', - f'({round(len(next_acc_hard_tasks) / prod_size, 2)})', - ) - acc_hard_tasks = self.sifting(next_acc_hard_tasks) - - plot_data.append(( - len(var_set), len(next_acc_hard_tasks), len(acc_hard_tasks) - )) - - print( - f'sifted: {len(next_acc_hard_tasks)} -> {len(acc_hard_tasks)}', - f'({round(len(acc_hard_tasks) / len(next_acc_hard_tasks), 2)})' - ) - if len(acc_hard_tasks) == 0: - print(f'total bds: {i + 1}') - print('total stats:', self.stats_sum) - print(f'total time ({self.executor.max_workers})', - now() - start_stamp) - print(f'total var set ({len(var_set)}):', - ' '.join(map(str, var_set))) - break - - print(self.stats_sum) - print(self.best_model) - print('parallel time:', now() - start_stamp) - print('sequential time:', self.stats_sum['grow_time'] + - self.stats_sum['time'] + self.stats_sum['prod_time']) + def launch(self, *backdoors: Backdoor) -> Estimation: + start_stamp, files = now(), [] + # self.best_model = (sum(formula.wght), []) + try: + all_hard_tasks = self._preprocess(*backdoors) + [acc_hard_tasks, *all_hard_tasks] = all_hard_tasks - print('plot 
data') - print(json.dumps(plot_data)) + if len(self.clauses) > 0: + patch_file, version = create_patch(self.clauses) + else: + patch_file, version = None, 0 + + plot_data = [( + len(set(map(abs, acc_hard_tasks[0]))), + len(acc_hard_tasks), len(acc_hard_tasks) + )] + + for i, hard_tasks in enumerate(all_hard_tasks): + next_acc_hard_tasks = [] + prod_size = len(acc_hard_tasks) * len(hard_tasks) + # print(acc_hard_tasks, 'x', hard_tasks) + + for future in self.executor.submit_all(prod_worker, *(( + acc_part_hard_tasks, hard_tasks, + self.problem, patch_file, version + ) for acc_part_hard_tasks in slice_into( + acc_hard_tasks, self.executor.max_workers + ))).as_complete(): + prod_tasks, prod_time = future.result() + self.stats_sum['prod_time'] += prod_time + next_acc_hard_tasks.extend(prod_tasks) + + var_set = sorted(set(map(abs, next_acc_hard_tasks[0]))) + print(f'var set ({len(var_set)}):', ' '.join(map(str, var_set))) + print('stats:', self.stats_sum) + print(f'time ({self.executor.max_workers})', + now() - start_stamp) - [os.remove(file) for file in files] - return self.stats_sum + print( + f'reduced: {prod_size} -> {len(next_acc_hard_tasks)}', + f'({round(len(next_acc_hard_tasks) / prod_size, 2)})', + ) + acc_hard_tasks = self.sifting( + next_acc_hard_tasks, patch_file, version + ) + + plot_data.append(( + len(var_set), len(next_acc_hard_tasks), len(acc_hard_tasks) + )) + + print( + f'sifted: {len(next_acc_hard_tasks)} -> {len(acc_hard_tasks)}', + f'({round(len(acc_hard_tasks) / len(next_acc_hard_tasks), 2)})' + ) + if len(acc_hard_tasks) == 0: + print(f'total bds: {i + 1}') + print('total stats:', self.stats_sum) + print(f'total time ({self.executor.max_workers})', + now() - start_stamp) + print(f'total var set ({len(var_set)}):', + ' '.join(map(str, var_set))) + break + + print(self.stats_sum) + print(self.best_model) + print('parallel time:', now() - start_stamp) + print('sequential time:', self.stats_sum['grow_time'] + + self.stats_sum['time'] + self.stats_sum['prod_time']) + + print('plot data') + print(json.dumps(plot_data)) + + return self.stats_sum + finally: + [os.remove(file) for file in files] def __config__(self) -> Dict[str, Any]: return {} diff --git a/core/impl/growing_t.py b/core/impl/growing_t.py deleted file mode 100644 index 417b2f7..0000000 --- a/core/impl/growing_t.py +++ /dev/null @@ -1,244 +0,0 @@ -import json -from typing import List, Optional, Tuple - -from output import Logger -from executor import Executor - -from lib_satprob.solver import Report -from lib_satprob.problem import Problem -from typings.searchable import Searchable -from lib_satprob.variables import Assumptions, Supplements, combine -from util.iterable import split_by - -from ..abc import Core - -from function.module.measure import Measure -from function.module.budget import TaskBudget, KeyLimit - - -def is_sat(clause: List[int], solution: List[int]) -> bool: - for literal in clause: - if literal in solution: - return True - return False - - -def calc_weight(formula, solution: Assumptions) -> int: - return sum([ - weight if is_sat(clause, solution) else 0 for - weight, clause in zip(formula.wght, formula.soft) - ]) - - -def bool2sign(b): - return -1 if b else 1 - - -def signed(x, s): - return bool2sign(s) * x - - -def minimize_dnf(dnf): - from pyeda.inter import espresso_exprs - - print(f"Minimizing DNF via Espresso...") - min_dnf = espresso_exprs(dnf) - return min_dnf - - -def cnf_to_clauses(cnf): - print("Converting CNF into clauses...") - - assert cnf.is_cnf() - - litmap, nvars, clauses = 
cnf.encode_cnf() - result = [] - for clause in clauses: - c = [] - for lit in clause: - v = litmap[abs(lit)].indices[0] # 1-based variable index - s = lit < 0 # sign - c.append(signed(v, s)) - c.sort(key=lambda x: abs(x)) - result.append(c) - - clauses = result - clauses.sort(key=lambda x: (len(x), tuple(map(abs, x)))) - print( - f"Total {len(clauses)} clauses: {sum(1 for clause in clauses if len(clause) == 1)} units, {sum(1 for clause in clauses if len(clause) == 2)} binary, {sum(1 for clause in clauses if len(clause) == 3)} ternary, {sum(1 for clause in clauses if len(clause) > 3)} larger" - ) - return clauses - - -def cubes_to_dnf(cubes): - from pyeda.inter import exprvar, And, Or - - var_map = dict() - cubes_expr = [] - - for cube in cubes: - lits_expr = [] - for lit in cube: - var = abs(lit) - if var not in var_map: - var_map[var] = exprvar("x", var) - if lit < 0: - lits_expr.append(~var_map[var]) - else: - lits_expr.append(var_map[var]) - cubes_expr.append(And(*lits_expr)) - - dnf = Or(*cubes_expr) - assert dnf.is_dnf() - return dnf - - -def backdoor_to_clauses_via_easy(easy): - # Note: here, 'dnf' represents the negation of characteristic function, - # because we use "easy" tasks here. - dnf = cubes_to_dnf(easy) - (min_dnf,) = minimize_dnf(dnf) - min_cnf = (~min_dnf).to_cnf() # here, we negate the function back - clauses = cnf_to_clauses(min_cnf) - return clauses - - -def propagate( - problem: Problem, searchable: Searchable -) -> Tuple[List[Supplements], List[Supplements]]: - up_tasks, no_up_tasks = [], [] - formula = problem.encoding.get_formula(copy=False) - with problem.solver.get_instance(formula.hard) as solver: - for supplements in searchable.enumerate(): - _, status, _, _ = solver.propagate(supplements) - if status is None or status: - no_up_tasks.append(supplements) - else: - up_tasks.append(supplements) - - return up_tasks, no_up_tasks - - -def limit_worker( - problem: Problem, task: Supplements, limit: KeyLimit -) -> Tuple[Supplements, Report]: - formula = problem.encoding.get_formula(copy=False) - with problem.solver.get_instance(formula) as solver: - weight, status, stats, model = solver.solve(task, limit) - if status: weight = calc_weight(formula, model) - return task, Report(weight, status, stats, model) - - -def grow(easy_tasks: List[Supplements]) -> Supplements: - easy_cubes = [sups[0] for sups in easy_tasks] - clauses = backdoor_to_clauses_via_easy(easy_cubes) - const, one_lit = split_by(clauses, lambda x: len(x) > 1) - return [clause[0] for clause in one_lit], const - - -class GrowingT(Core): - slug = 'core:growing' - - def __init__(self, logger: Logger, measure: Measure, problem: Problem, - executor: Executor, budget: TaskBudget, - random_seed: Optional[int] = None): - self.budget = budget - self.measure = measure - self.executor = executor - super().__init__(logger, problem, random_seed) - - self.stats_sum = {} - self.max_weight = None - self.best_model = (0, []) - - def sifting(self, tasks: List[Supplements]) -> List[Supplements]: - easy_tasks, limit = [], self.measure.get_limit(self.budget) - future_all, count = self.executor.submit_all(limit_worker, *( - (self.problem, task, limit) for task in tasks - )), len(tasks) - - while len(future_all) > 0: - for future in future_all.as_complete(count=1): - task, report = future.result() - for key, value in report.stats.items(): - self.stats_sum[key] = \ - self.stats_sum.get(key, 0.) 
+ value - - if report.status is False: easy_tasks.append(task) - if report.cost and report.cost > self.best_model[0]: - self.best_model = (report.cost, report.model) - print(f'{count - len(future_all)}/{count}: {report}') - - return easy_tasks - - def launch(self, *searchables: Searchable) -> Report: - formula = self.problem.encoding.get_formula(copy=False) - self.max_weight, grown_sups = sum(formula.wght), [] - - # start load formula cache - # --------------------------------- - filepath = self.problem.encoding.from_file - filename = filepath and filepath.split('/')[-1] - try: - with open('growing-cache.json') as handle: - growing_cache = json.load(handle) - except FileNotFoundError: - growing_cache = {} - - limit = str(self.measure.get_limit(self.budget)) - if filename and (filename in growing_cache): - formula_cache = growing_cache[filename] - limit_cache = formula_cache.get(limit, {}) - else: - formula_cache, limit_cache = {}, {} - # --------------------------------- - # end load formula cache - - print(len(limit_cache)) - uniq_assumptions, uniq_constraints = set(), set() - for index, searchable in enumerate(searchables): - if str(searchable) not in limit_cache: - easy, hard = propagate(self.problem, searchable) - grow_clauses = grow([*easy, *self.sifting(hard)]) - limit_cache[str(searchable)] = grow_clauses - - if filename is not None: - with open('growing-cache.json', 'w+') as handle: - json.dump({ - **growing_cache, - filename: { - **formula_cache, - limit: limit_cache - } - }, handle) - - grown_sups.append(limit_cache[str(searchable)]) - - for assumptions, constraints in grown_sups: - uniq_assumptions.update(assumptions) - uniq_constraints.update([ - tuple(sorted(constraint, key=abs)) - for constraint in constraints - ]) - - if filename is not None: - filename = filename.split('.')[0] - _formula = self.problem.encoding.get_formula() - _formula.extend(list(uniq_constraints) + [ - [lit] for lit in uniq_assumptions - ]) - _formula.to_file(f'{filename}_{len(searchables)}.wcnf') - - print('one lit:', len(uniq_assumptions)) - print('clauses:', len(uniq_constraints)) - - print('max weight:', self.max_weight) - print('best solution:', self.best_model) - return self.problem.solver.solve(formula, ( - list(uniq_assumptions), list(map(list, uniq_constraints)) - )) - - -__all__ = [ - 'GrowingT' -] diff --git a/function/impl/__init__.py b/function/impl/__init__.py index 2b34372..20ec535 100644 --- a/function/impl/__init__.py +++ b/function/impl/__init__.py @@ -1,11 +1,13 @@ from .function_gad import GuessAndDetermine -from .function_ibs import InverseBackdoorSets - +from .function_div import DivFunction from .function_rho import RhoFunction from .function_rho_t import RhoTFunction + +from .function_ibs import InverseBackdoorSets from .function_ips import InversePolynomialSets functions = { + DivFunction.slug: DivFunction, RhoFunction.slug: RhoFunction, RhoTFunction.slug: RhoTFunction, GuessAndDetermine.slug: GuessAndDetermine, diff --git a/function/impl/function_div.py b/function/impl/function_div.py new file mode 100644 index 0000000..7aa154c --- /dev/null +++ b/function/impl/function_div.py @@ -0,0 +1,82 @@ +from os import getpid +from time import time as now + +from util.iterable import list_of +from ..model import WorkerArgs, WorkerResult, \ + WorkerCallable, Payload, Results, Estimation +from .function_gad import GuessAndDetermine, gad_supplements +from ..abc.function import aggregate_results, format_statuses + +from ..module.budget import AutoBudget +from ..module.measure import 
SolvingTime + +from typings.searchable import Searchable + + +def div_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult: + space, budget, measure, problem, bytemask = payload + searchable, timestamp = space.unpack(bytemask), now() + + _formula = problem.encoding.get_formula(copy=False) + statuses, times, times2, values, values2 = {}, {}, {}, {}, {} + + hard, soft, vector = [], [], searchable.get_vector() + for bit, clause in zip(vector, _formula.clauses): + (hard if bit else soft).append(clause) + + # todo: optimize + formula = _formula.weighted() + formula.wght = list_of(1, soft) + formula.topw = len(soft) + 1 + formula.hard = hard + formula.soft = soft + + for supplements in gad_supplements(args, problem, searchable): + report = problem.solver.solve(formula, supplements) + time, value, status = measure.check_and_get(report, budget) + + times[status.value] = times.get(status.value, 0.) + time + values[status.value] = values.get(status.value, 0.) + value + statuses[status.value] = statuses.get(status.value, 0) + 1 + + times2[status.value] = times2.get(status.value, 0.) + time ** 2 + values2[status.value] = values2.get(status.value, 0.) + value ** 2 + return getpid(), now() - timestamp, times, times2, values, values2, statuses, args + + +class DivFunction(GuessAndDetermine): + slug = 'function:div' + + def __init__(self, budget: AutoBudget): + super().__init__(budget, SolvingTime()) + + def get_worker_fn(self) -> WorkerCallable: + return div_worker_fn + + def calculate(self, searchable: Searchable, results: Results) -> Estimation: + times, values, statuses, stats = aggregate_results(results) + time_sum, value_sum = sum(times.values()), sum(values.values()) + + power = searchable.power() + value = value_sum if stats.count else float('inf') + if stats.count > 0 and stats.count != power: + value = float(value_sum) / stats.count * power + + return { + 'power': power, + 'count': stats.count, + 'value': round(value, 2), + 'ptime': round(stats.ptime_sum, 4), + 'time_sum': round(time_sum, 4), + 'time_avg': round(stats.time_avg, 6), + 'time_var': round(stats.time_var, 6), + 'statuses': format_statuses(statuses), + 'value_sum': round(value_sum, 4), + 'value_avg': round(stats.value_avg, 6), + 'value_var': round(stats.value_var, 6), + } + + +__all__ = [ + 'DivFunction' +] diff --git a/function/impl/function_gad.py b/function/impl/function_gad.py index 6d19bd8..922796b 100644 --- a/function/impl/function_gad.py +++ b/function/impl/function_gad.py @@ -47,8 +47,9 @@ def gad_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult: searchable, timestamp = space.unpack(bytemask), now() # limit = measure.get_limit(budget) - times, times2, values, values2 = {}, {}, {}, {} - formula, statuses = problem.encoding.get_formula(), {} + formula = problem.encoding.get_formula(copy=False) + statuses, times, times2, values, values2 = {}, {}, {}, {}, {} + for supplements in gad_supplements(args, problem, searchable): report = problem.solver.solve(formula, supplements) time, value, status = measure.check_and_get(report, budget) diff --git a/function/impl/function_ibs.py b/function/impl/function_ibs.py index 524a3ef..616dfb5 100644 --- a/function/impl/function_ibs.py +++ b/function/impl/function_ibs.py @@ -37,8 +37,9 @@ def ibs_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult: backdoor, timestamp = space.unpack(bytemask), now() limit = measure.get_limit(budget) - times, times2, values, values2 = {}, {}, {}, {} - formula, statuses = problem.encoding.get_formula(), {} + formula = 
problem.encoding.get_formula(copy=False) + statuses, times, times2, values, values2 = {}, {}, {}, {}, {} + for supplements in ibs_supplements(args, problem, backdoor): report = problem.solver.solve(formula, supplements, limit) time, value, status = measure.check_and_get(report, budget) diff --git a/function/impl/function_ips.py b/function/impl/function_ips.py index ef134ab..6716b50 100644 --- a/function/impl/function_ips.py +++ b/function/impl/function_ips.py @@ -16,8 +16,9 @@ def ips_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult: space, budget, measure, problem, bytemask = payload searchable, timestamp = space.unpack(bytemask), now() - times, times2, values, values2 = {}, {}, {}, {} - formula, statuses = problem.encoding.get_formula(), {} + formula = problem.encoding.get_formula(copy=False) + statuses, times, times2, values, values2 = {}, {}, {}, {}, {} + with problem.solver.get_instance(formula) as incremental: for supplements in ibs_supplements(args, problem, searchable): report = incremental.propagate(supplements) diff --git a/function/impl/function_rho.py b/function/impl/function_rho.py index 1681425..7ec8086 100644 --- a/function/impl/function_rho.py +++ b/function/impl/function_rho.py @@ -11,24 +11,34 @@ from typings.searchable import Searchable +solvers = {} + def rho_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult: space, budget, measure, problem, bytemask = payload searchable, timestamp = space.unpack(bytemask), now() - times, times2, values, values2 = {}, {}, {}, {} - formula, statuses = problem.encoding.get_formula(), {} - with problem.solver.get_instance(formula) as incremental: - for supplements in gad_supplements(args, problem, searchable): - report = incremental.propagate(supplements) - time, value, status = measure.check_and_get(report, budget) + formula = problem.encoding.get_formula(copy=False) + statuses, times, times2, values, values2 = {}, {}, {}, {}, {} + + key = problem.encoding._get_formula_key() + if key not in solvers: + test_stamp = now() + solvers[key] = problem.solver.get_instance(formula) + print('created instance:', now() - test_stamp) + + # with problem.solver.get_instance(formula) as incremental: + for supplements in gad_supplements(args, problem, searchable): + report = solvers[key].propagate(supplements) + time, value, status = measure.check_and_get(report, budget) + + times[status.value] = times.get(status.value, 0.) + time + values[status.value] = values.get(status.value, 0.) + value + statuses[status.value] = statuses.get(status.value, 0) + 1 - times[status.value] = times.get(status.value, 0.) + time - values[status.value] = values.get(status.value, 0.) + value - statuses[status.value] = statuses.get(status.value, 0) + 1 + times2[status.value] = times2.get(status.value, 0.) + time ** 2 + values2[status.value] = values2.get(status.value, 0.) + value ** 2 - times2[status.value] = times2.get(status.value, 0.) + time ** 2 - values2[status.value] = values2.get(status.value, 0.) 
+ value ** 2 return getpid(), now() - timestamp, times, times2, values, values2, statuses, args @@ -59,6 +69,7 @@ def calculate(self, searchable: Searchable, results: Results) -> Estimation: value = rho_value * power + penalty_value return { + 'power': power, 'count': stats.count, 'value': round(value, 2), 'ptime': round(stats.ptime_sum, 4), diff --git a/function/impl/function_rho_t.py b/function/impl/function_rho_t.py index 9d8905b..f70943f 100644 --- a/function/impl/function_rho_t.py +++ b/function/impl/function_rho_t.py @@ -17,8 +17,9 @@ def tau_worker_fn(args: WorkerArgs, payload: Payload) -> WorkerResult: searchable, timestamp = space.unpack(bytemask), now() limit = measure.get_limit(budget) - times, times2, values, values2 = {}, {}, {}, {} - formula, statuses = problem.encoding.get_formula(), {} + formula = problem.encoding.get_formula(copy=False) + statuses, times, times2, values, values2 = {}, {}, {}, {}, {} + with problem.solver.get_instance(formula) as incremental: for supplements in gad_supplements(args, problem, searchable): # todo: clear interrupt in incremental diff --git a/lib_satprob/solver/impl/__init__.py b/lib_satprob/solver/impl/__init__.py index f3eed02..c75e6b6 100644 --- a/lib_satprob/solver/impl/__init__.py +++ b/lib_satprob/solver/impl/__init__.py @@ -1,7 +1,9 @@ from .pysat import * from .py2sat import * +from .external import * solvers = { + Kissat.slug: Kissat, PySatSolver.slug: PySatSolver, Py2SatSolver.slug: Py2SatSolver, } diff --git a/lib_satprob/solver/impl/external.py b/lib_satprob/solver/impl/external.py new file mode 100644 index 0000000..a93157f --- /dev/null +++ b/lib_satprob/solver/impl/external.py @@ -0,0 +1,221 @@ +import os +import re + +from time import time as now +from typing import List, Dict +from tempfile import NamedTemporaryFile +from subprocess import PIPE, Popen, TimeoutExpired + +from util.iterable import concat +from ...variables import Supplements + +from .pysat import PySatSetts, \ + FormulaError, _PySatSolver, PySatSolver +from ..solver import Report, KeyLimit, UNLIMITED +from ...encoding import PySatFormula, SatFormula, \ + MaxSatFormula, is_sat_formula, is_max_sat_formula + +STATUSES = { + 10: True, + 20: False +} + + +class _ExternalSolver(_PySatSolver): + limits = {} + statistic = {} + stdin_file = None + stdout_file = None + + def __init__( + self, + formula: PySatFormula, + settings: PySatSetts, + from_executable: str, + use_timer: bool = True, + ): + super().__init__(formula, settings, use_timer) + self.from_executable = from_executable + + def _parse_stats(self, output: str) -> Dict[str, int]: + def get_number(res): + return res and int(res.group(1)) + + return { + key: get_number(p.search(output)) + for key, p in self.statistic.items() + } + + def _parse_solution(self, output: str) -> List[int]: + raise NotImplementedError + + def solve( + self, supplements: Supplements, + limit: KeyLimit = UNLIMITED, + extract_model: bool = True, + ) -> Report: + files, source = [], None + launch_args = [self.from_executable] + assumptions, constraints = supplements + + str_clauses = [ + *map(str, assumptions), + *(' '.join(map(str, cl)) + for cl in constraints) + ] + + if is_sat_formula(self.formula): + str_supplements = '\n'.join([ + f'{cl} 0' for cl in str_clauses + ]) + '\n' + formula_len = sum(( + len(str_clauses), + len(self.formula.clauses) + )) + elif is_max_sat_formula(self.formula): + str_supplements = '\n'.join([ + f'{self.formula.topw} {cl} 0' + for cl in str_clauses + ]) + '\n' + formula_len = sum(( + len(str_clauses), + 
len(self.formula.hard), + len(self.formula.soft), + )) + else: + raise FormulaError(self.formula) + + if self.stdin_file is not None: + with NamedTemporaryFile( + delete=False, mode='w+' + ) as in_file: + files.append(in_file.name) + self.formula.to_fp(in_file) + in_file.write(str_supplements) + launch_args.append( + self.stdin_file % in_file.name + ) + else: + source = self.formula.to_dimacs() + source += str_supplements + + if self.stdout_file is not None: + with NamedTemporaryFile( + delete=False, mode='w+' + ) as out_file: + files.append(out_file.name) + launch_args.append( + self.stdout_file % out_file.name + ) + + timeout, (key, value) = None, limit + if value is not None and key == 'time': + timeout = value + formula_len * 2e-06 + if value is not None and key in self.limits: + launch_args.append(self.limits[key] % value) + + timestamp, process = now(), Popen( + launch_args, stdin=PIPE, stdout=PIPE, stderr=PIPE + ) + try: + data = None if source is None else source.encode() + output, error = process.communicate(data, timeout) + + # todo: handle error + if self.stdout_file is not None: + with open(files[-1], 'r') as handle: + output = handle.read() + else: + output = output.decode() + + stats = self._parse_stats(output) + stats['time'] = now() - timestamp + + status = STATUSES.get(process.returncode) + solution = self._parse_solution(output) \ + if extract_model and status else None + except TimeoutExpired: + process.terminate() + status, solution = None, None + stats = {'time': now() - timestamp} + finally: + [os.remove(file) for file in files] + + return Report(status, stats, solution, stats.get('cost')) + + +class ExternalSolver(PySatSolver): + slug = 'solver:external' + + def __init__(self, from_executable: str, pysat_propagator: str = 'm22'): + super().__init__(sat_name=pysat_propagator) + self.from_executable = from_executable + + def get_instance( + self, formula: PySatFormula, use_timer: bool = True + ) -> _ExternalSolver: + raise NotImplementedError + + +class _Kissat(_ExternalSolver): + limits = { + 'time': '--time=%d', + 'conflicts': '--conflicts=%d', + 'decisions': '--decisions=%d', + } + statistic = { + 'restarts': re.compile(r'^c restarts:\s+(\d+)', re.MULTILINE), + 'conflicts': re.compile(r'^c conflicts:\s+(\d+)', re.MULTILINE), + 'decisions': re.compile(r'^c decisions:\s+(\d+)', re.MULTILINE), + 'propagations': re.compile(r'^c propagations:\s+(\d+)', re.MULTILINE) + } + + def _parse_solution(self, output: str) -> List[int]: + return concat(*( + [int(var) for var in line.split()] for line in + re.findall(r'^v ([-\d ]*)', output, re.MULTILINE) + )) + + +class Kissat(ExternalSolver): + slug = 'solver:external:kissat' + + def get_instance( + self, formula: SatFormula, use_timer: bool = True + ) -> _Kissat: + return _Kissat( + formula, self.settings, self.from_executable, use_timer + ) + + +class _Loandra(_ExternalSolver): + stdin_file = '%s' + statistic = { + 'cost': re.compile(r'^v (\d*) \d*', re.MULTILINE) + } + + def _parse_solution(self, output: str) -> List[int]: + line = re.findall(r'^v \d* (\d*)', output, re.MULTILINE)[-1] + return [i + 1 if c else -(i + 1) for i, c in enumerate(line)] + + +class Loandra(ExternalSolver): + slug = 'solver:external:loandra' + + def get_instance( + self, formula: MaxSatFormula, use_timer: bool = True + ) -> _Loandra: + return _Loandra( + formula, self.settings, self.from_executable, use_timer + ) + + +__all__ = [ + 'Kissat', + '_Kissat', + 'Loandra', + '_Loandra', + # types + 'ExternalSolver', + '_ExternalSolver', +] diff --git 
a/lib_satprob/solver/impl/py2sat.py b/lib_satprob/solver/impl/py2sat.py index 2c34eb3..833b28e 100644 --- a/lib_satprob/solver/impl/py2sat.py +++ b/lib_satprob/solver/impl/py2sat.py @@ -1,13 +1,12 @@ from time import time as now -from ...encoding import Clause -from ...variables import Supplements -from ...variables.vars import VarMap - from .pysat import PySatSetts, \ _PySatSolver, PySatSolver + from ..solver import Report -from ...encoding import SatFormula +from ...variables import Supplements +from ...variables.vars import VarMap +from ...encoding import Clause, SatFormula def is2clause(clause: Clause, var_map: VarMap) -> bool: @@ -47,7 +46,7 @@ def propagate( stamp, no2clause = now() - stats['time'], 0 var_map = {abs(lit): lit for lit in literals} - for clause in self.formula: + for clause in self.formula.clauses: no2clause += not is2clause(clause, var_map) if no2clause > self.limit: break else: diff --git a/lib_satprob/solver/impl/pysat.py b/lib_satprob/solver/impl/pysat.py index c6d8496..de1511e 100644 --- a/lib_satprob/solver/impl/pysat.py +++ b/lib_satprob/solver/impl/pysat.py @@ -7,10 +7,11 @@ from pysat import solvers as slv from pysat.examples.rc2 import RC2 -from ..solver import Solver, _Solver, Report, KeyLimit, UNLIMITED -from ...encoding import SatFormula, PySatFormula, MaxSatFormula, \ - to_sat_formula, is_sat_formula, is_max_sat_formula +from ..solver import Report, Solver, \ + _Solver, KeyLimit, UNLIMITED from ...variables import Assumptions, Supplements +from ...encoding import PySatFormula, MaxSatFormula, \ + to_sat_formula, is_sat_formula, is_max_sat_formula # @@ -190,7 +191,6 @@ def get_max_sat_alg(settings: PySatSetts, formula: MaxSatFormula): class _PySatSolver(_Solver): _solver = None _last_stats = {} - _propagator = None def __init__( self, @@ -306,4 +306,6 @@ def __config__(self) -> Dict[str, Any]: # types 'PySatTimer', 'PySatSetts', + # errors + 'FormulaError' ] diff --git a/space/_utility.py b/space/_utility.py index 8576005..e1730e7 100644 --- a/space/_utility.py +++ b/space/_utility.py @@ -1,4 +1,5 @@ from numpy import argsort +from util.polyfill import prod from util.iterable import pick_by from lib_satprob.problem import Problem @@ -6,13 +7,20 @@ def rho_subset( - problem: Problem, variables: Variables, of_size: int + problem: Problem, variables: Variables, + of_size: int = None, by_weight: int = None ) -> Variables: formula = problem.encoding.get_formula() with problem.solver.get_instance(formula, use_timer=False) as solver: - _indexes = argsort([sum(( + print('no assumptions:', solver.propagate(([], [])).stats) + _weights = [prod(( solver.propagate(var.substitute({var: 0})).stats['propagations'], solver.propagate(var.substitute({var: 1})).stats['propagations'], - )) for var in variables])[::-1][:of_size] + )) for var in variables] + print('weights:', set(_weights)) + _indexes = argsort(_weights)[::-1][:of_size] if by_weight is None else [ + _i for _i, _weight in enumerate(_weights) if _weight >= by_weight + ] + print(len(_indexes)) return Variables(from_vars=pick_by(variables.variables(), _indexes)) diff --git a/space/impl/interval_set.py b/space/impl/interval_set.py index b00a1c8..d6da19d 100644 --- a/space/impl/interval_set.py +++ b/space/impl/interval_set.py @@ -3,7 +3,6 @@ from ..abc import Space from ..model import Interval -from lib_satprob.problem import Problem from typings.searchable import Vector from lib_satprob.variables import Indexes @@ -12,8 +11,7 @@ class IntervalSet(Space): slug = 'space:interval_set' def __init__( - self, - indexes: 
Indexes, + self, indexes: Indexes, by_vector: Optional[Vector] = None ): super().__init__(by_vector) diff --git a/space/impl/partition_set.py b/space/impl/partition_set.py new file mode 100644 index 0000000..a3618a6 --- /dev/null +++ b/space/impl/partition_set.py @@ -0,0 +1,41 @@ +from typing import Dict, Any, Optional + +from ..abc import Space + +from typings.searchable import Vector +from ..model import Interval, Partition + + +class PartitionSet(Space): + slug = 'space:partition_set' + + def __init__( + self, length: int, interval: Interval, + by_vector: Optional[Vector] = None + ): + super().__init__(by_vector) + self.interval = interval + self.length = length + + # noinspection PyProtectedMember + def get_initial(self) -> Partition: + interval = self._get_searchable() + if self.by_vector is not None: + interval._set_vector(self.by_vector) + return interval + + def _get_searchable(self) -> Partition: + return Partition(self.length, self.interval) + + def __config__(self) -> Dict[str, Any]: + return { + 'slug': self.slug, + 'length': self.length, + 'interval': self.interval.__config__(), + 'by_vector': self.by_vector, + } + + +__all__ = [ + 'PartitionSet' +] diff --git a/space/model/__init__.py b/space/model/__init__.py index 6c1aeac..40eca06 100644 --- a/space/model/__init__.py +++ b/space/model/__init__.py @@ -1,2 +1,3 @@ from .backdoor import Backdoor from .interval import Interval +from .partition import Partition diff --git a/space/model/partition.py b/space/model/partition.py new file mode 100644 index 0000000..5188f60 --- /dev/null +++ b/space/model/partition.py @@ -0,0 +1,56 @@ +from typing import Any, List, Dict, Optional + +from lib_satprob.variables import Supplements +from lib_satprob.variables.vars import VarMap, Var + +from .interval import Interval +from typings.searchable import Searchable + + +class Partition(Searchable): + slug = 'searchable:partition' + + def __init__( + self, length: int, + interval: Interval + ): + super().__init__(length=length) + self._interval = interval + + def power(self) -> int: + return self._interval.power() + + def dimension(self) -> List[int]: + return self._interval.dimension() + + def variables(self) -> List[Var]: + return self._interval.variables() + + def substitute(self, using_values: Optional[List[int]] = None, + using_var_map: Optional[VarMap] = None) -> Supplements: + return self._interval.substitute(using_values, using_var_map) + + def __len__(self) -> int: + return len(self._interval) + + def __str__(self) -> str: + return ''.join(map(str, self._vector)) + + def __repr__(self) -> str: + return f'[{str(self)}]({sum(self._vector)})' + + def __copy__(self) -> 'Partition': + return Partition(self._length, self._interval) + + def __eq__(self, other: 'Partition') -> bool: + return str(self) == str(other) + + def __config__(self) -> Dict[str, Any]: + return { + 'slug': self.slug, + 'interval': self._interval.__config__() + } + +__all__ = [ + 'Partition' +] \ No newline at end of file diff --git a/util/iterable.py b/util/iterable.py index bb230e7..a122fcc 100644 --- a/util/iterable.py +++ b/util/iterable.py @@ -39,10 +39,12 @@ def pick_by(iterable: Iterable[T], predicate: Predicate = identity) -> List[T]: if isinstance(predicate, Callable): return [item for item in iterable if predicate(item)] elif isinstance(predicate, Iterable): + predicate = set(predicate) return [item for i, item in enumerate(iterable) if i in predicate] else: raise TypeError( - f'unexpected predicate type: \'{type(predicate).__name__}\'') + f'unexpected predicate 
type: \'{type(predicate).__name__}\''
+        )
 
 
 def omit_by(iterable: Iterable[T], predicate: Predicate = identity) -> List[T]:
@@ -65,7 +67,8 @@ def slice_into(sized: List[T], count: int) -> Iterable[List[T]]:
 
 
 def split_by(
-        iterable: Iterable[T], predicate: Predicate = identity
+        iterable: Iterable[T],
+        predicate: Predicate = identity
 ) -> Tuple[List[T], List[T]]:
     left, right = [], []
     for i, item in enumerate(iterable):
diff --git a/util/wrapppers.py b/util/wrapppers.py
new file mode 100644
index 0000000..1d0ea9b
--- /dev/null
+++ b/util/wrapppers.py
@@ -0,0 +1,33 @@
+import time
+
+from functools import partial
+from typing import Callable, Any, TypeVar, NamedTuple
+
+R = TypeVar('R', covariant=True)
+
+
+class Timed(NamedTuple):
+    result: R
+    time: float
+
+
+def _timed(
+    fn: Callable[..., R],
+    *args: Any, **kwargs: Any
+):
+    stamp = time.time()
+    return Timed(
+        fn(*args, **kwargs),
+        time.time() - stamp
+    )
+
+
+def timed(
+    fn: Callable[..., R]
+) -> Callable[..., Timed]:
+    return partial(_timed, fn)
+
+
+__all__ = [
+    'timed'
+]
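
A minimal usage sketch for the new util/wrapppers.timed helper that CombineT now
wraps around get_derived_by; the wrapped function probe is a hypothetical stand-in:

    from util.wrapppers import timed

    def probe(x, y):                        # hypothetical work function
        return x + y

    result, elapsed = timed(probe)(2, 3)    # Timed is a NamedTuple -> unpacks
    print(result, round(elapsed, 6))        # 5, wall-clock seconds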
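
A small sketch of how the new prod_combine generator in core/impl/combine_t.py
drops conflicting cube pairs before the product is taken; the literals are made
up and the direct import is assumed to work:

    from core.impl.combine_t import prod_combine

    acc_tasks = [[1, 3]]               # made-up accumulated hard cubes
    tasks = [[-1, 2], [1, 4]]          # made-up cubes of the next backdoor

    print(list(prod_combine(acc_tasks, tasks)))
    # [({1, 3, 4}, [])] -- the cube [-1, 2] clashes with literal 1 and is skipped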
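
A rough sketch of the formula-patching flow added in core/impl/combine_t.py:
derived clauses are written once to a temporary JSON file, and each worker
extends its cached base formula on first use of a given version. The problem
object and the clause list are assumptions here:

    from core.impl.combine_t import create_patch, get_formula

    derived = [[-7], [2, 5]]                 # made-up clauses, as collected in _preprocess
    patch_file, version = create_patch(derived)

    # inside prod_worker / hard_worker: the first call reads the JSON patch and
    # extends the base formula; later calls with the same version hit FORMULAS
    formula = get_formula(problem, patch_file, version)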
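
The by_weight mode added to space/_utility.rho_subset keeps every variable whose
product of propagation counts (over both polarities) reaches a threshold, rather
than taking a fixed-size prefix; problem, variables and the threshold are assumptions:

    from space._utility import rho_subset

    best_32 = rho_subset(problem, variables, of_size=32)         # previous behaviour
    heavy = rho_subset(problem, variables, by_weight=100_000)    # threshold value is illustrative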