From 7467b813b364582aa0580380ed915dc50ee5d5ba Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Sun, 26 May 2024 08:28:20 -0400 Subject: [PATCH] Allow attaching custom metadata to lattices and electrons --- covalent/_dispatcher_plugins/local.py | 6 ++++-- covalent/_serialize/electron.py | 4 ++-- covalent/_serialize/lattice.py | 6 ++---- covalent/_shared_files/schemas/electron.py | 15 +++++---------- covalent/_shared_files/schemas/lattice.py | 15 +++++---------- covalent/_shared_files/schemas/result.py | 6 +++++- covalent_dispatcher/_dal/importers/electron.py | 8 ++++++-- covalent_dispatcher/_dal/importers/lattice.py | 8 ++++++-- .../_dal/importers/result_import_test.py | 4 ++-- .../serialize/lattice_serialization_test.py | 6 +++--- 10 files changed, 40 insertions(+), 38 deletions(-) diff --git a/covalent/_dispatcher_plugins/local.py b/covalent/_dispatcher_plugins/local.py index 3138873e5..e1a277e2f 100644 --- a/covalent/_dispatcher_plugins/local.py +++ b/covalent/_dispatcher_plugins/local.py @@ -585,7 +585,9 @@ def register_manifest( if parent_dispatch_id: endpoint = f"{BASE_ENDPOINT}/{parent_dispatch_id}/sublattices" - r = APIClient(dispatcher_addr).post(endpoint, data=stripped.model_dump_json()) + r = APIClient(dispatcher_addr).post( + endpoint, data=stripped.model_dump_json(exclude_unset=True) + ) r.raise_for_status() parsed_resp = ResultSchema.model_validate(r.json()) @@ -616,7 +618,7 @@ def register_derived_manifest( params = {"reuse_previous_results": reuse_previous_results} r = APIClient(dispatcher_addr).post( - endpoint, data=stripped.model_dump_json(), params=params + endpoint, data=stripped.model_dump_json(exclude_unset=True), params=params ) r.raise_for_status() diff --git a/covalent/_serialize/electron.py b/covalent/_serialize/electron.py index fe5763675..b90879fbf 100644 --- a/covalent/_serialize/electron.py +++ b/covalent/_serialize/electron.py @@ -210,8 +210,8 @@ def _get_node_custom_assets(node_attrs: dict) -> Dict[str, AssetSchema]: def serialize_node(node_id: int, node_attrs: dict, node_storage_path) -> ElectronSchema: meta = _serialize_node_metadata(node_attrs, node_storage_path) assets = _serialize_node_assets(node_attrs, node_storage_path) - custom_assets = _get_node_custom_assets(node_attrs) - return ElectronSchema(id=node_id, metadata=meta, assets=assets, custom_assets=custom_assets) + assets._custom = _get_node_custom_assets(node_attrs) + return ElectronSchema(id=node_id, metadata=meta, assets=assets) def deserialize_node(e: ElectronSchema, metadata_only: bool = False) -> dict: diff --git a/covalent/_serialize/lattice.py b/covalent/_serialize/lattice.py index 3ab39f2bc..6fbd1b98c 100644 --- a/covalent/_serialize/lattice.py +++ b/covalent/_serialize/lattice.py @@ -194,12 +194,10 @@ def _get_lattice_custom_assets(lat: Lattice) -> Dict[str, AssetSchema]: def serialize_lattice(lat, storage_path: str) -> LatticeSchema: meta = _serialize_lattice_metadata(lat) assets = _serialize_lattice_assets(lat, storage_path) - custom_assets = _get_lattice_custom_assets(lat) + assets._custom = _get_lattice_custom_assets(lat) tg = serialize_transport_graph(lat.transport_graph, storage_path) - return LatticeSchema( - metadata=meta, assets=assets, custom_assets=custom_assets, transport_graph=tg - ) + return LatticeSchema(metadata=meta, assets=assets, transport_graph=tg) def deserialize_lattice(model: LatticeSchema) -> Lattice: diff --git a/covalent/_shared_files/schemas/electron.py b/covalent/_shared_files/schemas/electron.py index b245cc93d..c5da65e1d 100644 --- a/covalent/_shared_files/schemas/electron.py +++ b/covalent/_shared_files/schemas/electron.py @@ -19,7 +19,7 @@ from datetime import datetime from typing import Dict, Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel from .asset import AssetSchema from .common import StatusEnum @@ -91,6 +91,8 @@ class ElectronAssets(BaseModel): # user dependent assets hooks: AssetSchema + _custom: Optional[Dict[str, AssetSchema]] = None + class ElectronMetadata(BaseModel): task_group_id: int @@ -103,6 +105,8 @@ class ElectronMetadata(BaseModel): start_time: Optional[datetime] = None end_time: Optional[datetime] = None + _custom: Optional[Dict] = None + # For use by redispatch def reset(self): self.status = StatusEnum.NEW_OBJECT @@ -114,12 +118,3 @@ class ElectronSchema(BaseModel): id: int metadata: ElectronMetadata assets: ElectronAssets - custom_assets: Optional[Dict[str, AssetSchema]] = None - - @field_validator("custom_assets") - def check_custom_asset_keys(cls, v): - if v is not None: - for key in v: - if key in ASSET_FILENAME_MAP: - raise ValueError(f"Asset {key} conflicts with built-in key") - return v diff --git a/covalent/_shared_files/schemas/lattice.py b/covalent/_shared_files/schemas/lattice.py index 6a3e2bbf9..2fece9c80 100644 --- a/covalent/_shared_files/schemas/lattice.py +++ b/covalent/_shared_files/schemas/lattice.py @@ -18,7 +18,7 @@ from typing import Dict, Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel from .asset import AssetSchema from .transport_graph import TransportGraphSchema @@ -91,6 +91,8 @@ class LatticeAssets(BaseModel): # lattice.metadata hooks: AssetSchema + _custom: Optional[Dict[str, AssetSchema]] = None + class LatticeMetadata(BaseModel): name: str # __name__ @@ -101,18 +103,11 @@ class LatticeMetadata(BaseModel): python_version: Optional[str] = None covalent_version: Optional[str] = None + _custom: Optional[Dict] = None + class LatticeSchema(BaseModel): metadata: LatticeMetadata assets: LatticeAssets - custom_assets: Optional[Dict[str, AssetSchema]] = None transport_graph: TransportGraphSchema - - @field_validator("custom_assets") - def check_custom_asset_keys(cls, v): - if v is not None: - for key in v: - if key in ASSET_FILENAME_MAP: - raise ValueError(f"Asset {key} conflicts with built-in key") - return v diff --git a/covalent/_shared_files/schemas/result.py b/covalent/_shared_files/schemas/result.py index fa771bf9b..3160c3708 100644 --- a/covalent/_shared_files/schemas/result.py +++ b/covalent/_shared_files/schemas/result.py @@ -17,7 +17,7 @@ """FastAPI models for /api/v1/resultv2 endpoints""" from datetime import datetime -from typing import Optional +from typing import Dict, Optional from pydantic import BaseModel @@ -54,6 +54,8 @@ class ResultMetadata(BaseModel): start_time: Optional[datetime] = None end_time: Optional[datetime] = None + _custom: Optional[Dict] = None + # For use by redispatch def reset(self): self.dispatch_id = "" @@ -67,6 +69,8 @@ class ResultAssets(BaseModel): result: AssetSchema error: AssetSchema + _custom: Optional[Dict[str, AssetSchema]] = None + class ResultSchema(BaseModel): metadata: ResultMetadata diff --git a/covalent_dispatcher/_dal/importers/electron.py b/covalent_dispatcher/_dal/importers/electron.py index d4b5047c5..1f3ca51fc 100644 --- a/covalent_dispatcher/_dal/importers/electron.py +++ b/covalent_dispatcher/_dal/importers/electron.py @@ -133,6 +133,10 @@ def import_electron_assets( asset_recs = {} for asset_key, asset in e.assets: + # Register these later + if asset_key == "_custom": + continue + node_storage_path, object_key = object_store.get_uri_components( dispatch_id, e.id, @@ -157,8 +161,8 @@ def import_electron_assets( asset.remote_uri = f"file://{local_uri}" # Register custom assets - if e.custom_assets: - for asset_key, asset in e.custom_assets.items(): + if e.assets._custom: + for asset_key, asset in e.assets._custom.items(): object_key = f"{asset_key}.data" local_uri = os.path.join(node_storage_path, object_key) diff --git a/covalent_dispatcher/_dal/importers/lattice.py b/covalent_dispatcher/_dal/importers/lattice.py index a14938f98..9e7f97037 100644 --- a/covalent_dispatcher/_dal/importers/lattice.py +++ b/covalent_dispatcher/_dal/importers/lattice.py @@ -94,6 +94,10 @@ def import_lattice_assets( # Register built-in assets for asset_key, asset in lat.assets: + # Deal with these later + if asset_key == "_custom": + continue + storage_path, object_key = object_store.get_uri_components( dispatch_id=dispatch_id, node_id=None, @@ -118,8 +122,8 @@ def import_lattice_assets( asset.remote_uri = f"file://{local_uri}" # Register custom assets - if lat.custom_assets: - for asset_key, asset in lat.custom_assets.items(): + if lat.assets._custom: + for asset_key, asset in lat.assets._custom.items(): object_key = f"{asset_key}.data" local_uri = os.path.join(storage_path, object_key) diff --git a/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py b/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py index 819f88bc6..964e7cbc5 100644 --- a/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py +++ b/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py @@ -275,8 +275,8 @@ def test_import_result_with_custom_assets(mocker, test_db): prefix="covalent-" ) as srv_dir: manifest = get_mock_result(dispatch_id, sdk_dir) - manifest.lattice.custom_assets = {"custom_lattice_asset": AssetSchema(size=0)} - manifest.lattice.transport_graph.nodes[0].custom_assets = { + manifest.lattice.assets._custom = {"custom_lattice_asset": AssetSchema(size=0)} + manifest.lattice.transport_graph.nodes[0].assets._custom = { "custom_electron_asset": AssetSchema(size=0) } filtered_res = import_result(manifest, srv_dir, None) diff --git a/tests/covalent_tests/serialize/lattice_serialization_test.py b/tests/covalent_tests/serialize/lattice_serialization_test.py index 4247b6230..709041bde 100644 --- a/tests/covalent_tests/serialize/lattice_serialization_test.py +++ b/tests/covalent_tests/serialize/lattice_serialization_test.py @@ -89,10 +89,10 @@ def workflow(x, y): with tempfile.TemporaryDirectory() as d: manifest = serialize_lattice(workflow, d) - assert ["custom_lat_asset"] == list(manifest.custom_assets.keys()) + assert ["custom_lat_asset"] == list(manifest.assets._custom.keys()) node_0 = manifest.transport_graph.nodes[0] - assert "custom_electron_asset" in node_0.custom_assets + assert "custom_electron_asset" in node_0.assets._custom node_1 = manifest.transport_graph.nodes[1] - assert not node_1.custom_assets + assert not node_1.assets._custom