Skip to content

Commit

Permalink
cleanup, get timeouts working, fix error propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan King committed Dec 13, 2023
1 parent 4e5fb09 commit da72b48
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 17 deletions.
13 changes: 6 additions & 7 deletions hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,8 @@ def __init__(self,
self._value = None

async def cleanup_task(self, task: asyncio.Task):
print(f'cleaning up {task}')
if task.done() and not task.cancelled():
if exc := task.exception():
print(f'raising {exc}')
raise exc
else:
task.cancel()
Expand All @@ -85,6 +83,7 @@ async def cleanup_task(self, task: asyncio.Task):

async def write(self, b):
assert not self.closed
assert not self._request_task.done()

fut = asyncio.ensure_future(self._it.feed(b))
try:
Expand All @@ -104,9 +103,11 @@ async def write(self, b):
async def _wait_closed(self):
fut = asyncio.ensure_future(self._it.stop())
self._exit_stack.push_async_callback(self.cleanup_task, fut)
async with await self._request_task as resp:
self._value = await resp.json()
await self._exit_stack.aclose()
try:
async with await self._request_task as resp:
self._value = await resp.json()
finally:
await self._exit_stack.aclose()


class _TaskManager:
Expand Down Expand Up @@ -371,7 +372,6 @@ async def insert_object(self, bucket: str, name: str, **kwargs) -> WritableStrea
f'https://storage.googleapis.com/upload/storage/v1/b/{bucket}/o',
retry=False,
**kwargs))
print(f'InsertObjectStream {bucket}/{name}')
return InsertObjectStream(it, request_task)

# Write using resumable uploads. See:
Expand All @@ -384,7 +384,6 @@ async def insert_object(self, bucket: str, name: str, **kwargs) -> WritableStrea
**kwargs
) as resp:
session_url = resp.headers['Location']
print(f'ResumableInsertObjectStream {bucket}/{name}')
return ResumableInsertObjectStream(self._session, session_url, chunk_size)

async def get_object(self, bucket: str, name: str, **kwargs) -> GetObjectStream:
Expand Down
11 changes: 10 additions & 1 deletion hail/python/hailtop/aiocloud/common/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import AsyncExitStack
from types import TracebackType
from typing import Optional, Type, TypeVar, Mapping, Union
from typing import Optional, Type, TypeVar, Mapping, Union, Dict, Any
import time
import aiohttp
import abc
Expand Down Expand Up @@ -67,6 +67,12 @@ async def close(self) -> None:
del self._session


def coerce_timeout(kwargs: Dict[str, Any]):
timeout = kwargs.get('timeout')
if timeout and isinstance(timeout, float) or isinstance(timeout, int):
kwargs['timeout'] = aiohttp.ClientTimeout(timeout)


class Session(BaseSession):
def __init__(self,
*,
Expand All @@ -82,6 +88,7 @@ def __init__(self,
self._http_session = http_session
else:
self._owns_http_session = True
coerce_timeout(kwargs)
self._http_session = httpx.ClientSession(**kwargs)
self._credentials = credentials

Expand All @@ -96,6 +103,8 @@ async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientRespon
if k not in request_params:
request_params[k] = v

coerce_timeout(kwargs)

# retry by default
retry = kwargs.pop('retry', True)
if retry:
Expand Down
22 changes: 20 additions & 2 deletions hail/python/hailtop/aiotools/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def copy(*,
s3_kwargs: Optional[dict] = None,
transfers: List[Transfer],
verbose: bool = False,
totals: Optional[Tuple[int, int]] = None
totals: Optional[Tuple[int, int]] = None,
) -> None:
with ThreadPoolExecutor() as thread_pool:
if max_simultaneous_transfers is None:
Expand Down Expand Up @@ -173,6 +173,8 @@ async def main() -> None:
parser.add_argument('-v', '--verbose', action='store_const',
const=True, default=False,
help='show logging information')
parser.add_argument('--timeout', type=str, default=None,
help='show logging information')
args = parser.parse_args()

if args.verbose:
Expand All @@ -183,11 +185,27 @@ async def main() -> None:
if args.files is None or args.files == '-':
args.files = sys.stdin.read()
files = json.loads(args.files)
gcs_kwargs = {'gcs_requester_pays_configuration': requester_pays_project}

timeout = args.timeout
if timeout:
timeout = float(timeout)
print(timeout)
gcs_kwargs = {
'gcs_requester_pays_configuration': requester_pays_project,
'timeout': timeout,
}
azure_kwargs = {
'timeout': timeout,
}
s3_kwargs = {
'timeout': timeout,
}

await copy_from_dict(
max_simultaneous_transfers=args.max_simultaneous_transfers,
gcs_kwargs=gcs_kwargs,
azure_kwargs=azure_kwargs,
s3_kwargs=s3_kwargs,
files=files,
verbose=args.verbose
)
Expand Down
5 changes: 3 additions & 2 deletions hail/python/hailtop/aiotools/fs/copier.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ async def _copy_part(self,
this_part_size: int,
part_creator: MultiPartCreate,
return_exceptions: bool) -> None:
total_written = 0
try:
async with self.xfer_sema.acquire_manager(min(Copier.BUFFER_SIZE, this_part_size)):
async with await self.router_fs.open_from(srcfile, part_number * part_size, length=this_part_size) as srcf:
Expand All @@ -238,10 +239,10 @@ async def _copy_part(self,
raise UnexpectedEOFError()
written = await destf.write(b)
assert written == len(b)
source_report.finish_bytes(written)
total_written += written
n -= len(b)
source_report.finish_bytes(total_written)
except Exception as e:
print(f'exception {return_exceptions} {e}')
if return_exceptions:
source_report.set_exception(e)
else:
Expand Down
4 changes: 0 additions & 4 deletions hail/python/hailtop/aiotools/fs/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,12 @@ def closed(self) -> bool:
return self._closed

async def __aenter__(self) -> 'WritableStream':
print(f'aenter {self}')
return self

async def __aexit__(
self, exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
exc_traceback: Optional[TracebackType] = None) -> None:
print(f'aexit {self}')
import traceback
traceback.print_stack()
await self.wait_closed()


Expand Down
1 change: 0 additions & 1 deletion hail/python/hailtop/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,6 @@ async def run_with_sema(pf: Callable[[], Awaitable[T]]):
try:
return await pf()
except Exception as exc:
print(f'bounded gather {exc}')
raise exc

tasks = [asyncio.create_task(run_with_sema(pf)) for pf in pfs]
Expand Down

0 comments on commit da72b48

Please sign in to comment.