Skip to content

Commit

Permalink
Add method to fetch a given property from multiple entities. (#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
keyurva authored May 20, 2024
1 parent a1b4df3 commit db812c3
Showing 1 changed file with 48 additions and 4 deletions.
52 changes: 48 additions & 4 deletions simple/util/dc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@
# Make the implementation more generic if more entity types are resolved via mapping functions.
_S2CELL_ENTITY_TYPE_PATTERN = r"S2CellLevel(\d+)"

# The maximum number of entities to include in a single DC resolve call.
_RESOLVE_BATCH_SIZE = 500
# The maximum number of entities to include in a single DC call.
_BATCH_SIZE = 500

_HTTPX_LIMITS = Limits(max_keepalive_connections=5, max_connections=10)

Expand Down Expand Up @@ -104,7 +104,7 @@ async def resolve_place_entities_async(
entity_type: str = None,
property_name: str = "description") -> dict[str, str]:

chunks = chunked(entities, _RESOLVE_BATCH_SIZE)
chunks = chunked(entities, _BATCH_SIZE)

resolved: dict[str, str] = {}
async with AsyncClient(limits=_HTTPX_LIMITS, timeout=None) as client:
Expand Down Expand Up @@ -187,7 +187,7 @@ def resolve_non_place_entities(entities: list[str],

# TODO: Cache results to file and return from cache if present.
def get_entities_of_type(entity_type: str,
next_token: str = None) -> (dict[str, str], str):
next_token: str = None) -> tuple[dict[str, str], str]:
data = {
"nodes": [entity_type],
"property": "<-typeOf",
Expand Down Expand Up @@ -244,6 +244,50 @@ def resolve_entity_type(entity_dcids: list[str]) -> str:
return common_entity_types.pop() if common_entity_types else ""


def get_property_of_entities(entities: list[str],
property_name: str) -> dict[str, str]:
return asyncio.run(get_property_of_entities_async(entities, property_name))


async def get_property_of_entities_async(entities: list[str],
property_name: str) -> dict[str, str]:

chunks = chunked(entities, _BATCH_SIZE)

result: dict[str, str] = {}
async with AsyncClient(limits=_HTTPX_LIMITS, timeout=None) as client:
futures: dict[str, str] = [
_get_property_of_entities_chunk(client, chunk, property_name)
for chunk in chunks
]
for result_chunk in await asyncio.gather(*futures):
result.update(result_chunk)

return result


async def _get_property_of_entities_chunk(client: AsyncClient,
entities_chunk: list[str],
property_name: str) -> dict[str, str]:
data = {
"nodes": entities_chunk,
"property": f"->{property_name}",
}

logging.debug("Fetching nodes: %s", data)
# TODO: handle pagination.
response = await post_async(client, path="/v2/node", data=data)

result_chunk: dict[str, str] = {}
for entity_dcid, entity_data in response.get("data", {}).items():
nodes = entity_data.get("arcs", {}).get(property_name, {}).get("nodes", [])
values = [node.get("value") for node in nodes if node.get("value")]
if values:
result_chunk[entity_dcid] = values[0]

return result_chunk


def post(path: str, data={}) -> dict:
url = get_api_root() + path
headers = {"Content-Type": "application/json"}
Expand Down

0 comments on commit db812c3

Please sign in to comment.