-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit f7514df
Showing
15 changed files
with
730 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Packaging script for the ``trainer`` distribution.
import setuptools

setuptools.setup(
    name="trainer",
    version="0.1",
    author="Chernyavskiy Vladimir and Stanislav Mikhaylevskiy (my best friend)",
    author_email="[email protected], [email protected]",
    description="Torch HypeTrainer",
    url="https://github.com/JJBT/HypeTrainer",
    packages=setuptools.find_packages(),
    # The package sources use f-strings, which need Python 3.6 or newer.
    python_requires=">=3.6",
    # Imported unconditionally when ``trainer`` is imported:
    # ``torch``, ``omegaconf`` and ``torch.utils.tensorboard`` (which needs
    # the ``tensorboard`` distribution to be installed).
    install_requires=[
        "torch",
        "omegaconf",
        "tensorboard",
    ],
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from trainer.trainer import Trainer | ||
from trainer.factory import Factory | ||
from trainer.metrics import * | ||
from trainer.callbacks import * | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from trainer.callbacks.callback import Callback | ||
from trainer.callbacks.validation import ValidationCallback | ||
from trainer.callbacks.logging import LogCallback | ||
from trainer.callbacks.checkpoint import SaveCheckpointCallback, SaveBestCheckpointCallback, LoadCheckpointCallback | ||
from trainer.callbacks.stop_criterion import StopAtStep, NoStopping | ||
from trainer.callbacks.tensorboard import TensorBoardCallback |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import logging | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Callback:
    """Base class for hooks that are invoked with a trainer.

    Subclasses implement ``__call__``. ``before_run`` and ``after_run``
    forward to ``__call__`` when the corresponding flag is set.

    Args:
        frequency: Non-negative number stored on the instance. How the
            trainer interprets it is not visible in this module.
        before: If true, ``before_run`` invokes the callback.
        after: If true, ``after_run`` invokes the callback.
        attributes: Optional mapping of attribute name to default value that
            ``set_trainer`` registers on ``trainer.state``.

    Raises:
        ValueError: If ``frequency`` is negative.
    """

    def __init__(self, frequency=0, before=False, after=False,
                 attributes=None):
        # Zero is a legal value (it is the default), so only negatives fail.
        if frequency < 0:
            raise ValueError("Frequency argument should be non-negative.")

        self._attributes = dict() if attributes is None else attributes

        self.frequency = frequency
        self.before = before
        self.after = after

    def __call__(self, trainer):
        """Run the callback; must be provided by subclasses."""
        raise NotImplementedError()

    def before_run(self, trainer):
        """Invoke the callback if it was configured with ``before=True``."""
        if self.before:
            self.__call__(trainer)

    def after_run(self, trainer):
        """Invoke the callback if it was configured with ``after=True``."""
        if self.after:
            self.__call__(trainer)

    def set_trainer(self, trainer):
        """Register every declared attribute and its default on ``trainer.state``."""
        for attribute_name, default_value in self._attributes.items():
            trainer.state.add_attribute(attribute_name, default_value)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import os | ||
from trainer.callbacks.callback import Callback | ||
import torch | ||
from trainer.utils.utils import get_state_dict, load_state_dict | ||
import logging | ||
import warnings | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class SaveCheckpointCallback(Callback):
    """Write model and trainer state to ``<cwd>/checkpoints``.

    The target directory is derived from the working directory at
    construction time and is created if it does not exist yet.
    """

    def __init__(self, frequency=0, before=False, after=True):
        super().__init__(frequency=frequency, before=before, after=after)
        self.savedir = os.path.join(os.getcwd(), 'checkpoints')
        os.makedirs(self.savedir, exist_ok=True)
        # Template is filled with the trainer step when a checkpoint is written.
        self.ckpt_filename = 'checkpoint-{}.pt'

    def __call__(self, trainer):
        """Save a checkpoint named after the current ``trainer.state.step``."""
        target_name = self.ckpt_filename.format(trainer.state.step)
        self._save_checkpoint(trainer, target_name)

    def _save_checkpoint(self, trainer, filename):
        """Serialize the checkpoint payload to ``filename`` inside ``savedir``."""
        # NOTE: optimizer and scheduler state are not part of the payload.
        payload = {
            'model_state_dict': get_state_dict(trainer.model),
            'trainer_state': get_state_dict(trainer.state),
            'model_class': str(trainer.model.__class__),
        }
        destination = os.path.join(self.savedir, filename)
        torch.save(payload, destination)
|
||
|
||
class LoadCheckpointCallback(Callback):
    """Restore model and trainer state from a checkpoint file.

    The callback is configured with ``before=True``, so ``before_run``
    triggers the load.

    Args:
        directory: Directory that contains the checkpoint.
        filename: Checkpoint file name inside ``directory``. When omitted,
            the directory is searched for a file whose name contains ``.pt``.

    Raises:
        FileNotFoundError: If ``filename`` is omitted and the directory holds
            no matching file.
    """

    def __init__(self, directory: str, filename=None):
        super().__init__(frequency=0, before=True, after=False)
        self.directory = os.path.join(directory)

        if filename is not None:
            self.filename = filename
        else:
            self.filename = self._search_checkpoint()

        self.filename_to_load = os.path.join(self.directory, self.filename)

    def __call__(self, trainer):
        """Load the checkpoint into ``trainer`` and log the file that was used."""
        self._load_checkpoint(trainer)
        logger.info(f'Checkpoint {self.filename_to_load} loaded')

    def _load_checkpoint(self, trainer):
        """Read the checkpoint and apply it to the trainer's model and state."""
        # NOTE(review): depending on the torch version and its weights_only
        # default, torch.load may unpickle arbitrary objects. Only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(self.filename_to_load, map_location=trainer.accelerator.device)

        # A class mismatch only warns; loading is still attempted.
        if checkpoint['model_class'] != str(trainer.model.__class__):
            warnings.warn(
                f'Models do not match: {checkpoint["model_class"]} and {trainer.model.__class__}', RuntimeWarning
            )

        # NOTE: optimizer and scheduler state are not restored.
        load_state_dict(trainer.model, checkpoint['model_state_dict'])
        trainer.state.load_state_dict(checkpoint['trainer_state'])

    def _search_checkpoint(self):
        """Return the name of a checkpoint file found in ``self.directory``.

        Candidates are all entries whose name contains ``.pt``. They are
        sorted by name so the choice does not depend on directory listing
        order, and the first one is returned.

        Raises:
            FileNotFoundError: If no entry matches.
        """
        candidates = sorted(name for name in os.listdir(self.directory) if '.pt' in name)
        if not candidates:
            # Fail here with a clear message instead of passing None on to
            # os.path.join in __init__.
            raise FileNotFoundError(f'No checkpoint file (*.pt) found in {self.directory!r}')
        return candidates[0]
|
||
|
||
class SaveBestCheckpointCallback(SaveCheckpointCallback):
    """Save a checkpoint when the monitored validation metric improves.

    Args:
        frequency: Forwarded to the base callback.
        val_metric_name: Name passed to ``trainer.state.get_validation_metric``.
        comparison_function: Called as ``comparison_function(metric, best)``;
            a truthy result means ``metric`` replaces ``best``. The default
            treats larger values as better.
    """

    def __init__(self, frequency, val_metric_name: str,
                 comparison_function=lambda metric, best: metric > best):
        super().__init__(frequency=frequency, before=False, after=False)
        self.val_metric_name = val_metric_name  # last_(train/validation)_{metric}
        self.comparison_function = comparison_function
        self.current_best = None
        # Defined up front so the attribute exists before the first call.
        self.state_last_metric = None
        self.best_ckpt_filename = 'best-checkpoint-{}.pt'

    def __call__(self, trainer):
        """Compare the latest metric with the best seen so far; save on improvement."""
        self.state_last_metric = trainer.state.get_validation_metric(self.val_metric_name)

        # The very first observation always counts as the best one.
        nothing_recorded_yet = self.current_best is None
        if nothing_recorded_yet or self.comparison_function(self.state_last_metric, self.current_best):
            self.current_best = self.state_last_metric

            # NOTE(review): saving is tied to the improvement branch, as the
            # class name implies - confirm against the upstream file.
            self._save_checkpoint(trainer, self.best_ckpt_filename.format(trainer.state.step))
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from trainer.callbacks.callback import Callback | ||
|
||
|
||
class LogCallback(Callback):
    """Callback that triggers training logging on the trainer state.

    Configured with ``before=False`` and ``after=True``.
    """

    def __init__(self, frequency):
        flags = {'before': False, 'after': True}
        super().__init__(frequency=frequency, **flags)

    def __call__(self, trainer):
        """Delegate to ``trainer.state.log_train()``."""
        state = trainer.state
        state.log_train()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
class StopAtStep:
    """Stop criterion that fires once the step counter reaches ``last_step``.

    Args:
        last_step: First step at which the criterion reports ``True``.
    """

    def __init__(self, last_step):
        self.last_step = last_step

    def __call__(self, state):
        """Return ``True`` when the step is no longer below ``last_step``.

        ``state`` may be a plain integer step or any object exposing a
        ``step`` attribute.
        """
        state_step = state if isinstance(state, int) else state.step
        # ``not (a < b)`` keeps the original comparison and always yields a bool.
        return not state_step < self.last_step
|
||
|
||
class NoStopping:
    """Stop criterion that never requests a stop."""

    def __call__(self, state):
        """Ignore ``state`` and always answer ``False``."""
        return False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
from trainer.callbacks.callback import Callback | ||
from torch.utils.tensorboard import SummaryWriter | ||
from omegaconf import OmegaConf | ||
|
||
|
||
class TensorBoardCallback(Callback):
    """Write training information to TensorBoard event files.

    The ``SummaryWriter`` is created at construction time and logs into the
    working directory that is current at that moment.

    Args:
        frequency: Forwarded to the base callback.
        add_weights: If true, each call also logs parameter histograms.
        add_grads: If true, each call also logs gradient histograms.
    """

    def __init__(self, frequency, add_weights=False, add_grads=False):
        super().__init__(frequency=frequency, before=True, after=True)
        self.log_dir = os.getcwd()
        self.writer = SummaryWriter(log_dir=self.log_dir)
        self.add_weights = add_weights
        self.add_grads = add_grads

    def before_run(self, trainer):
        """Log the configuration (and its description, if any) as text."""
        cfg = OmegaConf.to_yaml(trainer.cfg)
        # TensorBoard renders text as Markdown, where a hard line break needs
        # two trailing spaces before the newline.
        cfg = cfg.replace('\n', '  \n')
        self.writer.add_text('cfg', cfg)
        description = trainer.cfg.description
        if description:
            self.writer.add_text('description', description)

    def after_run(self, trainer):
        """Close the writer."""
        self.writer.close()

    def add_validation_metrics(self, trainer):
        """Log every entry of ``trainer.state.validation_metrics`` as a scalar."""
        step = trainer.state.step
        for name, metric in trainer.state.validation_metrics.items():
            self.writer.add_scalar(name, metric, step)

    def add_weights_histogram(self, trainer):
        """Log a histogram for each parameter whose name does not contain 'bn'."""
        for name, param in trainer.model.named_parameters():
            if 'bn' not in name:
                self.writer.add_histogram(name, param, trainer.state.step)

    def add_grads_histogram(self, trainer):
        """Log gradient histograms for trainable parameters without 'bn' in the name."""
        for name, param in trainer.model.named_parameters():
            if 'bn' in name or not param.requires_grad:
                continue
            # ``grad`` stays None until a backward pass reaches the parameter;
            # there is nothing to plot in that case.
            if param.grad is None:
                continue
            self.writer.add_histogram(name + '_grad', param.grad, trainer.state.step)

    def __call__(self, trainer):
        """Log the latest training losses, the learning rate and optional histograms."""
        for name, loss in trainer.state.last_train_loss.items():
            self.writer.add_scalar(f'trn/{name}', loss, trainer.state.step)

        # Only the first parameter group's learning rate is reported.
        self.writer.add_scalar('lr', trainer.optimizer.param_groups[0]['lr'], trainer.state.step)

        if self.add_weights:
            self.add_weights_histogram(trainer)

        if self.add_grads:
            self.add_grads_histogram(trainer)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from trainer.callbacks.callback import Callback | ||
|
||
|
||
class ValidationCallback(Callback):
    """Run evaluation on the validation dataloader and record the metrics.

    Configured with ``before=False`` and ``after=False``.
    """

    def __init__(self, frequency):
        super().__init__(frequency=frequency, before=False, after=False)

    def __call__(self, trainer):
        """Evaluate, store each result under ``val/<name>`` and log it."""
        self.computed_metrics = trainer.evaluate(
            dataloader=trainer.val_dataloader,
            metrics=trainer.metrics,
        )

        for key, result in self.computed_metrics.items():
            trainer.state.add_validation_metric(name=f'val/{key}', value=result)

        trainer.state.log_validation()
        self._forward_to_tensorboard(trainer)

    def _forward_to_tensorboard(self, trainer):
        """Pass the validation metrics to the TensorBoard callback, if registered."""
        if 'TensorBoardCallback' not in trainer.callbacks:
            return
        trainer.callbacks['TensorBoardCallback'].add_validation_metrics(trainer)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import torch | ||
from trainer.utils.utils import object_from_dict | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
class Factory:
    """Build training components from a configuration object.

    Most objects are created through ``object_from_dict``, which is imported
    from ``trainer.utils.utils`` and is not visible here. It presumably
    instantiates the object described by a config node and forwards extra
    keyword arguments - confirm against the helper.

    Args:
        cfg: Configuration with the attributes read by the ``create_*``
            methods (``model``, ``optimizer``, ``scheduler``, ``loss``,
            ``data``, ``bs``, ``metrics``, ``hooks``).
    """

    def __init__(self, cfg):
        self.cfg = cfg

    def create_model(self):
        """Create the model from ``cfg.model``."""
        return object_from_dict(self.cfg.model)

    def create_optimizer(self, model: torch.nn.Module):
        """Create the optimizer from ``cfg.optimizer`` over trainable parameters only."""
        trainable = (param for param in model.parameters() if param.requires_grad)
        return object_from_dict(self.cfg.optimizer, params=trainable)

    def create_scheduler(self, optimizer: torch.optim.Optimizer):
        """Create the scheduler from ``cfg.scheduler`` for ``optimizer``."""
        return object_from_dict(self.cfg.scheduler, optimizer=optimizer)

    def create_loss(self):
        """Create the loss from ``cfg.loss``."""
        return object_from_dict(self.cfg.loss)

    def create_train_dataloader(self):
        """Create a shuffled dataloader over ``cfg.data.train_dataset``."""
        dataset = self.create_dataset(self.cfg.data.train_dataset)
        return self.create_dataloader(self.cfg.bs, dataset)

    def create_val_dataloader(self):
        """Create a dataloader over ``cfg.data.validation_dataset``.

        Validation data is iterated in dataset order so that repeated
        evaluations see identical batches.
        """
        dataset = self.create_dataset(self.cfg.data.validation_dataset)
        return self.create_dataloader(self.cfg.bs, dataset, shuffle=False)

    def create_dataset(self, cfg):
        """Create a dataset from ``cfg``, passing its augmentations as ``transforms``."""
        augmentations = self.create_augmentations(cfg.augmentations)
        # The augmentation sub-config was consumed above, so it is excluded
        # from the arguments handed to the dataset itself.
        return object_from_dict(cfg, transforms=augmentations, ignore_keys=['augmentations'])

    def create_dataloader(self, bs, dataset, shuffle=True):
        """Wrap ``dataset`` in a ``DataLoader``.

        Args:
            bs: Batch size.
            dataset: Dataset to iterate over.
            shuffle: Whether to reshuffle every epoch. Defaults to ``True``,
                the behaviour before this parameter existed.
        """
        return DataLoader(dataset, batch_size=bs, shuffle=shuffle)

    def create_metrics(self):
        """Create one metric object per entry of ``cfg.metrics``."""
        return [object_from_dict(metric) for metric in self.cfg.metrics]

    def create_callbacks(self, trainer):
        """Create every callback listed in ``cfg.hooks`` and register it on ``trainer``."""
        for hook in self.cfg.hooks:
            trainer.register_callback(object_from_dict(hook))

    def create_augmentations(self, cfg):
        """Create the augmentations in ``cfg.augmentations`` and wrap them via ``cfg.compose``."""
        augmentations = [object_from_dict(augm) for augm in cfg.augmentations]
        return object_from_dict(cfg.compose, transforms=augmentations)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import torch | ||
|
||
|
||
def _argmax_over_dim1(x):
    """Return the indices of the largest values along dimension 1 of ``x``."""
    return torch.argmax(x, dim=1)


# Transform applied to predictions and targets, keyed by
# ``<metric>_<prediction|target>``. Every entry currently uses the same
# reduction; the table stays explicit so single entries can diverge later.
transforms_dict = {
    'accuracy_prediction': _argmax_over_dim1,
    'accuracy_target': _argmax_over_dim1,
    'recall_prediction': _argmax_over_dim1,
    'recall_target': _argmax_over_dim1,
    'precision_prediction': _argmax_over_dim1,
    'precision_target': _argmax_over_dim1,
    'conf_matrix_prediction': _argmax_over_dim1,
    'conf_matrix_target': _argmax_over_dim1,
}
Oops, something went wrong.