Skip to content

Commit

Permalink
Completion of the JIRA Ticket 1818.
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-kulkarni1 committed Dec 6, 2024
1 parent 8055590 commit 9f088da
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 30 deletions.
98 changes: 95 additions & 3 deletions src/provenaclient/models/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
---------- --- ---------------------------------------------------------
'''

from pydantic import BaseModel

from typing import Any, Dict, Optional, Type, TypedDict, List
from pydantic import BaseModel, ValidationError, validator
from ProvenaInterfaces.ProvenanceAPI import LineageResponse
from ProvenaInterfaces.RegistryAPI import Node

class HealthCheckResponse(BaseModel):
message: str
Expand All @@ -31,4 +33,94 @@ class AsyncAwaitSettings(BaseModel):
# how long do we wait for it to become in progress? (seconds)
job_async_in_progress_polling_timeout = 180 # 3 minutes

DEFAULT_AWAIT_SETTINGS = AsyncAwaitSettings()
DEFAULT_AWAIT_SETTINGS = AsyncAwaitSettings()


class CustomGraph(BaseModel):
directed: bool
multigraph: bool
graph: Dict[str, Any]
nodes: List[Node]

class CustomLineageResponse(LineageResponse):
"""A Custom Lineage Response Pydantic Model
that inherits from its parent (LineageResponse).
This model overrides the "graph" field within the
Lineage Response, and converts it from an untyped
dictionary into a pydantic object/ typed datatype.
The custom validator function is called, and has custom
parsing logic to parse the nodes as well.
"""

graph: Optional[CustomGraph] #type:ignore

@validator('graph')
@classmethod
def convert_graph(cls: Type["CustomLineageResponse"], v: Dict[str, Any]) -> Optional[CustomGraph]:
"""Converts the untyped "graph" dictionary into a typed pydantic object/ datatype.
Parameters
----------
cls : Type[CustomLineageResponse]
v : Dict[str, Any]
The "graph" untyped dictionary from Lineage Response.
Returns
-------
Optional[CustomGraph]
Pydantic object that is formed from the untyped dictionary.
"""

if v is None:
return None

# Parse the nodes before returning the pydantic object
list_of_parsed_nodes: List[Node] = cls.parse_nodes(v.get('nodes', []))

# Convert the generic dict to CustomGraph structure
return CustomGraph(
directed= v.get('directed', False),
multigraph= v.get('multigraph', False),
graph= v.get('graph', {}),
nodes= list_of_parsed_nodes
)

@classmethod
def parse_nodes(cls, nodes_to_parse: List[Any]) -> List[Node]:
"""Parses potential nodes into typed node objects.
Parameters
----------
nodes_to_parse : List[Any]
A list of potential nodes.
Returns
-------
List[Node]
A list of typed nodes.
Raises
------
ValidationError
Raised when node parsing fails.
Exception
Raised for any error that occurs during parsing of the node.
"""

try:
nodes_parsed_list: List[Node] = []

for node in nodes_to_parse:
parsed_node = Node.parse_obj(node)
nodes_parsed_list.append(parsed_node)

return nodes_parsed_list

except ValidationError as e:
raise ValidationError(f"Something has gone with parsing the nodes - {e}", model=Node)

except Exception as e:
raise Exception(f"Something has gone wrong with parsing the node - {e}")
64 changes: 38 additions & 26 deletions src/provenaclient/modules/prov.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from provenaclient.modules.module_helpers import *
from provenaclient.utils.helpers import read_file_helper, write_file_helper, get_and_validate_file_path
from typing import List
from provenaclient.models.general import HealthCheckResponse
from ProvenaInterfaces.ProvenanceAPI import LineageResponse, ModelRunRecord, ConvertModelRunsResponse, RegisterModelRunResponse, RegisterBatchModelRunRequest, RegisterBatchModelRunResponse, PostUpdateModelRunResponse
from provenaclient.models.general import CustomLineageResponse, HealthCheckResponse
from ProvenaInterfaces.ProvenanceAPI import ModelRunRecord, ConvertModelRunsResponse, RegisterModelRunResponse, RegisterBatchModelRunRequest, RegisterBatchModelRunResponse, PostUpdateModelRunResponse
from ProvenaInterfaces.RegistryAPI import ItemModelRun
from ProvenaInterfaces.SharedTypes import StatusResponse

Expand Down Expand Up @@ -214,7 +214,7 @@ async def update_model_run(self, model_run_id: str, reason: str, record: ModelRu
record=record
)

async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Explores in the upstream direction (inputs/associations)
starting at the specified node handle ID.
The search depth is bounded by the depth parameter which has a default maximum of 100.
Expand All @@ -228,13 +228,15 @@ async def explore_upstream(self, starting_id: str, depth: int = PROV_API_DEFAULT
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.explore_upstream(starting_id=starting_id, depth=depth)
upstream_response = await self._prov_api_client.explore_upstream(starting_id=starting_id, depth=depth)
typed_upstream_response = CustomLineageResponse.parse_obj(upstream_response.dict())
return typed_upstream_response

async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Explores in the downstream direction (inputs/associations)
starting at the specified node handle ID.
The search depth is bounded by the depth parameter which has a default maximum of 100.
Expand All @@ -248,13 +250,15 @@ async def explore_downstream(self, starting_id: str, depth: int = PROV_API_DEFAU
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.explore_downstream(starting_id=starting_id, depth=depth)
typed_downstream_response = await self._prov_api_client.explore_downstream(starting_id=starting_id, depth=depth)
typed_downstream_response = CustomLineageResponse.parse_obj(typed_downstream_response.dict())
return typed_downstream_response

async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches datasets (inputs) which involved in a model run
naturally in the upstream direction.
Expand All @@ -267,13 +271,15 @@ async def get_contributing_datasets(self, starting_id: str, depth: int = PROV_AP
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_contributing_datasets(starting_id=starting_id, depth=depth)
contributing_datasets = await self._prov_api_client.get_contributing_datasets(starting_id=starting_id, depth=depth)
typed_contributing_datasets = CustomLineageResponse.parse_obj(contributing_datasets.dict())
return typed_contributing_datasets

async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches datasets (outputs) which are derived from the model run
naturally in the downstream direction.
Expand All @@ -286,13 +292,15 @@ async def get_effected_datasets(self, starting_id: str, depth: int = PROV_API_DE
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_effected_datasets(starting_id=starting_id, depth=depth)
effected_datasets_response = await self._prov_api_client.get_effected_datasets(starting_id=starting_id, depth=depth)
typed_effected_datasets = CustomLineageResponse.parse_obj(effected_datasets_response.dict())
return typed_effected_datasets

async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches agents (organisations or peoples) that are involved or impacted by the model run.
naturally in the upstream direction.
Expand All @@ -305,13 +313,15 @@ async def get_contributing_agents(self, starting_id: str, depth: int = PROV_API_
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_contributing_agents(starting_id=starting_id, depth=depth)
contributing_agents_response = await self._prov_api_client.get_contributing_agents(starting_id=starting_id, depth=depth)
typed_contributing_agents = CustomLineageResponse.parse_obj(contributing_agents_response.dict())
return typed_contributing_agents

async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> LineageResponse:
async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFAULT_SEARCH_DEPTH) -> CustomLineageResponse:
"""Fetches agents (organisations or peoples) that are involved or impacted by the model run.
naturally in the downstream direction.
Expand All @@ -324,11 +334,13 @@ async def get_effected_agents(self, starting_id: str, depth: int = PROV_API_DEFA
Returns
-------
LineageResponse
A response containing the status, node count, and networkx serialised graph response.
CustomLineageResponse
A typed response containing the status, node count, and networkx serialised graph response.
"""

return await self._prov_api_client.get_effected_agents(starting_id=starting_id, depth=depth)
effected_agents_response = await self._prov_api_client.get_effected_agents(starting_id=starting_id, depth=depth)
typed_effected_agents = CustomLineageResponse.parse_obj(effected_agents_response.dict())
return typed_effected_agents

async def register_batch_model_runs(self, batch_model_run_payload: RegisterBatchModelRunRequest) -> RegisterBatchModelRunResponse:
"""This function allows you to register multiple model runs in one go (batch) asynchronously.
Expand Down
14 changes: 13 additions & 1 deletion tests/adhoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def random_num() -> int: return random.randint(100, 1000)

# print(item_counts)

"""Example for downloading specific files..."""
"""Example for downloading specific files...
# Downloading a file at root level
await client.datastore.io.download_specific_file(dataset_id="10378.1/1876000", s3_path="metadata.json", destination_directory="./")
Expand All @@ -278,8 +278,20 @@ def random_num() -> int: return random.randint(100, 1000)
# my_dataset = await client.datastore.interactive_dataset(dataset_id="10378.1/1948400")
# await my_dataset.download_all_files(destination_directory="./")
"""

response = await client.prov_api.explore_upstream(
starting_id="10378.1/1965416",
depth=2
)

print(response.record_count)

assert response.graph

print(response.graph.nodes)

for node in response.graph.nodes:
print(node.id, node.item_subtype)

asyncio.run(main())

0 comments on commit 9f088da

Please sign in to comment.