Skip to content

Commit

Permalink
Merge branch 'main' of github.com:provena/provena-python-client into …
Browse files Browse the repository at this point in the history
…generate-report
  • Loading branch information
jyucsiro committed Dec 20, 2024
2 parents 145a1bd + 926154b commit 5fcb875
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 47 deletions.
12 changes: 6 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 32 additions & 2 deletions src/provenaclient/models/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
---------- --- ---------------------------------------------------------
'''

from pydantic import BaseModel, Field
from ProvenaInterfaces.RegistryAPI import ItemSubType
from typing import Any, Dict, Optional, Type, TypedDict, List
from pydantic import BaseModel, Field, ValidationError, validator
from ProvenaInterfaces.RegistryAPI import ItemSubType, Node
from ProvenaInterfaces.ProvenanceAPI import LineageResponse


class HealthCheckResponse(BaseModel):
Expand All @@ -33,3 +35,31 @@ class AsyncAwaitSettings(BaseModel):
job_async_in_progress_polling_timeout = 180 # 3 minutes

DEFAULT_AWAIT_SETTINGS = AsyncAwaitSettings()

class GraphProperty(BaseModel):
    """A single link with a type, a source and a target.

    Instances of this model make up the ``links`` list of ``CustomGraph``.
    """

    # Label of the link.
    # NOTE(review): presumably the name of a provenance relation - confirm
    # against the lineage API output.
    type: str
    # First endpoint of the link.
    # NOTE(review): presumably the id of the node the link starts from - confirm.
    source: str
    # Second endpoint of the link.
    # NOTE(review): presumably the id of the node the link points to - confirm.
    target: str

class CustomGraph(BaseModel):
    """Typed model used for the ``graph`` field of ``CustomLineageResponse``.

    NOTE(review): the field names look like the networkx node-link JSON
    layout (directed / multigraph / graph / nodes / links) - confirm against
    the serialiser used by the provenance API.
    """

    # Flag describing whether the graph is directed.
    directed: bool
    # Flag describing whether the graph is a multigraph.
    multigraph: bool
    # Free-form graph-level attributes; left untyped.
    graph: Dict[str, Any]
    # Graph nodes, each parsed into the registry ``Node`` model.
    nodes: List[Node]
    # Graph links, each parsed into a ``GraphProperty``.
    links: List[GraphProperty]

class CustomLineageResponse(LineageResponse):
    """A custom lineage response pydantic model.

    Inherits from its parent (``LineageResponse``) and overrides the
    ``graph`` field, so that the graph is exposed as a typed pydantic
    object (``CustomGraph``) rather than an untyped dictionary. The
    nodes and links inside the graph are parsed into typed models as
    well, through the field types declared on ``CustomGraph``.

    NOTE(review): an earlier version of this docstring mentioned a custom
    validator function; none is defined on this class, so parsing appears
    to rely on pydantic's standard nested-model handling - confirm.
    """

    # Overrides the parent's ``graph`` field with the typed model.
    # NOTE(review): the ``type:ignore`` presumably silences the type
    # checker's complaint about changing the type of an inherited field -
    # confirm before removing it.
    graph: Optional[CustomGraph] #type:ignore

62 changes: 37 additions & 25 deletions src/provenaclient/modules/prov.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from provenaclient.modules.module_helpers import *
from provenaclient.utils.helpers import read_file_helper, write_file_helper, get_and_validate_file_path
from typing import List
from provenaclient.models.general import HealthCheckResponse
from provenaclient.models.general import CustomLineageResponse, HealthCheckResponse
from ProvenaInterfaces.ProvenanceAPI import LineageResponse, ModelRunRecord, ConvertModelRunsResponse, RegisterModelRunResponse, RegisterBatchModelRunRequest, RegisterBatchModelRunResponse, PostUpdateModelRunResponse, GenerateReportRequest
from ProvenaInterfaces.RegistryAPI import ItemModelRun
from ProvenaInterfaces.SharedTypes import StatusResponse
Expand Down Expand Up @@ -217,7 +217,7 @@ async def update_model_run(self, model_run_id: str, reason: str, record: ModelRu
record=record
)

async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Explores in the upstream direction (inputs/associations)
starting at the specified node handle ID.
The search depth is bounded by the depth parameter which has a default maximum of 100.
Expand All @@ -231,13 +231,15 @@ async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.explore_upstream(starting_id=starting_id, depth=depth)
upstream_response = await self._prov_api_client.explore_upstream(starting_id=starting_id, depth=depth)
typed_upstream_response = CustomLineageResponse.parse_obj(upstream_response.dict())
return typed_upstream_response

async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Explores in the downstream direction (inputs/associations)
starting at the specified node handle ID.
The search depth is bounded by the depth parameter which has a default maximum of 100.
Expand All @@ -251,13 +253,15 @@ async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAU
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.explore_downstream(starting_id=starting_id, depth=depth)
typed_downstream_response = await self._prov_api_client.explore_downstream(starting_id=starting_id, depth=depth)
typed_downstream_response = CustomLineageResponse.parse_obj(typed_downstream_response.dict())
return typed_downstream_response

async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches datasets (inputs) which involved in a model run
naturally in the upstream direction.
Expand All @@ -270,13 +274,15 @@ async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_AP
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_contributing_datasets(starting_id=starting_id, depth=depth)
contributing_datasets = await self._prov_api_client.get_contributing_datasets(starting_id=starting_id, depth=depth)
typed_contributing_datasets = CustomLineageResponse.parse_obj(contributing_datasets.dict())
return typed_contributing_datasets

async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches datasets (outputs) which are derived from the model run
naturally in the downstream direction.
Expand All @@ -289,13 +295,15 @@ async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DE
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_effected_datasets(starting_id=starting_id, depth=depth)
effected_datasets_response = await self._prov_api_client.get_effected_datasets(starting_id=starting_id, depth=depth)
typed_effected_datasets = CustomLineageResponse.parse_obj(effected_datasets_response.dict())
return typed_effected_datasets

async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches agents (organisations or peoples) that are involved or impacted by the model run.
naturally in the upstream direction.
Expand All @@ -308,13 +316,15 @@ async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_contributing_agents(starting_id=starting_id, depth=depth)
contributing_agents_response = await self._prov_api_client.get_contributing_agents(starting_id=starting_id, depth=depth)
typed_contributing_agents = CustomLineageResponse.parse_obj(contributing_agents_response.dict())
return typed_contributing_agents

async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches agents (organisations or peoples) that are involved or impacted by the model run.
naturally in the downstream direction.
Expand All @@ -327,11 +337,13 @@ async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFA
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_effected_agents(starting_id=starting_id, depth=depth)
effected_agents_response = await self._prov_api_client.get_effected_agents(starting_id=starting_id, depth=depth)
typed_effected_agents = CustomLineageResponse.parse_obj(effected_agents_response.dict())
return typed_effected_agents

async def register_batch_model_runs(self, batch_model_run_payload: RegisterBatchModelRunRequest) -> RegisterBatchModelRunResponse:
"""This function allows you to register multiple model runs in one go (batch) asynchronously.
Expand Down
12 changes: 12 additions & 0 deletions tests/adhoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,21 @@ def random_num() -> int: return random.randint(100, 1000)
depth=1
), file_path="./idontexistpath/butinhere/")


response = await client.prov_api.explore_upstream(
starting_id="10378.1/1965416",
depth=2
)

print(response.record_count)

assert response.graph

print(response.graph.nodes)

print("Listing all datasets")
for node in response.graph.nodes:
if node.item_subtype == ItemSubType.DATASET:
print(node.id, node.item_subtype)

asyncio.run(main())
21 changes: 7 additions & 14 deletions tests/integration_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ProvenaInterfaces.TestConfig import RouteParameters, route_params, non_test_route_params
from ProvenaInterfaces.AsyncJobModels import RegistryRegisterCreateActivityResult

from provenaclient.models.general import CustomGraph, CustomLineageResponse, GraphProperty
from provenaclient.modules.provena_client import ProvenaClient
from provenaclient.modules.registry import ModelClient, OrganisationClient, PersonClient, StudyClient, DatasetTemplateClient
from provenaclient.utils.registry_endpoints import *
Expand Down Expand Up @@ -344,14 +345,7 @@ def get_item_subtype_domain_info_example(item_subtype: ItemSubType) -> DomainInf
return get_item_subtype_route_params(item_subtype=item_subtype).model_examples.domain_info[0]


Graph = Dict[str, Any]


@dataclass
class GraphProperty():
type: str
source: str
target: str
Graph = CustomGraph


def assert_graph_property(prop: GraphProperty, graph: Graph) -> None:
Expand All @@ -368,18 +362,17 @@ def assert_graph_property(prop: GraphProperty, graph: Graph) -> None:
graph (Graph): The graph to analyse
"""

links = graph['links']
links = graph.links
found = False
for l in links:
actual_prop = GraphProperty(**l)
if actual_prop == prop:
actual_prop = GraphProperty(type=l.type, source=l.source, target=l.target)
if actual_prop == prop:
found = True
break

assert found, f"Could not find relation specified {prop}."


def assert_non_empty_graph_property(prop: GraphProperty, lineage_response: LineageResponse) -> None:
def assert_non_empty_graph_property(prop: GraphProperty, lineage_response: CustomLineageResponse) -> None:
"""
Determines if the desired graph property exists in the networkX JSON graph
lineage response in the graph object.
Expand Down
4 changes: 4 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,10 @@ async def test_provenance_workflow(client: ProvenaClient, org_person_fixture: Tu
depth=1,
)

# Adding tests for the "custom lineage response override"
assert activity_upstream_query.graph, f"The graph field is missing from upstream response with CustomLineageResponse override."
assert activity_upstream_query.graph.nodes, f"The nodes field is missing from upstream response with CustomLineageResponse override."

# model run -wasInformedBy-> study
assert_non_empty_graph_property(
prop=GraphProperty(
Expand Down

0 comments on commit 5fcb875

Please sign in to comment.