Commit a408283

fix: flake8 rule B018
Signed-off-by: win5923 <[email protected]>
win5923 committed Oct 22, 2024
1 parent fbf24a0 commit a408283
Showing 4 changed files with 35 additions and 10 deletions.
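
All four files make the same change: `@functools.cache` is dropped from instance methods and the result is cached on the instance instead. `functools.cache` on a method keeps its entries in a table attached to the function object, keyed by `self`, so every instance that ever calls the method stays reachable for the lifetime of the class — the memory-leak pattern flake8-bugbear warns about on decorated methods. A minimal before/after sketch (hypothetical classes, not Ray's; the dict return value stands in for an expensive computation):

import functools


class Leaky:
    @functools.cache  # cache lives on the function object, keyed by `self`,
    def schema(self):  # so it pins every instance that calls this method
        return {"value": "int64"}  # stand-in for an expensive computation


class Fixed:
    def __init__(self):
        self._schema_cache = None  # cache lives on the instance, dies with it

    def schema(self):
        if self._schema_cache is None:
            self._schema_cache = self._compute_schema()
        return self._schema_cache

    def _compute_schema(self):
        return {"value": "int64"}

One caveat of the `is None` check: a computation that legitimately returns None (as `_compute_schema` in range_datasource.py does when `self._n == 0`) is recomputed on every call. That is harmless here, since the recomputation is cheap and deterministic.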
14 changes: 10 additions & 4 deletions python/ray/data/_internal/datasource/range_datasource.py
@@ -1,5 +1,4 @@
 import builtins
-import functools
 from copy import copy
 from typing import Iterable, List, Optional, Tuple
 
@@ -25,6 +24,7 @@ def __init__(
         self._block_format = block_format
         self._tensor_shape = tensor_shape
         self._column_name = column_name
+        self._schema_cache = None
 
     def estimate_inmemory_data_size(self) -> Optional[int]:
         if self._block_format == "tensor":
@@ -96,7 +96,7 @@ def make_blocks(
             meta = BlockMetadata(
                 num_rows=count,
                 size_bytes=8 * count * element_size,
-                schema=copy(self._schema()),
+                schema=copy(self._get_schema()),
                 input_files=None,
                 exec_stats=None,
             )
@@ -112,8 +112,14 @@ def make_blocks(
 
         return read_tasks
 
-    @functools.cache
-    def _schema(self):
+    def _get_schema(self):
+        """Get the schema, using cached value if available."""
+        if self._schema_cache is None:
+            self._schema_cache = self._compute_schema()
+        return self._schema_cache
+
+    def _compute_schema(self):
+        """Compute the schema without caching."""
         if self._n == 0:
             return None
 
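
The leak the pattern removes is easy to observe. The following runnable check is a sketch (mine, not part of the commit), using throwaway `Leaky` and `Fixed` classes:

import functools
import gc
import weakref


class Leaky:
    @functools.cache
    def value(self):
        return 42


class Fixed:
    def __init__(self):
        self._cache = None

    def value(self):
        if self._cache is None:
            self._cache = 42
        return self._cache


leaky, fixed = Leaky(), Fixed()
leaky.value()
fixed.value()
leaky_ref, fixed_ref = weakref.ref(leaky), weakref.ref(fixed)
del leaky, fixed
gc.collect()
assert leaky_ref() is not None  # still alive: Leaky.value's cache holds it
assert fixed_ref() is None  # collected: its cache died with the instance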
10 changes: 8 additions & 2 deletions python/ray/data/_internal/logical/operators/from_operators.py
@@ -1,5 +1,4 @@
 import abc
-import functools
 from typing import TYPE_CHECKING, List, Optional, Union
 
 from ray.data._internal.execution.interfaces import RefBundle
@@ -32,6 +31,7 @@ def __init__(
             RefBundle([(input_blocks[i], input_metadata[i])], owns_blocks=False)
             for i in range(len(input_blocks))
         ]
+        self._output_metadata_cache = None
 
     @property
     def input_data(self) -> List[RefBundle]:
@@ -40,8 +40,14 @@ def input_data(self) -> List[RefBundle]:
     def output_data(self) -> Optional[List[RefBundle]]:
         return self._input_data
 
-    @functools.cache
     def aggregate_output_metadata(self) -> BlockMetadata:
+        """Get aggregated output metadata, using cache if available."""
+        if self._output_metadata_cache is None:
+            self._output_metadata_cache = self._compute_output_metadata()
+        return self._output_metadata_cache
+
+    def _compute_output_metadata(self) -> BlockMetadata:
+        """Compute the output metadata without caching."""
         return BlockMetadata(
             num_rows=self._num_rows(),
             size_bytes=self._size_bytes(),
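
Because `aggregate_output_metadata` takes no arguments, the standard library's `functools.cached_property` (Python 3.8+) would give the same per-instance lifetime with less code — but it turns the method call into attribute access at every call site, which this commit avoids by keeping the existing method signature. A sketch with a hypothetical `Op` class:

from functools import cached_property


class Op:
    @cached_property
    def output_metadata(self):  # note: accessed as an attribute, not called
        return {"num_rows": 0}  # stand-in for building a BlockMetadata


op = Op()
assert op.output_metadata is op.output_metadata  # computed once, stored on op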
10 changes: 8 additions & 2 deletions
@@ -1,4 +1,3 @@
-import functools
 from typing import Callable, List, Optional
 
 from ray.data._internal.execution.interfaces import RefBundle
@@ -27,17 +26,24 @@ def __init__(
         )
         self.input_data = input_data
         self.input_data_factory = input_data_factory
+        self._output_metadata_cache = None
 
     def output_data(self) -> Optional[List[RefBundle]]:
         if self.input_data is None:
             return None
         return self.input_data
 
-    @functools.cache
     def aggregate_output_metadata(self) -> BlockMetadata:
+        """Get aggregated output metadata, using cache if available."""
         if self.input_data is None:
             return BlockMetadata(None, None, None, None, None)
 
+        if self._output_metadata_cache is None:
+            self._output_metadata_cache = self._compute_output_metadata()
+        return self._output_metadata_cache
+
+    def _compute_output_metadata(self) -> BlockMetadata:
+        """Compute the output metadata without caching."""
         return BlockMetadata(
             num_rows=self._num_rows(),
             size_bytes=self._size_bytes(),
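
A subtlety in this file: the `input_data is None` early return comes before the cache check, so that all-`None` `BlockMetadata` is rebuilt on every call, and an `is None` sentinel can never cache a result that is itself `None`. A distinct sentinel object would handle both cases; a sketch of that alternative (mine, not what the commit does):

_UNSET = object()  # unique sentinel: distinguishes "never computed" from None


class Op:
    def __init__(self):
        self._output_metadata_cache = _UNSET

    def aggregate_output_metadata(self):
        if self._output_metadata_cache is _UNSET:
            self._output_metadata_cache = self._compute_output_metadata()
        return self._output_metadata_cache

    def _compute_output_metadata(self):
        return None  # even a None result is computed only once


op = Op()
assert op.aggregate_output_metadata() is None
assert op._output_metadata_cache is not _UNSET  # cached despite being None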
11 changes: 9 additions & 2 deletions python/ray/data/_internal/logical/operators/read_operator.py
@@ -1,4 +1,3 @@
-import functools
 from typing import Any, Dict, Optional, Union
 
 from ray.data._internal.logical.operators.map_operator import AbstractMap
@@ -32,6 +31,7 @@ def __init__(
         self._mem_size = mem_size
         self._concurrency = concurrency
         self._detected_parallelism = None
+        self._output_metadata_cache = None
 
     def set_detected_parallelism(self, parallelism: int):
         """
@@ -46,14 +46,21 @@ def get_detected_parallelism(self) -> int:
         """
         return self._detected_parallelism
 
-    @functools.cache
     def aggregate_output_metadata(self) -> BlockMetadata:
         """A ``BlockMetadata`` that represents the aggregate metadata of the outputs.
 
         This method gets metadata from the read tasks. It doesn't trigger any actual
         execution.
         """
         # Legacy datasources might not implement `get_read_tasks`.
+        if self._output_metadata_cache is not None:
+            return self._output_metadata_cache
+
+        self._output_metadata_cache = self._compute_output_metadata()
+        return self._output_metadata_cache
+
+    def _compute_output_metadata(self) -> BlockMetadata:
+        """Compute the output metadata without caching."""
         if self._datasource.should_create_reader:
             return BlockMetadata(None, None, None, None, None)
 
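
One property shared by all four hand-rolled caches: the check-then-set is not guarded by a lock, so two threads hitting a cold cache can both run the computation (`functools.cache` can likewise call the wrapped function more than once on concurrent misses). Since each `_compute_*` method here is deterministic and side-effect free, the race costs at most a duplicate computation — presumably why the commit adds no locking.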
