From a4082832072c27fcd743fda15ca72a222cbebf9e Mon Sep 17 00:00:00 2001 From: win5923 Date: Tue, 22 Oct 2024 21:36:19 +0800 Subject: [PATCH] fix: flake8 rule B018 Signed-off-by: win5923 --- .../data/_internal/datasource/range_datasource.py | 14 ++++++++++---- .../_internal/logical/operators/from_operators.py | 10 ++++++++-- .../logical/operators/input_data_operator.py | 10 ++++++++-- .../_internal/logical/operators/read_operator.py | 11 +++++++++-- 4 files changed, 35 insertions(+), 10 deletions(-) diff --git a/python/ray/data/_internal/datasource/range_datasource.py b/python/ray/data/_internal/datasource/range_datasource.py index 50dedbdf2fed..69318d6b424e 100644 --- a/python/ray/data/_internal/datasource/range_datasource.py +++ b/python/ray/data/_internal/datasource/range_datasource.py @@ -1,5 +1,4 @@ import builtins -import functools from copy import copy from typing import Iterable, List, Optional, Tuple @@ -25,6 +24,7 @@ def __init__( self._block_format = block_format self._tensor_shape = tensor_shape self._column_name = column_name + self._schema_cache = None def estimate_inmemory_data_size(self) -> Optional[int]: if self._block_format == "tensor": @@ -96,7 +96,7 @@ def make_blocks( meta = BlockMetadata( num_rows=count, size_bytes=8 * count * element_size, - schema=copy(self._schema()), + schema=copy(self._get_schema()), input_files=None, exec_stats=None, ) @@ -112,8 +112,14 @@ def make_blocks( return read_tasks - @functools.cache - def _schema(self): + def _get_schema(self): + """Get the schema, using cached value if available.""" + if self._schema_cache is None: + self._schema_cache = self._compute_schema() + return self._schema_cache + + def _compute_schema(self): + """Compute the schema without caching.""" if self._n == 0: return None diff --git a/python/ray/data/_internal/logical/operators/from_operators.py b/python/ray/data/_internal/logical/operators/from_operators.py index 7d2b07979c22..83b76f99690c 100644 --- a/python/ray/data/_internal/logical/operators/from_operators.py +++ b/python/ray/data/_internal/logical/operators/from_operators.py @@ -1,5 +1,4 @@ import abc -import functools from typing import TYPE_CHECKING, List, Optional, Union from ray.data._internal.execution.interfaces import RefBundle @@ -32,6 +31,7 @@ def __init__( RefBundle([(input_blocks[i], input_metadata[i])], owns_blocks=False) for i in range(len(input_blocks)) ] + self._output_metadata_cache = None @property def input_data(self) -> List[RefBundle]: @@ -40,8 +40,14 @@ def input_data(self) -> List[RefBundle]: def output_data(self) -> Optional[List[RefBundle]]: return self._input_data - @functools.cache def aggregate_output_metadata(self) -> BlockMetadata: + """Get aggregated output metadata, using cache if available.""" + if self._output_metadata_cache is None: + self._output_metadata_cache = self._compute_output_metadata() + return self._output_metadata_cache + + def _compute_output_metadata(self) -> BlockMetadata: + """Compute the output metadata without caching.""" return BlockMetadata( num_rows=self._num_rows(), size_bytes=self._size_bytes(), diff --git a/python/ray/data/_internal/logical/operators/input_data_operator.py b/python/ray/data/_internal/logical/operators/input_data_operator.py index 6592972e7c94..1feb9b94f601 100644 --- a/python/ray/data/_internal/logical/operators/input_data_operator.py +++ b/python/ray/data/_internal/logical/operators/input_data_operator.py @@ -1,4 +1,3 @@ -import functools from typing import Callable, List, Optional from ray.data._internal.execution.interfaces import RefBundle @@ -27,17 +26,24 @@ def __init__( ) self.input_data = input_data self.input_data_factory = input_data_factory + self._output_metadata_cache = None def output_data(self) -> Optional[List[RefBundle]]: if self.input_data is None: return None return self.input_data - @functools.cache def aggregate_output_metadata(self) -> BlockMetadata: + """Get aggregated output metadata, using cache if available.""" if self.input_data is None: return BlockMetadata(None, None, None, None, None) + if self._output_metadata_cache is None: + self._output_metadata_cache = self._compute_output_metadata() + return self._output_metadata_cache + + def _compute_output_metadata(self) -> BlockMetadata: + """Compute the output metadata without caching.""" return BlockMetadata( num_rows=self._num_rows(), size_bytes=self._size_bytes(), diff --git a/python/ray/data/_internal/logical/operators/read_operator.py b/python/ray/data/_internal/logical/operators/read_operator.py index b75a314ea6cb..f728de218ff2 100644 --- a/python/ray/data/_internal/logical/operators/read_operator.py +++ b/python/ray/data/_internal/logical/operators/read_operator.py @@ -1,4 +1,3 @@ -import functools from typing import Any, Dict, Optional, Union from ray.data._internal.logical.operators.map_operator import AbstractMap @@ -32,6 +31,7 @@ def __init__( self._mem_size = mem_size self._concurrency = concurrency self._detected_parallelism = None + self._output_metadata_cache = None def set_detected_parallelism(self, parallelism: int): """ @@ -46,7 +46,6 @@ def get_detected_parallelism(self) -> int: """ return self._detected_parallelism - @functools.cache def aggregate_output_metadata(self) -> BlockMetadata: """A ``BlockMetadata`` that represents the aggregate metadata of the outputs. @@ -54,6 +53,14 @@ def aggregate_output_metadata(self) -> BlockMetadata: execution. """ # Legacy datasources might not implement `get_read_tasks`. + if self._output_metadata_cache is not None: + return self._output_metadata_cache + + self._output_metadata_cache = self._compute_output_metadata() + return self._output_metadata_cache + + def _compute_output_metadata(self) -> BlockMetadata: + """Compute the output metadata without caching.""" if self._datasource.should_create_reader: return BlockMetadata(None, None, None, None, None)