From d29472ca4d7016967b5ae920ac4846ccfbe8c317 Mon Sep 17 00:00:00 2001 From: Payam Jome Yazdian Date: Tue, 6 Feb 2024 23:46:53 -0800 Subject: [PATCH] Add files via upload --- scripts/utils/Unityfier.py | 122 +++++------ scripts/utils/average_meter.py | 126 ++++++------ scripts/utils/data_utils.py | 242 +++++++++++----------- scripts/utils/data_utils_twh.py | 230 ++++++++++----------- scripts/utils/train_utils.py | 352 ++++++++++++++++---------------- scripts/utils/vocab_utils.py | 214 +++++++++---------- 6 files changed, 642 insertions(+), 644 deletions(-) diff --git a/scripts/utils/Unityfier.py b/scripts/utils/Unityfier.py index f2d4e95..cc40568 100644 --- a/scripts/utils/Unityfier.py +++ b/scripts/utils/Unityfier.py @@ -1,61 +1,61 @@ -"""This file converts transcript to a txt file that can be read by the unity. - -This script is meant to be run separately. -This file expects JSON files. jsons_path variable must be manually changed to the appropriate directory. -The expected structure of the original JSON file is as follow: - Subtitles are expected to be contained in a JSON file with the format: - { - 'alternative': - []: # only the first element contains the following: - { - 'words': [ - { - 'start_time': '0.100s', - 'end_time': '0.500s', - 'word': 'really' - }, - { } - ], - - } - } - Note: JSON uses double-quotes instead of single-quotes. Single quotes are used for doc-string reasons. - -The files are saved in a new folder in the jsons_path directory named 'Unity'. -""" - - -import glob -import os -from data_utils import SubtitleWrapper - -jsons_path = "/local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Test_data/Transcripts" - -json_output_path = jsons_path + "/Unity" -if not os.path.exists(json_output_path): - os.makedirs(json_output_path) - - -json_files = sorted(glob.glob(jsons_path + "/*.json")) - - -for jfile in json_files: - name = os.path.split(jfile)[1][:-5] - print(name) - - subtitle = SubtitleWrapper(jfile).get() - str_subtitle = "" - for word_boundle in subtitle: - start_time = word_boundle["start_time"][:-1] # Removing 's' - end_time = word_boundle["end_time"][:-1] # Removing 's' - word = word_boundle["word"] - str_subtitle += "{},{},{}\n".format(start_time, end_time, word) - - str_subtitle = str_subtitle[:-1] - - file2write = open(json_output_path + "/" + name + ".txt", "w") - file2write.write(str_subtitle) - file2write.flush() - file2write.close() - - print() +"""This file converts transcript to a txt file that can be read by the unity. + +This script is meant to be run separately. +This file expects JSON files. jsons_path variable must be manually changed to the appropriate directory. +The expected structure of the original JSON file is as follow: + Subtitles are expected to be contained in a JSON file with the format: + { + 'alternative': + []: # only the first element contains the following: + { + 'words': [ + { + 'start_time': '0.100s', + 'end_time': '0.500s', + 'word': 'really' + }, + { } + ], + + } + } + Note: JSON uses double-quotes instead of single-quotes. Single quotes are used for doc-string reasons. + +The files are saved in a new folder in the jsons_path directory named 'Unity'. 
+""" + + +import glob +import os +from data_utils import SubtitleWrapper + +jsons_path = "/local-scratch/pjomeyaz/GENEA_DATASET/trinityspeechgesture.scss.tcd.ie/data/GENEA_Challenge_2020_data_release/Test_data/Transcripts" + +json_output_path = jsons_path + "/Unity" +if not os.path.exists(json_output_path): + os.makedirs(json_output_path) + + +json_files = sorted(glob.glob(jsons_path + "/*.json")) + + +for jfile in json_files: + name = os.path.split(jfile)[1][:-5] + print(name) + + subtitle = SubtitleWrapper(jfile).get() + str_subtitle = "" + for word_boundle in subtitle: + start_time = word_boundle["start_time"][:-1] # Removing 's' + end_time = word_boundle["end_time"][:-1] # Removing 's' + word = word_boundle["word"] + str_subtitle += "{},{},{}\n".format(start_time, end_time, word) + + str_subtitle = str_subtitle[:-1] + + file2write = open(json_output_path + "/" + name + ".txt", "w") + file2write.write(str_subtitle) + file2write.flush() + file2write.close() + + print() diff --git a/scripts/utils/average_meter.py b/scripts/utils/average_meter.py index 5ecaad9..eca60d0 100644 --- a/scripts/utils/average_meter.py +++ b/scripts/utils/average_meter.py @@ -1,63 +1,63 @@ -"""Class to hold average and current values (such as loss values). - -Typical usage example: - l = AverageMeter('autoencoder_loss') - l.update(2.33) -""" - - -class AverageMeter(object): - """Computes and stores the average and current value (such as loss values). - - Attributes: - name: A string for the name for an instance of this object. - fmt: A string that acts as a formatting value during printing (ex. :f). - val: A float that is the most current value to be processed. - avg: A float average value calculated using the most recent sum and count values. - sum: A float running total that has been processed. - count: An integer count of the number of times that val has been updated. - """ - - def __init__(self, name: str, fmt: str = ":f"): - """Initialization method. - - Args: - name: The string name for an instance of this object. - fmt: A string formatting value during printing (ex. :f). - """ - self.name = name - self.fmt = fmt - self.reset() - - def reset(self) -> None: - """Reset all numerical attributes in this object. - - Modifies internal state of this object. - """ - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val: float, n: int = 1) -> None: - """Updates the numerical attributes in this object with the provided value and count. - - Modifies internal state of this object. - - Args: - val: A float value to be used to update the calculations. - n: A custom count value (default value is 1). - """ - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - def __str__(self) -> str: - """Print a custom formatted string with the val and avg values. - - Returns: - The custom format string. - """ - fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" - return fmtstr.format(**self.__dict__) +"""Class to hold average and current values (such as loss values). + +Typical usage example: + l = AverageMeter('autoencoder_loss') + l.update(2.33) +""" + + +class AverageMeter(object): + """Computes and stores the average and current value (such as loss values). + + Attributes: + name: A string for the name for an instance of this object. + fmt: A string that acts as a formatting value during printing (ex. :f). + val: A float that is the most current value to be processed. 
+ avg: A float average value calculated using the most recent sum and count values. + sum: A float running total that has been processed. + count: An integer count of the number of times that val has been updated. + """ + + def __init__(self, name: str, fmt: str = ":f"): + """Initialization method. + + Args: + name: The string name for an instance of this object. + fmt: A string formatting value during printing (ex. :f). + """ + self.name = name + self.fmt = fmt + self.reset() + + def reset(self) -> None: + """Reset all numerical attributes in this object. + + Modifies internal state of this object. + """ + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val: float, n: int = 1) -> None: + """Updates the numerical attributes in this object with the provided value and count. + + Modifies internal state of this object. + + Args: + val: A float value to be used to update the calculations. + n: A custom count value (default value is 1). + """ + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self) -> str: + """Print a custom formatted string with the val and avg values. + + Returns: + The custom format string. + """ + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) diff --git a/scripts/utils/data_utils.py b/scripts/utils/data_utils.py index 956d0f1..78ea2e7 100644 --- a/scripts/utils/data_utils.py +++ b/scripts/utils/data_utils.py @@ -1,121 +1,121 @@ -"""Utility file to read JSON files containing subtitles into an object. - -Typical usage example: - s = SubtitleWrapper('dataset/subtitles.json') - words = s.get() -""" - -import json -import re - - -def normalize_string(s: str) -> str: - """Standardize strings to a specific format. - - Standardize the string by: - - converting to lowercase, - - trim, and - - remove non alpha-numeric (except ,.!?) characters. - - Args: - s: The string to standardize. - - Returns: - A standardized version of the input string. - """ - s = s.lower().strip() - s = re.sub(r"([,.!?])", r" \1 ", s) # isolate some marks - s = re.sub(r"(['])", r"", s) # remove apostrophe (i.e., shouldn't --> shouldnt) - s = re.sub( - r"[^a-zA-Z0-9,.!?]+", r" ", s - ) # replace other characters with whitespace - s = re.sub(r"\s+", r" ", s).strip() - return s - - -class SubtitleWrapper: - """Contains the subtitles converted from a JSON file. - - Subtitles are expected to be contained in a JSON file with the format: - { - 'alternative': - []: # only the first element contains the following: - { - 'words': [ - { - 'start_time': '0.100s', - 'end_time': '0.500s', - 'word': 'really' - }, - { } - ], - - } - } - Note: JSON uses double-quotes instead of single-quotes. Single quotes are used for doc-string reasons. - - Attributes: - subtitle: A list of strings containing all the subtitles in order. - """ - - TIMESTAMP_PATTERN = re.compile("(\d+)?:?(\d{2}):(\d{2})[.,](\d{3})") - - def __init__(self, subtitle_path: str): - """Initialization method. - - Args: - subtitle_path: The string filepath to the subtitle (JSON) file. - """ - self.subtitle = [] - self.load_gentle_subtitle(subtitle_path) - - def get(self) -> list: - """Returns the subtitles as a list of words.""" - return self.subtitle - - def load_gentle_subtitle(self, subtitle_path: str) -> None: - """Loads a single subtitle file into this object. - - Modifies the internal state of this object. - The subtitles are loaded in order with a single word appended as a single element. 
- - Args: - substitle_path: The string filepath to the subtitle (JSON) file. - - Raises: - An exception if the specified file cannot be found. - """ - try: - with open(subtitle_path) as data_file: - data = json.load(data_file) - for item in data: - if "words" in item["alternatives"][0]: - raw_subtitle = item["alternatives"][0]["words"] - for word in raw_subtitle: - self.subtitle.append(word) - except FileNotFoundError: - self.subtitle = None - - def get_seconds(self, word_time_e: str) -> float: - """Convert a timestamp into seconds. - - Args: - word_time_e: The timestamp as a string (ex. hrs:mins:secs.milli - 02:02:02.125). - - Returns: - The timestamp as a float seconds starting from zero. - """ - time_value = re.match(self.TIMESTAMP_PATTERN, word_time_e) - if not time_value: - print("wrong time stamp pattern") - exit() - - values = list(map(lambda x: int(x) if x else 0, time_value.groups())) - hours, minutes, seconds, milliseconds = ( - values[0], - values[1], - values[2], - values[3], - ) - - return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000 +"""Utility file to read JSON files containing subtitles into an object. + +Typical usage example: + s = SubtitleWrapper('dataset/subtitles.json') + words = s.get() +""" + +import json +import re + + +def normalize_string(s: str) -> str: + """Standardize strings to a specific format. + + Standardize the string by: + - converting to lowercase, + - trim, and + - remove non alpha-numeric (except ,.!?) characters. + + Args: + s: The string to standardize. + + Returns: + A standardized version of the input string. + """ + s = s.lower().strip() + s = re.sub(r"([,.!?])", r" \1 ", s) # isolate some marks + s = re.sub(r"(['])", r"", s) # remove apostrophe (i.e., shouldn't --> shouldnt) + s = re.sub( + r"[^a-zA-Z0-9,.!?]+", r" ", s + ) # replace other characters with whitespace + s = re.sub(r"\s+", r" ", s).strip() + return s + + +class SubtitleWrapper: + """Contains the subtitles converted from a JSON file. + + Subtitles are expected to be contained in a JSON file with the format: + { + 'alternative': + []: # only the first element contains the following: + { + 'words': [ + { + 'start_time': '0.100s', + 'end_time': '0.500s', + 'word': 'really' + }, + { } + ], + + } + } + Note: JSON uses double-quotes instead of single-quotes. Single quotes are used for doc-string reasons. + + Attributes: + subtitle: A list of strings containing all the subtitles in order. + """ + + TIMESTAMP_PATTERN = re.compile("(\d+)?:?(\d{2}):(\d{2})[.,](\d{3})") + + def __init__(self, subtitle_path: str): + """Initialization method. + + Args: + subtitle_path: The string filepath to the subtitle (JSON) file. + """ + self.subtitle = [] + self.load_gentle_subtitle(subtitle_path) + + def get(self) -> list: + """Returns the subtitles as a list of words.""" + return self.subtitle + + def load_gentle_subtitle(self, subtitle_path: str) -> None: + """Loads a single subtitle file into this object. + + Modifies the internal state of this object. + The subtitles are loaded in order with a single word appended as a single element. + + Args: + substitle_path: The string filepath to the subtitle (JSON) file. + + Raises: + An exception if the specified file cannot be found. 
+ """ + try: + with open(subtitle_path) as data_file: + data = json.load(data_file) + for item in data: + if "words" in item["alternatives"][0]: + raw_subtitle = item["alternatives"][0]["words"] + for word in raw_subtitle: + self.subtitle.append(word) + except FileNotFoundError: + self.subtitle = None + + def get_seconds(self, word_time_e: str) -> float: + """Convert a timestamp into seconds. + + Args: + word_time_e: The timestamp as a string (ex. hrs:mins:secs.milli - 02:02:02.125). + + Returns: + The timestamp as a float seconds starting from zero. + """ + time_value = re.match(self.TIMESTAMP_PATTERN, word_time_e) + if not time_value: + print("wrong time stamp pattern") + exit() + + values = list(map(lambda x: int(x) if x else 0, time_value.groups())) + hours, minutes, seconds, milliseconds = ( + values[0], + values[1], + values[2], + values[3], + ) + + return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000 diff --git a/scripts/utils/data_utils_twh.py b/scripts/utils/data_utils_twh.py index 6cdfc77..61bbbab 100644 --- a/scripts/utils/data_utils_twh.py +++ b/scripts/utils/data_utils_twh.py @@ -1,115 +1,115 @@ -"""Utility file to read TSV files containing subtitles into an object. - -Typical usage example: - s = SubtitleWrapper('dataset/subtitles.tsv') - words = s.get() -""" - -import re - - -def normalize_string(s: str) -> str: - """Standardize strings to a specific format. - - Standardize the string by: - - converting to lowercase, - - trim, and - - remove non alpha-numeric characters. - - Args: - s: The string to standardize. - - Returns: - A standardized version of the input string. - """ - s = s.lower().strip() - # s = re.sub(r"([,.!?])", r" \1 ", s) # isolate some marks - s = re.sub(r"([,.!?])", r"", s) # remove marks - s = re.sub(r"(['])", r"", s) # remove apostrophe (i.e., shouldn't --> shouldnt) - s = re.sub( - r"[^a-zA-Z0-9,.!?]+", r" ", s - ) # replace other characters with whitespace - s = re.sub(r"\s+", r" ", s).strip() - return s - - -class SubtitleWrapper: - """Contains the subtitles converted from a JSON file. - - Subtitles are expected to be contained in a TSV file with each line containing words. - - Attributes: - subtitle: A list of strings containing all the subtitles in order. - """ - - TIMESTAMP_PATTERN = re.compile("(\d+)?:?(\d{2}):(\d{2})[.,](\d{3})") - - def __init__(self, subtitle_path: str): - """Initialization method. - - Args: - subtitle_path: The string filepath to the subtitle (JSON) file. - """ - self.subtitle = [] - self.load_tsv_subtitle(subtitle_path) - - def get(self) -> list: - """Returns the subtitles as a list of words.""" - return self.subtitle - - def load_tsv_subtitle(self, subtitle_path: str) -> None: - """Loads a single subtitle file into this object. - - Modifies the internal state of this object. - The subtitles are loaded in order with a single word appended as a single element. - - Args: - substitle_path: The string filepath to the subtitle (TSV) file. - - Raises: - An exception if the specified file cannot be found. 
- """ - try: - with open(subtitle_path) as file: - # I had to update it since file number 157 has a different structure and make some problems - for line in file: - line: str = line.strip() - if line.__contains__("content\t\t"): - line = line[len("content\t\t") :] - print("*****************************Error Content") - splitted = line.split("\t") - if len(splitted) == 2: - splitted.append("eh") - print( - "*****************************Error Lost Word " - + splitted[0] - ) - self.subtitle.append(splitted) - - except FileNotFoundError: - self.subtitle = None - print() - - def get_seconds(self, word_time_e: str) -> float: - """Convert a timestamp into seconds. - - Args: - word_time_e: The timestamp as a string (ex. hrs:mins:secs.milli - 02:02:02.125). - - Returns: - The timestamp as a float seconds starting from zero. - """ - time_value = re.match(self.TIMESTAMP_PATTERN, word_time_e) - if not time_value: - print("wrong time stamp pattern") - exit() - - values = list(map(lambda x: int(x) if x else 0, time_value.groups())) - hours, minutes, seconds, milliseconds = ( - values[0], - values[1], - values[2], - values[3], - ) - - return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000 +"""Utility file to read TSV files containing subtitles into an object. + +Typical usage example: + s = SubtitleWrapper('dataset/subtitles.tsv') + words = s.get() +""" + +import re + + +def normalize_string(s: str) -> str: + """Standardize strings to a specific format. + + Standardize the string by: + - converting to lowercase, + - trim, and + - remove non alpha-numeric characters. + + Args: + s: The string to standardize. + + Returns: + A standardized version of the input string. + """ + s = s.lower().strip() + # s = re.sub(r"([,.!?])", r" \1 ", s) # isolate some marks + s = re.sub(r"([,.!?])", r"", s) # remove marks + s = re.sub(r"(['])", r"", s) # remove apostrophe (i.e., shouldn't --> shouldnt) + s = re.sub( + r"[^a-zA-Z0-9,.!?]+", r" ", s + ) # replace other characters with whitespace + s = re.sub(r"\s+", r" ", s).strip() + return s + + +class SubtitleWrapper: + """Contains the subtitles converted from a JSON file. + + Subtitles are expected to be contained in a TSV file with each line containing words. + + Attributes: + subtitle: A list of strings containing all the subtitles in order. + """ + + TIMESTAMP_PATTERN = re.compile("(\d+)?:?(\d{2}):(\d{2})[.,](\d{3})") + + def __init__(self, subtitle_path: str): + """Initialization method. + + Args: + subtitle_path: The string filepath to the subtitle (JSON) file. + """ + self.subtitle = [] + self.load_tsv_subtitle(subtitle_path) + + def get(self) -> list: + """Returns the subtitles as a list of words.""" + return self.subtitle + + def load_tsv_subtitle(self, subtitle_path: str) -> None: + """Loads a single subtitle file into this object. + + Modifies the internal state of this object. + The subtitles are loaded in order with a single word appended as a single element. + + Args: + substitle_path: The string filepath to the subtitle (TSV) file. + + Raises: + An exception if the specified file cannot be found. 
+ """ + try: + with open(subtitle_path) as file: + # I had to update it since file number 157 has a different structure and make some problems + for line in file: + line: str = line.strip() + if line.__contains__("content\t\t"): + line = line[len("content\t\t") :] + print("*****************************Error Content") + splitted = line.split("\t") + if len(splitted) == 2: + splitted.append("eh") + print( + "*****************************Error Lost Word " + + splitted[0] + ) + self.subtitle.append(splitted) + + except FileNotFoundError: + self.subtitle = None + print() + + def get_seconds(self, word_time_e: str) -> float: + """Convert a timestamp into seconds. + + Args: + word_time_e: The timestamp as a string (ex. hrs:mins:secs.milli - 02:02:02.125). + + Returns: + The timestamp as a float seconds starting from zero. + """ + time_value = re.match(self.TIMESTAMP_PATTERN, word_time_e) + if not time_value: + print("wrong time stamp pattern") + exit() + + values = list(map(lambda x: int(x) if x else 0, time_value.groups())) + hours, minutes, seconds, milliseconds = ( + values[0], + values[1], + values[2], + values[3], + ) + + return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000 diff --git a/scripts/utils/train_utils.py b/scripts/utils/train_utils.py index 92438b1..73c79e0 100644 --- a/scripts/utils/train_utils.py +++ b/scripts/utils/train_utils.py @@ -1,177 +1,175 @@ -"""Utility file to load and save model training checkpoints. - -Saves the model and associated parameters into a .bin file. -Typically used to save model progress after every 20%. -If loading a model, must specify the model string. -Options: 'c2g', 'text2embedding', 'DAE', 'autoencoder', 'baseline', 'text2embedding_gan', 'autoencoder_vq' - -Typical usage example: - save_checkpoint({ - args: ArgumentParser object with the current parameters. - epoch: The epoch that have been completed. - lang_model: Vocab object containing the trained word vector representation. - pose_dim: An integer value of the number of dimensions of a single gesture. - gen_dict: A state_dict from a PyTorch Neural Net subclass.args: - }, 'autoencoder_progress_20.bin') - - or - - m = load_checkpoint_and_model('output/model.bin', 'gpu', 'autoencoder_vq') -""" - - -import logging -import os -from logging.handlers import RotatingFileHandler - -import time -import math -from typing import Tuple -from configargparse import argparse -import torch - -from model.vocab import Vocab -from train_DAE import init_model as DAE_init -from train_Autoencoder import init_model as autoencoder_init -from train_cluster2gesture import init_model as c2g_init -from train_text2embedding import init_model as text2embedding_init -from train_gan import init_model as text2embedding_gan_init -from train_autoencoder_VQVAE import init_model as VQVAE_init -from train import init_model as baseline_init_model - - -def set_logger(log_path: str = None, log_filename: str = "log") -> None: - """Set the logger with a given log name and directory. - - Max filesize limit is 10MB with up to 5 logs by default. - - Args: - log_path: The string specifying the directory to save logs. - log_filename: The string specifying the name of the log. 
- """ - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - handlers = [logging.StreamHandler()] - if log_path is not None: - os.makedirs(log_path, exist_ok=True) - handlers.append( - RotatingFileHandler( - os.path.join(log_path, log_filename), - maxBytes=10 * 1024 * 1024, - backupCount=5, - ) - ) - logging.basicConfig( - level=logging.DEBUG, format="%(asctime)s: %(message)s", handlers=handlers - ) - logging.getLogger("matplotlib").setLevel(logging.WARNING) - - -def as_minutes(s: int) -> str: - """Convert seconds into a mins, seconds string. - - Args: - s: The seconds as an integer. - - Returns: - A string in the format of ' '. - """ - m = math.floor(s / 60) - s -= m * 60 - return "%dm %ds" % (m, s) - - -def time_since(since: int) -> str: - """Calculate the time between now and a specific time in seconds. - - Args: - since: The starting time in seconds. - - Returns: - A string showing the elapsed time in the format of 'minutes seconds'. - """ - now = time.time() - s = now - since - return "%s" % as_minutes(s) - - -def save_checkpoint(state: dict, filename: str) -> None: - """Save the current model into a bin file. - - Args: - state: A dictionary. - { - args: ArgumentParser object with the current parameters. - epoch: The epoch that have been completed. - lang_model: Vocab object containing the trained word vector representation. - pose_dim: An integer value of the number of dimensions of a single gesture. - gen_dict: A state_dict from a PyTorch Neural Net subclass. - } - filename: The filename to save the current model into. - """ - torch.save(state, filename) - logging.info("Saved the checkpoint") - - -def load_checkpoint_and_model( - checkpoint_path: str, _device: str | torch.device = "cpu", what: str = "" -) -> Tuple[ - argparse.Namespace, torch.nn.Module, torch.nn.modules.loss._Loss, Vocab, int -]: - """Load a checkpoint file representing a saved state of a model into memory. - - Args: - checkpoint_path: The string filepath to find the file to load. - _device: A string or torch.device indicating the availability of a GPU. Default value is 'cpu'. - what: A string specifying the particular model to load. - Options: 'c2g', 'text2embedding', 'DAE', 'autoencoder', 'baseline', 'text2embedding_gan', 'autoencoder_vq' - - Returns a Tuple: - args: ArgumentParser object containing model and data parameters used. - generator: The model loaded. - loss_fn: A PyTorch loss function used to score the model. - lang_model: A Vocab object that contains the pre-trained word vector representations used. - pose_dim: An integer value of the number of dimensions of a gesture. - - Raises: - An assertion if the 'what' arg specifies a non-existing model. - """ - print("loading checkpoint {}".format(checkpoint_path)) - checkpoint = torch.load(checkpoint_path, map_location=_device) - args: argparse.Namespace = checkpoint["args"] - epoch: int = checkpoint["epoch"] - lang_model: Vocab = checkpoint["lang_model"] - pose_dim: int = checkpoint["pose_dim"] - print("epoch {}".format(epoch)) - - if what == "c2g": - generator, loss_fn = c2g_init(args, lang_model, pose_dim, _device) - generator.load_state_dict(checkpoint["gen_dict"]) - elif what == "text2embedding": - generator, loss_fn = text2embedding_init(args, lang_model, pose_dim, _device) - generator.load_state_dict(checkpoint["gen_dict"]) - # Todo: should change to what==? and then fix all it's callings. 
- elif what == "DAE": # Todo - generator, loss_fn = DAE_init(args, lang_model, pose_dim, _device) - generator.load_state_dict(checkpoint["gen_dict"]) - elif what == "autoencoder": # Todo - generator, loss_fn = autoencoder_init(args, lang_model, pose_dim, _device) - generator.load_state_dict(checkpoint["gen_dict"]) - elif what == "baseline": - generator, loss_fn = baseline_init_model(args, lang_model, pose_dim, _device) - generator.load_state_dict(checkpoint["gen_dict"]) - elif what == "text2embedding_gan": - generator, loss_fn = text2embedding_gan_init( - args, lang_model, pose_dim, _device - ) - generator.load_state_dict(checkpoint["gen_dict"]) - elif what == "autoencoder_vq": - generator, loss_fn = VQVAE_init(args, lang_model, pose_dim, _device) - generator.load_state_dict(checkpoint["gen_dict"]) - else: - assert 1 == 2 - - # set to eval mode - generator.train(False) - - return args, generator, loss_fn, lang_model, pose_dim +"""Utility file to load and save model training checkpoints. + +Saves the model and associated parameters into a .bin file. +Typically used to save model progress after every 20%. +If loading a model, must specify the model string. +Options: 'c2g', 'text2embedding', 'DAE', 'autoencoder', 'baseline', 'text2embedding_gan', 'autoencoder_vq' + +Typical usage example: + save_checkpoint({ + args: ArgumentParser object with the current parameters. + epoch: The epoch that have been completed. + lang_model: Vocab object containing the trained word vector representation. + pose_dim: An integer value of the number of dimensions of a single gesture. + gen_dict: A state_dict from a PyTorch Neural Net subclass.args: + }, 'autoencoder_progress_20.bin') + + or + + m = load_checkpoint_and_model('output/model.bin', 'gpu', 'autoencoder_vq') +""" + + +import logging +import os +from logging.handlers import RotatingFileHandler + +import time +import math +from typing import Tuple +from configargparse import argparse +import torch + +from model.vocab import Vocab +from train_DAE import init_model as DAE_init +from train_Autoencoder import init_model as autoencoder_init +from train_cluster2gesture import init_model as c2g_init +from train_text2embedding import init_model as text2embedding_init +from train_gan import init_model as text2embedding_gan_init +from train_autoencoder_VQVAE import init_model as VQVAE_init +from train import init_model as baseline_init_model + + +def set_logger(log_path: str = None, log_filename: str = "log") -> None: + """Set the logger with a given log name and directory. + + Max filesize limit is 10MB with up to 5 logs by default. + + Args: + log_path: The string specifying the directory to save logs. + log_filename: The string specifying the name of the log. + """ + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + handlers = [logging.StreamHandler()] + if log_path is not None: + os.makedirs(log_path, exist_ok=True) + handlers.append( + RotatingFileHandler( + os.path.join(log_path, log_filename), + maxBytes=10 * 1024 * 1024, + backupCount=5, + ) + ) + logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s: %(message)s", handlers=handlers + ) + logging.getLogger("matplotlib").setLevel(logging.WARNING) + + +def as_minutes(s: int) -> str: + """Convert seconds into a mins, seconds string. + + Args: + s: The seconds as an integer. + + Returns: + A string in the format of ' '. 
+ """ + m = math.floor(s / 60) + s -= m * 60 + return "%dm %ds" % (m, s) + + +def time_since(since: int) -> str: + """Calculate the time between now and a specific time in seconds. + + Args: + since: The starting time in seconds. + + Returns: + A string showing the elapsed time in the format of 'minutes seconds'. + """ + now = time.time() + s = now - since + return "%s" % as_minutes(s) + + +def save_checkpoint(state: dict, filename: str) -> None: + """Save the current model into a bin file. + + Args: + state: A dictionary. + { + args: ArgumentParser object with the current parameters. + epoch: The epoch that have been completed. + lang_model: Vocab object containing the trained word vector representation. + pose_dim: An integer value of the number of dimensions of a single gesture. + gen_dict: A state_dict from a PyTorch Neural Net subclass. + } + filename: The filename to save the current model into. + """ + torch.save(state, filename) + logging.info("Saved the checkpoint") + + +def load_checkpoint_and_model( + checkpoint_path, _device = "cpu", what: str = "" +): + """Load a checkpoint file representing a saved state of a model into memory. + + Args: + checkpoint_path: The string filepath to find the file to load. + _device: A string or torch.device indicating the availability of a GPU. Default value is 'cpu'. + what: A string specifying the particular model to load. + Options: 'c2g', 'text2embedding', 'DAE', 'autoencoder', 'baseline', 'text2embedding_gan', 'autoencoder_vq' + + Returns a Tuple: + args: ArgumentParser object containing model and data parameters used. + generator: The model loaded. + loss_fn: A PyTorch loss function used to score the model. + lang_model: A Vocab object that contains the pre-trained word vector representations used. + pose_dim: An integer value of the number of dimensions of a gesture. + + Raises: + An assertion if the 'what' arg specifies a non-existing model. + """ + print("loading checkpoint {}".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path, map_location=_device) + args: argparse.Namespace = checkpoint["args"] + epoch: int = checkpoint["epoch"] + lang_model: Vocab = checkpoint["lang_model"] + pose_dim: int = checkpoint["pose_dim"] + print("epoch {}".format(epoch)) + + if what == "c2g": + generator, loss_fn = c2g_init(args, lang_model, pose_dim, _device) + generator.load_state_dict(checkpoint["gen_dict"]) + elif what == "text2embedding": + generator, loss_fn = text2embedding_init(args, lang_model, pose_dim, _device) + generator.load_state_dict(checkpoint["gen_dict"]) + # Todo: should change to what==? and then fix all it's callings. 
+ elif what == "DAE": # Todo + generator, loss_fn = DAE_init(args, lang_model, pose_dim, _device) + generator.load_state_dict(checkpoint["gen_dict"]) + elif what == "autoencoder": # Todo + generator, loss_fn = autoencoder_init(args, lang_model, pose_dim, _device) + generator.load_state_dict(checkpoint["gen_dict"]) + elif what == "baseline": + generator, loss_fn = baseline_init_model(args, lang_model, pose_dim, _device) + generator.load_state_dict(checkpoint["gen_dict"]) + elif what == "text2embedding_gan": + generator, loss_fn = text2embedding_gan_init( + args, lang_model, pose_dim, _device + ) + generator.load_state_dict(checkpoint["gen_dict"]) + elif what == "autoencoder_vq": + generator, loss_fn = VQVAE_init(args, lang_model, pose_dim, _device) + generator.load_state_dict(checkpoint["gen_dict"]) + else: + assert 1 == 2 + + # set to eval mode + generator.train(False) + + return args, generator, loss_fn, lang_model, pose_dim diff --git a/scripts/utils/vocab_utils.py b/scripts/utils/vocab_utils.py index dac6f29..c900c2d 100644 --- a/scripts/utils/vocab_utils.py +++ b/scripts/utils/vocab_utils.py @@ -1,107 +1,107 @@ -"""Utility file to build a a language vector representation model. - -Language vector representations are built from existing FastText resources. -These resource files (ex. crawl-300d-2M-subword.bin) must be saved to a directory. -The directory must be included in the config files. -Training and Testing datasets (as PyTorch dataset subclasses) must also be provided. -If a cache file is valid, then the word_vec_path and feat_dim args can be ignored. - -Typical usage example: - v = build_vocab( - 'first_vocab', - [training, testing], - 'dataset/vocab_cache.pkl', - 'resource/crawl-300d-2M-subword.bin', - 300 - ) - index_words(v, 'dataset/lmdb/data.mdb') -""" - - -import logging -import os -import pickle - -import lmdb -import pyarrow - -from model.vocab import Vocab - - -def build_vocab( - name: str, - dataset_list: list, - cache_path: str, - word_vec_path: str = None, - feat_dim: int | None = None, -) -> Vocab: - """Build a language vector representation model from an existing source. - - Builds a language model using existing (English) FastText vector representations. - The 'word_vec_path' and 'feat_dim' arguments must be provided if a model has not been previously created. - Once the model has been built, saves the model using Pickle to the 'cache_path' location. - If an existing model has been detected at the 'cache_path' then load the model instead of build. - - Args: - name: A string to be used as a name for the language model. - dataset_list: A list containing PyTorch 'Dataset' objects that are represented within Lmdb files and contains dataset associated information. - cache_path: A string representing the filepath to check if a language model has been previously built. - word_vec_path: A string representing (FastText) .bin files to use. - feat_dim: An int representing the dimensions in the FastText files. - - Returns: - A Vocab object that contains the language vector representations. - - Raises: - Assertion that the model is consistent with its embedded weights. 
- """ - logging.info(" building a language model...") - if not os.path.exists(cache_path): - lang_model = Vocab(name) - for dataset in dataset_list: - logging.info(" indexing words from {}".format(dataset.lmdb_dir)) - index_words(lang_model, dataset.lmdb_dir) - - if word_vec_path is not None: - lang_model.load_word_vectors(word_vec_path, feat_dim) - - with open(cache_path, "wb") as f: - pickle.dump(lang_model, f) - else: - logging.info(" loaded from {}".format(cache_path)) - with open(cache_path, "rb") as f: - lang_model: Vocab = pickle.load(f) - - if word_vec_path is None: - lang_model.word_embedding_weights = None - elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words: - logging.warning(" failed to load word embedding weights. check this") - assert False - - return lang_model - - -def index_words(lang_model: Vocab, lmdb_dir: str) -> None: - """Analyzes and indexes the words in the dataset to a Vocab object. - - Modifies the lang_model object by calling a mutating method. - Adds all words in the lmdb file to the lang_model. - - Args: - lang_model: A Vocab object representing the language model to train. - lmdb_dir: A string representing the filepath of the dataset to analyze. - """ - lmdb_env: lmdb.Environment = lmdb.open(lmdb_dir, readonly=True, lock=False) - txn = lmdb_env.begin(write=False) - cursor = txn.cursor() - - for key, buf in cursor: - video = pyarrow.deserialize(buf) - - for clip in video["clips"]: - for word_info in clip["words"]: - word = word_info[0] - lang_model.index_word(word) - - lmdb_env.close() - logging.info(" indexed %d words" % lang_model.n_words) +"""Utility file to build a a language vector representation model. + +Language vector representations are built from existing FastText resources. +These resource files (ex. crawl-300d-2M-subword.bin) must be saved to a directory. +The directory must be included in the config files. +Training and Testing datasets (as PyTorch dataset subclasses) must also be provided. +If a cache file is valid, then the word_vec_path and feat_dim args can be ignored. + +Typical usage example: + v = build_vocab( + 'first_vocab', + [training, testing], + 'dataset/vocab_cache.pkl', + 'resource/crawl-300d-2M-subword.bin', + 300 + ) + index_words(v, 'dataset/lmdb/data.mdb') +""" + + +import logging +import os +import pickle + +import lmdb +import pyarrow + +from model.vocab import Vocab + + +def build_vocab( + name: str, + dataset_list: list, + cache_path: str, + word_vec_path: str = None, + feat_dim = None, +) -> Vocab: + """Build a language vector representation model from an existing source. + + Builds a language model using existing (English) FastText vector representations. + The 'word_vec_path' and 'feat_dim' arguments must be provided if a model has not been previously created. + Once the model has been built, saves the model using Pickle to the 'cache_path' location. + If an existing model has been detected at the 'cache_path' then load the model instead of build. + + Args: + name: A string to be used as a name for the language model. + dataset_list: A list containing PyTorch 'Dataset' objects that are represented within Lmdb files and contains dataset associated information. + cache_path: A string representing the filepath to check if a language model has been previously built. + word_vec_path: A string representing (FastText) .bin files to use. + feat_dim: An int representing the dimensions in the FastText files. + + Returns: + A Vocab object that contains the language vector representations. 
+ + Raises: + Assertion that the model is consistent with its embedded weights. + """ + logging.info(" building a language model...") + if not os.path.exists(cache_path): + lang_model = Vocab(name) + for dataset in dataset_list: + logging.info(" indexing words from {}".format(dataset.lmdb_dir)) + index_words(lang_model, dataset.lmdb_dir) + + if word_vec_path is not None: + lang_model.load_word_vectors(word_vec_path, feat_dim) + + with open(cache_path, "wb") as f: + pickle.dump(lang_model, f) + else: + logging.info(" loaded from {}".format(cache_path)) + with open(cache_path, "rb") as f: + lang_model: Vocab = pickle.load(f) + + if word_vec_path is None: + lang_model.word_embedding_weights = None + elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words: + logging.warning(" failed to load word embedding weights. check this") + assert False + + return lang_model + + +def index_words(lang_model: Vocab, lmdb_dir: str) -> None: + """Analyzes and indexes the words in the dataset to a Vocab object. + + Modifies the lang_model object by calling a mutating method. + Adds all words in the lmdb file to the lang_model. + + Args: + lang_model: A Vocab object representing the language model to train. + lmdb_dir: A string representing the filepath of the dataset to analyze. + """ + lmdb_env: lmdb.Environment = lmdb.open(lmdb_dir, readonly=True, lock=False) + txn = lmdb_env.begin(write=False) + cursor = txn.cursor() + + for key, buf in cursor: + video = pyarrow.deserialize(buf) + + for clip in video["clips"]: + for word_info in clip["words"]: + word = word_info[0] + lang_model.index_word(word) + + lmdb_env.close() + logging.info(" indexed %d words" % lang_model.n_words)