From 20d0f6f7c32f3ddbd3a18077861d8d6b5e60f0c2 Mon Sep 17 00:00:00 2001 From: kervias Date: Fri, 7 Jul 2023 10:31:12 +0800 Subject: [PATCH] merge cd and kt traintpl --- README.md | 2 +- docs/source/conf.py | 2 +- docs/source/get_started/quick_start.md | 2 +- docs/source/index.rst | 2 +- docs/source/user_guide/reference_table.md | 84 ++++----- docs/source/user_guide/usage/aht.md | 4 +- docs/source/user_guide/usage/run_edustudio.md | 4 +- edustudio/__init__.py | 2 +- edustudio/traintpl/__init__.py | 3 +- edustudio/traintpl/base_traintpl.py | 2 +- edustudio/traintpl/cd_inter_traintpl.py | 161 ------------------ .../{kt_inter_traintpl.py => edu_traintpl.py} | 9 +- examples/1.run_cd_demo.py | 2 +- examples/2.run_kt_demo.py | 2 +- examples/3.run_with_customized_tpl.py | 2 +- examples/5.run_with_hyperopt.py | 2 +- examples/6.run_with_ray.tune.py | 2 +- examples/single_model/run_akt_demo.py | 2 +- examples/single_model/run_cdgk_demo.py | 2 +- examples/single_model/run_cdmfkc_demo.py | 2 +- examples/single_model/run_ckt_demo.py | 2 +- examples/single_model/run_cl4kt_demo.py | 2 +- examples/single_model/run_cncd_f_demo.py | 2 +- examples/single_model/run_cncdq_demo.py | 2 +- examples/single_model/run_ctncm_demo.py | 2 +- examples/single_model/run_deepirt_demo.py | 2 +- examples/single_model/run_dimkt_demo.py | 2 +- examples/single_model/run_dina_demo.py | 2 +- examples/single_model/run_dkt_demo.py | 2 +- examples/single_model/run_dkt_dsc_demo.py | 2 +- examples/single_model/run_dkt_plus_demo.py | 2 +- examples/single_model/run_dktforget_demo.py | 2 +- examples/single_model/run_dkvmn_demo.py | 2 +- .../single_model/run_dtransformer_demo.py | 2 +- examples/single_model/run_eernn_demo.py | 2 +- examples/single_model/run_gkt_demo.py | 2 +- examples/single_model/run_hawkeskt_demo.py | 2 +- examples/single_model/run_hiercdf_demo.py | 2 +- examples/single_model/run_iekt_demo.py | 2 +- examples/single_model/run_irr_demo.py | 2 +- examples/single_model/run_irt_demo.py | 2 +- examples/single_model/run_kancd_demo.py | 2 +- examples/single_model/run_kqn_demo.py | 2 +- examples/single_model/run_kscd_demo.py | 2 +- examples/single_model/run_lpkt_demo.py | 2 +- examples/single_model/run_lpkt_s_demo.py | 2 +- examples/single_model/run_mgcd_demo.py | 2 +- examples/single_model/run_mirt_demo.py | 2 +- examples/single_model/run_ncdm_demo.py | 2 +- examples/single_model/run_qdkt_demo.py | 2 +- examples/single_model/run_qikt_demo.py | 2 +- examples/single_model/run_rcd_demo.py | 2 +- examples/single_model/run_rkt_demo.py | 2 +- examples/single_model/run_saint_demo.py | 2 +- examples/single_model/run_saint_plus_demo.py | 2 +- examples/single_model/run_sakt_demo.py | 2 +- examples/single_model/run_simplekt_demo.py | 2 +- examples/single_model/run_skvmn_demo.py | 2 +- setup.py | 2 +- 59 files changed, 103 insertions(+), 268 deletions(-) delete mode 100644 edustudio/traintpl/cd_inter_traintpl.py rename edustudio/traintpl/{kt_inter_traintpl.py => edu_traintpl.py} (96%) diff --git a/README.md b/README.md index d74793c..6999afe 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ run_edustudio( dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterExtendsQDataTPL' diff --git a/docs/source/conf.py b/docs/source/conf.py index 3acadce..0615213 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,7 +9,7 @@ project = 'EduStudio' copyright = '2023, HFUT-LEC' author = 'HFUT-LEC' -release = 'v1.0.0-alpha4' +release = 'v1.0.0-alpha5' import sphinx_rtd_theme import os diff --git a/docs/source/get_started/quick_start.md b/docs/source/get_started/quick_start.md index 1a8806e..b82cc33 100644 --- a/docs/source/get_started/quick_start.md +++ b/docs/source/get_started/quick_start.md @@ -13,7 +13,7 @@ run_edustudio( dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterExtendsQDataTPL' diff --git a/docs/source/index.rst b/docs/source/index.rst index aa41bc2..af81b3c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,5 +1,5 @@ .. EduStudio documentation master file. -.. title:: EduStudio v1.0.0-alpha4 +.. title:: EduStudio v1.0.0-alpha5 .. image:: assets/logo.png ========================================================= diff --git a/docs/source/user_guide/reference_table.md b/docs/source/user_guide/reference_table.md index 35f134b..680a0f4 100644 --- a/docs/source/user_guide/reference_table.md +++ b/docs/source/user_guide/reference_table.md @@ -4,51 +4,51 @@ | Model | DataTPL | TrainTPL | EvalTPL | | :------ | ---------------------: | :-------------: | ------------------------------------------------------ | -| IRT | CDInterDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL | -| MIRT | CDInterDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL | -| NCDM | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | -| CNCD_Q | CNCDQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL | -| CNCD_F | CNCDFDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL | -| DINA | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | -| HierCDF | HierCDFDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | -| CDGK | CDGKDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | -| CDMFKC | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL | -| ECD | ECDDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL | -| IRR | IRRDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL | -| KaNCD | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | -| KSCD | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL | -| MGCD | MGCDDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL | -| RCD | RCDDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL | +| IRT | CDInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| MIRT | CDInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| NCDM | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | +| CNCD_Q | CNCDQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| CNCD_F | CNCDFDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| DINA | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | +| HierCDF | HierCDFDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | +| CDGK | CDGKDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | +| CDMFKC | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| ECD | ECDDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| IRR | IRRDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| KaNCD | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | +| KSCD | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| MGCD | MGCDDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| RCD | RCDDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | ## KT models | Model | DataTPL | TrainTPL | EvalTPL | | :----------- | ----------------------: | :-------------: | --------------------------- | -| AKT | KTInterDataTPLCptUnfold | KTInterTrainTPL | BinaryClassificationEvalTPL | +| AKT | KTInterDataTPLCptUnfold | EduTrainTPL | BinaryClassificationEvalTPL | | ATKT | KTInterDataTPLCptUnfold | AtktTrainTPL | BinaryClassificationEvalTPL | -| CKT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| CL4KT | CL4KTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| CT_NCM | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| DeepIRT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| DIMKT | DIMKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| DKT | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| DKTDSC | DKTDSCDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| DKTForget | DKTForgetDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| DKT_plus | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| DKVMN | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| DTransformer | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| EERNN | EERNNDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| EKT | EERNNDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| HawkesKT | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| IEKT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| KQN | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| LPKT | LPKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| LPKT_S | LPKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| QDKT | QDKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| QIKT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| RKT | RKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| SAINT | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| SAINT_plus | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| SAKT | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| SimpleKT | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | -| SKVMN | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL | +| CKT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| CL4KT | CL4KTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| CT_NCM | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| DeepIRT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| DIMKT | DIMKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| DKT | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| DKTDSC | DKTDSCDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| DKTForget | DKTForgetDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| DKT_plus | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| DKVMN | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| DTransformer | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| EERNN | EERNNDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| EKT | EERNNDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| HawkesKT | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| IEKT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| KQN | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| LPKT | LPKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| LPKT_S | LPKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| QDKT | QDKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| QIKT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| RKT | RKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| SAINT | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| SAINT_plus | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| SAKT | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| SimpleKT | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | +| SKVMN | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL | diff --git a/docs/source/user_guide/usage/aht.md b/docs/source/user_guide/usage/aht.md index fc93128..936b54d 100644 --- a/docs/source/user_guide/usage/aht.md +++ b/docs/source/user_guide/usage/aht.md @@ -54,7 +54,7 @@ def objective_function(args): search_space= { - 'traintpl_cfg.cls': tune.grid_search(['CDInterTrainTPL']), + 'traintpl_cfg.cls': tune.grid_search(['EduTrainTPL']), 'datatpl_cfg.cls': tune.grid_search(['CDInterExtendsQDataTPL']), 'modeltpl_cfg.cls': tune.grid_search(['KaNCD']), 'evaltpl_cfg.clses': tune.grid_search([['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]), @@ -115,7 +115,7 @@ def objective_function(args): space = { - 'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['CDInterTrainTPL']), + 'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['EduTrainTPL']), 'datatpl_cfg.cls': hp.choice('datapl_cfg.cls', ['CDInterExtendsQDataTPL']), 'modeltpl_cfg.cls': hp.choice('modeltpl_cfg.cls', ['KaNCD']), 'evaltpl_cfg.clses': hp.choice('evaltpl_cfg.clses', [['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]), diff --git a/docs/source/user_guide/usage/run_edustudio.md b/docs/source/user_guide/usage/run_edustudio.md index be78d91..bb1ecdf 100644 --- a/docs/source/user_guide/usage/run_edustudio.md +++ b/docs/source/user_guide/usage/run_edustudio.md @@ -11,7 +11,7 @@ run_edustudio( dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterExtendsQDataTPL' @@ -48,7 +48,7 @@ datatpl_cfg: cls: CDInterDataTPL traintpl_cfg: - cls: CDTrainTPL + cls: EduTrainTPL batch_size: 512 modeltpl_cfg: diff --git a/edustudio/__init__.py b/edustudio/__init__.py index 819bd80..c8514e6 100644 --- a/edustudio/__init__.py +++ b/edustudio/__init__.py @@ -2,4 +2,4 @@ from __future__ import print_function from __future__ import division -__version__ = '1.0.0-alpha4' +__version__ = 'v1.0.0-alpha5' diff --git a/edustudio/traintpl/__init__.py b/edustudio/traintpl/__init__.py index e7c67cb..c8436e4 100644 --- a/edustudio/traintpl/__init__.py +++ b/edustudio/traintpl/__init__.py @@ -4,6 +4,5 @@ from .base_traintpl import BaseTrainTPL from .gd_traintpl import GDTrainTPL -from .cd_inter_traintpl import CDInterTrainTPL -from .kt_inter_traintpl import KTInterTrainTPL +from .edu_traintpl import EduTrainTPL from .atkt_traintpl import AtktTrainTPL diff --git a/edustudio/traintpl/base_traintpl.py b/edustudio/traintpl/base_traintpl.py index 96c88ca..eee6172 100644 --- a/edustudio/traintpl/base_traintpl.py +++ b/edustudio/traintpl/base_traintpl.py @@ -60,7 +60,7 @@ def get_default_cfg(cls): return cfg def start(self): - self.logger.info(f"TrainTPL {self.__class__.__base__} Started!") + self.logger.info(f"TrainTPL {self.__class__} Started!") set_same_seeds(self.traintpl_cfg['seed']) def _check_params(self): diff --git a/edustudio/traintpl/cd_inter_traintpl.py b/edustudio/traintpl/cd_inter_traintpl.py deleted file mode 100644 index 43922bc..0000000 --- a/edustudio/traintpl/cd_inter_traintpl.py +++ /dev/null @@ -1,161 +0,0 @@ -from .gd_traintpl import GDTrainTPL -from edustudio.utils.common import UnifyConfig, set_same_seeds, tensor2npy -from edustudio.utils.callback import ModelCheckPoint, EarlyStopping, History, BaseLogger, Callback, CallbackList -from edustudio.model import BaseModel -import torch -from typing import Sequence -from collections import defaultdict -from tqdm import tqdm -import numpy as np -import shutil - -class CDInterTrainTPL(GDTrainTPL): - default_cfg = { - 'num_stop_rounds': 10, - 'early_stop_metrics': [('auc','max')], - 'best_epoch_metric': 'auc', - 'unsave_best_epoch_pth': True, - 'ignore_metrics_in_train': [] - } - - def __init__(self, cfg: UnifyConfig): - super().__init__(cfg) - - def _check_params(self): - super()._check_params() - assert self.traintpl_cfg['best_epoch_metric'] in set(i[0] for i in self.traintpl_cfg['early_stop_metrics']) - - def one_fold_start(self, fold_id): - super().one_fold_start(fold_id) - # callbacks - num_stop_rounds = self.traintpl_cfg['num_stop_rounds'] - es_metrics = self.traintpl_cfg["early_stop_metrics"] - modelCheckPoint = ModelCheckPoint( - es_metrics, save_folder_path=f"{self.frame_cfg.temp_folder_path}/pths/{fold_id}/" - ) - es_cb = EarlyStopping(es_metrics, num_stop_rounds=num_stop_rounds, start_round=1) - history_cb = History(folder_path=f"{self.frame_cfg.temp_folder_path}/history/{fold_id}", plot_curve=False) - callbacks = [ - modelCheckPoint, es_cb, history_cb, - BaseLogger(self.logger, group_by_contains=['loss']) - ] - self.callback_list = CallbackList(callbacks=callbacks, model=self.model, logger=self.logger) - # evaltpls - for evaltpl in self.evaltpls: - evaltpl.set_callback_list(self.callback_list) - evaltpl.set_dataloaders(train_loader=self.train_loader, - valid_loader=self.valid_loader, - test_loader=self.test_loader - ) - # train - set_same_seeds(self.traintpl_cfg['seed']) - if self.valid_loader is not None: - self.fit(train_loader=self.train_loader, valid_loader=self.valid_loader) - else: - self.fit(train_loader=self.train_loader, valid_loader=self.test_loader) - - metric_name = self.traintpl_cfg['best_epoch_metric'] - metric = [m for m in modelCheckPoint.metric_list if m.name == metric_name][0] - if self.valid_loader is not None: - # load best params - fpth = f"{self.frame_cfg.temp_folder_path}/pths/{fold_id}/best-epoch-{metric.best_epoch:03d}-for-{metric.name}.pth" - self.model.load_state_dict(torch.load(fpth)) - - metrics = self.inference(self.test_loader) - for name in metrics: self.logger.info(f"{name}: {metrics[name]}") - History.dump_json(metrics, f"{self.frame_cfg.temp_folder_path}/{fold_id}/result.json") - else: - metrics = history_cb.log_as_time[metric.best_epoch] - - if self.traintpl_cfg['unsave_best_epoch_pth']: shutil.rmtree(f"{self.frame_cfg.temp_folder_path}/pths/") - return metrics - - def fit(self, train_loader, valid_loader): - self.model.train() - self.optimizer = self._get_optim() - self.callback_list.on_train_begin() - for epoch in range(self.traintpl_cfg['epoch_num']): - self.callback_list.on_epoch_begin(epoch + 1) - logs = defaultdict(lambda: np.full((len(train_loader),), np.nan, dtype=np.float32)) - for batch_id, batch_dict in enumerate( - tqdm(train_loader, ncols=self.frame_cfg['TQDM_NCOLS'], desc="[EPOCH={:03d}]".format(epoch + 1)) - ): - batch_dict = self.batch_dict2device(batch_dict) - loss_dict = self.model.get_loss_dict(**batch_dict) - loss = torch.hstack([i for i in loss_dict.values() if i is not None]).sum() - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - for k in loss_dict: logs[k][batch_id] = loss_dict[k].item() if loss_dict[k] is not None else np.nan - - for name in logs: logs[name] = float(np.nanmean(logs[name])) - - if valid_loader is not None: - val_metrics = self.evaluate(valid_loader) - logs.update({f"{metric}": val_metrics[metric] for metric in val_metrics}) - - self.callback_list.on_epoch_end(epoch + 1, logs=logs) - if self.model.share_callback_dict.get('stop_training', False): - break - - self.callback_list.on_train_end() - - @torch.no_grad() - def evaluate(self, loader): - self.model.eval() - pd_list = list(range(len(loader))) - gt_list = list(range(len(loader))) - for idx, batch_dict in enumerate(tqdm(loader, ncols=self.frame_cfg['TQDM_NCOLS'], desc="[PREDICT]")): - batch_dict = self.batch_dict2device(batch_dict) - eval_dict = self.model.predict(**batch_dict) - pd_list[idx] = eval_dict['y_pd'] - gt_list[idx] = batch_dict['label'] - y_pd = torch.hstack(pd_list) - y_gt = torch.hstack(gt_list) - - eval_data_dict = { - 'y_pd': y_pd, - 'y_gt': y_gt, - } - if hasattr(self.model, 'get_stu_status'): - eval_data_dict.update({ - 'stu_stats': tensor2npy(self.model.get_stu_status()), - }) - if hasattr(self.datatpl, 'Q_mat'): - eval_data_dict.update({ - 'Q_mat': tensor2npy(self.datatpl.Q_mat) - }) - eval_result = {} - for evaltpl in self.evaltpls: eval_result.update( - evaltpl.eval(ignore_metrics=self.traintpl_cfg['ignore_metrics_in_train'], **eval_data_dict) - ) - return eval_result - - @torch.no_grad() - def inference(self, loader): - self.model.eval() - pd_list = list(range(len(loader))) - gt_list = list(range(len(loader))) - for idx, batch_dict in enumerate(tqdm(loader, ncols=self.frame_cfg['TQDM_NCOLS'], desc="[PREDICT]")): - batch_dict = self.batch_dict2device(batch_dict) - eval_dict = self.model.predict(**batch_dict) - pd_list[idx] = eval_dict['y_pd'] - gt_list[idx] = batch_dict['label'] - y_pd = torch.hstack(pd_list) - y_gt = torch.hstack(gt_list) - - eval_data_dict = { - 'y_pd': y_pd, - 'y_gt': y_gt, - } - if hasattr(self.model, 'get_stu_status'): - eval_data_dict.update({ - 'stu_stats': tensor2npy(self.model.get_stu_status()), - }) - if hasattr(self.datatpl, 'Q_mat'): - eval_data_dict.update({ - 'Q_mat': tensor2npy(self.datatpl.Q_mat) - }) - eval_result = {} - for evaltpl in self.evaltpls: eval_result.update(evaltpl.eval(**eval_data_dict)) - return eval_result diff --git a/edustudio/traintpl/kt_inter_traintpl.py b/edustudio/traintpl/edu_traintpl.py similarity index 96% rename from edustudio/traintpl/kt_inter_traintpl.py rename to edustudio/traintpl/edu_traintpl.py index 88eafd1..46eac98 100644 --- a/edustudio/traintpl/kt_inter_traintpl.py +++ b/edustudio/traintpl/edu_traintpl.py @@ -10,7 +10,7 @@ import shutil -class KTInterTrainTPL(GDTrainTPL): +class EduTrainTPL(GDTrainTPL): default_cfg = { 'num_stop_rounds': 10, 'early_stop_metrics': [('auc','max')], @@ -20,9 +20,6 @@ class KTInterTrainTPL(GDTrainTPL): 'batch_size': 32, } - def __init__(self, cfg: UnifyConfig): - super().__init__(cfg) - def _check_params(self): super()._check_params() assert self.traintpl_cfg['best_epoch_metric'] in set(i[0] for i in self.traintpl_cfg['early_stop_metrics']) @@ -111,7 +108,7 @@ def evaluate(self, loader): batch_dict = self.batch_dict2device(batch_dict) eval_dict = self.model.predict(**batch_dict) pd_list[idx] = eval_dict['y_pd'] - gt_list[idx] = eval_dict['y_gt'] + gt_list[idx] = eval_dict['y_gt'] if 'y_gt' in eval_dict else batch_dict['label'] y_pd = torch.hstack(pd_list) y_gt = torch.hstack(gt_list) @@ -142,7 +139,7 @@ def inference(self, loader): batch_dict = self.batch_dict2device(batch_dict) eval_dict = self.model.predict(**batch_dict) pd_list[idx] = eval_dict['y_pd'] - gt_list[idx] = eval_dict['y_gt'] + gt_list[idx] = eval_dict['y_gt'] if 'y_gt' in eval_dict else batch_dict['label'] y_pd = torch.hstack(pd_list) y_gt = torch.hstack(gt_list) diff --git a/examples/1.run_cd_demo.py b/examples/1.run_cd_demo.py index ce18273..e1cd225 100644 --- a/examples/1.run_cd_demo.py +++ b/examples/1.run_cd_demo.py @@ -10,7 +10,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterExtendsQDataTPL' diff --git a/examples/2.run_kt_demo.py b/examples/2.run_kt_demo.py index 2c66e7c..65ba200 100644 --- a/examples/2.run_kt_demo.py +++ b/examples/2.run_kt_demo.py @@ -10,7 +10,7 @@ dataset='ASSIST_0910', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'KTInterCptUnfoldDataTPL' diff --git a/examples/3.run_with_customized_tpl.py b/examples/3.run_with_customized_tpl.py index 5bc8a1d..2a73e1d 100644 --- a/examples/3.run_with_customized_tpl.py +++ b/examples/3.run_with_customized_tpl.py @@ -11,7 +11,7 @@ dataset='ASSIST_0910', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'KTInterCptUnfoldDataTPL' diff --git a/examples/5.run_with_hyperopt.py b/examples/5.run_with_hyperopt.py index 2f5a3ed..ca65652 100644 --- a/examples/5.run_with_hyperopt.py +++ b/examples/5.run_with_hyperopt.py @@ -41,7 +41,7 @@ def objective_function(args): space = { - 'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['CDInterTrainTPL']), + 'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['EduTrainTPL']), 'datatpl_cfg.cls': hp.choice('datapl_cfg.cls', ['CDInterExtendsQDataTPL']), 'modeltpl_cfg.cls': hp.choice('modeltpl_cfg.cls', ['KaNCD']), 'evaltpl_cfg.clses': hp.choice('evaltpl_cfg.clses', [['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]), diff --git a/examples/6.run_with_ray.tune.py b/examples/6.run_with_ray.tune.py index 8be0f59..f69a079 100644 --- a/examples/6.run_with_ray.tune.py +++ b/examples/6.run_with_ray.tune.py @@ -43,7 +43,7 @@ def objective_function(args): search_space= { - 'traintpl_cfg.cls': tune.grid_search(['CDInterTrainTPL']), + 'traintpl_cfg.cls': tune.grid_search(['EduTrainTPL']), 'datatpl_cfg.cls': tune.grid_search(['CDInterExtendsQDataTPL']), 'modeltpl_cfg.cls': tune.grid_search(['KaNCD']), 'evaltpl_cfg.clses': tune.grid_search([['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]), diff --git a/examples/single_model/run_akt_demo.py b/examples/single_model/run_akt_demo.py index a61b828..5c03de5 100644 --- a/examples/single_model/run_akt_demo.py +++ b/examples/single_model/run_akt_demo.py @@ -10,7 +10,7 @@ dataset='ASSIST_0910', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'KTInterCptUnfoldDataTPL' diff --git a/examples/single_model/run_cdgk_demo.py b/examples/single_model/run_cdgk_demo.py index 5405398..6dedba2 100644 --- a/examples/single_model/run_cdgk_demo.py +++ b/examples/single_model/run_cdgk_demo.py @@ -10,7 +10,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterExtendsQDataTPL', diff --git a/examples/single_model/run_cdmfkc_demo.py b/examples/single_model/run_cdmfkc_demo.py index c07a594..7ba31d5 100644 --- a/examples/single_model/run_cdmfkc_demo.py +++ b/examples/single_model/run_cdmfkc_demo.py @@ -11,7 +11,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', 'lr': 0.01, 'epoch_num': 1000 }, diff --git a/examples/single_model/run_ckt_demo.py b/examples/single_model/run_ckt_demo.py index 0683f96..8218d82 100644 --- a/examples/single_model/run_ckt_demo.py +++ b/examples/single_model/run_ckt_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterExtendsQDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'CKT', diff --git a/examples/single_model/run_cl4kt_demo.py b/examples/single_model/run_cl4kt_demo.py index 8357581..dbbfd0b 100644 --- a/examples/single_model/run_cl4kt_demo.py +++ b/examples/single_model/run_cl4kt_demo.py @@ -13,7 +13,7 @@ 'cls': 'CL4KTDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', 'eval_batch_size': 1024, }, modeltpl_cfg_dict={ diff --git a/examples/single_model/run_cncd_f_demo.py b/examples/single_model/run_cncd_f_demo.py index 46dcd43..f54dee8 100644 --- a/examples/single_model/run_cncd_f_demo.py +++ b/examples/single_model/run_cncd_f_demo.py @@ -13,7 +13,7 @@ 'cls': 'CNCDFDataTPL', }, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'CNCD_F', diff --git a/examples/single_model/run_cncdq_demo.py b/examples/single_model/run_cncdq_demo.py index dde7dab..2451d8e 100644 --- a/examples/single_model/run_cncdq_demo.py +++ b/examples/single_model/run_cncdq_demo.py @@ -13,7 +13,7 @@ 'cls': 'CNCDQDataTPL', }, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'CNCD_Q', diff --git a/examples/single_model/run_ctncm_demo.py b/examples/single_model/run_ctncm_demo.py index 9c14b77..80976f1 100644 --- a/examples/single_model/run_ctncm_demo.py +++ b/examples/single_model/run_ctncm_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptUnfoldDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'CT_NCM', diff --git a/examples/single_model/run_deepirt_demo.py b/examples/single_model/run_deepirt_demo.py index 2003f7c..c605499 100644 --- a/examples/single_model/run_deepirt_demo.py +++ b/examples/single_model/run_deepirt_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterExtendsQDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'DeepIRT', diff --git a/examples/single_model/run_dimkt_demo.py b/examples/single_model/run_dimkt_demo.py index 4885c6e..b844e38 100644 --- a/examples/single_model/run_dimkt_demo.py +++ b/examples/single_model/run_dimkt_demo.py @@ -13,7 +13,7 @@ 'cls': 'DIMKTDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', 'device': 'cpu', }, modeltpl_cfg_dict={ diff --git a/examples/single_model/run_dina_demo.py b/examples/single_model/run_dina_demo.py index f179b78..72f2db9 100644 --- a/examples/single_model/run_dina_demo.py +++ b/examples/single_model/run_dina_demo.py @@ -11,7 +11,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterExtendsQDataTPL', diff --git a/examples/single_model/run_dkt_demo.py b/examples/single_model/run_dkt_demo.py index 5cefe9b..2896251 100644 --- a/examples/single_model/run_dkt_demo.py +++ b/examples/single_model/run_dkt_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptAsExerDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'DKT', diff --git a/examples/single_model/run_dkt_dsc_demo.py b/examples/single_model/run_dkt_dsc_demo.py index b6adaaf..bf50c7e 100644 --- a/examples/single_model/run_dkt_dsc_demo.py +++ b/examples/single_model/run_dkt_dsc_demo.py @@ -13,7 +13,7 @@ 'cls': 'DKTDSCDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'DKTDSC', diff --git a/examples/single_model/run_dkt_plus_demo.py b/examples/single_model/run_dkt_plus_demo.py index c5ccbb5..87d9288 100644 --- a/examples/single_model/run_dkt_plus_demo.py +++ b/examples/single_model/run_dkt_plus_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptAsExerDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'DKT_plus', diff --git a/examples/single_model/run_dktforget_demo.py b/examples/single_model/run_dktforget_demo.py index f09c078..f3d7882 100644 --- a/examples/single_model/run_dktforget_demo.py +++ b/examples/single_model/run_dktforget_demo.py @@ -13,7 +13,7 @@ 'cls': 'DKTForgetDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'DKTForget', diff --git a/examples/single_model/run_dkvmn_demo.py b/examples/single_model/run_dkvmn_demo.py index 640b030..87466d5 100644 --- a/examples/single_model/run_dkvmn_demo.py +++ b/examples/single_model/run_dkvmn_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterExtendsQDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'DKVMN', diff --git a/examples/single_model/run_dtransformer_demo.py b/examples/single_model/run_dtransformer_demo.py index 077bfbb..d1a4dac 100644 --- a/examples/single_model/run_dtransformer_demo.py +++ b/examples/single_model/run_dtransformer_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptUnfoldDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'DTransformer', diff --git a/examples/single_model/run_eernn_demo.py b/examples/single_model/run_eernn_demo.py index 6473aaf..064c4d8 100644 --- a/examples/single_model/run_eernn_demo.py +++ b/examples/single_model/run_eernn_demo.py @@ -13,7 +13,7 @@ 'cls': 'EERNNDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'EERNNA', diff --git a/examples/single_model/run_gkt_demo.py b/examples/single_model/run_gkt_demo.py index 0e2c400..b75f07b 100644 --- a/examples/single_model/run_gkt_demo.py +++ b/examples/single_model/run_gkt_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptAsExerDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'GKT', diff --git a/examples/single_model/run_hawkeskt_demo.py b/examples/single_model/run_hawkeskt_demo.py index f555147..0c15854 100644 --- a/examples/single_model/run_hawkeskt_demo.py +++ b/examples/single_model/run_hawkeskt_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptUnfoldDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'HawkesKT', diff --git a/examples/single_model/run_hiercdf_demo.py b/examples/single_model/run_hiercdf_demo.py index 79b746b..dc6ebb8 100644 --- a/examples/single_model/run_hiercdf_demo.py +++ b/examples/single_model/run_hiercdf_demo.py @@ -45,7 +45,7 @@ def process(self, **kwargs): dataset='JunyiExerAsCpt', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'HierCDFDataTPL', diff --git a/examples/single_model/run_iekt_demo.py b/examples/single_model/run_iekt_demo.py index f86dc38..a788bd8 100644 --- a/examples/single_model/run_iekt_demo.py +++ b/examples/single_model/run_iekt_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterExtendsQDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'IEKT', diff --git a/examples/single_model/run_irr_demo.py b/examples/single_model/run_irr_demo.py index ff78978..b0386ce 100644 --- a/examples/single_model/run_irr_demo.py +++ b/examples/single_model/run_irr_demo.py @@ -11,7 +11,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'IRRDataTPL', diff --git a/examples/single_model/run_irt_demo.py b/examples/single_model/run_irt_demo.py index 73859ca..f07de4e 100644 --- a/examples/single_model/run_irt_demo.py +++ b/examples/single_model/run_irt_demo.py @@ -10,7 +10,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterDataTPL', diff --git a/examples/single_model/run_kancd_demo.py b/examples/single_model/run_kancd_demo.py index c56961c..2f2c4b4 100644 --- a/examples/single_model/run_kancd_demo.py +++ b/examples/single_model/run_kancd_demo.py @@ -10,7 +10,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterExtendsQDataTPL', diff --git a/examples/single_model/run_kqn_demo.py b/examples/single_model/run_kqn_demo.py index 367ab70..186e6d9 100644 --- a/examples/single_model/run_kqn_demo.py +++ b/examples/single_model/run_kqn_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptAsExerDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'KQN', diff --git a/examples/single_model/run_kscd_demo.py b/examples/single_model/run_kscd_demo.py index 1592777..17ef967 100644 --- a/examples/single_model/run_kscd_demo.py +++ b/examples/single_model/run_kscd_demo.py @@ -10,7 +10,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterExtendsQDataTPL', diff --git a/examples/single_model/run_lpkt_demo.py b/examples/single_model/run_lpkt_demo.py index 96fbd4a..39cae6d 100644 --- a/examples/single_model/run_lpkt_demo.py +++ b/examples/single_model/run_lpkt_demo.py @@ -13,7 +13,7 @@ 'cls': 'LPKTDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'LPKT', diff --git a/examples/single_model/run_lpkt_s_demo.py b/examples/single_model/run_lpkt_s_demo.py index f6db4aa..2034783 100644 --- a/examples/single_model/run_lpkt_s_demo.py +++ b/examples/single_model/run_lpkt_s_demo.py @@ -13,7 +13,7 @@ 'cls': 'LPKTDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'LPKT_S', diff --git a/examples/single_model/run_mgcd_demo.py b/examples/single_model/run_mgcd_demo.py index c57a301..c2b3b19 100644 --- a/examples/single_model/run_mgcd_demo.py +++ b/examples/single_model/run_mgcd_demo.py @@ -10,7 +10,7 @@ dataset='ASSIST_1213', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', 'early_stop_metrics': [('rmse','min')], 'best_epoch_metric': 'rmse', 'batch_size': 512 diff --git a/examples/single_model/run_mirt_demo.py b/examples/single_model/run_mirt_demo.py index 631958d..34e25b9 100644 --- a/examples/single_model/run_mirt_demo.py +++ b/examples/single_model/run_mirt_demo.py @@ -10,7 +10,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterDataTPL', diff --git a/examples/single_model/run_ncdm_demo.py b/examples/single_model/run_ncdm_demo.py index fb4690e..aecb972 100644 --- a/examples/single_model/run_ncdm_demo.py +++ b/examples/single_model/run_ncdm_demo.py @@ -10,7 +10,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'CDInterExtendsQDataTPL', diff --git a/examples/single_model/run_qdkt_demo.py b/examples/single_model/run_qdkt_demo.py index dc8f3e7..6298826 100644 --- a/examples/single_model/run_qdkt_demo.py +++ b/examples/single_model/run_qdkt_demo.py @@ -13,7 +13,7 @@ 'cls': 'QDKTDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'QDKT', diff --git a/examples/single_model/run_qikt_demo.py b/examples/single_model/run_qikt_demo.py index 57546be..40eeeea 100644 --- a/examples/single_model/run_qikt_demo.py +++ b/examples/single_model/run_qikt_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterExtendsQDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'QIKT', diff --git a/examples/single_model/run_rcd_demo.py b/examples/single_model/run_rcd_demo.py index 4df5d5e..4ebe7cd 100644 --- a/examples/single_model/run_rcd_demo.py +++ b/examples/single_model/run_rcd_demo.py @@ -10,7 +10,7 @@ dataset='FrcSub', cfg_file_name=None, traintpl_cfg_dict={ - 'cls': 'CDInterTrainTPL', + 'cls': 'EduTrainTPL', }, datatpl_cfg_dict={ 'cls': 'RCDDataTPL', diff --git a/examples/single_model/run_rkt_demo.py b/examples/single_model/run_rkt_demo.py index 389d175..48a744d 100644 --- a/examples/single_model/run_rkt_demo.py +++ b/examples/single_model/run_rkt_demo.py @@ -13,7 +13,7 @@ 'cls': 'RKTDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'RKT', diff --git a/examples/single_model/run_saint_demo.py b/examples/single_model/run_saint_demo.py index db6cc61..14ca775 100644 --- a/examples/single_model/run_saint_demo.py +++ b/examples/single_model/run_saint_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptUnfoldDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'SAINT', diff --git a/examples/single_model/run_saint_plus_demo.py b/examples/single_model/run_saint_plus_demo.py index 862f15c..ea30c86 100644 --- a/examples/single_model/run_saint_plus_demo.py +++ b/examples/single_model/run_saint_plus_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptUnfoldDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'SAINT_plus', diff --git a/examples/single_model/run_sakt_demo.py b/examples/single_model/run_sakt_demo.py index 7ce08a2..22b70a0 100644 --- a/examples/single_model/run_sakt_demo.py +++ b/examples/single_model/run_sakt_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptAsExerDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'SAKT', diff --git a/examples/single_model/run_simplekt_demo.py b/examples/single_model/run_simplekt_demo.py index 4dd7d15..1339cfb 100644 --- a/examples/single_model/run_simplekt_demo.py +++ b/examples/single_model/run_simplekt_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptUnfoldDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'SimpleKT', diff --git a/examples/single_model/run_skvmn_demo.py b/examples/single_model/run_skvmn_demo.py index 7d1f0a1..8971edf 100644 --- a/examples/single_model/run_skvmn_demo.py +++ b/examples/single_model/run_skvmn_demo.py @@ -13,7 +13,7 @@ 'cls': 'KTInterCptAsExerDataTPL', }, traintpl_cfg_dict={ - 'cls': 'KTInterTrainTPL', + 'cls': 'EduTrainTPL', }, modeltpl_cfg_dict={ 'cls': 'SKVMN', diff --git a/setup.py b/setup.py index 4ea76d3..0a56ddb 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ setup( name="edustudio", - version="v1.0.0-alpha4", + version="v1.0.0-alpha5", description="a Unified and Templatized Framework for Student Assessment Models", long_description=long_description, python_requires='>=3.8',