Skip to content

Commit

Permalink
[improve] get stu status iteratively in train_tpl
Browse files Browse the repository at this point in the history
  • Loading branch information
kervias committed Mar 12, 2024
1 parent 6aa7649 commit 68611db
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 7 deletions.
18 changes: 16 additions & 2 deletions edustudio/traintpl/atkt_traintpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,15 @@ def evaluate(self, loader):
'y_gt': y_gt,
}
if hasattr(self.model, 'get_stu_status'):
stu_stats_list = []
idx = torch.arange(0, self.datatpl_cfg['dt_info']['stu_count']).to(self.traintpl_cfg['device'])
for i in range(0,self.datatpl_cfg['dt_info']['stu_count'], self.traintpl_cfg['eval_batch_size']):
batch_stu_id = idx[i:i+self.traintpl_cfg['eval_batch_size']]
batch = self.model.get_stu_status(batch_stu_id)
stu_stats_list.append(batch)
stu_stats = torch.vstack(stu_stats_list)
eval_data_dict.update({
'stu_stats': tensor2npy(self.model.get_stu_status()),
'stu_stats': tensor2npy(stu_stats),
})
if hasattr(loader.dataset, 'Q_mat'):
eval_data_dict.update({
Expand Down Expand Up @@ -201,8 +208,15 @@ def inference(self, loader):
'y_gt': y_gt,
}
if hasattr(self.model, 'get_stu_status'):
stu_stats_list = []
idx = torch.arange(0, self.datatpl_cfg['dt_info']['stu_count']).to(self.traintpl_cfg['device'])
for i in range(0,self.datatpl_cfg['dt_info']['stu_count'], self.traintpl_cfg['eval_batch_size']):
batch_stu_id = idx[i:i+self.traintpl_cfg['eval_batch_size']]
batch = self.model.get_stu_status(batch_stu_id)
stu_stats_list.append(batch)
stu_stats = torch.vstack(stu_stats_list)
eval_data_dict.update({
'stu_stats': tensor2npy(self.model.get_stu_status()),
'stu_stats': tensor2npy(stu_stats),
})
if hasattr(loader.dataset, 'Q_mat'):
eval_data_dict.update({
Expand Down
19 changes: 16 additions & 3 deletions edustudio/traintpl/dcd_traintpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,16 @@ def evaluate(self, loader):
'y_gt': y_gt,
}
if hasattr(self.model, 'get_stu_status'):
stu_stats_list = []
idx = torch.arange(0, self.datatpl_cfg['dt_info']['stu_count']).to(self.traintpl_cfg['device'])
for i in range(0,self.datatpl_cfg['dt_info']['stu_count'], self.traintpl_cfg['eval_batch_size']):
batch_stu_id = idx[i:i+self.traintpl_cfg['eval_batch_size']]
batch = self.model.get_stu_status(batch_stu_id)
stu_stats_list.append(batch)
stu_stats = torch.vstack(stu_stats_list)
eval_data_dict.update({
'stu_stats': tensor2npy(self.model.get_stu_status()),
'stu_stats': tensor2npy(stu_stats),
})

if hasattr(self.model, 'get_exer_emb'):
eval_data_dict.update({
'exer_emb': self.model.get_exer_emb(),
Expand Down Expand Up @@ -123,8 +129,15 @@ def inference(self, loader):
'y_gt': y_gt,
}
if hasattr(self.model, 'get_stu_status'):
stu_stats_list = []
idx = torch.arange(0, self.datatpl_cfg['dt_info']['stu_count']).to(self.traintpl_cfg['device'])
for i in range(0,self.datatpl_cfg['dt_info']['stu_count'], self.traintpl_cfg['eval_batch_size']):
batch_stu_id = idx[i:i+self.traintpl_cfg['eval_batch_size']]
batch = self.model.get_stu_status(batch_stu_id)
stu_stats_list.append(batch)
stu_stats = torch.vstack(stu_stats_list)
eval_data_dict.update({
'stu_stats': tensor2npy(self.model.get_stu_status()),
'stu_stats': tensor2npy(stu_stats),
})

if hasattr(self.model, 'get_exer_emb'):
Expand Down
19 changes: 17 additions & 2 deletions edustudio/traintpl/general_traintpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def one_fold_start(self, fold_id):
if self.valid_loader is not None:
self.fit(train_loader=self.train_loader, valid_loader=self.valid_loader)
else:
self.logger.info("Without validation set, replace it with test set")
self.fit(train_loader=self.train_loader, valid_loader=self.test_loader)

metric_name = self.traintpl_cfg['best_epoch_metric']
Expand Down Expand Up @@ -135,8 +136,15 @@ def evaluate(self, loader):
'y_gt': y_gt,
}
if hasattr(self.model, 'get_stu_status'):
stu_stats_list = []
idx = torch.arange(0, self.datatpl_cfg['dt_info']['stu_count']).to(self.traintpl_cfg['device'])
for i in range(0,self.datatpl_cfg['dt_info']['stu_count'], self.traintpl_cfg['eval_batch_size']):
batch_stu_id = idx[i:i+self.traintpl_cfg['eval_batch_size']]
batch = self.model.get_stu_status(batch_stu_id)
stu_stats_list.append(batch)
stu_stats = torch.vstack(stu_stats_list)
eval_data_dict.update({
'stu_stats': tensor2npy(self.model.get_stu_status()),
'stu_stats': tensor2npy(stu_stats),
})
if hasattr(self.datatpl, 'Q_mat'):
eval_data_dict.update({
Expand Down Expand Up @@ -170,8 +178,15 @@ def inference(self, loader):
'y_gt': y_gt,
}
if hasattr(self.model, 'get_stu_status'):
stu_stats_list = []
idx = torch.arange(0, self.datatpl_cfg['dt_info']['stu_count']).to(self.traintpl_cfg['device'])
for i in range(0,self.datatpl_cfg['dt_info']['stu_count'], self.traintpl_cfg['eval_batch_size']):
batch_stu_id = idx[i:i+self.traintpl_cfg['eval_batch_size']]
batch = self.model.get_stu_status(batch_stu_id)
stu_stats_list.append(batch)
stu_stats = torch.vstack(stu_stats_list)
eval_data_dict.update({
'stu_stats': tensor2npy(self.model.get_stu_status()),
'stu_stats': tensor2npy(stu_stats),
})
if hasattr(self.datatpl, 'Q_mat'):
eval_data_dict.update({
Expand Down

0 comments on commit 68611db

Please sign in to comment.