diff --git a/README.md b/README.md index d5688d9..e22c529 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ run_edustudio( 'cls': 'NCDM', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } ) diff --git a/docs/source/developer_guide/customize_evaltpl.md b/docs/source/developer_guide/customize_evaltpl.md index 684f96f..823e9af 100644 --- a/docs/source/developer_guide/customize_evaltpl.md +++ b/docs/source/developer_guide/customize_evaltpl.md @@ -23,14 +23,14 @@ The protocols in ``BaseEvalTPL`` are listed as follows. EvalTPLs ---------------------- -EduStudio provides ``BinaryClassificationEvalTPL`` and ``CognitiveDiagnosisEvalTPL``, which inherent ``BaseEvalTPL``. +EduStudio provides ``PredictionEvalTPL`` and ``InterpretabilityEvalTPL``, which inherent ``BaseEvalTPL``. -### BinaryClassificationEvalTPL +### PredictionEvalTPL This EvalTPL is for the model evaluation using binary classification metrics. -The protocols in ``BinaryClassificationEvalTPL`` are listed as follows. +The protocols in ``PredictionEvalTPL`` are listed as follows. -### CognitiveDiagnosisEvalTPL +### InterpretabilityEvalT This EvalTPL is for the model evaluation for interpretability. It uses states of students and Q matrix for ``eval``, which are domain-specific in student assessment. ## Develop a New EvalTPL in EduStudio diff --git a/docs/source/features/dataset_folder_protocol.md b/docs/source/features/dataset_folder_protocol.md index d45f541..5d30d0c 100644 --- a/docs/source/features/dataset_folder_protocol.md +++ b/docs/source/features/dataset_folder_protocol.md @@ -62,7 +62,7 @@ run_edustudio( 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } ) ``` @@ -90,7 +90,7 @@ run_edustudio( 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } ) ``` @@ -118,7 +118,7 @@ run_edustudio( 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } ) ``` @@ -154,7 +154,7 @@ run_edustudio( 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } ) ``` diff --git a/docs/source/get_started/quick_start.md b/docs/source/get_started/quick_start.md index 687dd65..7948f44 100644 --- a/docs/source/get_started/quick_start.md +++ b/docs/source/get_started/quick_start.md @@ -22,7 +22,7 @@ run_edustudio( 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } ) ``` diff --git a/docs/source/user_guide/reference_table.md b/docs/source/user_guide/reference_table.md index 2efb772..7223121 100644 --- a/docs/source/user_guide/reference_table.md +++ b/docs/source/user_guide/reference_table.md @@ -4,52 +4,52 @@ | Model | DataTPL | TrainTPL | EvalTPL | | :------ | ---------------------: | :-------------: | ------------------------------------------------------ | -| IRT | CDInterDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| MIRT | CDInterDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| NCDM | CDInterExtendsQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | -| CNCD_Q | CNCDQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| CNCD_F | CNCDFDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| DINA | CDInterExtendsQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | -| HierCDF | HierCDFDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | -| CDGK | CDGKDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | -| CDMFKC | CDInterExtendsQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| ECD | ECDDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| IRR | IRRDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| KaNCD | CDInterExtendsQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL | -| KSCD | CDInterExtendsQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| MGCD | MGCDDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| RCD | RCDDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | +| IRT | CDInterDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| MIRT | CDInterDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| NCDM | CDInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL、InterpretabilityEvalTPL | +| CNCD_Q | CNCDQDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| CNCD_F | CNCDFDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| DINA | CDInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL、InterpretabilityEvalTPL | +| HierCDF | HierCDFDataTPL | GeneralTrainTPL | PredictionEvalTPL、InterpretabilityEvalTPL | +| CDGK | CDGKDataTPL | GeneralTrainTPL | PredictionEvalTPL、InterpretabilityEvalTPL | +| CDMFKC | CDInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| ECD | ECDDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| IRR | IRRDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| KaNCD | CDInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL、InterpretabilityEvalTPL | +| KSCD | CDInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| MGCD | MGCDDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| RCD | RCDDataTPL | GeneralTrainTPL | PredictionEvalTPL | ## KT models | Model | DataTPL | TrainTPL | EvalTPL | | :----------- | ----------------------: | :-------------: | --------------------------- | -| AKT | KTInterDataTPLCptUnfold | GeneralTrainTPL | BinaryClassificationEvalTPL | -| ATKT | KTInterDataTPLCptUnfold | AtktTrainTPL | BinaryClassificationEvalTPL | -| CKT | KTInterExtendsQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| CL4KT | CL4KTDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| CT_NCM | KTInterCptUnfoldDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| DeepIRT | KTInterExtendsQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| DIMKT | DIMKTDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| DKT | KTInterDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| DKTDSC | DKTDSCDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| DKTForget | DKTForgetDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| DKT_plus | KTInterDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| DKVMN | KTInterExtendsQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| DTransformer | KTInterCptUnfoldDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| EERNN | EERNNDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| EKT | EKTDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| GKT | GKTDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| HawkesKT | KTInterCptUnfoldDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| IEKT | KTInterExtendsQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| KQN | KTInterDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| LPKT | LPKTDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| LPKT_S | LPKTDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| QDKT | QDKTDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| QIKT | KTInterExtendsQDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| RKT | RKTDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| SAINT | KTInterCptUnfoldDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| SAINT_plus | KTInterCptUnfoldDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| SAKT | KTInterDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| SimpleKT | KTInterCptUnfoldDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | -| SKVMN | KTInterDataTPL | GeneralTrainTPL | BinaryClassificationEvalTPL | +| AKT | KTInterDataTPLCptUnfold | GeneralTrainTPL | PredictionEvalTPL | +| ATKT | KTInterDataTPLCptUnfold | AtktTrainTPL | PredictionEvalTPL | +| CKT | KTInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| CL4KT | CL4KTDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| CT_NCM | KTInterCptUnfoldDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| DeepIRT | KTInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| DIMKT | DIMKTDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| DKT | KTInterDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| DKTDSC | DKTDSCDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| DKTForget | DKTForgetDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| DKT_plus | KTInterDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| DKVMN | KTInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| DTransformer | KTInterCptUnfoldDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| EERNN | EERNNDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| EKT | EKTDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| GKT | GKTDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| HawkesKT | KTInterCptUnfoldDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| IEKT | KTInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| KQN | KTInterDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| LPKT | LPKTDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| LPKT_S | LPKTDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| QDKT | QDKTDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| QIKT | KTInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| RKT | RKTDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| SAINT | KTInterCptUnfoldDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| SAINT_plus | KTInterCptUnfoldDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| SAKT | KTInterDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| SimpleKT | KTInterCptUnfoldDataTPL | GeneralTrainTPL | PredictionEvalTPL | +| SKVMN | KTInterDataTPL | GeneralTrainTPL | PredictionEvalTPL | diff --git a/docs/source/user_guide/usage/aht.md b/docs/source/user_guide/usage/aht.md index 6645eaa..edec924 100644 --- a/docs/source/user_guide/usage/aht.md +++ b/docs/source/user_guide/usage/aht.md @@ -57,7 +57,7 @@ search_space= { 'traintpl_cfg.cls': tune.grid_search(['GeneralTrainTPL']), 'datatpl_cfg.cls': tune.grid_search(['CDInterExtendsQDataTPL']), 'modeltpl_cfg.cls': tune.grid_search(['KaNCD']), - 'evaltpl_cfg.clses': tune.grid_search([['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]), + 'evaltpl_cfg.clses': tune.grid_search([['PredictionEvalTPL', 'InterpretabilityEvalTPL']]), 'traintpl_cfg.batch_size': tune.grid_search([256,]), @@ -118,7 +118,7 @@ space = { 'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['GeneralTrainTPL']), 'datatpl_cfg.cls': hp.choice('datapl_cfg.cls', ['CDInterExtendsQDataTPL']), 'modeltpl_cfg.cls': hp.choice('modeltpl_cfg.cls', ['KaNCD']), - 'evaltpl_cfg.clses': hp.choice('evaltpl_cfg.clses', [['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]), + 'evaltpl_cfg.clses': hp.choice('evaltpl_cfg.clses', [['PredictionEvalTPL', 'InterpretabilityEvalTPL']]), 'traintpl_cfg.batch_size': hp.choice('traintpl_cfg.batch_size', [256,]), diff --git a/docs/source/user_guide/usage/run_edustudio.md b/docs/source/user_guide/usage/run_edustudio.md index cfaa452..d33b9ac 100644 --- a/docs/source/user_guide/usage/run_edustudio.md +++ b/docs/source/user_guide/usage/run_edustudio.md @@ -20,7 +20,7 @@ run_edustudio( 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } ) ``` @@ -55,7 +55,7 @@ modeltpl_cfg: cls: NCDM evaltpl_cfg: - clses: [BinaryClassificationEvalTPL, CognitiveDiagnosisEvalTPL] + clses: [PredictionEvalTPL, InterpretabilityEvalT] ``` then, run command: diff --git a/docs/source/user_guide/usage/use_case_of_config.md b/docs/source/user_guide/usage/use_case_of_config.md index f09be42..d696d5d 100644 --- a/docs/source/user_guide/usage/use_case_of_config.md +++ b/docs/source/user_guide/usage/use_case_of_config.md @@ -34,15 +34,15 @@ run_edustudio( 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } ) ``` ## Q2: How to specify the config of evaluate template -The default_cfg of `BinaryClassificationEvalTPL` is as follows: +The default_cfg of `PredictionEvalTPL` is as follows: ```python -class BinaryClassificationEvalTPL(BaseEvalTPL): +class PredictionEvalTPL(BaseEvalTPL): default_cfg = { 'use_metrics': ['auc', 'acc', 'rmse'] } @@ -70,8 +70,8 @@ run_edustudio( 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], - 'CognitiveDiagnosisEvalTPL': { + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], + 'InterpretabilityEvalTPL': { 'use_metrics': {"auc"} # look here } } diff --git a/edustudio/evaltpl/__init__.py b/edustudio/evaltpl/__init__.py index 7079012..f8d8dff 100644 --- a/edustudio/evaltpl/__init__.py +++ b/edustudio/evaltpl/__init__.py @@ -1,4 +1,4 @@ from .base_evaltpl import BaseEvalTPL -from .bc_evaltpl import BinaryClassificationEvalTPL -from .cd_evaltpl import CognitiveDiagnosisEvalTPL +from .prediction_evaltpl import PredictionEvalTPL +from .interpretability_evaltpl import InterpretabilityEvalTPL from .fairness_evaltpl import FairnessEvalTPL diff --git a/edustudio/evaltpl/fairness_evaltpl.py b/edustudio/evaltpl/fairness_evaltpl.py index 13a53a6..240219e 100644 --- a/edustudio/evaltpl/fairness_evaltpl.py +++ b/edustudio/evaltpl/fairness_evaltpl.py @@ -5,6 +5,8 @@ class FairnessEvalTPL(BaseEvalTPL): + """Fairness Cognitive Evaluation + """ default_cfg = { 'use_sensi_attrs': ['gender:token'], 'use_metrics': ['EO', 'DP', 'FCD'] diff --git a/edustudio/evaltpl/cd_evaltpl.py b/edustudio/evaltpl/interpretability_evaltpl.py similarity index 97% rename from edustudio/evaltpl/cd_evaltpl.py rename to edustudio/evaltpl/interpretability_evaltpl.py index 4e6141e..b89daaa 100644 --- a/edustudio/evaltpl/cd_evaltpl.py +++ b/edustudio/evaltpl/interpretability_evaltpl.py @@ -6,7 +6,9 @@ from edustudio.utils.callback import ModeState -class CognitiveDiagnosisEvalTPL(BaseEvalTPL): +class InterpretabilityEvalTPL(BaseEvalTPL): + """Student Cogntive Representation Interpretability Evaluation + """ default_cfg = { 'use_metrics': ['doa_all'], 'test_only_metrics': ['doa_all'] @@ -189,10 +191,4 @@ def doa_eval(self, y_true, y_pred): doa.append(_doa / _z) z_support += _z # 有效pair个数 doa_support += 1 # 有效doa - # return { - # "doa": np.mean(doa), - # "doa_know_support": doa_support, - # "doa_z_support": z_support, - # "doa_list": doa, - # } return float(np.mean(doa)) diff --git a/edustudio/evaltpl/bc_evaltpl.py b/edustudio/evaltpl/prediction_evaltpl.py similarity index 95% rename from edustudio/evaltpl/bc_evaltpl.py rename to edustudio/evaltpl/prediction_evaltpl.py index 6c04d76..cd8997b 100644 --- a/edustudio/evaltpl/bc_evaltpl.py +++ b/edustudio/evaltpl/prediction_evaltpl.py @@ -4,7 +4,9 @@ from sklearn.metrics import mean_squared_error, roc_auc_score, accuracy_score, f1_score, label_ranking_loss, coverage_error -class BinaryClassificationEvalTPL(BaseEvalTPL): +class PredictionEvalTPL(BaseEvalTPL): + """Student Performance Prediction Evaluation + """ default_cfg = { 'use_metrics': ['auc', 'acc', 'rmse'] } diff --git a/examples/1.run_cd_demo.py b/examples/1.run_cd_demo.py index dab7f6d..e747cf3 100644 --- a/examples/1.run_cd_demo.py +++ b/examples/1.run_cd_demo.py @@ -19,6 +19,6 @@ 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } ) diff --git a/examples/2.run_kt_demo.py b/examples/2.run_kt_demo.py index 8b167ba..379a8f3 100644 --- a/examples/2.run_kt_demo.py +++ b/examples/2.run_kt_demo.py @@ -19,6 +19,6 @@ 'cls': 'DKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/3.run_with_customized_tpl.py b/examples/3.run_with_customized_tpl.py index 988855d..9c2e58e 100644 --- a/examples/3.run_with_customized_tpl.py +++ b/examples/3.run_with_customized_tpl.py @@ -20,6 +20,6 @@ 'cls': DKT, }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/5.run_with_hyperopt.py b/examples/5.run_with_hyperopt.py index 555e5bf..93dcdd1 100644 --- a/examples/5.run_with_hyperopt.py +++ b/examples/5.run_with_hyperopt.py @@ -44,7 +44,7 @@ def objective_function(args): 'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['GeneralTrainTPL']), 'datatpl_cfg.cls': hp.choice('datapl_cfg.cls', ['CDInterExtendsQDataTPL']), 'modeltpl_cfg.cls': hp.choice('modeltpl_cfg.cls', ['KaNCD']), - 'evaltpl_cfg.clses': hp.choice('evaltpl_cfg.clses', [['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]), + 'evaltpl_cfg.clses': hp.choice('evaltpl_cfg.clses', [['PredictionEvalTPL', 'InterpretabilityEvalTPL']]), 'traintpl_cfg.batch_size': hp.choice('traintpl_cfg.batch_size', [256,]), diff --git a/examples/6.run_with_ray.tune.py b/examples/6.run_with_ray.tune.py index 7f7b922..e3c305b 100644 --- a/examples/6.run_with_ray.tune.py +++ b/examples/6.run_with_ray.tune.py @@ -42,7 +42,7 @@ def objective_function(args): 'traintpl_cfg.cls': tune.grid_search(['GeneralTrainTPL']), 'datatpl_cfg.cls': tune.grid_search(['CDInterExtendsQDataTPL']), 'modeltpl_cfg.cls': tune.grid_search(['KaNCD']), - 'evaltpl_cfg.clses': tune.grid_search([['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]), + 'evaltpl_cfg.clses': tune.grid_search([['PredictionEvalTPL', 'InterpretabilityEvalTPL']]), 'traintpl_cfg.batch_size': tune.grid_search([256,]), diff --git a/examples/single_model/run_akt_demo.py b/examples/single_model/run_akt_demo.py index 0af7a6f..9a35fed 100644 --- a/examples/single_model/run_akt_demo.py +++ b/examples/single_model/run_akt_demo.py @@ -19,6 +19,6 @@ 'cls': 'AKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_atkt_demo.py b/examples/single_model/run_atkt_demo.py index 144da03..91155ef 100644 --- a/examples/single_model/run_atkt_demo.py +++ b/examples/single_model/run_atkt_demo.py @@ -19,6 +19,6 @@ 'cls': 'ATKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_cdgk_demo.py b/examples/single_model/run_cdgk_demo.py index 0afa185..8ebc5cb 100644 --- a/examples/single_model/run_cdgk_demo.py +++ b/examples/single_model/run_cdgk_demo.py @@ -23,6 +23,6 @@ 'cls': 'CDGK_MULTI', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_cdmfkc_demo.py b/examples/single_model/run_cdmfkc_demo.py index 7742877..365b081 100644 --- a/examples/single_model/run_cdmfkc_demo.py +++ b/examples/single_model/run_cdmfkc_demo.py @@ -22,6 +22,6 @@ 'cls': 'CDMFKC', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_ckt_demo.py b/examples/single_model/run_ckt_demo.py index b0a67d1..8c6eeba 100644 --- a/examples/single_model/run_ckt_demo.py +++ b/examples/single_model/run_ckt_demo.py @@ -21,6 +21,6 @@ 'cls': 'CKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_cl4kt_demo.py b/examples/single_model/run_cl4kt_demo.py index 6a3609a..e737997 100644 --- a/examples/single_model/run_cl4kt_demo.py +++ b/examples/single_model/run_cl4kt_demo.py @@ -20,6 +20,6 @@ 'cls': 'CL4KT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_cncd_f_demo.py b/examples/single_model/run_cncd_f_demo.py index 18fe89c..33d44ed 100644 --- a/examples/single_model/run_cncd_f_demo.py +++ b/examples/single_model/run_cncd_f_demo.py @@ -19,6 +19,6 @@ 'cls': 'CNCD_F', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_cncdq_demo.py b/examples/single_model/run_cncdq_demo.py index b579435..10c39e1 100644 --- a/examples/single_model/run_cncdq_demo.py +++ b/examples/single_model/run_cncdq_demo.py @@ -19,6 +19,6 @@ 'cls': 'CNCD_Q', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_ctncm_demo.py b/examples/single_model/run_ctncm_demo.py index 4017724..2145076 100644 --- a/examples/single_model/run_ctncm_demo.py +++ b/examples/single_model/run_ctncm_demo.py @@ -19,7 +19,7 @@ 'cls': 'CT_NCM', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_deepirt_demo.py b/examples/single_model/run_deepirt_demo.py index 4d32e2a..85813d4 100644 --- a/examples/single_model/run_deepirt_demo.py +++ b/examples/single_model/run_deepirt_demo.py @@ -20,6 +20,6 @@ 'cls': 'DeepIRT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_dimkt_demo.py b/examples/single_model/run_dimkt_demo.py index 547f30e..970f0bb 100644 --- a/examples/single_model/run_dimkt_demo.py +++ b/examples/single_model/run_dimkt_demo.py @@ -20,6 +20,6 @@ 'cls': 'DIMKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_dina_demo.py b/examples/single_model/run_dina_demo.py index 0d2d705..6d7bf9e 100644 --- a/examples/single_model/run_dina_demo.py +++ b/examples/single_model/run_dina_demo.py @@ -20,6 +20,6 @@ 'cls': 'DINA', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } ) diff --git a/examples/single_model/run_dkt_demo.py b/examples/single_model/run_dkt_demo.py index 770812d..1bf2d3b 100644 --- a/examples/single_model/run_dkt_demo.py +++ b/examples/single_model/run_dkt_demo.py @@ -19,6 +19,6 @@ 'cls': 'DKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_dkt_dsc_demo.py b/examples/single_model/run_dkt_dsc_demo.py index fc61c98..e198f2a 100644 --- a/examples/single_model/run_dkt_dsc_demo.py +++ b/examples/single_model/run_dkt_dsc_demo.py @@ -19,6 +19,6 @@ 'cls': 'DKTDSC', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_dkt_plus_demo.py b/examples/single_model/run_dkt_plus_demo.py index 09d751e..01bd653 100644 --- a/examples/single_model/run_dkt_plus_demo.py +++ b/examples/single_model/run_dkt_plus_demo.py @@ -19,6 +19,6 @@ 'cls': 'DKT_plus', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_dktforget_demo.py b/examples/single_model/run_dktforget_demo.py index 596338b..9c4c113 100644 --- a/examples/single_model/run_dktforget_demo.py +++ b/examples/single_model/run_dktforget_demo.py @@ -19,6 +19,6 @@ 'cls': 'DKTForget', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_dkvmn_demo.py b/examples/single_model/run_dkvmn_demo.py index 784569f..f54b267 100644 --- a/examples/single_model/run_dkvmn_demo.py +++ b/examples/single_model/run_dkvmn_demo.py @@ -20,7 +20,7 @@ 'cls': 'DKVMN', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_dtransformer_demo.py b/examples/single_model/run_dtransformer_demo.py index 1a612fe..bc34d99 100644 --- a/examples/single_model/run_dtransformer_demo.py +++ b/examples/single_model/run_dtransformer_demo.py @@ -21,7 +21,7 @@ 'cls': 'DTransformer', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_ecd_demo.py b/examples/single_model/run_ecd_demo.py index a57dfa3..a931603 100644 --- a/examples/single_model/run_ecd_demo.py +++ b/examples/single_model/run_ecd_demo.py @@ -21,6 +21,6 @@ 'cls': 'ECD_IRT',#ECD_IRT,ECD_MIRT,ECD_NCD }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_eernn_demo.py b/examples/single_model/run_eernn_demo.py index 305a14f..ef05848 100644 --- a/examples/single_model/run_eernn_demo.py +++ b/examples/single_model/run_eernn_demo.py @@ -19,7 +19,7 @@ 'cls': 'EERNNA', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_ekt_demo.py b/examples/single_model/run_ekt_demo.py index 9b9c860..21c2d68 100644 --- a/examples/single_model/run_ekt_demo.py +++ b/examples/single_model/run_ekt_demo.py @@ -24,6 +24,6 @@ 'cls': 'EKTM', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_gkt_demo.py b/examples/single_model/run_gkt_demo.py index c0e302e..190a325 100644 --- a/examples/single_model/run_gkt_demo.py +++ b/examples/single_model/run_gkt_demo.py @@ -19,6 +19,6 @@ 'cls': 'GKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_hawkeskt_demo.py b/examples/single_model/run_hawkeskt_demo.py index 8052e14..b9c336c 100644 --- a/examples/single_model/run_hawkeskt_demo.py +++ b/examples/single_model/run_hawkeskt_demo.py @@ -19,6 +19,6 @@ 'cls': 'HawkesKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_hiercdf_demo.py b/examples/single_model/run_hiercdf_demo.py index 6fb36d1..804f15a 100644 --- a/examples/single_model/run_hiercdf_demo.py +++ b/examples/single_model/run_hiercdf_demo.py @@ -65,6 +65,6 @@ def process(self, **kwargs): 'cls': 'HierCDF', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_iekt_demo.py b/examples/single_model/run_iekt_demo.py index bbe0c4c..1a737ca 100644 --- a/examples/single_model/run_iekt_demo.py +++ b/examples/single_model/run_iekt_demo.py @@ -19,6 +19,6 @@ 'cls': 'IEKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_irr_demo.py b/examples/single_model/run_irr_demo.py index 8d68e63..fceb955 100644 --- a/examples/single_model/run_irr_demo.py +++ b/examples/single_model/run_irr_demo.py @@ -21,6 +21,6 @@ 'cls': 'IRR', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_irt_demo.py b/examples/single_model/run_irt_demo.py index 8f70b66..7b28e56 100644 --- a/examples/single_model/run_irt_demo.py +++ b/examples/single_model/run_irt_demo.py @@ -19,6 +19,6 @@ 'cls': 'IRT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_kancd_demo.py b/examples/single_model/run_kancd_demo.py index 46ea9e1..78c981b 100644 --- a/examples/single_model/run_kancd_demo.py +++ b/examples/single_model/run_kancd_demo.py @@ -19,6 +19,6 @@ 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_kqn_demo.py b/examples/single_model/run_kqn_demo.py index e7fbcec..0a2ded8 100644 --- a/examples/single_model/run_kqn_demo.py +++ b/examples/single_model/run_kqn_demo.py @@ -19,6 +19,6 @@ 'cls': 'KQN', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_kscd_demo.py b/examples/single_model/run_kscd_demo.py index ad12591..c695758 100644 --- a/examples/single_model/run_kscd_demo.py +++ b/examples/single_model/run_kscd_demo.py @@ -19,6 +19,6 @@ 'cls': 'KSCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_lpkt_demo.py b/examples/single_model/run_lpkt_demo.py index 7ccd544..a925a57 100644 --- a/examples/single_model/run_lpkt_demo.py +++ b/examples/single_model/run_lpkt_demo.py @@ -19,6 +19,6 @@ 'cls': 'LPKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_lpkt_s_demo.py b/examples/single_model/run_lpkt_s_demo.py index fa0ca37..814dac8 100644 --- a/examples/single_model/run_lpkt_s_demo.py +++ b/examples/single_model/run_lpkt_s_demo.py @@ -19,6 +19,6 @@ 'cls': 'LPKT_S', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_mgcd_demo.py b/examples/single_model/run_mgcd_demo.py index d44b232..b79aca5 100644 --- a/examples/single_model/run_mgcd_demo.py +++ b/examples/single_model/run_mgcd_demo.py @@ -24,8 +24,8 @@ 'cls': 'MGCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], - 'BinaryClassificationEvalTPL': { + 'clses': ['PredictionEvalTPL'], + 'PredictionEvalTPL': { 'use_metrics': ['rmse'] } } diff --git a/examples/single_model/run_mirt_demo.py b/examples/single_model/run_mirt_demo.py index aa5e2b6..967cf9c 100644 --- a/examples/single_model/run_mirt_demo.py +++ b/examples/single_model/run_mirt_demo.py @@ -19,6 +19,6 @@ 'cls': 'MIRT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_ncdm_demo.py b/examples/single_model/run_ncdm_demo.py index 178c314..352511c 100644 --- a/examples/single_model/run_ncdm_demo.py +++ b/examples/single_model/run_ncdm_demo.py @@ -19,6 +19,6 @@ 'cls': 'NCDM', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_qdkt_demo.py b/examples/single_model/run_qdkt_demo.py index b6032fe..632397d 100644 --- a/examples/single_model/run_qdkt_demo.py +++ b/examples/single_model/run_qdkt_demo.py @@ -19,6 +19,6 @@ 'cls': 'QDKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_qikt_demo.py b/examples/single_model/run_qikt_demo.py index 69996a0..9bacd7c 100644 --- a/examples/single_model/run_qikt_demo.py +++ b/examples/single_model/run_qikt_demo.py @@ -21,6 +21,6 @@ 'cls': 'QIKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_rcd_demo.py b/examples/single_model/run_rcd_demo.py index b815d54..b46820f 100644 --- a/examples/single_model/run_rcd_demo.py +++ b/examples/single_model/run_rcd_demo.py @@ -19,6 +19,6 @@ 'cls': 'RCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_rkt_demo.py b/examples/single_model/run_rkt_demo.py index 2a2d89b..bf9496f 100644 --- a/examples/single_model/run_rkt_demo.py +++ b/examples/single_model/run_rkt_demo.py @@ -19,6 +19,6 @@ 'cls': 'RKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_saint_demo.py b/examples/single_model/run_saint_demo.py index 98eefee..5579e8b 100644 --- a/examples/single_model/run_saint_demo.py +++ b/examples/single_model/run_saint_demo.py @@ -19,7 +19,7 @@ 'cls': 'SAINT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_saint_plus_demo.py b/examples/single_model/run_saint_plus_demo.py index dce0551..c73ac54 100644 --- a/examples/single_model/run_saint_plus_demo.py +++ b/examples/single_model/run_saint_plus_demo.py @@ -19,7 +19,7 @@ 'cls': 'SAINT_plus', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_sakt_demo.py b/examples/single_model/run_sakt_demo.py index 5948afc..a0c4da3 100644 --- a/examples/single_model/run_sakt_demo.py +++ b/examples/single_model/run_sakt_demo.py @@ -19,6 +19,6 @@ 'cls': 'SAKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_simplekt_demo.py b/examples/single_model/run_simplekt_demo.py index ca2cac8..0fd877a 100644 --- a/examples/single_model/run_simplekt_demo.py +++ b/examples/single_model/run_simplekt_demo.py @@ -19,6 +19,6 @@ 'cls': 'SimpleKT', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/examples/single_model/run_skvmn_demo.py b/examples/single_model/run_skvmn_demo.py index 63fd555..fa149f6 100644 --- a/examples/single_model/run_skvmn_demo.py +++ b/examples/single_model/run_skvmn_demo.py @@ -19,6 +19,6 @@ 'cls': 'SKVMN', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL'], + 'clses': ['PredictionEvalTPL'], } ) diff --git a/tests/test_run.py b/tests/test_run.py index 0422897..3663579 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -18,6 +18,6 @@ def test_cd(self): 'cls': 'KaNCD', }, evaltpl_cfg_dict={ - 'clses': ['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL'], + 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], } )