Skip to content

Commit

Permalink
feat: JIRA 1818 - Adding Types to Lineage Response for Special Querie…
Browse files Browse the repository at this point in the history
…s in Prov-API. (#38)
  • Loading branch information
parth-kulkarni1 authored Dec 20, 2024
1 parent 8055590 commit 926154b
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 291 deletions.
534 changes: 287 additions & 247 deletions poetry.lock

Large diffs are not rendered by default.

36 changes: 33 additions & 3 deletions src/provenaclient/models/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
---------- --- ---------------------------------------------------------
'''

from pydantic import BaseModel

from typing import Any, Dict, Optional, Type, TypedDict, List
from pydantic import BaseModel, ValidationError, validator
from ProvenaInterfaces.ProvenanceAPI import LineageResponse
from ProvenaInterfaces.RegistryAPI import Node

class HealthCheckResponse(BaseModel):
message: str
Expand All @@ -31,4 +33,32 @@ class AsyncAwaitSettings(BaseModel):
# how long do we wait for it to become in progress? (seconds)
job_async_in_progress_polling_timeout = 180 # 3 minutes

DEFAULT_AWAIT_SETTINGS = AsyncAwaitSettings()
DEFAULT_AWAIT_SETTINGS = AsyncAwaitSettings()

class GraphProperty(BaseModel):
type: str
source: str
target: str

class CustomGraph(BaseModel):
directed: bool
multigraph: bool
graph: Dict[str, Any]
nodes: List[Node]
links: List[GraphProperty]

class CustomLineageResponse(LineageResponse):
"""A Custom Lineage Response Pydantic Model
that inherits from its parent (LineageResponse).
This model overrides the "graph" field within the
Lineage Response, and converts it from an untyped
dictionary into a pydantic object/ typed datatype.
The custom validator function is called, and has custom
parsing logic to parse the nodes as well.
"""

graph: Optional[CustomGraph] #type:ignore

64 changes: 38 additions & 26 deletions src/provenaclient/modules/prov.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from provenaclient.modules.module_helpers import *
from provenaclient.utils.helpers import read_file_helper, write_file_helper, get_and_validate_file_path
from typing import List
from provenaclient.models.general import HealthCheckResponse
from ProvenaInterfaces.ProvenanceAPI import LineageResponse, ModelRunRecord, ConvertModelRunsResponse, RegisterModelRunResponse, RegisterBatchModelRunRequest, RegisterBatchModelRunResponse, PostUpdateModelRunResponse
from provenaclient.models.general import CustomLineageResponse, HealthCheckResponse
from ProvenaInterfaces.ProvenanceAPI import ModelRunRecord, ConvertModelRunsResponse, RegisterModelRunResponse, RegisterBatchModelRunRequest, RegisterBatchModelRunResponse, PostUpdateModelRunResponse
from ProvenaInterfaces.RegistryAPI import ItemModelRun
from ProvenaInterfaces.SharedTypes import StatusResponse

Expand Down Expand Up @@ -214,7 +214,7 @@ async def update_model_run(self, model_run_id: str, reason: str, record: ModelRu
record=record
)

async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Explores in the upstream direction (inputs/associations)
starting at the specified node handle ID.
The search depth is bounded by the depth parameter which has a default maximum of 100.
Expand All @@ -228,13 +228,15 @@ async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.explore_upstream(starting_id=starting_id, depth=depth)
upstream_response = await self._prov_api_client.explore_upstream(starting_id=starting_id, depth=depth)
typed_upstream_response = CustomLineageResponse.parse_obj(upstream_response.dict())
return typed_upstream_response

async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Explores in the downstream direction (inputs/associations)
starting at the specified node handle ID.
The search depth is bounded by the depth parameter which has a default maximum of 100.
Expand All @@ -248,13 +250,15 @@ async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAU
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.explore_downstream(starting_id=starting_id, depth=depth)
typed_downstream_response = await self._prov_api_client.explore_downstream(starting_id=starting_id, depth=depth)
typed_downstream_response = CustomLineageResponse.parse_obj(typed_downstream_response.dict())
return typed_downstream_response

async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches datasets (inputs) which involved in a model run
naturally in the upstream direction.
Expand All @@ -267,13 +271,15 @@ async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_AP
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_contributing_datasets(starting_id=starting_id, depth=depth)
contributing_datasets = await self._prov_api_client.get_contributing_datasets(starting_id=starting_id, depth=depth)
typed_contributing_datasets = CustomLineageResponse.parse_obj(contributing_datasets.dict())
return typed_contributing_datasets

async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches datasets (outputs) which are derived from the model run
naturally in the downstream direction.
Expand All @@ -286,13 +292,15 @@ async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DE
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_effected_datasets(starting_id=starting_id, depth=depth)
effected_datasets_response = await self._prov_api_client.get_effected_datasets(starting_id=starting_id, depth=depth)
typed_effected_datasets = CustomLineageResponse.parse_obj(effected_datasets_response.dict())
return typed_effected_datasets

async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches agents (organisations or peoples) that are involved or impacted by the model run.
naturally in the upstream direction.
Expand All @@ -305,13 +313,15 @@ async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_contributing_agents(starting_id=starting_id, depth=depth)
contributing_agents_response = await self._prov_api_client.get_contributing_agents(starting_id=starting_id, depth=depth)
typed_contributing_agents = CustomLineageResponse.parse_obj(contributing_agents_response.dict())
return typed_contributing_agents

async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches agents (organisations or peoples) that are involved or impacted by the model run.
naturally in the downstream direction.
Expand All @@ -324,11 +334,13 @@ async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFA
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_effected_agents(starting_id=starting_id, depth=depth)
effected_agents_response = await self._prov_api_client.get_effected_agents(starting_id=starting_id, depth=depth)
typed_effected_agents = CustomLineageResponse.parse_obj(effected_agents_response.dict())
return typed_effected_agents

async def register_batch_model_runs(self, batch_model_run_payload: RegisterBatchModelRunRequest) -> RegisterBatchModelRunResponse:
"""This function allows you to register multiple model runs in one go (batch) asynchronously.
Expand Down
16 changes: 15 additions & 1 deletion tests/adhoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def random_num() -> int: return random.randint(100, 1000)

# print(item_counts)

"""Example for downloading specific files..."""
"""Example for downloading specific files...
# Downloading a file at root level
await client.datastore.io.download_specific_file(dataset_id="10378.1/1876000", s3_path="metadata.json", destination_directory="./")
Expand All @@ -278,8 +278,22 @@ def random_num() -> int: return random.randint(100, 1000)
# my_dataset = await client.datastore.interactive_dataset(dataset_id="10378.1/1948400")
# await my_dataset.download_all_files(destination_directory="./")
"""

response = await client.prov_api.explore_upstream(
starting_id="10378.1/1965416",
depth=2
)

print(response.record_count)

assert response.graph

print(response.graph.nodes)

print("Listing all datasets")
for node in response.graph.nodes:
if node.item_subtype == ItemSubType.DATASET:
print(node.id, node.item_subtype)

asyncio.run(main())
21 changes: 7 additions & 14 deletions tests/integration_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ProvenaInterfaces.TestConfig import RouteParameters, route_params, non_test_route_params
from ProvenaInterfaces.AsyncJobModels import RegistryRegisterCreateActivityResult

from provenaclient.models.general import CustomGraph, CustomLineageResponse, GraphProperty
from provenaclient.modules.provena_client import ProvenaClient
from provenaclient.modules.registry import ModelClient, OrganisationClient, PersonClient, StudyClient, DatasetTemplateClient
from provenaclient.utils.registry_endpoints import *
Expand Down Expand Up @@ -344,14 +345,7 @@ def get_item_subtype_domain_info_example(item_subtype: ItemSubType) -> DomainInf
return get_item_subtype_route_params(item_subtype=item_subtype).model_examples.domain_info[0]


Graph = Dict[str, Any]


@dataclass
class GraphProperty():
type: str
source: str
target: str
Graph = CustomGraph


def assert_graph_property(prop: GraphProperty, graph: Graph) -> None:
Expand All @@ -368,18 +362,17 @@ def assert_graph_property(prop: GraphProperty, graph: Graph) -> None:
graph (Graph): The graph to analyse
"""

links = graph['links']
links = graph.links
found = False
for l in links:
actual_prop = GraphProperty(**l)
if actual_prop == prop:
actual_prop = GraphProperty(type=l.type, source=l.source, target=l.target)
if actual_prop == prop:
found = True
break

assert found, f"Could not find relation specified {prop}."


def assert_non_empty_graph_property(prop: GraphProperty, lineage_response: LineageResponse) -> None:
def assert_non_empty_graph_property(prop: GraphProperty, lineage_response: CustomLineageResponse) -> None:
"""
Determines if the desired graph property exists in the networkX JSON graph
lineage response in the graph object.
Expand Down
4 changes: 4 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,10 @@ async def test_provenance_workflow(client: ProvenaClient, org_person_fixture: Tu
depth=1,
)

# Adding tests for the "custom lineage response override"
assert activity_upstream_query.graph, f"The graph field is missing from upstream response with CustomLineageResponse override."
assert activity_upstream_query.graph.nodes, f"The nodes field is missing from upstream response with CustomLineageResponse override."

# model run -wasInformedBy-> study
assert_non_empty_graph_property(
prop=GraphProperty(
Expand Down

0 comments on commit 926154b

Please sign in to comment.