Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Rita Kurban authored and Rita Kurban committed Jan 15, 2025
1 parent 5de4151 commit 52717b2
Showing 1 changed file with 46 additions and 38 deletions.
84 changes: 46 additions & 38 deletions src/eva/language/data/datasets/classification/pubmedqa.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""PubMedQA dataset class."""

import os
from typing import Any, Dict, List, Literal
from typing import Any, Dict, List, Literal, Optional

import torch
from datasets import Dataset, load_dataset
from loguru import logger
from typing_extensions import override

from eva.language.data.datasets.classification import base
Expand All @@ -19,7 +20,7 @@ class PubMedQA(base.TextClassification):
def __init__(
self,
root: str | None = None,
split: Literal["train", "val", "test"] | None = None,
split: Literal["train", "validation", "test"] | None = None,
download: bool = False,
) -> None:
"""Initialize the PubMedQA dataset.
Expand All @@ -31,13 +32,51 @@ def __init__(
download: Whether to download the dataset if not found locally. Default is False.
"""
super().__init__()

self._root = root
if split is None:
self._split = "train+test+validation"
else:
self._split = "+".join(split)
self._split = split
self._download = download

def _load_dataset(self, dataset_cache_path: Optional[str]) -> Dataset:
    """Fetch the PubMedQA dataset, preferring an existing local copy.

    Args:
        dataset_cache_path: Path of a locally cached copy of the dataset,
            or a falsy value when no such path is known.

    Returns:
        The dataset for the configured split, or for all of
        ``train+test+validation`` when no split was configured.

    Raises:
        ValueError: If no local copy exists at ``dataset_cache_path``,
            downloading is disabled and a root directory is configured.
        TypeError: If the loader returns anything other than a ``Dataset``.
    """
    config_name = "pubmed_qa_labeled_fold0_source"
    split_spec = self._split or "train+test+validation"
    has_local_copy = bool(dataset_cache_path) and os.path.exists(dataset_cache_path)

    if has_local_copy:
        loaded = load_dataset(
            dataset_cache_path,
            name=config_name,
            split=split_spec,
            streaming=False,
        )
        logger.info(f"Loaded dataset from local cache: {dataset_cache_path}")
    else:
        # A configured root with downloading switched off means the caller
        # expected the data to already be on disk, so fail loudly.
        if self._root and not self._download:
            raise ValueError(
                "Dataset not found locally and downloading is disabled. "
                "Set `download=True` or provide a valid local cache."
            )
        loaded = load_dataset(
            "bigbio/pubmed_qa",
            name=config_name,
            split=split_spec,
            cache_dir=self._root or None,
            streaming=False,
        )
        if self._root:
            outcome = f"Dataset downloaded and cached in: {self._root}"
        else:
            outcome = "Using dataset directly from HuggingFace without caching."
        logger.info(outcome)

    # The loader can hand back other container types; only a plain
    # `Dataset` is accepted here.
    if isinstance(loaded, Dataset):
        return loaded
    raise TypeError(f"Expected a `Dataset`, but got {type(loaded)}")

@override
def prepare_data(self) -> None:
"""Downloads and prepares the PubMedQA dataset.
Expand All @@ -53,38 +92,7 @@ def prepare_data(self) -> None:
os.makedirs(self._root, exist_ok=True)

try:
if dataset_cache_path and os.path.exists(dataset_cache_path):
raw_dataset = load_dataset(
dataset_cache_path,
name="pubmed_qa_labeled_fold0_source",
split=self._split,
streaming=False,
)
print(f"Loaded dataset from local cache: {dataset_cache_path}")
else:
if not self._download and self._root:
raise ValueError(
"Dataset not found locally and downloading is disabled. "
"Set `download=True` or provide a valid local cache."
)

raw_dataset = load_dataset(
"bigbio/pubmed_qa",
name="pubmed_qa_labeled_fold0_source",
split=self._split,
cache_dir=self._root if self._root else None,
streaming=False,
)
if self._root:
print(f"Dataset downloaded and cached in: {self._root}")
else:
print("Using dataset directly from Hugging Face without caching.")

if not isinstance(raw_dataset, Dataset):
raise TypeError(f"Expected a `Dataset`, but got {type(raw_dataset)}")

self.dataset: Dataset = raw_dataset

self.dataset = self._load_dataset(dataset_cache_path)
except Exception as e:
raise RuntimeError(f"Failed to prepare dataset: {e}") from e

Expand Down

0 comments on commit 52717b2

Please sign in to comment.