From db812c3555b6665ad903d0ad36dea5b5deeece6b Mon Sep 17 00:00:00 2001
From: Keyur Shah
Date: Mon, 20 May 2024 15:52:17 -0700
Subject: [PATCH] Add method to fetch a given property from multiple entities. (#309)

---
 simple/util/dc_client.py | 52 ++++++++++++++++++++++++++++++++++++----
 1 file changed, 48 insertions(+), 4 deletions(-)

diff --git a/simple/util/dc_client.py b/simple/util/dc_client.py
index 6c53b121..88b68725 100644
--- a/simple/util/dc_client.py
+++ b/simple/util/dc_client.py
@@ -56,8 +56,8 @@
 # Make the implementation more generic if more entity types are resolved via mapping functions.
 _S2CELL_ENTITY_TYPE_PATTERN = r"S2CellLevel(\d+)"
 
-# The maximum number of entities to include in a single DC resolve call.
-_RESOLVE_BATCH_SIZE = 500
+# The maximum number of entities to include in a single DC call.
+_BATCH_SIZE = 500
 
 _HTTPX_LIMITS = Limits(max_keepalive_connections=5, max_connections=10)
 
@@ -104,7 +104,7 @@ async def resolve_place_entities_async(
     entity_type: str = None,
     property_name: str = "description") -> dict[str, str]:
 
-  chunks = chunked(entities, _RESOLVE_BATCH_SIZE)
+  chunks = chunked(entities, _BATCH_SIZE)
 
   resolved: dict[str, str] = {}
   async with AsyncClient(limits=_HTTPX_LIMITS, timeout=None) as client:
@@ -187,7 +187,7 @@ def resolve_non_place_entities(entities: list[str],
 
 # TODO: Cache results to file and return from cache if present.
 def get_entities_of_type(entity_type: str,
-                         next_token: str = None) -> (dict[str, str], str):
+                         next_token: str = None) -> tuple[dict[str, str], str]:
   data = {
       "nodes": [entity_type],
       "property": "<-typeOf",
@@ -244,6 +244,50 @@ def resolve_entity_type(entity_dcids: list[str]) -> str:
   return common_entity_types.pop() if common_entity_types else ""
 
 
+def get_property_of_entities(entities: list[str],
+                             property_name: str) -> dict[str, str]:
+  return asyncio.run(get_property_of_entities_async(entities, property_name))
+
+
+async def get_property_of_entities_async(entities: list[str],
+                                         property_name: str) -> dict[str, str]:
+
+  chunks = chunked(entities, _BATCH_SIZE)
+
+  result: dict[str, str] = {}
+  async with AsyncClient(limits=_HTTPX_LIMITS, timeout=None) as client:
+    futures: dict[str, str] = [
+        _get_property_of_entities_chunk(client, chunk, property_name)
+        for chunk in chunks
+    ]
+    for result_chunk in await asyncio.gather(*futures):
+      result.update(result_chunk)
+
+  return result
+
+
+async def _get_property_of_entities_chunk(client: AsyncClient,
+                                          entities_chunk: list[str],
+                                          property_name: str) -> dict[str, str]:
+  data = {
+      "nodes": entities_chunk,
+      "property": f"->{property_name}",
+  }
+
+  logging.debug("Fetching nodes: %s", data)
+  # TODO: handle pagination.
+  response = await post_async(client, path="/v2/node", data=data)
+
+  result_chunk: dict[str, str] = {}
+  for entity_dcid, entity_data in response.get("data", {}).items():
+    nodes = entity_data.get("arcs", {}).get(property_name, {}).get("nodes", [])
+    values = [node.get("value") for node in nodes if node.get("value")]
+    if values:
+      result_chunk[entity_dcid] = values[0]
+
+  return result_chunk
+
+
 def post(path: str, data={}) -> dict:
   url = get_api_root() + path
   headers = {"Content-Type": "application/json"}