diff --git a/poetry.lock b/poetry.lock index 3de431e..78583f2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -13,24 +13,24 @@ files = [ [[package]] name = "anyio" -version = "4.6.2.post1" +version = "4.7.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.9" files = [ - {file = "anyio-4.6.2.post1-py3-none-any.whl", hash = "sha256:6d170c36fba3bdd840c73d3868c1e777e33676a69c3a72cf0a0d5d6d8009b61d"}, - {file = "anyio-4.6.2.post1.tar.gz", hash = "sha256:4c8bc31ccdb51c7f7bd251f51c609e038d63e34219b44aa86e47576389880b4c"}, + {file = "anyio-4.7.0-py3-none-any.whl", hash = "sha256:ea60c3723ab42ba6fff7e8ccb0488c898ec538ff4df1f1d5e642c3601d07e352"}, + {file = "anyio-4.7.0.tar.gz", hash = "sha256:2f834749c602966b7d456a7567cafcb309f96482b5081d14ac93ccd457f9dd48"}, ] [package.dependencies] exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} idna = ">=2.8" sniffio = ">=1.1" -typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} +typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} [package.extras] -doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] +doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"] trio = ["trio (>=0.26.1)"] [[package]] diff --git a/src/provenaclient/models/general.py b/src/provenaclient/models/general.py index 706572f..b020f49 100644 --- a/src/provenaclient/models/general.py +++ 
b/src/provenaclient/models/general.py @@ -12,8 +12,10 @@ ---------- --- --------------------------------------------------------- ''' -from pydantic import BaseModel, Field -from ProvenaInterfaces.RegistryAPI import ItemSubType +from typing import Any, Dict, Optional, Type, TypedDict, List +from pydantic import BaseModel, Field, ValidationError, validator +from ProvenaInterfaces.RegistryAPI import ItemSubType, Node +from ProvenaInterfaces.ProvenanceAPI import LineageResponse class HealthCheckResponse(BaseModel): @@ -33,3 +35,31 @@ class AsyncAwaitSettings(BaseModel): job_async_in_progress_polling_timeout = 180 # 3 minutes DEFAULT_AWAIT_SETTINGS = AsyncAwaitSettings() + +class GraphProperty(BaseModel): + type: str + source: str + target: str + +class CustomGraph(BaseModel): + directed: bool + multigraph: bool + graph: Dict[str, Any] + nodes: List[Node] + links: List[GraphProperty] + +class CustomLineageResponse(LineageResponse): + """A Custom Lineage Response Pydantic Model + that inherits from its parent (LineageResponse). + + This model overrides the "graph" field within the + Lineage Response, and converts it from an untyped + dictionary into a typed pydantic object (CustomGraph). + + The nested graph, including its nodes and links, is parsed + by pydantic's standard validation of the typed fields. 
+ + """ + graph: Optional[CustomGraph] #type:ignore + diff --git a/src/provenaclient/modules/prov.py b/src/provenaclient/modules/prov.py index c13067d..e98fd18 100644 --- a/src/provenaclient/modules/prov.py +++ b/src/provenaclient/modules/prov.py @@ -22,7 +22,7 @@ from provenaclient.modules.module_helpers import * from provenaclient.utils.helpers import read_file_helper, write_file_helper, get_and_validate_file_path from typing import List -from provenaclient.models.general import HealthCheckResponse +from provenaclient.models.general import CustomLineageResponse, HealthCheckResponse from ProvenaInterfaces.ProvenanceAPI import LineageResponse, ModelRunRecord, ConvertModelRunsResponse, RegisterModelRunResponse, RegisterBatchModelRunRequest, RegisterBatchModelRunResponse, PostUpdateModelRunResponse, GenerateReportRequest from ProvenaInterfaces.RegistryAPI import ItemModelRun from ProvenaInterfaces.SharedTypes import StatusResponse @@ -217,7 +217,7 @@ async def update_model_run(self, model_run_id: str, reason: str, record: ModelRu record=record ) - async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse: + async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse: """Explores in the upstream direction (inputs/associations) starting at the specified node handle ID. The search depth is bounded by the depth parameter which has a default maximum of 100. @@ -231,13 +231,15 @@ async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT Returns ------- - LineageResponse - A response containing the status, node count, and networkx serialised graph response. + CustomLineageResponse + A typed response containing the status, node count, and networkx serialised graph response. 
""" - return await self._prov_api_client.explore_upstream(starting_id=starting_id, depth=depth) + upstream_response = await self._prov_api_client.explore_upstream(starting_id=starting_id, depth=depth) + typed_upstream_response = CustomLineageResponse.parse_obj(upstream_response.dict()) + return typed_upstream_response - async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse: + async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse: """Explores in the downstream direction (inputs/associations) starting at the specified node handle ID. The search depth is bounded by the depth parameter which has a default maximum of 100. @@ -251,13 +253,15 @@ async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAU Returns ------- - LineageResponse - A response containing the status, node count, and networkx serialised graph response. + CustomLineageResponse + A typed response containing the status, node count, and networkx serialised graph response. """ - return await self._prov_api_client.explore_downstream(starting_id=starting_id, depth=depth) + downstream_response = await self._prov_api_client.explore_downstream(starting_id=starting_id, depth=depth) + typed_downstream_response = CustomLineageResponse.parse_obj(downstream_response.dict()) + return typed_downstream_response - async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse: + async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse: """Fetches datasets (inputs) which involved in a model run naturally in the upstream direction. 
@@ -270,13 +274,15 @@ async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_AP Returns ------- - LineageResponse - A response containing the status, node count, and networkx serialised graph response. + CustomLineageResponse + A typed response containing the status, node count, and networkx serialised graph response. """ - return await self._prov_api_client.get_contributing_datasets(starting_id=starting_id, depth=depth) + contributing_datasets = await self._prov_api_client.get_contributing_datasets(starting_id=starting_id, depth=depth) + typed_contributing_datasets = CustomLineageResponse.parse_obj(contributing_datasets.dict()) + return typed_contributing_datasets - async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse: + async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse: """Fetches datasets (outputs) which are derived from the model run naturally in the downstream direction. @@ -289,13 +295,15 @@ async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DE Returns ------- - LineageResponse - A response containing the status, node count, and networkx serialised graph response. + CustomLineageResponse + A typed response containing the status, node count, and networkx serialised graph response. 
""" - return await self._prov_api_client.get_effected_datasets(starting_id=starting_id, depth=depth) + effected_datasets_response = await self._prov_api_client.get_effected_datasets(starting_id=starting_id, depth=depth) + typed_effected_datasets = CustomLineageResponse.parse_obj(effected_datasets_response.dict()) + return typed_effected_datasets - async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse: + async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse: """Fetches agents (organisations or peoples) that are involved or impacted by the model run. naturally in the upstream direction. @@ -308,13 +316,15 @@ async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_ Returns ------- - LineageResponse - A response containing the status, node count, and networkx serialised graph response. + CustomLineageResponse + A typed response containing the status, node count, and networkx serialised graph response. """ - return await self._prov_api_client.get_contributing_agents(starting_id=starting_id, depth=depth) + contributing_agents_response = await self._prov_api_client.get_contributing_agents(starting_id=starting_id, depth=depth) + typed_contributing_agents = CustomLineageResponse.parse_obj(contributing_agents_response.dict()) + return typed_contributing_agents - async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse: + async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse: """Fetches agents (organisations or peoples) that are involved or impacted by the model run. naturally in the downstream direction. 
@@ -327,11 +337,13 @@ async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFA Returns ------- - LineageResponse - A response containing the status, node count, and networkx serialised graph response. + CustomLineageResponse + A typed response containing the status, node count, and networkx serialised graph response. """ - return await self._prov_api_client.get_effected_agents(starting_id=starting_id, depth=depth) + effected_agents_response = await self._prov_api_client.get_effected_agents(starting_id=starting_id, depth=depth) + typed_effected_agents = CustomLineageResponse.parse_obj(effected_agents_response.dict()) + return typed_effected_agents async def register_batch_model_runs(self, batch_model_run_payload: RegisterBatchModelRunRequest) -> RegisterBatchModelRunResponse: """This function allows you to register multiple model runs in one go (batch) asynchronously. diff --git a/tests/adhoc.py b/tests/adhoc.py index 6dad927..ff9ab11 100644 --- a/tests/adhoc.py +++ b/tests/adhoc.py @@ -307,9 +307,21 @@ def random_num() -> int: return random.randint(100, 1000) depth=1 ), file_path="./idontexistpath/butinhere/") + + response = await client.prov_api.explore_upstream( + starting_id="10378.1/1965416", + depth=2 + ) + print(response.record_count) + assert response.graph + print(response.graph.nodes) + print("Listing all datasets") + for node in response.graph.nodes: + if node.item_subtype == ItemSubType.DATASET: + print(node.id, node.item_subtype) asyncio.run(main()) diff --git a/tests/integration_helpers.py b/tests/integration_helpers.py index 315403b..0ceda00 100644 --- a/tests/integration_helpers.py +++ b/tests/integration_helpers.py @@ -14,6 +14,7 @@ from ProvenaInterfaces.TestConfig import RouteParameters, route_params, non_test_route_params from ProvenaInterfaces.AsyncJobModels import RegistryRegisterCreateActivityResult +from provenaclient.models.general import CustomGraph, CustomLineageResponse, GraphProperty from 
provenaclient.modules.provena_client import ProvenaClient from provenaclient.modules.registry import ModelClient, OrganisationClient, PersonClient, StudyClient, DatasetTemplateClient from provenaclient.utils.registry_endpoints import * @@ -344,14 +345,7 @@ def get_item_subtype_domain_info_example(item_subtype: ItemSubType) -> DomainInf return get_item_subtype_route_params(item_subtype=item_subtype).model_examples.domain_info[0] -Graph = Dict[str, Any] - - -@dataclass -class GraphProperty(): - type: str - source: str - target: str +Graph = CustomGraph def assert_graph_property(prop: GraphProperty, graph: Graph) -> None: @@ -368,18 +362,17 @@ def assert_graph_property(prop: GraphProperty, graph: Graph) -> None: graph (Graph): The graph to analyse """ - links = graph['links'] + links = graph.links found = False for l in links: - actual_prop = GraphProperty(**l) - if actual_prop == prop: + actual_prop = GraphProperty(type=l.type, source=l.source, target=l.target) + if actual_prop == prop: found = True break - + assert found, f"Could not find relation specified {prop}." - -def assert_non_empty_graph_property(prop: GraphProperty, lineage_response: LineageResponse) -> None: +def assert_non_empty_graph_property(prop: GraphProperty, lineage_response: CustomLineageResponse) -> None: """ Determines if the desired graph property exists in the networkX JSON graph lineage response in the graph object. diff --git a/tests/test_integration.py b/tests/test_integration.py index 5b87067..6f7ff7e 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -706,6 +706,10 @@ async def test_provenance_workflow(client: ProvenaClient, org_person_fixture: Tu depth=1, ) + # Adding tests for the "custom lineage response override" + assert activity_upstream_query.graph, f"The graph field is missing from upstream response with CustomLineageResponse override." 
+ assert activity_upstream_query.graph.nodes, f"The nodes field is missing from upstream response with CustomLineageResponse override." + # model run -wasInformedBy-> study assert_non_empty_graph_property( prop=GraphProperty(