diff --git a/src/eva/language/data/datasets/classification/pubmedqa.py b/src/eva/language/data/datasets/classification/pubmedqa.py
index c9622f28..6ebe04b5 100644
--- a/src/eva/language/data/datasets/classification/pubmedqa.py
+++ b/src/eva/language/data/datasets/classification/pubmedqa.py
@@ -1,10 +1,11 @@
 """PubMedQA dataset class."""
 
 import os
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional
 
 import torch
 from datasets import Dataset, load_dataset
+from loguru import logger
 from typing_extensions import override
 
 from eva.language.data.datasets.classification import base
@@ -19,7 +20,7 @@ class PubMedQA(base.TextClassification):
     def __init__(
         self,
         root: str | None = None,
-        split: Literal["train", "val", "test"] | None = None,
+        split: Literal["train", "validation", "test"] | None = None,
         download: bool = False,
     ) -> None:
         """Initialize the PubMedQA dataset.
@@ -31,13 +32,51 @@ def __init__(
             download: Whether to download the dataset if not found locally. Default is False.
         """
         super().__init__()
+
         self._root = root
-        if split is None:
-            self._split = "train+test+validation"
-        else:
-            self._split = "+".join(split)
+        self._split = split
         self._download = download
 
+    def _load_dataset(self, dataset_cache_path: Optional[str]) -> Dataset:
+        """Loads the PubMedQA dataset from the local cache or downloads it if needed.
+
+        Args:
+            dataset_cache_path: The path to the local cache.
+
+        Returns:
+            The loaded dataset.
+        """
+        if dataset_cache_path and os.path.exists(dataset_cache_path):
+            raw_dataset = load_dataset(
+                dataset_cache_path,
+                name="pubmed_qa_labeled_fold0_source",
+                split=self._split or "train+test+validation",
+                streaming=False,
+            )
+            logger.info(f"Loaded dataset from local cache: {dataset_cache_path}")
+        else:
+            if not self._download and self._root:
+                raise ValueError(
+                    "Dataset not found locally and downloading is disabled. "
+                    "Set `download=True` or provide a valid local cache."
+                )
+            raw_dataset = load_dataset(
+                "bigbio/pubmed_qa",
+                name="pubmed_qa_labeled_fold0_source",
+                split=self._split or "train+test+validation",
+                cache_dir=self._root if self._root else None,
+                streaming=False,
+            )
+            if self._root:
+                logger.info(f"Dataset downloaded and cached in: {self._root}")
+            else:
+                logger.info("Using dataset directly from HuggingFace without caching.")
+
+        if not isinstance(raw_dataset, Dataset):
+            raise TypeError(f"Expected a `Dataset`, but got {type(raw_dataset)}")
+
+        return raw_dataset
+
     @override
     def prepare_data(self) -> None:
         """Downloads and prepares the PubMedQA dataset.
@@ -53,38 +92,7 @@ def prepare_data(self) -> None:
             os.makedirs(self._root, exist_ok=True)
 
         try:
-            if dataset_cache_path and os.path.exists(dataset_cache_path):
-                raw_dataset = load_dataset(
-                    dataset_cache_path,
-                    name="pubmed_qa_labeled_fold0_source",
-                    split=self._split,
-                    streaming=False,
-                )
-                print(f"Loaded dataset from local cache: {dataset_cache_path}")
-            else:
-                if not self._download and self._root:
-                    raise ValueError(
-                        "Dataset not found locally and downloading is disabled. "
-                        "Set `download=True` or provide a valid local cache."
-                    )
-
-                raw_dataset = load_dataset(
-                    "bigbio/pubmed_qa",
-                    name="pubmed_qa_labeled_fold0_source",
-                    split=self._split,
-                    cache_dir=self._root if self._root else None,
-                    streaming=False,
-                )
-                if self._root:
-                    print(f"Dataset downloaded and cached in: {self._root}")
-                else:
-                    print("Using dataset directly from Hugging Face without caching.")
-
-            if not isinstance(raw_dataset, Dataset):
-                raise TypeError(f"Expected a `Dataset`, but got {type(raw_dataset)}")
-
-            self.dataset: Dataset = raw_dataset
-
+            self.dataset = self._load_dataset(dataset_cache_path)
         except Exception as e:
             raise RuntimeError(f"Failed to prepare dataset: {e}") from e
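
For reviewers, a minimal usage sketch of the refactored class as this diff leaves it (hypothetical driver code, not part of the PR; the `root` path is illustrative):

```python
from eva.language.data.datasets.classification.pubmedqa import PubMedQA

# `split` now uses the HuggingFace name "validation" instead of "val";
# passing split=None falls back to the combined "train+test+validation" split.
dataset = PubMedQA(root="data/pubmedqa", split="validation", download=True)
dataset.prepare_data()  # loads from the local cache, or downloads via `load_dataset`
```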