diff --git a/dev_requirements.txt b/dev_requirements.txt
index 9efe4b3..a6453dd 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -1,2 +1,3 @@
 pytest==8.2.1
 pytest-cov==5.0.0
+ruff==0.4.8
diff --git a/dvrprocess/comtune.py b/dvrprocess/comtune.py
index 7c0b319..832eea5 100755
--- a/dvrprocess/comtune.py
+++ b/dvrprocess/comtune.py
@@ -1,11 +1,12 @@
 #!/usr/bin/env python3
-import atexit
 import configparser
 import copy
 import getopt
 import hashlib
 import json
 import logging
+
+import cloudpickle
 import math
 import os.path
 import random
@@ -20,7 +21,7 @@
 from math import ceil
 from multiprocessing import Pool
 from statistics import stdev, mean, median
-from typing import Union
+from typing import Union, Any
 
 import numpy
 import pygad
@@ -625,9 +626,9 @@ def setup_gad(process_pool: Pool, thread_pool: ThreadPoolExecutor, files, workdi
     # construct list of genes
     # TODO: support locking genes, i.e. detect_method if we need to exclude methods we know are broken for the recording
     genes_all = GENES.copy()
-    if episode_common_duration == 30*60:
+    if episode_common_duration == 30 * 60:
         genes_all.extend(GENES_30)
-    elif episode_common_duration == 60*60:
+    elif episode_common_duration == 60 * 60:
         genes_all.extend(GENES_60)
     genes = list(
         filter(lambda g: (experimental or not g.experimental) and g.space_has_elements() and (
@@ -713,7 +714,8 @@ def f(gad: pygad.GA, solution, solution_idx):
 
         # added_recording may be a gene, so we need to calculate it for each run
         if added_recording_gene_idx >= 0:
-            expected_adjusted_duration = common.round_episode_duration(dvr_infos_sample[0]) - ((int(solution[added_recording_gene_idx])+1) * 60.0)
+            expected_adjusted_duration = common.round_episode_duration(dvr_infos_sample[0]) - (
+                    (int(solution[added_recording_gene_idx]) + 1) * 60.0)
         else:
             expected_adjusted_duration = expected_adjusted_duration_default
 
@@ -891,6 +893,43 @@ def find_comskip_starter_ini():
     raise OSError(f"Cannot find comskip-starter.ini in any of {','.join(get_comskip_starter_ini_sources())}")
 
 
+# pygad.GA attributes written by ga_instance_save() and copied back onto the GA instance on resume
+GA_INSTANCE_ATTR_SAVE = [
+    'best_solutions',
+    'best_solutions_fitness',
+    'solutions',
+    'solutions_fitness',
+    'last_generation_fitness',
+    'last_generation_parents',
+    'last_generation_offspring_crossover',
+    'last_generation_offspring_mutation',
+    'previous_generation_fitness',
+    'last_generation_elitism',
+    'last_generation_elitism_indices',
+    'pareto_fronts',
+]
+
+
+def ga_instance_save(ga_instance: pygad.GA, filename) -> None:
+    """
+    Checkpoint a GA run to filename with cloudpickle.
+
+    The saved dict holds num_generations, generations_completed, the current population and
+    every attribute named in GA_INSTANCE_ATTR_SAVE.
+    """
+    gad_state = {
+        'num_generations': ga_instance.num_generations,
+        'generations_completed': ga_instance.generations_completed,
+        'population': ga_instance.population,
+    }
+
+    for attr_name in GA_INSTANCE_ATTR_SAVE:
+        gad_state[attr_name] = getattr(ga_instance, attr_name)
+
+    with open(filename, 'wb') as file:
+        cloudpickle.dump(gad_state, file)
+
+
 def generate_initial_solutions(genes: list[ComskipGene], values_in: dict[ComskipGene, list]) -> list[list]:
     # remove values that are not in the list of genes
     values: dict[ComskipGene, list] = dict()
@@ -965,6 +1004,15 @@ def tune_show(season_dir, process_pool: Pool, files, workdir, dry_run, force, ex
 
     target_comskip_ini = os.path.join(season_dir, 'comskip.ini')
 
+    # checkpoint file lives in the season directory; its name reflects the gene options in use
+    gad_name_parts = ['gad']
+    if expensive_genes:
+        gad_name_parts.append('expensive_genes')
+    if experimental:
+        gad_name_parts.append('experimental')
+    gad_state_filename = os.path.join(season_dir, f"{'-'.join(gad_name_parts)}.pkl")
+    gad_state_filename_tmp = gad_state_filename + ".tmp"
+
     # https://pygad.readthedocs.io/en/latest/README_pygad_ReadTheDocs.html#pygad-ga-class
     num_generations = 100
     sol_per_pop = 500
@@ -1003,6 +1051,37 @@ def tune_show(season_dir, process_pool: Pool, files, workdir, dry_run, force, ex
     for s in initial_solutions:
         logger.info("Initial solution : {solution}".format(solution=solution_repl(genes, s)))
 
+    convergence_gauge = progress.gauge("CONV")
+    convergence_gauge.renderer = lambda v: f"{v:.2f}"
+
+    def gen_callback(ga_instance: pygad.GA):
+        best_fitness = ga_instance.best_solutions_fitness
+        if len(best_fitness) > 1:
+            convergence_gauge.value(stdev(best_fitness))
+
+        # checkpoint every generation: write a temp file, then swap it in so the real file is never partial
+        ga_instance_save(ga_instance, gad_state_filename_tmp)
+        os.replace(gad_state_filename_tmp, gad_state_filename)
+
+        return None
+
+    # resume from an earlier checkpoint if one exists
+    # NOTE: unpickling can execute code from the file; only checkpoints written by this tool belong here
+    ga_in: dict[str, Any] = dict()
+    try:
+        with open(gad_state_filename, 'rb') as file:
+            ga_in = cloudpickle.load(file)
+        num_generations = max(1, int(ga_in.get('num_generations', num_generations)) - int(ga_in.get('generations_completed', 0)))
+        initial_solutions = ga_in.get('population', initial_solutions)
+        logger.info("Resuming from %s, %d generations left", gad_state_filename, num_generations)
+    except FileNotFoundError:
+        pass
+    except Exception as e:
+        # drop anything already loaded so no stale state is restored onto the fresh GA instance
+        ga_in = dict()
+        logger.debug("Error resuming from %s, starting over", gad_state_filename, exc_info=e)
+        logger.warning("Error resuming from %s, starting over", gad_state_filename)
+
     # ensure we have the desired population size
     additional_solutions_needed = max(0, sol_per_pop - len(initial_solutions))
     if additional_solutions_needed > 0:
@@ -1022,15 +1101,6 @@ def tune_show(season_dir, process_pool: Pool, files, workdir, dry_run, force, ex
     else:
         initial_population = initial_solutions
 
-    convergence_gauge = progress.gauge("CONV")
-    convergence_gauge.renderer = lambda v: f"{v:.2f}"
-
-    def gen_callback(ga_instance: pygad.GA):
-        best_fitness = ga_instance.best_solutions_fitness
-        if len(best_fitness) > 1:
-            convergence_gauge.value(stdev(best_fitness))
-        return None
-
     # https://pygad.readthedocs.io/en/latest/README_pygad_ReadTheDocs.html#pygad-ga-class
     ga_instance = pygad.GA(num_generations=num_generations,
                            num_parents_mating=num_parents_mating,
@@ -1048,6 +1118,11 @@ def gen_callback(ga_instance: pygad.GA):
                            suppress_warnings=True,
                            on_generation=gen_callback,
                            )
+
+    # restore checkpointed attributes; ga_in is empty when no checkpoint was loaded
+    for attr_name in filter(lambda e: e in ga_in, GA_INSTANCE_ATTR_SAVE):
+        setattr(ga_instance, attr_name, ga_in[attr_name])
+
     ga_instance.run()
     tuning_progress.stop()
     solution, solution_fitness, solution_idx = ga_instance.best_solution()
@@ -1060,6 +1135,8 @@ def gen_callback(ga_instance: pygad.GA):
             solution_fitness=(1.0 / solution_fitness)))
     logger.debug("Best solutions = {best}\n fitness = {fitness}".format(best=ga_instance.best_solutions,
                                                                         fitness=ga_instance.best_solutions_fitness))
+    logger.info("Best fitness reached at generation %d", ga_instance.best_solution_generation)
+
     solution = solution.copy()
     gene_values = []
     for idx in range(len(genes)):
@@ -1093,6 +1170,10 @@ def gen_callback(ga_instance: pygad.GA):
     if fitness_json_path and os.path.exists(fitness_json_path):
         shutil.move(fitness_json_path, os.path.join(season_dir, 'fitness.json'))
 
+    # the GA run has completed, so the checkpoint is no longer needed
+    if os.path.isfile(gad_state_filename):
+        os.remove(gad_state_filename)
+
     return_code = common.ReturnCodeReducer()
     for filepath in files:
         process_pool.apply_async(common.pool_apply_wrapper(comchap), (filepath, filepath),
diff --git a/requirements.txt b/requirements.txt
index d3180a0..a2a7386 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,8 +4,7 @@ plexapi==4.15.13
 psutil==5.9.8
 requests==2.32.3
 pygad==3.3.1
-# specific version of numpy required for pygad 2023-01-18
-# numpy==1.23.5
+# vosk==0.3.45 doesn't install correctly
 vosk==0.3.44
 websockets==12.0
 IMDbPY==2022.7.9