diff --git a/dalm/__init__.py b/dalm/__init__.py index 3b93d0b..27fdca4 100644 --- a/dalm/__init__.py +++ b/dalm/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2" +__version__ = "0.0.3" diff --git a/dalm/datasets/qa_gen/question_answer_generation.py b/dalm/datasets/qa_gen/question_answer_generation.py index c938193..a09dbdc 100644 --- a/dalm/datasets/qa_gen/question_answer_generation.py +++ b/dalm/datasets/qa_gen/question_answer_generation.py @@ -142,15 +142,20 @@ def generate_qa_from_dataset( def _load_dataset_from_path(dataset_path: str) -> Dataset: if dataset_path.endswith(".csv"): dataset = Dataset.from_csv(dataset_path) - elif not os.path.splitext(dataset_path): + elif not os.path.splitext(dataset_path)[-1]: if os.path.isdir(dataset_path): dataset = datasets.load_from_disk(dataset_path) else: dataset = datasets.load_dataset(dataset_path) - key = next(iter(dataset)) if isinstance(dataset, DatasetDict): + if "train" in dataset: + key = "train" + elif "training" in dataset: + key = "training" + else: + key = next(iter(dataset)) warnings.warn(f"Found multiple keys in dataset. Generating qa for split {key}", stacklevel=0) - dataset = dataset[key] + dataset = dataset[key] else: raise ValueError( "dataset-path must be one of csv, dataset directory "