From 365f6b35891f2b1d2ca1ea265460a16318d72c32 Mon Sep 17 00:00:00 2001 From: Rita Kurban Date: Fri, 17 Jan 2025 15:34:56 +0000 Subject: [PATCH] Reduce duplicate code --- .../data/datasets/classification/pubmedqa.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/eva/language/data/datasets/classification/pubmedqa.py b/src/eva/language/data/datasets/classification/pubmedqa.py index 6ebe04b5..1b465c1c 100644 --- a/src/eva/language/data/datasets/classification/pubmedqa.py +++ b/src/eva/language/data/datasets/classification/pubmedqa.py @@ -44,34 +44,33 @@ def _load_dataset(self, dataset_cache_path: Optional[str]) -> Dataset: dataset_cache_path: The path to the local cache. Returns: - The loaded dataset. + The loaded Dataset object. """ - if dataset_cache_path and os.path.exists(dataset_cache_path): - raw_dataset = load_dataset( - dataset_cache_path, - name="pubmed_qa_labeled_fold0_source", - split=self._split or "train+test+validation", - streaming=False, + is_local = bool(dataset_cache_path and os.path.exists(dataset_cache_path)) + dataset_path = dataset_cache_path if is_local else "bigbio/pubmed_qa" + + if not is_local and not self._download and self._root: + raise ValueError( + "Dataset not found locally and downloading is disabled. " + "Set `download=True` or provide a valid local cache." ) + + if is_local: logger.info(f"Loaded dataset from local cache: {dataset_cache_path}") else: - if not self._download and self._root: - raise ValueError( - "Dataset not found locally and downloading is disabled. " - "Set `download=True` or provide a valid local cache." - ) - raw_dataset = load_dataset( - "bigbio/pubmed_qa", - name="pubmed_qa_labeled_fold0_source", - split=self._split or "train+test+validation", - cache_dir=self._root if self._root else None, - streaming=False, - ) if self._root: logger.info(f"Dataset downloaded and cached in: {self._root}") else: logger.info("Using dataset directly from HuggingFace without caching.") + raw_dataset = load_dataset( + dataset_path, + name="pubmed_qa_labeled_fold0_source", + split=self._split or "train+test+validation", + streaming=False, + cache_dir=self._root if (not is_local and self._root) else None, + ) + if not isinstance(raw_dataset, Dataset): raise TypeError(f"Expected a `Dataset`, but got {type(raw_dataset)}")