From 9f088daf98b521361a02e127a02ae7e7775a5ad6 Mon Sep 17 00:00:00 2001
From: parth-kulkarni1
Date: Fri, 6 Dec 2024 01:38:00 +0000
Subject: [PATCH] JIRA 1818: Return typed lineage responses from the prov API exploration methods.

---
 src/provenaclient/models/general.py | 98 ++++++++++++++++++++++++++++-
 src/provenaclient/modules/prov.py   | 64 +++++++++++--------
 tests/adhoc.py                      | 14 ++++-
 3 files changed, 146 insertions(+), 30 deletions(-)

diff --git a/src/provenaclient/models/general.py b/src/provenaclient/models/general.py
index 4a122da..ca327cf 100644
--- a/src/provenaclient/models/general.py
+++ b/src/provenaclient/models/general.py
@@ -12,8 +12,10 @@
 ----------   ---   ---------------------------------------------------------
 '''
-from pydantic import BaseModel
-
+from typing import Any, Dict, Optional, Type, List
+from pydantic import BaseModel, ValidationError, validator
+from ProvenaInterfaces.ProvenanceAPI import LineageResponse
+from ProvenaInterfaces.RegistryAPI import Node
 
 class HealthCheckResponse(BaseModel):
     message: str
@@ -31,4 +33,94 @@ class AsyncAwaitSettings(BaseModel):
     # how long do we wait for it to become in progress? (seconds)
     job_async_in_progress_polling_timeout = 180  # 3 minutes
 
-DEFAULT_AWAIT_SETTINGS = AsyncAwaitSettings()
\ No newline at end of file
+DEFAULT_AWAIT_SETTINGS = AsyncAwaitSettings()
+
+
+class CustomGraph(BaseModel):
+    directed: bool
+    multigraph: bool
+    graph: Dict[str, Any]
+    nodes: List[Node]
+
+class CustomLineageResponse(LineageResponse):
+    """A custom Lineage Response pydantic model that
+    inherits from its parent (LineageResponse).
+
+    This model overrides the "graph" field of the
+    LineageResponse, converting it from an untyped
+    dictionary into a typed pydantic object.
+
+    The custom validator below performs the conversion
+    and parses the contained nodes as well.
+
+    """
+
+    graph: Optional[CustomGraph]  # type: ignore
+
+    @validator('graph')
+    @classmethod
+    def convert_graph(cls: Type["CustomLineageResponse"], v: Optional[Dict[str, Any]]) -> Optional[CustomGraph]:
+        """Converts the untyped "graph" dictionary into a typed pydantic object.
+
+        Parameters
+        ----------
+        cls : Type[CustomLineageResponse]
+        v : Optional[Dict[str, Any]]
+            The untyped "graph" dictionary from the Lineage Response.
+
+        Returns
+        -------
+        Optional[CustomGraph]
+            Pydantic object formed from the untyped dictionary.
+        """
+
+        if v is None:
+            return None
+
+        # Parse the nodes before returning the pydantic object
+        list_of_parsed_nodes: List[Node] = cls.parse_nodes(v.get('nodes', []))
+
+        # Convert the generic dict to the CustomGraph structure
+        return CustomGraph(
+            directed=v.get('directed', False),
+            multigraph=v.get('multigraph', False),
+            graph=v.get('graph', {}),
+            nodes=list_of_parsed_nodes
+        )
+
+    @classmethod
+    def parse_nodes(cls, nodes_to_parse: List[Any]) -> List[Node]:
+        """Parses potential nodes into typed node objects.
+
+        Parameters
+        ----------
+        nodes_to_parse : List[Any]
+            A list of potential nodes.
+
+        Returns
+        -------
+        List[Node]
+            A list of typed nodes.
+
+        Raises
+        ------
+        ValidationError
+            Raised when node parsing fails.
+        Exception
+            Raised for any other error that occurs during parsing of the nodes.
+        """
+
+        try:
+            nodes_parsed_list: List[Node] = []
+
+            for node in nodes_to_parse:
+                parsed_node = Node.parse_obj(node)
+                nodes_parsed_list.append(parsed_node)
+
+            return nodes_parsed_list
+
+        except ValidationError:
+            raise
+
+        except Exception as e:
+            raise Exception(f"Something has gone wrong with parsing the node - {e}") from e
\ No newline at end of file
diff --git a/src/provenaclient/modules/prov.py b/src/provenaclient/modules/prov.py
index a4a880e..c340843 100644
--- a/src/provenaclient/modules/prov.py
+++ b/src/provenaclient/modules/prov.py
@@ -19,8 +19,8 @@
 from provenaclient.modules.module_helpers import *
 from provenaclient.utils.helpers import read_file_helper, write_file_helper, get_and_validate_file_path
 from typing import List
-from provenaclient.models.general import HealthCheckResponse
-from ProvenaInterfaces.ProvenanceAPI import LineageResponse, ModelRunRecord, ConvertModelRunsResponse, RegisterModelRunResponse, RegisterBatchModelRunRequest, RegisterBatchModelRunResponse, PostUpdateModelRunResponse
+from provenaclient.models.general import CustomLineageResponse, HealthCheckResponse
+from ProvenaInterfaces.ProvenanceAPI import ModelRunRecord, ConvertModelRunsResponse, RegisterModelRunResponse, RegisterBatchModelRunRequest, RegisterBatchModelRunResponse, PostUpdateModelRunResponse
 from ProvenaInterfaces.RegistryAPI import ItemModelRun
 from ProvenaInterfaces.SharedTypes import StatusResponse
 
@@ -214,7 +214,7 @@ async def update_model_run(self, model_run_id: str, reason: str, record: ModelRu
             record=record
         )
 
-    async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
+    async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
         """Explores in the upstream direction (inputs/associations)
         starting at the specified node handle ID.
         The search depth is bounded by the depth parameter which has a default maximum of 100.
@@ -228,13 +228,15 @@ async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT
 
         Returns
         -------
-        LineageResponse
-            A response containing the status, node count, and networkx serialised graph response.
+        CustomLineageResponse
+            A typed response containing the status, node count, and networkx serialised graph response.
         """
 
-        return await self._prov_api_client.explore_upstream(starting_id=starting_id, depth=depth)
+        upstream_response = await self._prov_api_client.explore_upstream(starting_id=starting_id, depth=depth)
+        typed_upstream_response = CustomLineageResponse.parse_obj(upstream_response.dict())
+        return typed_upstream_response
 
-    async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
+    async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
         """Explores in the downstream direction (inputs/associations)
         starting at the specified node handle ID.
         The search depth is bounded by the depth parameter which has a default maximum of 100.
@@ -248,13 +250,15 @@ async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAU
 
         Returns
         -------
-        LineageResponse
-            A response containing the status, node count, and networkx serialised graph response.
+        CustomLineageResponse
+            A typed response containing the status, node count, and networkx serialised graph response.
""" - return await self._prov_api_client.explore_downstream(starting_id=starting_id, depth=depth) + typed_downstream_response = await self._prov_api_client.explore_downstream(starting_id=starting_id, depth=depth) + typed_downstream_response = CustomLineageResponse.parse_obj(typed_downstream_response.dict()) + return typed_downstream_response - async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse: + async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse: """Fetches datasets (inputs) which involved in a model run naturally in the upstream direction. @@ -267,13 +271,15 @@ async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_AP Returns ------- - LineageResponse - A response containing the status, node count, and networkx serialised graph response. + CustomLineageResponse + A typed response containing the status, node count, and networkx serialised graph response. """ - return await self._prov_api_client.get_contributing_datasets(starting_id=starting_id, depth=depth) + contributing_datasets = await self._prov_api_client.get_contributing_datasets(starting_id=starting_id, depth=depth) + typed_contributing_datasets = CustomLineageResponse.parse_obj(contributing_datasets.dict()) + return typed_contributing_datasets - async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse: + async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse: """Fetches datasets (outputs) which are derived from the model run naturally in the downstream direction. @@ -286,13 +292,15 @@ async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DE Returns ------- - LineageResponse - A response containing the status, node count, and networkx serialised graph response. + CustomLineageResponse + A typed response containing the status, node count, and networkx serialised graph response. """ - return await self._prov_api_client.get_effected_datasets(starting_id=starting_id, depth=depth) + effected_datasets_response = await self._prov_api_client.get_effected_datasets(starting_id=starting_id, depth=depth) + typed_effected_datasets = CustomLineageResponse.parse_obj(effected_datasets_response.dict()) + return typed_effected_datasets - async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse: + async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse: """Fetches agents (organisations or peoples) that are involved or impacted by the model run. naturally in the upstream direction. @@ -305,13 +313,15 @@ async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_ Returns ------- - LineageResponse - A response containing the status, node count, and networkx serialised graph response. + CustomLineageResponse + A typed response containing the status, node count, and networkx serialised graph response. 
""" - return await self._prov_api_client.get_contributing_agents(starting_id=starting_id, depth=depth) + contributing_agents_response = await self._prov_api_client.get_contributing_agents(starting_id=starting_id, depth=depth) + typed_contributing_agents = CustomLineageResponse.parse_obj(contributing_agents_response.dict()) + return typed_contributing_agents - async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse: + async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse: """Fetches agents (organisations or peoples) that are involved or impacted by the model run. naturally in the downstream direction. @@ -324,11 +334,13 @@ async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFA Returns ------- - LineageResponse - A response containing the status, node count, and networkx serialised graph response. + CustomLineageResponse + A typed response containing the status, node count, and networkx serialised graph response. """ - return await self._prov_api_client.get_effected_agents(starting_id=starting_id, depth=depth) + effected_agents_response = await self._prov_api_client.get_effected_agents(starting_id=starting_id, depth=depth) + typed_effected_agents = CustomLineageResponse.parse_obj(effected_agents_response.dict()) + return typed_effected_agents async def register_batch_model_runs(self, batch_model_run_payload: RegisterBatchModelRunRequest) -> RegisterBatchModelRunResponse: """This function allows you to register multiple model runs in one go (batch) asynchronously. diff --git a/tests/adhoc.py b/tests/adhoc.py index f7bd216..613731a 100644 --- a/tests/adhoc.py +++ b/tests/adhoc.py @@ -257,7 +257,7 @@ def random_num() -> int: return random.randint(100, 1000) # print(item_counts) - """Example for downloading specific files...""" + """Example for downloading specific files... # Downloading a file at root level await client.datastore.io.download_specific_file(dataset_id="10378.1/1876000", s3_path="metadata.json", destination_directory="./") @@ -278,8 +278,20 @@ def random_num() -> int: return random.randint(100, 1000) # my_dataset = await client.datastore.interactive_dataset(dataset_id="10378.1/1948400") # await my_dataset.download_all_files(destination_directory="./") + """ + + response = await client.prov_api.explore_upstream( + starting_id="10378.1/1965416", + depth=2 + ) + + print(response.record_count) + assert response.graph + print(response.graph.nodes) + for node in response.graph.nodes: + print(node.id, node.item_subtype) asyncio.run(main())