From a13ccf25c5d62cf3fc5c97a6a0b8b9a3d143fad5 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 6 Jan 2025 16:39:14 +0100 Subject: [PATCH 1/2] align remove_columns in the formatted case --- src/datasets/arrow_dataset.py | 11 ++++------- tests/test_arrow_dataset.py | 3 +-- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 9648a25e139..46d88cacb04 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3336,13 +3336,11 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example if with_rank: additional_args += (rank,) processed_inputs = function(*fn_args, *additional_args, **fn_kwargs) + returned_same_object = processed_inputs is inputs if isinstance(processed_inputs, LazyDict): processed_inputs = { k: v for k, v in processed_inputs.data.items() if k not in processed_inputs.keys_to_format } - returned_lazy_dict = True - else: - returned_lazy_dict = False if update_data is None: # Check if the function returns updated examples updatable_types = (Mapping, pa.Table, pd.DataFrame) @@ -3366,10 +3364,9 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example if remove_columns is not None: for column in remove_columns: # `function` can modify input in-place causing column to be already removed. - if column in inputs_to_merge: - inputs_to_merge.pop(column) - if returned_lazy_dict and column in processed_inputs: - processed_inputs.pop(column) + inputs_to_merge.pop(column, None) + if returned_same_object: + processed_inputs.pop(column, None) if check_same_num_examples: input_num_examples = len(pa_inputs) processed_inputs_num_examples = len(processed_inputs[next(iter(processed_inputs.keys()))]) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 6cf8898ce67..048ef2acd80 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -4356,13 +4356,12 @@ def f(x): outputs = ds[:] assert outputs == {"b": [-1, -1, 2, 3]} - # The formatted dataset version removes the lazy column from a different dictionary, hence it should be preserved in the output ds = Dataset.from_dict({"a": [0, 1, 2, 3]}) ds = ds.with_format("numpy") ds = ds.map(f, remove_columns=["a"]) ds = ds.with_format(None) outputs = ds[:] - assert outputs == {"a": [0, 1, 2, 3], "b": [-1, -1, 2, 3]} + assert outputs == {"b": [-1, -1, 2, 3]} def f(x): """May return a mix of LazyDict and regular Dict, but we replace a lazy column""" From 1222bd876b5f8f2ff0995d43f6e9570e0469d527 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 6 Jan 2025 16:42:33 +0100 Subject: [PATCH 2/2] simplify --- src/datasets/arrow_dataset.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 46d88cacb04..a732942a7ee 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3352,15 +3352,12 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example validate_function_output(processed_inputs, indices) if not update_data: return None # Nothing to update, let's move on - if shard._format_type or input_columns: - # TODO(QL, MS): ideally the behavior should be the same even if the dataset is formatted (may require major release) - inputs_to_merge = dict(zip(pa_inputs.column_names, pa_inputs.itercolumns())) - elif isinstance(inputs, LazyDict): + if isinstance(inputs, LazyDict): inputs_to_merge = { k: (v if k not in inputs.keys_to_format else pa_inputs[k]) for k, v in inputs.data.items() } else: - inputs_to_merge = inputs + inputs_to_merge = dict(zip(pa_inputs.column_names, pa_inputs.itercolumns())) if remove_columns is not None: for column in remove_columns: # `function` can modify input in-place causing column to be already removed.