first hype commit
JJBT committed Jul 8, 2021
0 parents commit f7514df
Showing 15 changed files with 730 additions and 0 deletions.
16 changes: 16 additions & 0 deletions setup.py
@@ -0,0 +1,16 @@
import setuptools

setuptools.setup(
    name="trainer",
    version="0.1",
    author="Chernyavskiy Vladimir and Stanislav Mikhaylevskiy (my best friend)",
    author_email="[email protected], [email protected]",
    description="Torch HypeTrainer",
    url="https://github.com/JJBT/HypeTrainer",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
)
5 changes: 5 additions & 0 deletions trainer/__init__.py
@@ -0,0 +1,5 @@
from trainer.trainer import Trainer
from trainer.factory import Factory
from trainer.metrics import *
from trainer.callbacks import *

6 changes: 6 additions & 0 deletions trainer/callbacks/__init__.py
@@ -0,0 +1,6 @@
from trainer.callbacks.callback import Callback
from trainer.callbacks.validation import ValidationCallback
from trainer.callbacks.logging import LogCallback
from trainer.callbacks.checkpoint import SaveCheckpointCallback, SaveBestCheckpointCallback, LoadCheckpointCallback
from trainer.callbacks.stop_criterion import StopAtStep, NoStopping
from trainer.callbacks.tensorboard import TensorBoardCallback
37 changes: 37 additions & 0 deletions trainer/callbacks/callback.py
@@ -0,0 +1,37 @@
import logging


logger = logging.getLogger(__name__)


class Callback:
    def __init__(self, frequency=0, before=False, after=False,
                 attributes=None):
        if frequency < 0:
            raise ValueError("Frequency argument must be non-negative.")

        if attributes is None:
            self._attributes = dict()
        else:
            self._attributes = attributes

        self.frequency = frequency
        self.before = before
        self.after = after

    def __call__(self, trainer):
        raise NotImplementedError()

    def before_run(self, trainer):
        if self.before:
            self.__call__(trainer)

    def after_run(self, trainer):
        if self.after:
            self.__call__(trainer)

    def set_trainer(self, trainer):
        # Register this callback's default attributes on the trainer state.
        for attribute_name, attribute_default_value in self._attributes.items():
            trainer.state.add_attribute(attribute_name, attribute_default_value)
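
The base class above drives every hook in the package: before_run/after_run fire once around training, and __call__ fires on the frequency schedule. A minimal sketch of a custom subclass, assuming the Trainer invokes registered callbacks every `frequency` steps and exposes a step counter on `trainer.state.step` (the PrintStepCallback name is hypothetical, not part of this commit):

from trainer.callbacks.callback import Callback


class PrintStepCallback(Callback):
    """Hypothetical example: print the step counter every `frequency` steps."""

    def __init__(self, frequency=100):
        super().__init__(frequency=frequency, before=False, after=True)

    def __call__(self, trainer):
        print(f'step {trainer.state.step}')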
87 changes: 87 additions & 0 deletions trainer/callbacks/checkpoint.py
@@ -0,0 +1,87 @@
import os
import logging
import warnings

import torch

from trainer.callbacks.callback import Callback
from trainer.utils.utils import get_state_dict, load_state_dict


logger = logging.getLogger(__name__)


class SaveCheckpointCallback(Callback):
    def __init__(self, frequency=0, before=False, after=True):
        super().__init__(frequency=frequency, before=before, after=after)
        cwd = os.getcwd()
        self.savedir = os.path.join(cwd, 'checkpoints')
        os.makedirs(self.savedir, exist_ok=True)
        self.ckpt_filename = 'checkpoint-{}.pt'

    def __call__(self, trainer):
        self._save_checkpoint(trainer, self.ckpt_filename.format(trainer.state.step))

    def _save_checkpoint(self, trainer, filename):
        torch.save({
            'model_state_dict': get_state_dict(trainer.model),
            # 'optimizer_state_dict': get_state_dict(trainer.optimizer),
            # 'scheduler_state_dict': get_state_dict(trainer.scheduler),
            'trainer_state': get_state_dict(trainer.state),
            'model_class': str(trainer.model.__class__),
            # 'optimizer_class': str(trainer.optimizer.__class__),
            # 'scheduler_class': str(trainer.scheduler.__class__)
        }, os.path.join(self.savedir, filename))


class LoadCheckpointCallback(Callback):
    def __init__(self, directory: str, filename=None):
        super().__init__(frequency=0, before=True, after=False)
        self.directory = directory

        if filename is not None:
            self.filename = filename
        else:
            self.filename = self._search_checkpoint()

        self.filename_to_load = os.path.join(self.directory, self.filename)

    def __call__(self, trainer):
        self._load_checkpoint(trainer)
        logger.info(f'Checkpoint {self.filename_to_load} loaded')

    def _load_checkpoint(self, trainer):
        checkpoint = torch.load(self.filename_to_load, map_location=trainer.accelerator.device)

        # Sanity check: warn if the checkpointed model class differs from the current one.
        if checkpoint['model_class'] != str(trainer.model.__class__):
            warnings.warn(
                f'Models do not match: {checkpoint["model_class"]} and {trainer.model.__class__}', RuntimeWarning
            )

        load_state_dict(trainer.model, checkpoint['model_state_dict'])
        # trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        trainer.state.load_state_dict(checkpoint['trainer_state'])

    def _search_checkpoint(self):
        # Return the first checkpoint file found in the directory.
        for file in os.listdir(self.directory):
            if file.endswith('.pt'):
                return file
        raise FileNotFoundError(f'No .pt checkpoint found in {self.directory}')


class SaveBestCheckpointCallback(SaveCheckpointCallback):
    def __init__(self, frequency, val_metric_name: str,
                 comparison_function=lambda metric, best: metric > best):
        super().__init__(frequency=frequency, before=False, after=False)
        self.val_metric_name = val_metric_name  # last_(train/validation)_{metric}
        self.comparison_function = comparison_function
        self.current_best = None
        self.best_ckpt_filename = 'best-checkpoint-{}.pt'

    def __call__(self, trainer):
        last_metric = trainer.state.get_validation_metric(self.val_metric_name)
        if self.current_best is None or self.comparison_function(last_metric, self.current_best):
            self.current_best = last_metric
            # Only write a checkpoint when the tracked metric improves.
            self._save_checkpoint(trainer, self.best_ckpt_filename.format(trainer.state.step))
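
Taken together, the three callbacks cover resume-and-save. A hypothetical wiring sketch, not part of this commit, assuming a Trainer exposing the register_callback method that Factory.create_callbacks relies on, and a metric logged as 'val/accuracy' by ValidationCallback:

trainer.register_callback(LoadCheckpointCallback(directory='checkpoints'))  # resume before training
trainer.register_callback(SaveCheckpointCallback(frequency=1000))           # periodic snapshots
trainer.register_callback(SaveBestCheckpointCallback(frequency=1000,
                                                     val_metric_name='val/accuracy'))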

9 changes: 9 additions & 0 deletions trainer/callbacks/logging.py
@@ -0,0 +1,9 @@
from trainer.callbacks.callback import Callback


class LogCallback(Callback):
    def __init__(self, frequency):
        super().__init__(frequency=frequency, before=False, after=True)

    def __call__(self, trainer):
        trainer.state.log_train()
19 changes: 19 additions & 0 deletions trainer/callbacks/stop_criterion.py
@@ -0,0 +1,19 @@
class StopAtStep:
    def __init__(self, last_step):
        self.last_step = last_step

    def __call__(self, state):
        # Accept either a bare step count or a trainer state with a `.step` attribute.
        step = state if isinstance(state, int) else state.step
        return step >= self.last_step


class NoStopping:
    def __call__(self, state):
        return False
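
Because stop criteria are plain callables that accept either a bare step count or a state object, they can be exercised in isolation:

stop = StopAtStep(last_step=10000)
assert stop(9999) is False           # keep training
assert stop(10000) is True           # stop
assert NoStopping()(10000) is False  # never stops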
51 changes: 51 additions & 0 deletions trainer/callbacks/tensorboard.py
@@ -0,0 +1,51 @@
import os
from trainer.callbacks.callback import Callback
from torch.utils.tensorboard import SummaryWriter
from omegaconf import OmegaConf


class TensorBoardCallback(Callback):
    def __init__(self, frequency, add_weights=False, add_grads=False):
        super().__init__(frequency=frequency, before=True, after=True)
        self.log_dir = os.getcwd()
        self.writer = SummaryWriter(log_dir=self.log_dir)
        self.add_weights = add_weights
        self.add_grads = add_grads

    def before_run(self, trainer):
        # Dump the full config (and optional description) as text once, before training.
        cfg = OmegaConf.to_yaml(trainer.cfg)
        cfg = cfg.replace('\n', '  \n')  # two trailing spaces force a Markdown line break
        self.writer.add_text('cfg', cfg)
        description = trainer.cfg.description
        if description:
            self.writer.add_text('description', description)

    def after_run(self, trainer):
        self.writer.close()

    def add_validation_metrics(self, trainer):
        metrics = trainer.state.validation_metrics
        for name, metric in metrics.items():
            self.writer.add_scalar(name, metric, trainer.state.step)

    def add_weights_histogram(self, trainer):
        for name, param in trainer.model.named_parameters():
            if 'bn' not in name:  # skip batch-norm parameters
                self.writer.add_histogram(name, param, trainer.state.step)

    def add_grads_histogram(self, trainer):
        for name, param in trainer.model.named_parameters():
            if 'bn' not in name and param.requires_grad:
                self.writer.add_histogram(name + '_grad', param.grad, trainer.state.step)

    def __call__(self, trainer):
        for name, loss in trainer.state.last_train_loss.items():
            self.writer.add_scalar(f'trn/{name}', loss, trainer.state.step)

        self.writer.add_scalar('lr', trainer.optimizer.param_groups[0]['lr'], trainer.state.step)

        if self.add_weights:
            self.add_weights_histogram(trainer)

        if self.add_grads:
            self.add_grads_histogram(trainer)
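
A hypothetical registration sketch, not part of this commit, assuming the same register_callback method as above; weight/grad histograms are opt-in because they are comparatively slow on large models:

trainer.register_callback(TensorBoardCallback(frequency=50, add_weights=True, add_grads=False))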
17 changes: 17 additions & 0 deletions trainer/callbacks/validation.py
@@ -0,0 +1,17 @@
from trainer.callbacks.callback import Callback


class ValidationCallback(Callback):
    def __init__(self, frequency):
        super().__init__(frequency=frequency, before=False, after=False)

    def __call__(self, trainer):
        self.computed_metrics = trainer.evaluate(dataloader=trainer.val_dataloader, metrics=trainer.metrics)
        for metric_name, metric_value in self.computed_metrics.items():
            trainer.state.add_validation_metric(name=f'val/{metric_name}', value=metric_value)

        trainer.state.log_validation()

        # Push the freshly computed metrics to TensorBoard, if that callback is registered.
        if 'TensorBoardCallback' in trainer.callbacks:
            tb_callback = trainer.callbacks['TensorBoardCallback']
            tb_callback.add_validation_metrics(trainer)
64 changes: 64 additions & 0 deletions trainer/factory.py
@@ -0,0 +1,64 @@
import torch
from torch.utils.data import DataLoader

from trainer.utils.utils import object_from_dict


class Factory:
    def __init__(self, cfg):
        self.cfg = cfg

    def create_model(self):
        model = object_from_dict(self.cfg.model)
        return model

    def create_optimizer(self, model: torch.nn.Module):
        # Only trainable parameters are handed to the optimizer.
        optimizer = object_from_dict(self.cfg.optimizer,
                                     params=filter(lambda x: x.requires_grad, model.parameters()))
        return optimizer

    def create_scheduler(self, optimizer: torch.optim.Optimizer):
        scheduler = object_from_dict(self.cfg.scheduler, optimizer=optimizer)
        return scheduler

    def create_loss(self):
        loss = object_from_dict(self.cfg.loss)
        return loss

    def create_train_dataloader(self):
        dataset = self.create_dataset(self.cfg.data.train_dataset)
        return self.create_dataloader(self.cfg.bs, dataset)

    def create_val_dataloader(self):
        dataset = self.create_dataset(self.cfg.data.validation_dataset)
        return self.create_dataloader(self.cfg.bs, dataset)

    def create_dataset(self, cfg):
        augmentations = self.create_augmentations(cfg.augmentations)
        dataset = object_from_dict(cfg, transforms=augmentations, ignore_keys=['augmentations'])
        return dataset

    def create_dataloader(self, bs, dataset):
        return DataLoader(dataset, batch_size=bs, shuffle=True)

    def create_metrics(self):
        return [object_from_dict(metric) for metric in self.cfg.metrics]

    def create_callbacks(self, trainer):
        for hook in self.cfg.hooks:
            trainer.register_callback(object_from_dict(hook))

    def create_augmentations(self, cfg):
        augmentations = [object_from_dict(augm) for augm in cfg.augmentations]
        compose = object_from_dict(cfg.compose, transforms=augmentations)
        return compose
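
Every create_* method funnels through object_from_dict, so one config describes the whole pipeline. The key object_from_dict uses to locate a class is not shown in this commit, so the 'type' key below, and the concrete classes, are assumptions; a sketch of how a config might drive the factory:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'model': {'type': 'torchvision.models.resnet18', 'num_classes': 10},
    'optimizer': {'type': 'torch.optim.Adam', 'lr': 1e-3},
    'loss': {'type': 'torch.nn.CrossEntropyLoss'},
    'bs': 32,
})

factory = Factory(cfg)
model = factory.create_model()
optimizer = factory.create_optimizer(model)
criterion = factory.create_loss()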
13 changes: 13 additions & 0 deletions trainer/metric_transforms.py
@@ -0,0 +1,13 @@
import torch


transforms_dict = {
    'accuracy_prediction': lambda x: torch.argmax(x, dim=1),
    'accuracy_target': lambda x: torch.argmax(x, dim=1),
    'recall_prediction': lambda x: torch.argmax(x, dim=1),
    'recall_target': lambda x: torch.argmax(x, dim=1),
    'precision_prediction': lambda x: torch.argmax(x, dim=1),
    'precision_target': lambda x: torch.argmax(x, dim=1),
    'conf_matrix_prediction': lambda x: torch.argmax(x, dim=1),
    'conf_matrix_target': lambda x: torch.argmax(x, dim=1),
}
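
Each entry maps raw model outputs (or one-hot targets) to class indices before the corresponding metric is computed, e.g.:

import torch

logits = torch.tensor([[0.1, 2.3, -1.0]])               # one sample, three classes
preds = transforms_dict['accuracy_prediction'](logits)  # tensor([1]): argmax over dim=1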