From e17710c8068b3e6796f698147770d2444c722414 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Fri, 18 Dec 2020 19:01:38 +0800 Subject: [PATCH] FEA: add config['benchmark_filename'] to load pre-split dataset. --- recbole/data/dataset/dataset.py | 7 ++++++- recbole/trainer/trainer.py | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index aadf4eeeb..7612d4052 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -1201,7 +1201,7 @@ def copy(self, new_inter_feat): :class:`~Dataset`: the new :class:`~Dataset` object, whose interaction feature has been updated. """ nxt = copy.copy(self) - nxt.inter_feat = new_inter_feat + nxt.inter_feat = pd.DataFrame(new_inter_feat) return nxt def _calcu_split_ids(self, tot, ratios): @@ -1325,6 +1325,11 @@ def build(self, eval_setting): Returns: list: List of builded :class:`Dataset`. """ + if self.benchmark_filename_list is not None: + cumsum = list(np.cumsum(self.file_size_list)) + datasets = [self.copy(self.inter_feat[start: end]) for start, end in zip([0] + cumsum[:-1], cumsum)] + return datasets + ordering_args = eval_setting.ordering_args if ordering_args['strategy'] == 'shuffle': self.shuffle() diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index 4caae0271..f855cf6cc 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -346,6 +346,9 @@ def evaluate(self, eval_data, load_best_model=True, model_file=None): Returns: dict: eval result, key is the eval metric and value in the corresponding metric value """ + if not eval_data: + return + if load_best_model: if model_file: checkpoint_file = model_file