
add AdversarialTrainTPL
kervias committed Feb 4, 2024
1 parent 34c13e0 commit b55f2c6
Showing 7 changed files with 128 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -50,6 +50,7 @@ The overall structure is illustrated as follows:
:maxdepth: 1
:caption: User Guide

user_guide/atom_op
user_guide/datasets
user_guide/models
user_guide/reference_table
14 changes: 14 additions & 0 deletions docs/source/user_guide/atom_op.md
@@ -0,0 +1,14 @@
# Atomic Data Operation List

| M2C Atomic Operation | Description |
| :--------------------: | :----------------------------------------------------------: |
| M2C_Label2Int | Convert the label column into discrete integer values |
| M2C_MergeDividedSplits | Merge the train/valid/test splits into a single dataframe |
| M2C_ReMapId | Remap column IDs |
| M2C_GenQMat | Generate the Q-matrix |
| M2C_RandomDataSplit4CD | Randomly split the dataset for cognitive diagnosis (CD) |
| M2C_FilterRecords4CD | Filter out students or exercises with fewer interaction records than a threshold |
| M2C_BuildSeqInterFeats | Build sequential interaction features and split the dataset |
| M2C_CptAsExer | Treat knowledge concepts as exercises |
| M2C_GenCptSeq | Generate knowledge concept sequences |
| M2C_GenUnFoldCptSeq | Unfold knowledge concepts |
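
These atomic operations are chained by a data template to preprocess a dataset step by step. As a rough, hypothetical sketch only (the `mid2cache_op_seq` key, the operation order, and the `CDInterExtendsQDataTPL` name are assumptions for illustration, not taken from this commit), overriding the pipeline from a config might look like:

datatpl_cfg_dict = {
    'cls': 'CDInterExtendsQDataTPL',    # assumed data template name
    'mid2cache_op_seq': [               # assumed config key; operations run in the listed order
        'M2C_Label2Int',                # discretize the label column
        'M2C_FilterRecords4CD',         # drop students/exercises with too few records
        'M2C_ReMapId',                  # remap column IDs
        'M2C_RandomDataSplit4CD',       # random train/valid/test split for CD
        'M2C_GenQMat',                  # build the Q-matrix
    ],
}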
1 change: 1 addition & 0 deletions edustudio/traintpl/__init__.py
@@ -7,3 +7,4 @@
from .general_traintpl import GeneralTrainTPL
from .atkt_traintpl import AtktTrainTPL
from .dcd_traintpl import DCDTrainTPL
from .adversarial_traintpl import AdversarialTrainTPL
100 changes: 100 additions & 0 deletions edustudio/traintpl/adversarial_traintpl.py
@@ -0,0 +1,100 @@
from .general_traintpl import GeneralTrainTPL
import torch
from collections import defaultdict
import numpy as np
from tqdm import tqdm


class AdversarialTrainTPL(GeneralTrainTPL):
    default_cfg = {
        'lr': 0.001,        # learning rate of the generator optimizer
        'lr_d': 0.001,      # learning rate of the discriminator optimizer
        'g_rounds': 1,      # generator update rounds per epoch
        'd_rounds': 1,      # discriminator update rounds per epoch
        'optim': 'adam',    # optimizer type for the generator
        'optim_d': 'adam',  # optimizer type for the discriminator
    }

    def _get_optim(self, model_params, optimizer='adam', lr=0.001, weight_decay=0.0, eps=1e-8):
        """Build an optimizer for the given parameter group."""
        if optimizer == "sgd":
            # torch.optim.SGD does not accept an eps argument
            optim = torch.optim.SGD(model_params, lr=lr, weight_decay=weight_decay)
        elif optimizer == "adam":
            optim = torch.optim.Adam(model_params, lr=lr, weight_decay=weight_decay, eps=eps)
        elif optimizer == "adagrad":
            optim = torch.optim.Adagrad(model_params, lr=lr, weight_decay=weight_decay, eps=eps)
        elif optimizer == "rmsprop":
            optim = torch.optim.RMSprop(model_params, lr=lr, weight_decay=weight_decay, eps=eps)
        else:
            raise ValueError(f"unsupported optimizer: {optimizer}")

        return optim

def fit(self, train_loader, valid_loader):
self.model.train()
lr = self.traintpl_cfg['lr']
lr_d = self.traintpl_cfg['lr_d']
weight_decay = self.traintpl_cfg['weight_decay']
eps = self.traintpl_cfg['eps']

        self.optimizer_g = self._get_optim(self.model.get_g_parameters(), self.traintpl_cfg['optim'], lr=lr, weight_decay=weight_decay, eps=eps)
        self.optimizer_d = self._get_optim(self.model.get_d_parameters(), self.traintpl_cfg['optim_d'], lr=lr_d, weight_decay=weight_decay, eps=eps)

self.callback_list.on_train_begin()
for epoch in range(self.traintpl_cfg['epoch_num']):
self.callback_list.on_epoch_begin(epoch + 1)

# train_for_generator
g_rounds = self.traintpl_cfg['g_rounds']
d_rounds = self.traintpl_cfg['d_rounds']
logs = defaultdict(lambda: np.full((len(train_loader) * g_rounds,), np.nan, dtype=np.float32))
for round_id in range(g_rounds):
for batch_id, batch_dict in enumerate(
tqdm(train_loader, ncols=self.frame_cfg['TQDM_NCOLS'], desc="[GEN:EPOCH={:03d}]".format(epoch + 1))
):
batch_dict = self.batch_dict2device(batch_dict)
loss_gen_dict, loss_dis_dict = self.model.get_loss_dict(**batch_dict)
loss_gen = torch.hstack([i for i in loss_gen_dict.values() if i is not None]).sum()
loss_dis = torch.hstack([i for i in loss_dis_dict.values() if i is not None]).sum()
loss = loss_gen - loss_dis
self.optimizer_g.zero_grad()
loss.backward()
self.optimizer_g.step()
for k in loss_gen_dict: logs[k][batch_id + len(train_loader) * round_id] = loss_gen_dict[k].item() if loss_gen_dict[k] is not None else np.nan
for k in loss_dis_dict: logs[k][batch_id + len(train_loader) * round_id] = loss_dis_dict[k].item() if loss_dis_dict[k] is not None else np.nan

logs_g = {}
for name in logs: logs_g[f"GEN_{name}"] = float(np.nanmean(logs[name]))

# train_for_discriminator
logs = defaultdict(lambda: np.full((len(train_loader) * d_rounds,), np.nan, dtype=np.float32))
for round_id in range(d_rounds):
for batch_id, batch_dict in enumerate(
tqdm(train_loader, ncols=self.frame_cfg['TQDM_NCOLS'], desc="[DIS:EPOCH={:03d}]".format(epoch + 1))
):
batch_dict = self.batch_dict2device(batch_dict)
loss_gen_dict, loss_dis_dict = self.model.get_loss_dict(**batch_dict)
loss_gen = torch.hstack([i for i in loss_gen_dict.values() if i is not None]).sum()
loss_dis = torch.hstack([i for i in loss_dis_dict.values() if i is not None]).sum()
loss = - loss_gen + loss_dis
self.optimizer_d.zero_grad()
loss.backward()
self.optimizer_d.step()
for k in loss_gen_dict: logs[k][batch_id + len(train_loader) * round_id] = loss_gen_dict[k].item() if loss_gen_dict[k] is not None else np.nan
for k in loss_dis_dict: logs[k][batch_id + len(train_loader) * round_id] = loss_dis_dict[k].item() if loss_dis_dict[k] is not None else np.nan

logs_d = {}
for name in logs: logs_d[f"DIS_{name}"] = float(np.nanmean(logs[name]))

logs = logs_g
logs.update(logs_d)
if valid_loader is not None:
val_metrics = self.evaluate(valid_loader)
logs.update({f"{metric}": val_metrics[metric] for metric in val_metrics})

self.callback_list.on_epoch_end(epoch + 1, logs=logs)
if self.model.share_callback_dict.get('stop_training', False):
break

self.callback_list.on_train_end()
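
For orientation, here is a hypothetical sketch of how the new template might be selected through EduStudio's quickstart entry point; the dataset, data template, model, and evaluator names are placeholders, and the exact `run_edustudio` keyword names are assumed rather than taken from this commit. Whatever model is used must expose `get_g_parameters()`, `get_d_parameters()`, and a `get_loss_dict()` that returns separate generator and discriminator loss dicts, since the fit loop above relies on them:

from edustudio.quickstart import run_edustudio   # assumed entry point

run_edustudio(
    dataset='FrcSub',                                   # placeholder dataset
    cfg_file_name=None,
    traintpl_cfg_dict={
        'cls': 'AdversarialTrainTPL',                   # the template added in this commit
        'lr': 0.001, 'lr_d': 0.001,                     # generator / discriminator learning rates
        'g_rounds': 1, 'd_rounds': 1,                   # alternating update rounds per epoch
    },
    datatpl_cfg_dict={'cls': 'SomeDataTPL'},            # placeholder data template
    modeltpl_cfg_dict={'cls': 'SomeAdversarialModel'},  # placeholder adversarial model
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},  # placeholder evaluator
)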
7 changes: 6 additions & 1 deletion edustudio/traintpl/atkt_traintpl.py
@@ -106,7 +106,12 @@ def one_fold_start(self, fold_id):
def fit(self, train_loader, valid_loader):
kt_loss = KTLoss()
self.model.train()
self.optimizer = self._get_optim()
optimizer = self.traintpl_cfg['optim']
lr = self.traintpl_cfg['lr']
weight_decay = self.traintpl_cfg['weight_decay']
eps = self.traintpl_cfg['eps']
self.optimizer = self._get_optim(optimizer=optimizer, lr=lr, weight_decay=weight_decay, eps=eps)

self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=self.traintpl_cfg['lr_decay'], gamma=self.traintpl_cfg['gamma'])
self.callback_list.on_train_begin()
for epoch in range(self.traintpl_cfg['epoch_num']):
6 changes: 1 addition & 5 deletions edustudio/traintpl/gd_traintpl.py
@@ -27,13 +27,9 @@ def __init__(self, cfg: UnifyConfig):
self.valid_loader_list = []
self.test_loader_list = []

def _get_optim(self):
def _get_optim(self, optimizer='adam', lr=0.001, weight_decay=0.0, eps=1e-8):
"""Get optimizer
"""
optimizer = self.traintpl_cfg['optim']
lr = self.traintpl_cfg['lr']
weight_decay = self.traintpl_cfg['weight_decay']
eps = self.traintpl_cfg['eps']
if optimizer == "sgd":
optim = torch.optim.SGD(self.model.parameters(), lr=lr, weight_decay=weight_decay, eps=eps)
elif optimizer == "adam":
6 changes: 5 additions & 1 deletion edustudio/traintpl/general_traintpl.py
@@ -78,7 +78,11 @@ def one_fold_start(self, fold_id):

def fit(self, train_loader, valid_loader):
self.model.train()
self.optimizer = self._get_optim()
optimizer = self.traintpl_cfg['optim']
lr = self.traintpl_cfg['lr']
weight_decay = self.traintpl_cfg['weight_decay']
eps = self.traintpl_cfg['eps']
self.optimizer = self._get_optim(optimizer=optimizer, lr=lr, weight_decay=weight_decay, eps=eps)
self.callback_list.on_train_begin()
for epoch in range(self.traintpl_cfg['epoch_num']):
self.callback_list.on_epoch_begin(epoch + 1)
