Skip to content

Commit

Permalink
add save/resume to comtune.py (#60)
Browse files Browse the repository at this point in the history
* add save/resume to comtune.py
  • Loading branch information
double16 authored Jun 12, 2024
1 parent 9be3291 commit ed32d4a
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 16 deletions.
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pytest==8.2.1
pytest-cov==5.0.0
ruff==0.4.8
97 changes: 83 additions & 14 deletions dvrprocess/comtune.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#!/usr/bin/env python3
import atexit
import configparser
import copy
import getopt
import hashlib
import json
import logging

import cloudpickle
import math
import os.path
import random
Expand All @@ -20,7 +21,7 @@
from math import ceil
from multiprocessing import Pool
from statistics import stdev, mean, median
from typing import Union
from typing import Union, Any

import numpy
import pygad
Expand Down Expand Up @@ -625,9 +626,9 @@ def setup_gad(process_pool: Pool, thread_pool: ThreadPoolExecutor, files, workdi
# construct list of genes
# TODO: support locking genes, i.e. detect_method if we need to exclude methods we know are broken for the recording
genes_all = GENES.copy()
if episode_common_duration == 30*60:
if episode_common_duration == 30 * 60:
genes_all.extend(GENES_30)
elif episode_common_duration == 60*60:
elif episode_common_duration == 60 * 60:
genes_all.extend(GENES_60)
genes = list(
filter(lambda g: (experimental or not g.experimental) and g.space_has_elements() and (
Expand Down Expand Up @@ -713,7 +714,8 @@ def f(gad: pygad.GA, solution, solution_idx):

# added_recording may be a gene, so we need to calculate it for each run
if added_recording_gene_idx >= 0:
expected_adjusted_duration = common.round_episode_duration(dvr_infos_sample[0]) - ((int(solution[added_recording_gene_idx])+1) * 60.0)
expected_adjusted_duration = common.round_episode_duration(dvr_infos_sample[0]) - (
(int(solution[added_recording_gene_idx]) + 1) * 60.0)
else:
expected_adjusted_duration = expected_adjusted_duration_default

Expand Down Expand Up @@ -891,6 +893,38 @@ def find_comskip_starter_ini():
raise OSError(f"Cannot find comskip-starter.ini in any of {','.join(get_comskip_starter_ini_sources())}")


GA_INSTANCE_ATTR_SAVE = [
'best_solutions',
'best_solutions_fitness',
'solutions',
'solutions_fitness',
'last_generation_fitness',
'last_generation_parents',
'last_generation_offspring_crossover',
'last_generation_offspring_mutation',
'previous_generation_fitness',
'last_generation_elitism',
'last_generation_elitism_indices',
'pareto_fronts',
]


def ga_instance_save(ga_instance: pygad.GA, filename):
gad_state = {
'num_generations': ga_instance.num_generations,
'generations_completed': ga_instance.generations_completed,
'population': ga_instance.population,
}

for attr_name in GA_INSTANCE_ATTR_SAVE:
gad_state[attr_name] = getattr(ga_instance, attr_name)

with open(filename, 'wb') as file:
cloudpickle.dump(gad_state, file)

return None


def generate_initial_solutions(genes: list[ComskipGene], values_in: dict[ComskipGene, list]) -> list[list]:
# remove values that are not in the list of genes
values: dict[ComskipGene, list] = dict()
Expand Down Expand Up @@ -965,6 +999,14 @@ def tune_show(season_dir, process_pool: Pool, files, workdir, dry_run, force, ex

target_comskip_ini = os.path.join(season_dir, 'comskip.ini')

gad_name_parts = ['gad']
if expensive_genes:
gad_name_parts.append('expensive_genes')
if experimental:
gad_name_parts.append('experimental')
gad_state_filename = os.path.join(season_dir, f"{'-'.join(gad_name_parts)}.pkl")
gad_state_filename_tmp = gad_state_filename + ".tmp"

# https://pygad.readthedocs.io/en/latest/README_pygad_ReadTheDocs.html#pygad-ga-class
num_generations = 100
sol_per_pop = 500
Expand Down Expand Up @@ -1003,6 +1045,32 @@ def tune_show(season_dir, process_pool: Pool, files, workdir, dry_run, force, ex
for s in initial_solutions:
logger.info("Initial solution : {solution}".format(solution=solution_repl(genes, s)))

convergence_gauge = progress.gauge("CONV")
convergence_gauge.renderer = lambda v: f"{v:.2f}"

def gen_callback(ga_instance: pygad.GA):
best_fitness = ga_instance.best_solutions_fitness
if len(best_fitness) > 1:
convergence_gauge.value(stdev(best_fitness))

ga_instance_save(ga_instance, gad_state_filename_tmp)
shutil.move(gad_state_filename_tmp, gad_state_filename)

return None

try:
ga_in: dict[str, Any] = dict()
with open(gad_state_filename, 'rb') as file:
ga_in = cloudpickle.load(file)
num_generations = max(1, int(ga_in.get('num_generations', num_generations)) - int(ga_in.get('generations_completed', 0)))
initial_solutions = ga_in.get('population', initial_solutions)
logging.info("Resuming from %s, %d generations left", gad_state_filename, num_generations)
except FileNotFoundError:
pass
except BaseException as e:
logging.debug("Error resuming from %s, starting over", gad_state_filename, e)
logging.warning("Error resuming from %s, starting over", gad_state_filename)

# ensure we have the desired population size
additional_solutions_needed = max(0, sol_per_pop - len(initial_solutions))
if additional_solutions_needed > 0:
Expand All @@ -1022,15 +1090,6 @@ def tune_show(season_dir, process_pool: Pool, files, workdir, dry_run, force, ex
else:
initial_population = initial_solutions

convergence_gauge = progress.gauge("CONV")
convergence_gauge.renderer = lambda v: f"{v:.2f}"

def gen_callback(ga_instance: pygad.GA):
best_fitness = ga_instance.best_solutions_fitness
if len(best_fitness) > 1:
convergence_gauge.value(stdev(best_fitness))
return None

# https://pygad.readthedocs.io/en/latest/README_pygad_ReadTheDocs.html#pygad-ga-class
ga_instance = pygad.GA(num_generations=num_generations,
num_parents_mating=num_parents_mating,
Expand All @@ -1048,6 +1107,11 @@ def gen_callback(ga_instance: pygad.GA):
suppress_warnings=True,
on_generation=gen_callback,
)

# restore state
for attr_name in filter(lambda e: e in ga_in, GA_INSTANCE_ATTR_SAVE):
setattr(ga_instance, attr_name, ga_in[attr_name])

ga_instance.run()
tuning_progress.stop()
solution, solution_fitness, solution_idx = ga_instance.best_solution()
Expand All @@ -1060,6 +1124,8 @@ def gen_callback(ga_instance: pygad.GA):
solution_fitness=(1.0 / solution_fitness)))
logger.debug("Best solutions = {best}\n fitness = {fitness}".format(best=ga_instance.best_solutions,
fitness=ga_instance.best_solutions_fitness))
logger.info("Best fitness reached at generation %d", ga_instance.best_solution_generation)

solution = solution.copy()
gene_values = []
for idx in range(len(genes)):
Expand Down Expand Up @@ -1093,6 +1159,9 @@ def gen_callback(ga_instance: pygad.GA):
if fitness_json_path and os.path.exists(fitness_json_path):
shutil.move(fitness_json_path, os.path.join(season_dir, 'fitness.json'))

if os.path.isfile(gad_state_filename):
os.remove(gad_state_filename)

return_code = common.ReturnCodeReducer()
for filepath in files:
process_pool.apply_async(common.pool_apply_wrapper(comchap), (filepath, filepath),
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ plexapi==4.15.13
psutil==5.9.8
requests==2.32.3
pygad==3.3.1
# specific version of numpy required for pygad 2023-01-18
# numpy==1.23.5
# vosk==0.3.45 doesn't install correctly
vosk==0.3.44
websockets==12.0
IMDbPY==2022.7.9
Expand Down

0 comments on commit ed32d4a

Please sign in to comment.