Skip to content

Commit

Permalink
merge cd and kt traintpl
Browse files Browse the repository at this point in the history
  • Loading branch information
kervias committed Jul 7, 2023
1 parent e471ebd commit 20d0f6f
Show file tree
Hide file tree
Showing 59 changed files with 103 additions and 268 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ run_edustudio(
dataset='FrcSub',
cfg_file_name=None,
traintpl_cfg_dict={
'cls': 'CDInterTrainTPL',
'cls': 'EduTrainTPL',
},
datatpl_cfg_dict={
'cls': 'CDInterExtendsQDataTPL'
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
project = 'EduStudio'
copyright = '2023, HFUT-LEC'
author = 'HFUT-LEC'
release = 'v1.0.0-alpha4'
release = 'v1.0.0-alpha5'

import sphinx_rtd_theme
import os
Expand Down
2 changes: 1 addition & 1 deletion docs/source/get_started/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ run_edustudio(
dataset='FrcSub',
cfg_file_name=None,
traintpl_cfg_dict={
'cls': 'CDInterTrainTPL',
'cls': 'EduTrainTPL',
},
datatpl_cfg_dict={
'cls': 'CDInterExtendsQDataTPL'
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
.. EduStudio documentation master file.
.. title:: EduStudio v1.0.0-alpha4
.. title:: EduStudio v1.0.0-alpha5
.. image:: assets/logo.png

=========================================================
Expand Down
84 changes: 42 additions & 42 deletions docs/source/user_guide/reference_table.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,51 @@

| Model | DataTPL | TrainTPL | EvalTPL |
| :------ | ---------------------: | :-------------: | ------------------------------------------------------ |
| IRT | CDInterDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
| MIRT | CDInterDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
| NCDM | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
| CNCD_Q | CNCDQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
| CNCD_F | CNCDFDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
| DINA | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
| HierCDF | HierCDFDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
| CDGK | CDGKDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
| CDMFKC | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
| ECD | ECDDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
| IRR | IRRDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
| KaNCD | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
| KSCD | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
| MGCD | MGCDDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
| RCD | RCDDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
| IRT | CDInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| MIRT | CDInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| NCDM | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
| CNCD_Q | CNCDQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| CNCD_F | CNCDFDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| DINA | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
| HierCDF | HierCDFDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
| CDGK | CDGKDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
| CDMFKC | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| ECD | ECDDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| IRR | IRRDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| KaNCD | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
| KSCD | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| MGCD | MGCDDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| RCD | RCDDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |

## KT models

| Model | DataTPL | TrainTPL | EvalTPL |
| :----------- | ----------------------: | :-------------: | --------------------------- |
| AKT | KTInterDataTPLCptUnfold | KTInterTrainTPL | BinaryClassificationEvalTPL |
| AKT | KTInterDataTPLCptUnfold | EduTrainTPL | BinaryClassificationEvalTPL |
| ATKT | KTInterDataTPLCptUnfold | AtktTrainTPL | BinaryClassificationEvalTPL |
| CKT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| CL4KT | CL4KTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| CT_NCM | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| DeepIRT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| DIMKT | DIMKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| DKT | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| DKTDSC | DKTDSCDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| DKTForget | DKTForgetDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| DKT_plus | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| DKVMN | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| DTransformer | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| EERNN | EERNNDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| EKT | EERNNDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| HawkesKT | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| IEKT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| KQN | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| LPKT | LPKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| LPKT_S | LPKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| QDKT | QDKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| QIKT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| RKT | RKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| SAINT | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| SAINT_plus | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| SAKT | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| SimpleKT | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| SKVMN | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
| CKT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| CL4KT | CL4KTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| CT_NCM | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| DeepIRT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| DIMKT | DIMKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| DKT | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| DKTDSC | DKTDSCDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| DKTForget | DKTForgetDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| DKT_plus | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| DKVMN | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| DTransformer | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| EERNN | EERNNDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| EKT | EERNNDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| HawkesKT | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| IEKT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| KQN | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| LPKT | LPKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| LPKT_S | LPKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| QDKT | QDKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| QIKT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| RKT | RKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| SAINT | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| SAINT_plus | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| SAKT | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| SimpleKT | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
| SKVMN | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
4 changes: 2 additions & 2 deletions docs/source/user_guide/usage/aht.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def objective_function(args):


search_space= {
'traintpl_cfg.cls': tune.grid_search(['CDInterTrainTPL']),
'traintpl_cfg.cls': tune.grid_search(['EduTrainTPL']),
'datatpl_cfg.cls': tune.grid_search(['CDInterExtendsQDataTPL']),
'modeltpl_cfg.cls': tune.grid_search(['KaNCD']),
'evaltpl_cfg.clses': tune.grid_search([['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]),
Expand Down Expand Up @@ -115,7 +115,7 @@ def objective_function(args):


space = {
'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['CDInterTrainTPL']),
'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['EduTrainTPL']),
'datatpl_cfg.cls': hp.choice('datapl_cfg.cls', ['CDInterExtendsQDataTPL']),
'modeltpl_cfg.cls': hp.choice('modeltpl_cfg.cls', ['KaNCD']),
'evaltpl_cfg.clses': hp.choice('evaltpl_cfg.clses', [['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]),
Expand Down
4 changes: 2 additions & 2 deletions docs/source/user_guide/usage/run_edustudio.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ run_edustudio(
dataset='FrcSub',
cfg_file_name=None,
traintpl_cfg_dict={
'cls': 'CDInterTrainTPL',
'cls': 'EduTrainTPL',
},
datatpl_cfg_dict={
'cls': 'CDInterExtendsQDataTPL'
Expand Down Expand Up @@ -48,7 +48,7 @@ datatpl_cfg:
cls: CDInterDataTPL

traintpl_cfg:
cls: CDTrainTPL
cls: EduTrainTPL
batch_size: 512

modeltpl_cfg:
Expand Down
2 changes: 1 addition & 1 deletion edustudio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from __future__ import print_function
from __future__ import division

__version__ = '1.0.0-alpha4'
__version__ = 'v1.0.0-alpha5'
3 changes: 1 addition & 2 deletions edustudio/traintpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@

from .base_traintpl import BaseTrainTPL
from .gd_traintpl import GDTrainTPL
from .cd_inter_traintpl import CDInterTrainTPL
from .kt_inter_traintpl import KTInterTrainTPL
from .edu_traintpl import EduTrainTPL
from .atkt_traintpl import AtktTrainTPL
2 changes: 1 addition & 1 deletion edustudio/traintpl/base_traintpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_default_cfg(cls):
return cfg

def start(self):
self.logger.info(f"TrainTPL {self.__class__.__base__} Started!")
self.logger.info(f"TrainTPL {self.__class__} Started!")
set_same_seeds(self.traintpl_cfg['seed'])

def _check_params(self):
Expand Down
161 changes: 0 additions & 161 deletions edustudio/traintpl/cd_inter_traintpl.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import shutil


class KTInterTrainTPL(GDTrainTPL):
class EduTrainTPL(GDTrainTPL):
default_cfg = {
'num_stop_rounds': 10,
'early_stop_metrics': [('auc','max')],
Expand All @@ -20,9 +20,6 @@ class KTInterTrainTPL(GDTrainTPL):
'batch_size': 32,
}

def __init__(self, cfg: UnifyConfig):
super().__init__(cfg)

def _check_params(self):
super()._check_params()
assert self.traintpl_cfg['best_epoch_metric'] in set(i[0] for i in self.traintpl_cfg['early_stop_metrics'])
Expand Down Expand Up @@ -111,7 +108,7 @@ def evaluate(self, loader):
batch_dict = self.batch_dict2device(batch_dict)
eval_dict = self.model.predict(**batch_dict)
pd_list[idx] = eval_dict['y_pd']
gt_list[idx] = eval_dict['y_gt']
gt_list[idx] = eval_dict['y_gt'] if 'y_gt' in eval_dict else batch_dict['label']
y_pd = torch.hstack(pd_list)
y_gt = torch.hstack(gt_list)

Expand Down Expand Up @@ -142,7 +139,7 @@ def inference(self, loader):
batch_dict = self.batch_dict2device(batch_dict)
eval_dict = self.model.predict(**batch_dict)
pd_list[idx] = eval_dict['y_pd']
gt_list[idx] = eval_dict['y_gt']
gt_list[idx] = eval_dict['y_gt'] if 'y_gt' in eval_dict else batch_dict['label']
y_pd = torch.hstack(pd_list)
y_gt = torch.hstack(gt_list)

Expand Down
Loading

0 comments on commit 20d0f6f

Please sign in to comment.