Skip to content

Commit

Permalink
fix mgcd bug
Browse files Browse the repository at this point in the history
  • Loading branch information
tzt-star committed Dec 5, 2024
0 parents commit 8f6ae5d
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
11 changes: 11 additions & 0 deletions edustudio/traintpl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from .base_traintpl import BaseTrainTPL
from .gd_traintpl import GDTrainTPL
from .general_traintpl import GeneralTrainTPL
from .atkt_traintpl import AtktTrainTPL
from .dcd_traintpl import DCDTrainTPL
from .adversarial_traintpl import AdversarialTrainTPL
from .group_traintpl import GroupTrainTPL
89 changes: 89 additions & 0 deletions edustudio/traintpl/group_traintpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from edustudio.traintpl import GeneralTrainTPL
from edustudio.utils.common import tensor2npy
import torch
from tqdm import tqdm


class GroupTrainTPL(GeneralTrainTPL):

@torch.no_grad()
def evaluate(self, loader):
self.model.eval()
stu_id_list = list(range(len(loader)))
pd_list = list(range(len(loader)))
gt_list = list(range(len(loader)))
for idx, batch_dict in enumerate(tqdm(loader, ncols=self.frame_cfg['TQDM_NCOLS'], desc="[PREDICT]")):
batch_dict = self.batch_dict2device(batch_dict)
eval_dict = self.model.predict(**batch_dict)
stu_id_list[idx] = batch_dict['group_id']
pd_list[idx] = eval_dict['y_pd']
gt_list[idx] = eval_dict['y_gt'] if 'y_gt' in eval_dict else batch_dict['label']
y_pd = torch.hstack(pd_list)
y_gt = torch.hstack(gt_list)
group_id = torch.hstack(stu_id_list)

eval_data_dict = {
'group_id': group_id,
'y_pd': y_pd,
'y_gt': y_gt,
}
if hasattr(self.model, 'get_stu_status'):
stu_stats_list = []
idx = torch.arange(0, self.datatpl_cfg['dt_info']['stu_count']).to(self.traintpl_cfg['device'])
for i in range(0,self.datatpl_cfg['dt_info']['stu_count'], self.traintpl_cfg['eval_batch_size']):
batch_stu_id = idx[i:i+self.traintpl_cfg['eval_batch_size']]
batch = self.model.get_stu_status(batch_stu_id)
stu_stats_list.append(batch)
stu_stats = torch.vstack(stu_stats_list)
eval_data_dict.update({
'stu_stats': tensor2npy(stu_stats),
})
if hasattr(self.datatpl, 'Q_mat'):
eval_data_dict.update({
'Q_mat': tensor2npy(self.datatpl.Q_mat)
})
eval_result = {}
for evaltpl in self.evaltpls: eval_result.update(
evaltpl.eval(ignore_metrics=self.traintpl_cfg['ignore_metrics_in_train'], **eval_data_dict)
)
return eval_result

@torch.no_grad()
def inference(self, loader):
self.model.eval()
stu_id_list = list(range(len(loader)))
pd_list = list(range(len(loader)))
gt_list = list(range(len(loader)))
for idx, batch_dict in enumerate(tqdm(loader, ncols=self.frame_cfg['TQDM_NCOLS'], desc="[PREDICT]")):
batch_dict = self.batch_dict2device(batch_dict)
eval_dict = self.model.predict(**batch_dict)
stu_id_list[idx] = batch_dict['group_id']
pd_list[idx] = eval_dict['y_pd']
gt_list[idx] = eval_dict['y_gt'] if 'y_gt' in eval_dict else batch_dict['label']
y_pd = torch.hstack(pd_list)
y_gt = torch.hstack(gt_list)
group_id = torch.hstack(stu_id_list)

eval_data_dict = {
'group_id': group_id,
'y_pd': y_pd,
'y_gt': y_gt,
}
if hasattr(self.model, 'get_stu_status'):
stu_stats_list = []
idx = torch.arange(0, self.datatpl_cfg['dt_info']['stu_count']).to(self.traintpl_cfg['device'])
for i in range(0,self.datatpl_cfg['dt_info']['stu_count'], self.traintpl_cfg['eval_batch_size']):
batch_stu_id = idx[i:i+self.traintpl_cfg['eval_batch_size']]
batch = self.model.get_stu_status(batch_stu_id)
stu_stats_list.append(batch)
stu_stats = torch.vstack(stu_stats_list)
eval_data_dict.update({
'stu_stats': tensor2npy(stu_stats),
})
if hasattr(self.datatpl, 'Q_mat'):
eval_data_dict.update({
'Q_mat': tensor2npy(self.datatpl.Q_mat)
})
eval_result = {}
for evaltpl in self.evaltpls: eval_result.update(evaltpl.eval(**eval_data_dict))
return eval_result
32 changes: 32 additions & 0 deletions examples/single_model/run_mgcd_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

run_edustudio(
dataset='ASSIST_0910',
cfg_file_name=None,
traintpl_cfg_dict={
'cls': 'GroupTrainTPL',
'early_stop_metrics': [('rmse','min')],
'best_epoch_metric': 'rmse',
'batch_size': 512
},
datatpl_cfg_dict={
'cls': 'MGCDDataTPL',
# 'load_data_from': 'rawdata',
# 'raw2mid_op': 'R2M_ASSIST_1213'
},
modeltpl_cfg_dict={
'cls': 'MGCD',
},
evaltpl_cfg_dict={
'clses': ['PredictionEvalTPL'],
'PredictionEvalTPL': {
'use_metrics': ['rmse']
}
}
)

0 comments on commit 8f6ae5d

Please sign in to comment.