From f7514dfdf06a450a617962e671706ae210f774e3 Mon Sep 17 00:00:00 2001
From: vladimir
Date: Thu, 8 Jul 2021 11:28:49 +0300
Subject: [PATCH] first hype commit

---
 setup.py                            |  16 +++
 trainer/__init__.py                 |   5 +
 trainer/callbacks/__init__.py       |   6 +
 trainer/callbacks/callback.py       |  37 ++++++
 trainer/callbacks/checkpoint.py     |  87 +++++++++++++
 trainer/callbacks/logging.py        |   9 ++
 trainer/callbacks/stop_criterion.py |  19 +++
 trainer/callbacks/tensorboard.py    |  51 ++++++++
 trainer/callbacks/validation.py     |  17 +++
 trainer/factory.py                  |  64 ++++++++++
 trainer/metric_transforms.py        |  13 ++
 trainer/metrics.py                  | 129 +++++++++++++++++++
 trainer/trainer.py                  | 191 ++++++++++++++++++++++++++++
 trainer/utils/__init__.py           |   2 +
 trainer/utils/utils.py              |  84 ++++++++++++
 15 files changed, 730 insertions(+)
 create mode 100644 setup.py
 create mode 100644 trainer/__init__.py
 create mode 100644 trainer/callbacks/__init__.py
 create mode 100644 trainer/callbacks/callback.py
 create mode 100644 trainer/callbacks/checkpoint.py
 create mode 100644 trainer/callbacks/logging.py
 create mode 100644 trainer/callbacks/stop_criterion.py
 create mode 100644 trainer/callbacks/tensorboard.py
 create mode 100644 trainer/callbacks/validation.py
 create mode 100644 trainer/factory.py
 create mode 100644 trainer/metric_transforms.py
 create mode 100644 trainer/metrics.py
 create mode 100644 trainer/trainer.py
 create mode 100644 trainer/utils/__init__.py
 create mode 100644 trainer/utils/utils.py

diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..0ea96d1
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,16 @@
+import setuptools
+
+setuptools.setup(
+    name="trainer",
+    version="0.1",
+    author="Chernyavskiy Vladimir and Stanislav Mikhaylevskiy (my best friend)",
+    author_email="chernvld@gmail.com, m1xst99@yandex.ru",
+    description="Torch HypeTrainer",
+    url="https://github.com/JJBT/HypeTrainer",
+    packages=setuptools.find_packages(),
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+)
diff --git a/trainer/__init__.py b/trainer/__init__.py
new file mode 100644
index 0000000..2f8000d
--- /dev/null
+++ b/trainer/__init__.py
@@ -0,0 +1,5 @@
+from trainer.trainer import Trainer
+from trainer.factory import Factory
+from trainer.metrics import *
+from trainer.callbacks import *
+
diff --git a/trainer/callbacks/__init__.py b/trainer/callbacks/__init__.py
new file mode 100644
index 0000000..547144c
--- /dev/null
+++ b/trainer/callbacks/__init__.py
@@ -0,0 +1,6 @@
+from trainer.callbacks.callback import Callback
+from trainer.callbacks.validation import ValidationCallback
+from trainer.callbacks.logging import LogCallback
+from trainer.callbacks.checkpoint import SaveCheckpointCallback, SaveBestCheckpointCallback, LoadCheckpointCallback
+from trainer.callbacks.stop_criterion import StopAtStep, NoStopping
+from trainer.callbacks.tensorboard import TensorBoardCallback
diff --git a/trainer/callbacks/callback.py b/trainer/callbacks/callback.py
new file mode 100644
index 0000000..82ae2cd
--- /dev/null
+++ b/trainer/callbacks/callback.py
@@ -0,0 +1,37 @@
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class Callback:
+    def __init__(self, frequency=0, before=False, after=False,
+                 attributes=None):
+
+        if frequency < 0:
+            raise ValueError("Frequency argument should be non-negative.")
+
+        if attributes is None:
+            self._attributes = dict()
+        else:
+            self._attributes = attributes
+
+        self.frequency = frequency
+        self.before = before
+        self.after = after
+
+    def __call__(self, trainer):
+        raise NotImplementedError()
+
+    def before_run(self, trainer):
+        if self.before:
+            self.__call__(trainer)
+
+    def after_run(self, trainer):
+        if self.after:
+            self.__call__(trainer)
+
+    def set_trainer(self, trainer):
+        for attribute_name in self._attributes:
+            attribute_default_value = self._attributes[attribute_name]
+            trainer.state.add_attribute(attribute_name, attribute_default_value)
diff --git a/trainer/callbacks/checkpoint.py b/trainer/callbacks/checkpoint.py
new file mode 100644
index 0000000..a2ac9cf
--- /dev/null
+++ b/trainer/callbacks/checkpoint.py
@@ -0,0 +1,87 @@
+import os
+from trainer.callbacks.callback import Callback
+import torch
+from trainer.utils.utils import get_state_dict, load_state_dict
+import logging
+import warnings
+
+
+logger = logging.getLogger(__name__)
+
+
+class SaveCheckpointCallback(Callback):
+    def __init__(self, frequency=0, before=False, after=True):
+        super().__init__(frequency=frequency, before=before, after=after)
+        cwd = os.getcwd()
+        self.savedir = os.path.join(cwd, 'checkpoints')
+        os.makedirs(self.savedir, exist_ok=True)
+        self.ckpt_filename = 'checkpoint-{}.pt'
+
+    def __call__(self, trainer):
+        self._save_checkpoint(trainer, self.ckpt_filename.format(trainer.state.step))
+
+    def _save_checkpoint(self, trainer, filename):
+        torch.save({
+            'model_state_dict': get_state_dict(trainer.model),
+            # 'optimizer_state_dict': get_state_dict(trainer.optimizer),
+            # 'scheduler_state_dict': get_state_dict(trainer.scheduler),
+            'trainer_state': get_state_dict(trainer.state),
+            'model_class': str(trainer.model.__class__),
+            # 'optimizer_class': str(trainer.optimizer.__class__),
+            # 'scheduler_class': str(trainer.scheduler.__class__)
+        }, os.path.join(self.savedir, filename))
+
+
+class LoadCheckpointCallback(Callback):
+    def __init__(self, directory: str, filename=None):
+        super().__init__(frequency=0, before=True, after=False)
+        self.directory = os.path.join(directory)
+
+        if filename is not None:
+            self.filename = filename
+        else:
+            self.filename = self._search_checkpoint()
+
+        self.filename_to_load = os.path.join(self.directory, self.filename)
+
+    def __call__(self, trainer):
+        self._load_checkpoint(trainer)
+        logger.info(f'Checkpoint {self.filename_to_load} loaded')
+
+    def _load_checkpoint(self, trainer):
+        checkpoint = torch.load(self.filename_to_load, map_location=trainer.accelerator.device)
+
+        # checks
+        if checkpoint['model_class'] != str(trainer.model.__class__):
+            warnings.warn(
+                f'Models do not match: {checkpoint["model_class"]} and {trainer.model.__class__}', RuntimeWarning
+            )
+
+        load_state_dict(trainer.model, checkpoint['model_state_dict'])
+        # trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        # trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+        trainer.state.load_state_dict(checkpoint['trainer_state'])
+
+    def _search_checkpoint(self):
+        filelist = os.listdir(self.directory)
+        for file in filelist:
+            if '.pt' in file:
+                return file
+
+
+class SaveBestCheckpointCallback(SaveCheckpointCallback):
+    def __init__(self, frequency, val_metric_name: str,
+                 comparison_function=lambda metric, best: metric > best):
+        super().__init__(frequency=frequency, before=False, after=False)
+        self.val_metric_name = val_metric_name  # last_(train/validation)_{metric}
+        self.comparison_function = comparison_function
+        self.current_best = None
+        self.best_ckpt_filename = 'best-checkpoint-{}.pt'
+
+    def __call__(self, trainer):
+        self.state_last_metric = trainer.state.get_validation_metric(self.val_metric_name)
+        if self.current_best is None or self.comparison_function(self.state_last_metric, self.current_best):
+            self.current_best = self.state_last_metric
+
+            self._save_checkpoint(trainer, self.best_ckpt_filename.format(trainer.state.step))
+
diff --git a/trainer/callbacks/logging.py b/trainer/callbacks/logging.py
new file mode 100644
index 0000000..a11ebcd
--- /dev/null
+++ b/trainer/callbacks/logging.py
@@ -0,0 +1,9 @@
+from trainer.callbacks.callback import Callback
+
+
+class LogCallback(Callback):
+    def __init__(self, frequency):
+        super().__init__(frequency=frequency, before=False, after=True)
+
+    def __call__(self, trainer):
+        trainer.state.log_train()
diff --git a/trainer/callbacks/stop_criterion.py b/trainer/callbacks/stop_criterion.py
new file mode 100644
index 0000000..b32f51b
--- /dev/null
+++ b/trainer/callbacks/stop_criterion.py
@@ -0,0 +1,19 @@
+class StopAtStep:
+    def __init__(self, last_step):
+        self.last_step = last_step
+
+    def __call__(self, state):
+        if isinstance(state, int):
+            state_step = state
+        else:
+            state_step = state.step
+
+        if state_step < self.last_step:
+            return False
+        else:
+            return True
+
+
+class NoStopping:
+    def __call__(self, state):
+        return False
diff --git a/trainer/callbacks/tensorboard.py b/trainer/callbacks/tensorboard.py
new file mode 100644
index 0000000..495a37b
--- /dev/null
+++ b/trainer/callbacks/tensorboard.py
@@ -0,0 +1,51 @@
+import os
+from trainer.callbacks.callback import Callback
+from torch.utils.tensorboard import SummaryWriter
+from omegaconf import OmegaConf
+
+
+class TensorBoardCallback(Callback):
+    def __init__(self, frequency, add_weights=False, add_grads=False):
+        super().__init__(frequency=frequency, before=True, after=True)
+        self.log_dir = os.getcwd()
+        self.writer = SummaryWriter(log_dir=self.log_dir)
+        self.add_weights = add_weights
+        self.add_grads = add_grads
+
+    def before_run(self, trainer):
+        cfg = OmegaConf.to_yaml(trainer.cfg)
+        cfg = cfg.replace('\n', ' \n')
+        self.writer.add_text('cfg', cfg)
+        description = trainer.cfg.description
+        if description:
+            self.writer.add_text('description', description)
+
+    def after_run(self, trainer):
+        self.writer.close()
+
+    def add_validation_metrics(self, trainer):
+        metrics = trainer.state.validation_metrics
+        for name, metric in metrics.items():
+            self.writer.add_scalar(name, metric, trainer.state.step)
+
+    def add_weights_histogram(self, trainer):
+        for name, param in trainer.model.named_parameters():
+            if 'bn' not in name:
+                self.writer.add_histogram(name, param, trainer.state.step)
+
+    def add_grads_histogram(self, trainer):
+        for name, param in trainer.model.named_parameters():
+            if 'bn' not in name and param.requires_grad:
+                self.writer.add_histogram(name + '_grad', param.grad, trainer.state.step)
+
+    def __call__(self, trainer):
+        for name, loss in trainer.state.last_train_loss.items():
+            self.writer.add_scalar(f'trn/{name}', loss, trainer.state.step)
+
+        self.writer.add_scalar('lr', trainer.optimizer.param_groups[0]['lr'], trainer.state.step)
+
+        if self.add_weights:
+            self.add_weights_histogram(trainer)
+
+        if self.add_grads:
+            self.add_grads_histogram(trainer)
diff --git a/trainer/callbacks/validation.py b/trainer/callbacks/validation.py
new file mode 100644
index 0000000..9007627
--- /dev/null
+++ b/trainer/callbacks/validation.py
@@ -0,0 +1,17 @@
+from trainer.callbacks.callback import Callback
+
+
+class ValidationCallback(Callback):
+    def __init__(self, frequency):
+        super().__init__(frequency=frequency, before=False, after=False)
+
+    def __call__(self, trainer):
+        self.computed_metrics = trainer.evaluate(dataloader=trainer.val_dataloader, metrics=trainer.metrics)
+        for metric_name, metric_value in self.computed_metrics.items():
+            trainer.state.add_validation_metric(name=f'val/{metric_name}', value=metric_value)
+
+        trainer.state.log_validation()
+
+        if 'TensorBoardCallback' in trainer.callbacks:
+            tb_callback = trainer.callbacks['TensorBoardCallback']
+            tb_callback.add_validation_metrics(trainer)
diff --git a/trainer/factory.py b/trainer/factory.py
new file mode 100644
index 0000000..1fe484d
--- /dev/null
+++ b/trainer/factory.py
@@ -0,0 +1,64 @@
+import torch
+from trainer.utils.utils import object_from_dict
+from torch.utils.data import DataLoader
+
+
+class Factory:
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def create_model(self):
+        model = object_from_dict(self.cfg.model)
+        return model
+
+    def create_optimizer(self, model: torch.nn.Module):
+        optimizer = object_from_dict(self.cfg.optimizer, params=filter(lambda x: x.requires_grad, model.parameters()))
+        return optimizer
+
+    def create_scheduler(self, optimizer: torch.optim.Optimizer):
+        scheduler = object_from_dict(self.cfg.scheduler, optimizer=optimizer)
+        return scheduler
+
+    def create_loss(self):
+        loss = object_from_dict(self.cfg.loss)
+        return loss
+
+    def create_train_dataloader(self):
+        dataset = self.create_dataset(self.cfg.data.train_dataset)
+        train_dataloader = self.create_dataloader(self.cfg.bs, dataset)
+        return train_dataloader
+
+    def create_val_dataloader(self):
+        dataset = self.create_dataset(self.cfg.data.validation_dataset)
+        val_dataloader = self.create_dataloader(self.cfg.bs, dataset)
+        return val_dataloader
+
+    def create_dataset(self, cfg):
+        augmentations = self.create_augmentations(cfg.augmentations)
+        dataset = object_from_dict(cfg, transforms=augmentations, ignore_keys=['augmentations'])
+        return dataset
+
+    def create_dataloader(self, bs, dataset):
+        dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
+        return dataloader
+
+    def create_metrics(self):
+        metrics = []
+        for metric in self.cfg.metrics:
+            metric_obj = object_from_dict(metric)
+            metrics.append(metric_obj)
+
+        return metrics
+
+    def create_callbacks(self, trainer):
+        for hook in self.cfg.hooks:
+            hook_obj = object_from_dict(hook)
+            trainer.register_callback(hook_obj)
+
+    def create_augmentations(self, cfg):
+        augmentations = []
+        for augm in cfg.augmentations:
+            augmentations.append(object_from_dict(augm))
+
+        compose = object_from_dict(cfg.compose, transforms=augmentations)
+        return compose
diff --git a/trainer/metric_transforms.py b/trainer/metric_transforms.py
new file mode 100644
index 0000000..f200816
--- /dev/null
+++ b/trainer/metric_transforms.py
@@ -0,0 +1,13 @@
+import torch
+
+
+transforms_dict = {
+    'accuracy_prediction': lambda x: torch.argmax(x, dim=1),
+    'accuracy_target': lambda x: torch.argmax(x, dim=1),
+    'recall_prediction': lambda x: torch.argmax(x, dim=1),
+    'recall_target': lambda x: torch.argmax(x, dim=1),
+    'precision_prediction': lambda x: torch.argmax(x, dim=1),
+    'precision_target': lambda x: torch.argmax(x, dim=1),
+    'conf_matrix_prediction': lambda x: torch.argmax(x, dim=1),
+    'conf_matrix_target': lambda x: torch.argmax(x, dim=1),
+}
diff --git a/trainer/metrics.py b/trainer/metrics.py
new file mode 100644
index 0000000..dce0b6e
--- /dev/null
+++ b/trainer/metrics.py
@@ -0,0 +1,129 @@
+import torch
+from trainer.metric_transforms import transforms_dict
+from sklearn.metrics import precision_score, recall_score
+import numpy as np
+
+
+class Metric:
+    def __init__(self, name: str, default_value=None, target_transform=None, prediction_transform=None):
+        self.name = name.replace(' ', '_')
+        self.default_value = default_value
+        self.target_transform = target_transform if target_transform else \
+            transforms_dict.get(f'{self.name}_target', lambda x: x)
+        self.prediction_transform = prediction_transform if prediction_transform else \
+            transforms_dict.get(f'{self.name}_prediction', lambda x: x)
+
+    def prepare(self, y: torch.Tensor, y_pred: torch.Tensor):
+        y = self.target_transform(y)
+        y_pred = self.prediction_transform(y_pred)
+
+        if isinstance(y, torch.Tensor):
+            y = y.detach()
+
+        if isinstance(y_pred, torch.Tensor):
+            y_pred = y_pred.detach()
+
+        return y, y_pred
+
+    def step(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError()
+
+    def compute(self):
+        raise NotImplementedError()
+
+    def reset(self):
+        raise NotImplementedError()
+
+
+class Accuracy(Metric):
+    def __init__(self):
+        super().__init__("accuracy", default_value=0)
+        self.total_correct = 0
+        self.total = 0
+
+    def step(self, y, y_pred):
+        y, y_pred = self.prepare(y, y_pred)
+        correct = torch.eq(y, y_pred)
+
+        self.total_correct += torch.sum(correct).item()
+        self.total += correct.shape[0]
+
+    def compute(self):
+        return self.total_correct / self.total
+
+    def reset(self):
+        self.total_correct = 0
+        self.total = 0
+
+
+class Recall(Metric):
+    def __init__(self, average=None, target_transform=None, prediction_transform=None):
+        super().__init__('recall', default_value=0, target_transform=target_transform,
+                         prediction_transform=prediction_transform)
+        self.predictions = []
+        self.targets = []
+        self.average = average
+
+    def step(self, y: torch.Tensor, y_pred: torch.Tensor):
+        # TODO
+        y, y_pred = self.prepare(y, y_pred)
+        self.targets.extend(y.tolist())
+        self.predictions.extend(y_pred.tolist())
+
+    def compute(self):
+        result = recall_score(self.targets, self.predictions, average=self.average)
+        if self.average:
+            return result
+        else:
+            return {f'{i}_recall': result[i] for i in range(result.shape[0])}
+
+    def reset(self):
+        self.predictions = []
+        self.targets = []
+
+
+class Precision(Metric):
+    def __init__(self, average=None, target_transform=None, prediction_transform=None):
+        super().__init__('precision', default_value=0, target_transform=target_transform,
+                         prediction_transform=prediction_transform)
+        self.predictions = []
+        self.targets = []
+        self.average = average
+
+    def step(self, y: torch.Tensor, y_pred: torch.Tensor):
+        # TODO
+        y, y_pred = self.prepare(y, y_pred)
+        self.targets.extend(y.tolist())
+        self.predictions.extend(y_pred.tolist())
+
+    def compute(self):
+        result = precision_score(self.targets, self.predictions, average=self.average)
+        if self.average:
+            return result
+        else:
+            return {f'{i}_precision': result[i] for i in range(result.shape[0])}
+
+    def reset(self):
+        self.predictions = []
+        self.targets = []
+
+
+class ConfusionMatrix(Metric):
+    def __init__(self, num_classes, target_transform=None, prediction_transform=None):
+        super().__init__('conf_matrix', default_value=0, target_transform=target_transform,
+                         prediction_transform=prediction_transform)
+        self.matrix = np.zeros((num_classes, num_classes))
+        self.num_classes = num_classes
+
+    def step(self, y: torch.Tensor, y_pred: torch.Tensor):
+        y, y_pred = self.prepare(y, y_pred)
+        y, y_pred = y.tolist(), y_pred.tolist()
+
+        for t, p in zip(y, y_pred):
+            self.matrix[int(t)][int(p)] += 1
+
+    def compute(self):
+        return self.matrix
+
+    def reset(self):
+        self.matrix = np.zeros((self.num_classes, self.num_classes))
diff --git a/trainer/trainer.py b/trainer/trainer.py
new file mode 100644
index 0000000..f70bdf1
--- /dev/null
+++ b/trainer/trainer.py
@@ -0,0 +1,191 @@
+import signal
+import torch
+from trainer.callbacks import Callback, StopAtStep
+import logging
+from collections import OrderedDict
+from trainer.utils.utils import set_deterministic, flatten_dict, loss_to_dict
+from accelerate import Accelerator, GradScalerKwargs
+
+logger = logging.getLogger(__name__)
+
+
+class State:
+    def __init__(self):
+        self.step = 0
+        self.last_train_loss = None
+
+        self.validation_metrics = dict()
+
+    def get_validation_metric(self, name):
+        return self.validation_metrics[name]
+
+    def get(self, attribute_name: str):
+        return getattr(self, attribute_name)
+
+    def load_state_dict(self, state_dict):
+        for k, v in state_dict.items():
+            setattr(self, k, v)
+
+    def state_dict(self):
+        return self.__dict__
+
+    def add_attribute(self, name, value):
+        if not hasattr(self, name):
+            setattr(self, name, value)
+
+    def add_validation_metric(self, name, value):
+        self.validation_metrics[name] = value
+
+    def reset(self):
+        self.step = 0
+        self.last_train_loss = None
+
+    def update(self, loss_dict=None):
+        self.step += 1
+        if loss_dict is not None:
+            self.last_train_loss = flatten_dict(loss_dict)
+
+    def log_train(self):
+        msg = f'Step - {self.step} '
+        for name, value in self.last_train_loss.items():
+            msg += f'{name} - {value:.7f} '
+
+        logger.info(msg)
+
+    def log_validation(self):
+        msg = f'Validation '
+        for name, value in self.validation_metrics.items():
+            msg += f'{name} - {value:.7f} '
+
+        logger.info(msg)
+
+
+class Trainer:
+    def __init__(self, cfg, factory):
+        signal.signal(signal.SIGINT, self._soft_exit)
+        set_deterministic()
+
+        self.factory = factory
+
+        self.train_dataloader = self.factory.create_train_dataloader()
+        self.val_dataloader = self.factory.create_val_dataloader()
+        self.state = State()
+        self.criterion = self.factory.create_loss()
+        self.model = self.factory.create_model()
+        self.optimizer = self.factory.create_optimizer(self.model)
+        self.scheduler = self.factory.create_scheduler(self.optimizer)
+        self.n_steps = cfg.n_steps
+        self.stop_condition = StopAtStep(last_step=self.n_steps)
+        self.callbacks = OrderedDict()
+        self.metrics = self.factory.create_metrics()
+        self.factory.create_callbacks(self)
+
+        self.cfg = cfg
+        self.stop_validation = False
+        self.grad_scaler_kwargs = GradScalerKwargs(init_scale=2048, enabled=cfg.amp)
+        self.accelerator = Accelerator(cpu=bool(cfg.device == 'cpu'), fp16=cfg.amp, kwargs_handlers=[self.grad_scaler_kwargs])
+        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+        self.train_dataloader, self.val_dataloader = \
+            self.accelerator.prepare(self.train_dataloader, self.val_dataloader)
+
+    def get_train_batch(self):
+        if not getattr(self, 'train_data_iter', False):
+            self.train_data_iter = iter(self.train_dataloader)
+        try:
+            batch = next(self.train_data_iter)
+        except StopIteration:
+            self.train_data_iter = iter(self.train_dataloader)
+            batch = next(self.train_data_iter)
+
+        return batch
+
+    def run_step(self, batch):
+        self.optimizer.zero_grad()
+
+        inputs, targets = self.get_input_and_target_from_batch(batch)
+
+        outputs = self.model(inputs)
+
+        loss_dict = self.criterion(outputs, targets)
+
+        loss_dict = loss_to_dict(loss_dict)
+        loss = loss_dict['loss']
+        self.accelerator.backward(loss)
+
+        self.optimizer.step()
+        if self.scheduler is not None:
+            self.scheduler.step()
+
+        loss_dict['loss'] = loss_dict['loss'].detach()
+
+        return loss_dict
+
+    def run_train(self, n_steps=None):
+        if n_steps is not None:
+            self.stop_condition = StopAtStep(last_step=n_steps)
+
+        self.state.reset()
+        self.model.train()
+
+        self._before_run_callbacks()
+
+        while not self.stop_condition(self.state):
+            batch = self.get_train_batch()
+            loss = self.run_step(batch)
+            self.state.update(loss)
+
+            self._run_callbacks()
+
+        self._after_run_callbacks()
+        logger.info('Done')
+
+    def evaluate(self, dataloader, metrics):
+        previous_training_flag = self.model.training
+
+        self.model.eval()
+        for metric in metrics:
+            metric.reset()
+
+        with torch.no_grad():
+            for batch in dataloader:
+                if self.stop_validation:
+                    break
+
+                input_tensor = batch[0]
+                target_tensor = batch[1]
+                outputs = self.model(input_tensor)
+
+                for metric in metrics:
+                    metric.step(y=target_tensor, y_pred=outputs)
+
+        metrics_computed = {metric.name: metric.compute() for metric in metrics}
+        self.model.train(previous_training_flag)
+
+        return flatten_dict(metrics_computed)
+
+    def get_input_and_target_from_batch(self, batch):
+        return batch[0], batch[1]
+
+    def register_callback(self, callback: Callback):
+        callback.set_trainer(self)
+        callback_name = callback.__class__.__name__
+        self.callbacks[callback_name] = callback
+
+    def _soft_exit(self, sig, frame):
+        logger.info('Soft exit... Currently running steps will be finished')
+        self.stop_condition = lambda state: True
+        self.stop_validation = True
+
+    def _before_run_callbacks(self):
+        for name, callback in self.callbacks.items():
+            callback.before_run(self)
+
+    def _after_run_callbacks(self):
+        for name, callback in self.callbacks.items():
+            callback.after_run(self)
+
+    def _run_callbacks(self):
+        for name, callback in self.callbacks.items():
+            freq = callback.frequency
+            if freq != 0 and self.state.step % freq == 0:
+                callback(self)
diff --git a/trainer/utils/__init__.py b/trainer/utils/__init__.py
new file mode 100644
index 0000000..bf10207
--- /dev/null
+++ b/trainer/utils/__init__.py
@@ -0,0 +1,2 @@
+from trainer.utils import utils
+
diff --git a/trainer/utils/utils.py b/trainer/utils/utils.py
new file mode 100644
index 0000000..80891c7
--- /dev/null
+++ b/trainer/utils/utils.py
@@ -0,0 +1,84 @@
+import torch
+import random
+import numpy as np
+import pydoc
+from omegaconf import DictConfig
+from collections.abc import MutableMapping
+
+
+def object_from_dict(d, parent=None, ignore_keys=None, **default_kwargs):
+    assert isinstance(d, (dict, DictConfig)) and 'type' in d
+    kwargs = d.copy()
+    kwargs = dict(kwargs)
+    object_type = kwargs.pop('type')
+
+    if object_type is None:
+        return None
+
+    if ignore_keys:
+        for key in ignore_keys:
+            kwargs.pop(key, None)
+
+    for name, value in default_kwargs.items():
+        kwargs.setdefault(name, value)
+
+    # support nested constructions
+    for key, value in kwargs.items():
+        if isinstance(value, (dict, DictConfig)) and 'type' in value:
+            value = object_from_dict(value)
+            kwargs[key] = value
+
+    if parent is not None:
+        return getattr(parent, object_type)(**kwargs)
+    else:
+        return pydoc.locate(object_type)(**kwargs)
+
+
+def freeze_layers(model, layers_to_train):
+    """Freeze layers not included in layers_to_train"""
+    for name, parameter in model.named_parameters():
+        if all([not name.startswith(layer) for layer in layers_to_train]):
+            parameter.requires_grad_(False)
+
+
+def set_deterministic(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+    torch.cuda.manual_seed_all(seed)
+    torch.manual_seed(seed)
+
+
+def flatten_dict(d, parent_key='', sep='_'):
+    items = []
+    for k, v in d.items():
+        new_key = parent_key + sep + k if parent_key else k
+        if isinstance(v, MutableMapping):
+            items.extend(flatten_dict(v, '', sep=sep).items())
+        else:
+            value = v.item() if isinstance(v, torch.Tensor) else v
+            items.append((new_key, value))
+
+    return dict(items)
+
+
+def loss_to_dict(loss):
+    if not isinstance(loss, dict):
+        return {'loss': loss}
+    else:
+        return loss
+
+
+def get_state_dict(model):
+    if model is None:
+        return None
+    else:
+        return model.state_dict()
+
+
+def load_state_dict(model, state_dict):
+    if model is None:
+        return None
+    else:
+        return model.load_state_dict(state_dict)
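
Usage note (illustrative, not part of the diff above): every buildable object in the config is described by a 'type' key holding a dotted import path plus constructor kwargs, which Factory resolves through object_from_dict; a full run then wires things together with trainer = Trainer(cfg, Factory(cfg)) and trainer.run_train(), with cfg supplying n_steps, device, amp, bs, and the model/optimizer/scheduler/loss/data/metrics/hooks sections read above. The sketch below only demonstrates the two core conventions (config-driven construction and the step()/compute() metric API); the Linear model, SGD config, and tensors are arbitrary examples, not defaults of the library.

# usage_sketch.py -- illustrative example, not part of the patch
import torch
from trainer.utils.utils import object_from_dict
from trainer.metrics import Accuracy

# 'type' is a dotted import path; the remaining keys become constructor kwargs.
model = torch.nn.Linear(4, 2)
optimizer_cfg = {'type': 'torch.optim.SGD', 'lr': 0.1}
optimizer = object_from_dict(optimizer_cfg, params=model.parameters())

# Metrics accumulate over batches with step() and are reduced with compute().
# The default transforms in metric_transforms.py argmax both tensors along dim=1,
# so targets are passed one-hot here.
metric = Accuracy()
logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.0]])
targets = torch.nn.functional.one_hot(torch.tensor([0, 1, 1]), num_classes=2)
metric.step(y=targets, y_pred=logits)
print(metric.compute())  # 2 of 3 predictions correct -> 0.666...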