-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
945 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from .base_mid2cache import BaseMid2Cache | ||
import numpy as np | ||
import pandas as pd | ||
from itertools import chain | ||
import torch | ||
from edustudio.utils.common import set_same_seeds | ||
|
||
|
||
class M2C_BuildMissingQ(BaseMid2Cache):
    """Mid2Cache step that produces an incomplete copy of the Q-matrix.

    A share of exercises (``Q_delete_ratio``) loses its concept annotation.
    The step adds ``missing_df_Q`` (the rows that are kept) and
    ``missing_Q_mat`` (the matching exercise-by-concept 0/1 matrix) to the
    keyword arguments flowing through the pipeline.
    """

    default_cfg = {
        'seed': 20230518,
        'Q_delete_ratio': 0.0,
    }

    def process(self, **kwargs):
        """Build the reduced Q table/matrix and store both in ``kwargs``.

        Reads ``kwargs['dt_info']`` (``exer_count``, ``cpt_count``) and
        ``kwargs['df_exer']``; returns the same ``kwargs`` with the keys
        ``missing_df_Q`` and ``missing_Q_mat`` added.
        """
        dt_info = kwargs['dt_info']
        self.item_count = dt_info['exer_count']
        self.cpt_count = dt_info['cpt_count']
        self.df_Q = kwargs['df_exer'][['exer_id:token', 'cpt_seq:token_seq']]

        self.missing_df_Q = self.get_missing_df_Q()
        self.missing_Q_mat = self.get_Q_mat_from_df_arr(
            self.missing_df_Q, self.item_count, self.cpt_count
        )

        kwargs['missing_df_Q'] = self.missing_df_Q
        kwargs['missing_Q_mat'] = self.missing_Q_mat
        return kwargs

    def get_missing_df_Q(self):
        """Return the rows of ``self.df_Q`` whose annotation is preserved.

        One randomly chosen exercise per concept is always preserved; further
        exercises are drawn at random until
        ``item_count - int(ratio * item_count)`` exercises are kept. If the
        per-concept picks already reach that number, a warning is logged and
        only those picks are kept.
        """
        set_same_seeds(seed=self.m2c_cfg['seed'])
        ratio = self.m2c_cfg['Q_delete_ratio']

        # Flatten the exercise -> concept-list mapping into (exercise, concept) pairs.
        exer2cpts = self.df_Q.set_index('exer_id:token')['cpt_seq:token_seq'].to_dict()
        exer_col = np.array(list(chain.from_iterable(
            [exer_id] * len(cpts) for exer_id, cpts in exer2cpts.items()
        )))
        cpt_col = np.array(list(chain.from_iterable(exer2cpts.values())))
        entry_arr = np.vstack([exer_col, cpt_col]).T

        np.random.shuffle(entry_arr)

        # reference: https://stackoverflow.com/questions/64834655/python-how-to-find-first-duplicated-items-in-an-numpy-array
        # First pick one exercise for every concept: after the shuffle, the
        # first occurrence of each concept is a random exercise carrying it.
        _, first_pos = np.unique(entry_arr[:, 1], return_index=True)
        # The exercises picked that way form the initial preserved set.
        preserved_exers = np.unique(entry_arr[first_pos, 0])

        delete_num = int(ratio * self.item_count)
        preserved_num = self.item_count - delete_num

        if len(preserved_exers) < preserved_num:
            # Top up with random exercises that are not preserved yet.
            shortfall = preserved_num - len(preserved_exers)
            all_iids = np.arange(self.item_count)
            left_iids = all_iids[~np.isin(all_iids, preserved_exers)]
            np.random.shuffle(left_iids)
            preserved_exers = np.hstack([preserved_exers, left_iids[:shortfall]])
        else:
            self.logger.warning(
                f"Cant Satisfy Delete Require: {len(preserved_exers)=},{preserved_num=}"
            )

        keep_mask = self.df_Q['exer_id:token'].isin(preserved_exers)
        return self.df_Q.loc[keep_mask].reset_index(drop=True)

    def get_Q_mat_from_df_arr(self, df_Q_arr, item_count, cpt_count):
        """Turn a Q table into an ``item_count`` x ``cpt_count`` int64 tensor.

        Every row of ``df_Q_arr`` sets the entries addressed by its exercise id
        and its concept sequence to 1; all other entries stay 0.
        """
        Q_mat = torch.zeros((item_count, cpt_count), dtype=torch.int64)
        rows = zip(df_Q_arr['exer_id:token'], df_Q_arr['cpt_seq:token_seq'])
        for exer_id, cpt_seq in rows:
            Q_mat[exer_id, cpt_seq] = 1
        return Q_mat
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from .base_mid2cache import BaseMid2Cache | ||
import numpy as np | ||
import pandas as pd | ||
from itertools import chain | ||
import torch | ||
from edustudio.utils.common import set_same_seeds, tensor2npy | ||
from tqdm import tqdm | ||
|
||
class M2C_FillMissingQ(BaseMid2Cache):
    """Mid2Cache step that fills in concept annotations of unlabeled exercises.

    For every training fold one Q-matrix is appended to
    ``filling_Q_mat_list``:

    * no exercise row of ``missing_Q_mat`` is empty -> the original ``Q_mat``;
    * ``Q_fill_type == "None"`` -> ``missing_Q_mat`` unchanged;
    * ``Q_fill_type == "sim_dist_for_by_exer"`` -> a matrix whose empty rows
      are filled by voting among the most similar labeled exercises.
    """

    default_cfg = {
        'Q_fill_type': "None",
        'params_topk': 5,
        'params_votek': 2,
    }

    def __init__(self, m2c_cfg, cfg) -> None:
        # The full cfg is kept because the fill routine reads
        # traintpl_cfg['device'] and frame_cfg['TQDM_NCOLS'] from it.
        self.logger = cfg.logger
        self.m2c_cfg = m2c_cfg
        self.cfg = cfg

    @classmethod
    def from_cfg(cls, cfg):
        """Create the step from the global cfg, using this class's own sub-config."""
        return cls(cfg.datatpl_cfg.get(cls.__name__), cfg)

    def process(self, **kwargs):
        """Compute one (possibly filled) Q-matrix per training fold.

        Reads ``dt_info``, ``df_exer``, ``Q_mat``, ``missing_Q_mat``,
        ``missing_df_Q`` and ``df_train_folds`` from ``kwargs`` and returns
        ``kwargs`` with ``filling_Q_mat_list`` added.

        Raises:
            ValueError: if filling is required and ``Q_fill_type`` is unknown.
        """
        dt_info = kwargs['dt_info']
        self.user_count = dt_info['stu_count']
        self.item_count = dt_info['exer_count']
        self.cpt_count = dt_info['cpt_count']
        self.df_Q = kwargs['df_exer'][['exer_id:token', 'cpt_seq:token_seq']]

        Q_mat = kwargs['Q_mat']
        missing_Q_mat = kwargs['missing_Q_mat']

        # Whether any exercise has no concept at all does not depend on the
        # fold, so it is evaluated once instead of once per fold.
        needs_filling = bool((missing_Q_mat.sum(dim=1) == 0).sum() > 0)

        self.filling_Q_mat_list = []
        for df_train in kwargs['df_train_folds']:
            if not needs_filling:
                self.filling_Q_mat_list.append(Q_mat)
                continue

            fill_type = self.m2c_cfg['Q_fill_type']
            if fill_type == "sim_dist_for_by_exer":
                fill_df_Q = self.fill_df_Q_by_sim_dist(
                    df_train, kwargs['missing_df_Q'],
                    params_topk=self.m2c_cfg['params_topk'],
                    params_votek=self.m2c_cfg['params_votek']
                )
                fill_Q_mat = self.get_Q_mat_from_df_arr(fill_df_Q, self.item_count, self.cpt_count)
                # The other branches append torch tensors (missing_Q_mat is
                # one); convert so every list entry has the same type.
                self.filling_Q_mat_list.append(torch.from_numpy(fill_Q_mat))
            elif fill_type == "None":
                self.filling_Q_mat_list.append(missing_Q_mat)
            else:
                raise ValueError(f"unknown Q_fill_type: {self.m2c_cfg['Q_fill_type']}")

        kwargs['filling_Q_mat_list'] = self.filling_Q_mat_list
        return kwargs

    def get_Q_mat_from_df_arr(self, df_Q_arr, item_count, cpt_count):
        """Turn a Q table into an ``item_count`` x ``cpt_count`` int64 NumPy array.

        Every row of ``df_Q_arr`` sets the entries addressed by its exercise id
        and its concept sequence to 1; all other entries stay 0.
        """
        Q_mat = np.zeros((item_count, cpt_count), dtype=np.int64)
        for exer_id, cpt_seq in zip(df_Q_arr['exer_id:token'], df_Q_arr['cpt_seq:token_seq']):
            Q_mat[exer_id, cpt_seq] = 1
        return Q_mat

    def fill_df_Q_by_sim_dist(self, df_interaction, df_Q_left, params_topk=5, params_votek=2):
        """Fill unlabeled exercises by voting among similar labeled exercises.

        Similarity of an unlabeled exercise to another exercise is the number
        of students who answered both with the same outcome, divided by the
        number of students who answered the unlabeled exercise. For each
        unlabeled exercise the ``params_topk`` most similar exercises are
        taken, the concepts of the labeled ones among them are counted, and
        the ``params_votek`` most frequent concepts are assigned.

        Args:
            df_interaction: training interactions with the columns
                ``stu_id:token``, ``exer_id:token`` and ``label:float``.
            df_Q_left: Q table of the exercises that still have concepts.
            params_topk: number of neighbours considered per exercise
                (clamped to ``item_count``).
            params_votek: maximum number of concepts assigned per exercise.

        Returns:
            ``df_Q_left`` followed by one row per filled exercise, re-indexed
            from 0. The share of filled exercises that got at least one of
            their true concepts is logged as ``Hit_ratio``.
        """
        preserved_exers = df_Q_left['exer_id:token'].to_numpy()
        # Sorted and deterministic, unlike iterating a Python set difference.
        missing_iids = np.setdiff1d(np.arange(self.item_count), preserved_exers)
        if missing_iids.size == 0:
            # Nothing to fill; also avoids dividing by zero in the hit ratio.
            return df_Q_left.reset_index(drop=True)

        id_cols = ['stu_id:token', 'exer_id:token']
        interact_mat = torch.zeros(
            (self.user_count, self.item_count), dtype=torch.int8
        ).to(self.cfg.traintpl_cfg['device'])
        is_correct = df_interaction['label:float'] == 1
        idx = df_interaction[is_correct][id_cols].to_numpy()
        interact_mat[idx[:, 0], idx[:, 1]] = 1
        idx = df_interaction[~is_correct][id_cols].to_numpy()
        interact_mat[idx[:, 0], idx[:, 1]] = -1

        # Rows become exercises: 1 = correct, -1 = not correct, 0 = no record.
        interact_mat = interact_mat.T
        answered = interact_mat != 0  # loop-invariant, computed once

        sim_mat = torch.zeros((self.item_count, self.item_count))
        for iid in tqdm(missing_iids, desc="[FILL_Q_MAT] compute sim_mat", ncols=self.cfg.frame_cfg['TQDM_NCOLS']):
            target_answered = answered[iid]
            co_answered = target_answered & answered
            agree = (interact_mat[iid] == interact_mat) & co_answered
            sim_row = agree.sum(dim=1) / target_answered.sum()
            # Exercises sharing no student get similarity 0. This also clears
            # the NaN row produced when the target has no interaction at all.
            # The mask is applied here, on the interaction matrix's device,
            # because a mask living on another device cannot index sim_mat.
            sim_row[co_answered.sum(dim=1) == 0] = 0.0
            sim_mat[iid] = sim_row.to(sim_mat.device)
            sim_mat[iid, iid] = -1.0
            # Unlabeled exercises have no concepts to vote with.
            sim_mat[iid, missing_iids] = -1.0

        assert torch.isnan(sim_mat).sum() == 0

        # torch.topk raises when k exceeds the size of the selected dimension.
        topk = min(params_topk, self.item_count)
        _, topk_mat_idx = torch.topk(sim_mat, dim=1, k=topk, largest=True, sorted=True)
        topk_mat_idx = tensor2npy(topk_mat_idx)

        # Dict lookups replace the per-neighbour array scan and .loc access.
        # Assumes exercise ids are unique in df_Q_left.
        cpts_of = dict(zip(df_Q_left['exer_id:token'], df_Q_left['cpt_seq:token_seq']))
        missing_iid_fill_cpts = {}
        for iid in tqdm(missing_iids, desc="[FILL_Q_MAT] fill process", ncols=self.cfg.frame_cfg['TQDM_NCOLS']):
            neighbour_cpts = list(chain.from_iterable(
                cpts_of[int(iid2)] for iid2 in topk_mat_idx[iid] if int(iid2) in cpts_of
            ))
            cpts, counts = np.unique(neighbour_cpts, return_counts=True)
            # Stable sort: concepts with equal votes stay in ascending id order.
            ranked = sorted(zip(cpts, counts), key=lambda pair: pair[1], reverse=True)
            missing_iid_fill_cpts[iid] = [cpt for cpt, _ in ranked[0:params_votek]]

        missing_fill_df_Q = pd.DataFrame({
            'exer_id:token': list(missing_iid_fill_cpts.keys()),
            'cpt_seq:token_seq': list(missing_iid_fill_cpts.values()),
        })
        final_df_Q = pd.concat([df_Q_left, missing_fill_df_Q], axis=0, ignore_index=True)

        # Diagnostic only: how often the filled concepts overlap the true ones.
        true_cpts = dict(zip(self.df_Q['exer_id:token'], self.df_Q['cpt_seq:token_seq']))
        hit_count = sum(
            1 for iid, filled in missing_iid_fill_cpts.items()
            if set(true_cpts[int(iid)]) & set(filled)
        )
        hit_ratio = hit_count / len(missing_iid_fill_cpts)

        self.logger.info(f"[FILL_Q] Hit_ratio={hit_ratio}")

        return final_df_Q
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import os | ||
from ..common.edu_datatpl import EduDataTPL | ||
import json | ||
from edustudio.datatpl.common.general_datatpl import DataTPLStatus | ||
import torch | ||
|
||
|
||
class DCDDataTPL(EduDataTPL):
    """Data template that adds per-fold interaction matrices, the per-fold
    filled Q-matrix and an optional concept-relation dictionary to the data
    prepared by ``EduDataTPL``.
    """

    default_cfg = {
        'n_folds': 5,
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat', 'M2C_BuildMissingQ', 'M2C_FillMissingQ'],
        'cpt_relation_file_name': 'cpt_relation',
    }

    def __init__(self, cfg, df, df_train=None, df_valid=None, df_test=None, dict_cpt_relation=None, status=DataTPLStatus(), df_stu=None, df_exer=None):
        # Stored before the parent initializer runs, as in the original order.
        self.dict_cpt_relation = dict_cpt_relation
        super().__init__(cfg, df, df_train, df_valid, df_test, df_stu, df_exer, status)

    def _check_param(self):
        """Run the parent's parameter check, then require 0 <= Q_delete_ratio < 1."""
        # NOTE(review): this method is named `_check_param` while it calls the
        # parent's `_check_params`; if the framework only invokes
        # `_check_params`, this check never runs - confirm. Also verify that
        # 'Q_delete_ratio' is really a top-level key of datatpl_cfg.
        super()._check_params()
        delete_ratio = self.datatpl_cfg['Q_delete_ratio']
        assert 0 <= delete_ratio < 1

    @property
    def common_str2df(self):
        """Parent's name-to-object mapping plus the ``dict_cpt_relation`` entry."""
        mapping = super().common_str2df
        mapping['dict_cpt_relation'] = self.dict_cpt_relation
        return mapping

    def process_data(self):
        """Run the parent pipeline, then build one interaction matrix per fold.

        Each matrix is ``stu_count`` x ``exer_count`` (int8): 1 where the
        training label equals 1, -1 for every other training record, 0 where
        there is no record. If no concept relation is available, every concept
        is related only to itself.
        """
        super().process_data()
        dt_info = self.final_kwargs['dt_info']
        mat_shape = (dt_info['stu_count'], dt_info['exer_count'])
        id_cols = ['stu_id:token', 'exer_id:token']

        self.interact_mat_list = []
        for fold_df in self.final_kwargs['df_train_folds']:
            fold_mat = torch.zeros(mat_shape, dtype=torch.int8)
            is_correct = fold_df['label:float'] == 1
            # Correct answers are written first, all other records second,
            # matching the original write order.
            for row_mask, value in ((is_correct, 1), (~is_correct, -1)):
                pos = fold_df[row_mask][id_cols].to_numpy()
                fold_mat[pos[:, 0], pos[:, 1]] = value
            self.interact_mat_list.append(fold_mat)

        self.final_kwargs['interact_mat_list'] = self.interact_mat_list

        if self.final_kwargs['dict_cpt_relation'] is None:
            cpt_count = self.final_kwargs['dt_info']['cpt_count']
            self.final_kwargs['dict_cpt_relation'] = {cpt: [cpt] for cpt in range(cpt_count)}

    @classmethod
    def load_data(cls, cfg):
        """Load the parent's data plus the optional concept-relation JSON file.

        ``dict_cpt_relation`` is the parsed JSON when the file exists under
        ``<data_folder_path>/middata``; otherwise a warning is logged and the
        entry is ``None``.
        """
        kwargs = super().load_data(cfg)
        fph = f"{cfg.frame_cfg.data_folder_path}/middata/{cfg.datatpl_cfg['cpt_relation_file_name']}.json"
        if not os.path.exists(fph):
            cfg.logger.warning("without cpt_relation.json")
            kwargs['dict_cpt_relation'] = None
            return kwargs

        with open(fph, 'r', encoding='utf-8') as f:
            kwargs['dict_cpt_relation'] = json.load(f)
        return kwargs

    def get_extra_data(self):
        """Parent's extra data plus the current fold's filled Q-matrix and interaction matrix."""
        extra_dict = super().get_extra_data()
        extra_dict['filling_Q_mat'] = self.filling_Q_mat
        extra_dict['interact_mat'] = self.interact_mat
        return extra_dict

    def set_info_for_fold(self, fold_id):
        """Select the filled Q-matrix and interaction matrix of fold ``fold_id``."""
        super().set_info_for_fold(fold_id)
        fold_kwargs = self.final_kwargs
        self.filling_Q_mat = fold_kwargs['filling_Q_mat_list'][fold_id]
        self.interact_mat = fold_kwargs['interact_mat_list'][fold_id]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.