diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py
index 2bd2c40b..1187eba4 100644
--- a/axlearn/common/array_serialization.py
+++ b/axlearn/common/array_serialization.py
@@ -24,6 +24,10 @@
 from axlearn.common.utils import Tensor
 
 
+def _local_size(array: Tensor) -> int:
+    return sum(shard.data.nbytes for shard in array.addressable_shards)
+
+
 def _proxy(fut: asyncio.Future) -> asyncio.Future:
     """Returns a proxy that can be used to await (but does not cancel) `fut`."""
     loop = asyncio.get_event_loop()
@@ -134,7 +138,11 @@ async def async_serialize(
         # Memory usage seems to be proportional to the array size rather than shard sizes.
         # TODO(markblee): Investigate why this is the case.
         _acquire_and_write(
-            t, limiter=limiter, shard=shard, nbytes=array.nbytes, release_tasks=release_tasks
+            t,
+            limiter=limiter,
+            shard=shard,
+            nbytes=_local_size(array),
+            release_tasks=release_tasks,
         )
         for shard in local_shards
     )
@@ -165,7 +173,7 @@ def serialize(
         logging.info("Waiting for previous serialization to finish.")
         self.wait_until_finished()
 
-        max_shard_bytes = max((0, *(array.nbytes for array in arrays)))
+        max_shard_bytes = max((0, *(_local_size(array) for array in arrays)))
         max_concurrent_bytes = self._max_concurrent_bytes
         if max_shard_bytes > max_concurrent_bytes:
             logging.warning(
diff --git a/axlearn/common/array_serialization_test.py b/axlearn/common/array_serialization_test.py
index 3d68516e..09be4adc 100644
--- a/axlearn/common/array_serialization_test.py
+++ b/axlearn/common/array_serialization_test.py
@@ -6,6 +6,7 @@
 import asyncio
 import contextlib
 import functools
+import math
 from typing import List, Optional
 from unittest import mock
 
@@ -315,6 +316,12 @@ async def acquire_and_write(*, shards: List[Shard]):
             max_concurrent_gb=1,
             expect_max_concurrent_gb=1,
         ),
+        # Test non-addressable shards (which are represented by negative numbers).
+        dict(
+            arrays=[[1, 1, -1, -1], [1]],
+            max_concurrent_gb=1,
+            expect_max_concurrent_gb=2,
+        ),
     )
     def test_serialize(
         self, arrays: List[List[int]], max_concurrent_gb: int, expect_max_concurrent_gb: int
@@ -324,8 +331,9 @@ def test_serialize(
             addressable_shards=[
                 mock.Mock(replica_id=0, **{"data.nbytes": int(shard * 10**9)})
                 for shard in array
+                if shard >= 0
             ],
-            nbytes=int(sum(array) * 10**9),
+            nbytes=int(sum(math.fabs(shard) for shard in array) * 10**9),
         )
         for array in arrays
     ]