Skip to content

Commit

Permalink
fixed MGCD bug and add GroupCDTrainTPL
Browse files Browse the repository at this point in the history
  • Loading branch information
kervias committed Dec 5, 2024
1 parent f4ebcaa commit 115da3d
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/source/user_guide/reference_table.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
| IRR | IRRDataTPL | GeneralTrainTPL | PredictionEvalTPL |
| KaNCD | CDInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL, InterpretabilityEvalTPL, IdentifiabilityEvalTPL |
| KSCD | CDInterExtendsQDataTPL | GeneralTrainTPL | PredictionEvalTPL |
| MGCD | MGCDDataTPL | GeneralTrainTPL | PredictionEvalTPL |
| MGCD | MGCDDataTPL | GroupCDTrainTPL | PredictionEvalTPL |
| RCD | RCDDataTPL | GeneralTrainTPL | PredictionEvalTPL |
| DCD | CCDDataTPL | DCDTrainTPL | PredictionEvalTPL, InterpretabilityEvalTPL, IdentifiabilityEvalTPL |
| FairCD | FAIRDataTPL | AdversarialTrainTPL | PredictionEvalTPL, FairnessEvalTPL |
Expand Down
2 changes: 1 addition & 1 deletion edustudio/traintpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
from .atkt_traintpl import AtktTrainTPL
from .dcd_traintpl import DCDTrainTPL
from .adversarial_traintpl import AdversarialTrainTPL
from .group_traintpl import GroupTrainTPL
from .group_cd_traintpl import GroupCDTrainTPL
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tqdm import tqdm


class GroupTrainTPL(GeneralTrainTPL):
class GroupCDTrainTPL(GeneralTrainTPL):

@torch.no_grad()
def evaluate(self, loader):
Expand Down
6 changes: 3 additions & 3 deletions examples/single_model/run_mgcd_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@
dataset='ASSIST_0910',
cfg_file_name=None,
traintpl_cfg_dict={
'cls': 'GroupTrainTPL',
'cls': 'GroupCDTrainTPL',
'early_stop_metrics': [('rmse','min')],
'best_epoch_metric': 'rmse',
'batch_size': 512
},
datatpl_cfg_dict={
'cls': 'MGCDDataTPL',
# 'load_data_from': 'rawdata',
# 'raw2mid_op': 'R2M_ASSIST_1213'
# 'raw2mid_op': 'R2M_ASSIST_0910'
},
modeltpl_cfg_dict={
'cls': 'MGCD',
},
evaltpl_cfg_dict={
'clses': ['PredictionEvalTPL'],
'PredictionEvalTPL': {
'use_metrics': ['rmse']
'use_metrics': ['acc', 'auc','rmse']
}
}
)

0 comments on commit 115da3d

Please sign in to comment.