Our trainer is a class that inherits from openstereo.stereo.modeling.trainer_template.TrainerTemplate.
You can define your own trainer by inheriting from TrainerTemplate
and overriding or defining the methods you need.
To use your own trainer, you need to mount it to your model. Let's take PSMNet as an example:
stereo/modeling/models/psmnet/trainer.py
from stereo.modeling.trainer_template import TrainerTemplate
from .psmnet import PSMNet

__all__ = {
    'PSMNet': PSMNet,
}


class Trainer(TrainerTemplate):
    def __init__(self, args, cfgs, local_rank, global_rank, logger, tb_writer):
        model = __all__[cfgs.MODEL.NAME](cfgs.MODEL)
        super().__init__(args, cfgs, local_rank, global_rank, logger, tb_writer, model)
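The lookup pattern in the PSMNet trainer above can be tried in isolation. The following is a minimal sketch, not the project's actual code: the TrainerTemplate here is a stand-in that only stores the model, ToyNet and MODELS are hypothetical names, and the config is mocked with SimpleNamespace on the assumption that cfgs supports attribute access (cfgs.MODEL.NAME), as the snippet above suggests.

```python
from types import SimpleNamespace


class TrainerTemplate:
    """Stand-in for the real base class; it only stores what it is given."""

    def __init__(self, args, cfgs, local_rank, global_rank, logger, tb_writer, model):
        self.cfgs = cfgs
        self.model = model


class ToyNet:
    """Hypothetical model class that receives the MODEL sub-config."""

    def __init__(self, model_cfgs):
        self.name = model_cfgs.NAME


# Name-to-class mapping, playing the role of __all__ in the PSMNet trainer.
MODELS = {'ToyNet': ToyNet}


class Trainer(TrainerTemplate):
    def __init__(self, args, cfgs, local_rank, global_rank, logger, tb_writer):
        # Look up the model class by the name in the config, then build it.
        model = MODELS[cfgs.MODEL.NAME](cfgs.MODEL)
        super().__init__(args, cfgs, local_rank, global_rank, logger, tb_writer, model)


# Mocked config: only the fields the trainer reads are provided.
cfgs = SimpleNamespace(MODEL=SimpleNamespace(NAME='ToyNet'))
trainer = Trainer(args=None, cfgs=cfgs, local_rank=0, global_rank=0,
                  logger=None, tb_writer=None)
print(type(trainer.model).__name__)
```

Because the model is selected by cfgs.MODEL.NAME, a name that is missing from the mapping raises a KeyError when the trainer is constructed, so each model class needs an entry in the mapping.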