Skip to content

Commit

Permalink
training agent abstract class
Browse files Browse the repository at this point in the history
  • Loading branch information
eric8607242 committed Jul 12, 2021
1 parent a65b33f commit ce94120
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 383 deletions.
132 changes: 132 additions & 0 deletions agent/base_training_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import os

from abc import ABC, abstractmethod

from utils import save

class MetaTrainingAgent:
def train_loop(self, model, train_loader, val_loader, agent):
""" The main training loop.
Args:
model (nn.Module)
train_loader (torch.utils.data.DataLoader)
val_loader (torch.utils.data.DataLoader)
agent (Object)
"""
training_step = getattr(self, f"_{agent.agent_state}_training_step")
validate_step = getattr(self, f"_{agent.agent_state}_validate_step")

best_val_performance = -float("inf")
for epoch in range(agent.start_epochs, agent.epochs):
agent.logger.info(f"Start to train for epoch {epoch}")
agent.logger.info(f"Learning Rate : {agent.optimizer.param_groups[0]['lr']:.8f}")

training_step(
model,
train_loader,
agent,
epoch)
val_performance = validate_step(
model,
val_loader,
agent,
epoch)

if val_performance > best_val_performance:
agent.logger.info(f"Best validation performance : {val_performance}. Save model!")
best_val_performance = val_performance
save(
model,
agent.config["experiment_path"]["best_checkpoint_path"],
agent.criterion,
agent.optimizer,
agent.lr_scheduler,
epoch + 1)

save(
model,
os.path.join(
agent.config["experiment_path"]["checkpoint_root_path"],
f"{agent.agent_state}_{epoch}.pth"),
agent.criterion,
agent.optimizer,
agent.lr_scheduler,
epoch + 1)

def _search_training_step(self, model, train_loader, agent, epoch):
""" The training step for searching process. Users should step the sampler
to decide how to train supernet and step the search strategy to search the architecture.
Args:
model (nn.Module)
train_loader (torch.utils.data.DataLoader)
agent (Object): The search agent.
epoch (int)
"""
self._training_step(model, train_loader, agent, epoch)

def _evaluate_training_step(self, model, train_loader, agent, epoch):
""" The training step for evaluating process (training from scratch).
Args:
model (nn.Module)
train_loader (torch.utils.data.DataLoader)
agent (Object): The evaluate agent
epoch (int)
"""
self._training_step(model, train_loader, agent, epoch)

def _search_validate_step(self, model, val_loader, agent, epoch):
""" The validate step for searching process.
Args:
model (nn.Module)
val_loader (torch.utils.data.DataLoader)
agent (Object): The search agent.
epoch (int)
Return:
evaluate_metric (float): The performance of the supernet
"""
return self._validate_step(model, val_loader, agent, epoch)

def _evaluate_validate_step(self, model, val_loader, agent, epoch):
""" The training step for evaluating process (training from scratch).
Args:
model (nn.Module)
val_loader (torch.utils.data.DataLoader)
agent (Object): The evaluate agent
epoch (int)
Return:
evaluate_metric (float): The performance of the searched model.
"""
return self._validate_step(model, val_loader, agent, epoch)

@abstractmethod
def _training_step(self, model, train_loader, agent, epoch):
"""
Args:
model (nn.Module)
train_loader (torch.utils.data.DataLoader)
agent (Object): The evaluate agent
epoch (int)
"""
raise NotImplementedError

@abstractmethod
def _validate_step(self, model, val_loader, agent, epoch):
"""
Args:
model (nn.Module)
val_loader (torch.utils.data.DataLoader)
agent (Object): The evaluate agent
epoch (int)
Return:
evaluate_metric (float): The performance of the searched model.
"""
raise NotImplementedError
return evaluate_metric
126 changes: 21 additions & 105 deletions agent/classification_agent/training_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,115 +4,14 @@
import torch

from utils import AverageMeter, accuracy, save
from ..base_training_agent import MetaTrainingAgent

class CFTrainingAgent:
class CFTrainingAgent(MetaTrainingAgent):
"""The training agent to train the supernet and the searched architecture.
By implementing TrainingAgent class, users can adapt the searching and evaluating agent into
various tasks easily.
"""
def train_loop(self, model, train_loader, val_loader, agent):
""" The main training loop.
Args:
model (nn.Module)
train_loader (torch.utils.data.DataLoader)
val_loader (torch.utils.data.DataLoader)
agent (Object)
"""
# Utilize different step method based on differet agent state
training_step = getattr(self, f"_{agent.agent_state}_training_step")
validate_step = getattr(self, f"_{agent.agent_state}_validate_step")

best_val_performance = -float("inf")
for epoch in range(agent.start_epochs, agent.epochs):
agent.logger.info(f"Start to train for epoch {epoch}")
agent.logger.info(f"Learning Rate : {agent.optimizer.param_groups[0]['lr']:.8f}")

training_step(
model,
train_loader,
agent,
epoch)
val_performance = validate_step(
model,
val_loader,
agent,
epoch)

if val_performance > best_val_performance:
agent.logger.info(f"Best validation performance : {val_performance}. Save model!")
best_val_performance = val_performance
save(
model,
agent.config["experiment_path"]["best_checkpoint_path"],
agent.criterion,
agent.optimizer,
agent.lr_scheduler,
epoch + 1)

save(
model,
os.path.join(
agent.config["experiment_path"]["checkpoint_root_path"],
f"{agent.agent_state}_{epoch}.pth"),
agent.criterion,
agent.optimizer,
agent.lr_scheduler,
epoch + 1)


def _search_training_step(self, model, train_loader, agent, epoch):
""" The training step for searching process. Users should step the sampler
to decide how to train supernet and step the search strategy to search the architecture.
Args:
model (nn.Module)
train_loader (torch.utils.data.DataLoader)
agent (Object): The search agent.
epoch (int)
"""
self._training_step(model, train_loader, agent, epoch)

def _search_validate_step(self, model, val_loader, agent, epoch):
""" The validate step for searching process.
Args:
model (nn.Module)
val_loader (torch.utils.data.DataLoader)
agent (Object): The search agent.
epoch (int)
Return:
evaluate_metric (float): The performance of the supernet
"""
return self._validate(model ,val_loader, agent, epoch)

def _evaluate_training_step(self, model, train_loader, agent, epoch):
""" The training step for evaluating process (training from scratch).
Args:
model (nn.Module)
train_loader (torch.utils.data.DataLoader)
agent (Object): The evaluate agent
epoch (int)
"""
self._training_step(model, train_loader, agent, epoch)

def _evaluate_validate_step(self, model, val_loader, agent, epoch):
""" The training step for evaluating process (training from scratch).
Args:
model (nn.Module)
val_loader (torch.utils.data.DataLoader)
agent (Object): The evaluate agent
epoch (int)
Return:
evaluate_metric (float): The performance of the searched model.
"""
return self._validate(model ,val_loader, agent, epoch)

@staticmethod
def searching_evaluate(model, val_loader, device, criterion):
""" Evaluating the performance of the supernet. The search strategy will evaluate
Expand Down Expand Up @@ -150,6 +49,13 @@ def searching_evaluate(model, val_loader, device, criterion):


def _training_step(self, model, train_loader, agent, epoch, print_freq=100):
"""
Args:
model (nn.Module)
train_loader (torch.utils.data.DataLoader)
agent (Object): The evaluate agent
epoch (int)
"""
top1 = AverageMeter()
top5 = AverageMeter()
losses = AverageMeter()
Expand Down Expand Up @@ -182,7 +88,7 @@ def _training_step(self, model, train_loader, agent, epoch, print_freq=100):
top5.update(prec5.item(), N)

if (step > 1 and step % print_freq == 0) or (step == len(train_loader) - 1):
agent.logger.info(f"Train : [{(epoch+1):3d}/{agent.epochs}]"
agent.logger.info(f"Train : [{(epoch+1):3d}/{agent.epochs}] "
f"Step {step:3d}/{len(train_loader)-1:3d} Loss {losses.get_avg():.3f}"
f"Prec@(1, 5) ({top1.get_avg():.1%}, {top5.get_avg():.1%})")

Expand All @@ -196,7 +102,17 @@ def _training_step(self, model, train_loader, agent, epoch, print_freq=100):
f"Time {time.time() - start_time:.2f}")


def _validate(self, model, val_loader, agent, epoch):
def _validate_step(self, model, val_loader, agent, epoch):
"""
Args:
model (nn.Module)
val_loader (torch.utils.data.DataLoader)
agent (Object): The evaluate agent
epoch (int)
Return:
evaluate_metric (float): The performance of the searched model.
"""
model.eval()
start_time = time.time()

Expand Down
Loading

0 comments on commit ce94120

Please sign in to comment.