Skip to content

Commit

Permalink
Reduce duplicate code
Browse files: browse the repository at this point in the history
  • Loading branch information
Rita Kurban authored and Rita Kurban committed Jan 17, 2025
1 parent 52717b2 commit 365f6b3
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions src/eva/language/data/datasets/classification/pubmedqa.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -44,34 +44,33 @@ def _load_dataset(self, dataset_cache_path: Optional[str]) -> Dataset:
dataset_cache_path: The path to the local cache.
Returns:
The loaded dataset.
The loaded Dataset object.
"""
if dataset_cache_path and os.path.exists(dataset_cache_path):
raw_dataset = load_dataset(
dataset_cache_path,
name="pubmed_qa_labeled_fold0_source",
split=self._split or "train+test+validation",
streaming=False,
is_local = bool(dataset_cache_path and os.path.exists(dataset_cache_path))
dataset_path = dataset_cache_path if is_local else "bigbio/pubmed_qa"

if not is_local and not self._download and self._root:
raise ValueError(
"Dataset not found locally and downloading is disabled. "
"Set `download=True` or provide a valid local cache."
)

if is_local:
logger.info(f"Loaded dataset from local cache: {dataset_cache_path}")
else:
if not self._download and self._root:
raise ValueError(
"Dataset not found locally and downloading is disabled. "
"Set `download=True` or provide a valid local cache."
)
raw_dataset = load_dataset(
"bigbio/pubmed_qa",
name="pubmed_qa_labeled_fold0_source",
split=self._split or "train+test+validation",
cache_dir=self._root if self._root else None,
streaming=False,
)
if self._root:
logger.info(f"Dataset downloaded and cached in: {self._root}")
else:
logger.info("Using dataset directly from HuggingFace without caching.")

raw_dataset = load_dataset(
dataset_path,
name="pubmed_qa_labeled_fold0_source",
split=self._split or "train+test+validation",
streaming=False,
cache_dir=self._root if (not is_local and self._root) else None,
)

if not isinstance(raw_dataset, Dataset):
raise TypeError(f"Expected a `Dataset`, but got {type(raw_dataset)}")

Expand Down

0 comments on commit 365f6b3

Please sign in to comment.