Skip to content

Commit

Permalink
Allow attaching custom metadata to lattices and electrons
Browse files Browse the repository at this point in the history
  • Loading branch information
cjao committed May 26, 2024
1 parent 573195b commit ee78a82
Show file tree
Hide file tree
Showing 10 changed files with 40 additions and 38 deletions.
6 changes: 4 additions & 2 deletions covalent/_dispatcher_plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,9 @@ def register_manifest(
if parent_dispatch_id:
endpoint = f"{BASE_ENDPOINT}/{parent_dispatch_id}/sublattices"

r = APIClient(dispatcher_addr).post(endpoint, data=stripped.model_dump_json())
r = APIClient(dispatcher_addr).post(
endpoint, data=stripped.model_dump_json(exclude_unset=True)
)
r.raise_for_status()

parsed_resp = ResultSchema.model_validate(r.json())
Expand Down Expand Up @@ -616,7 +618,7 @@ def register_derived_manifest(

params = {"reuse_previous_results": reuse_previous_results}
r = APIClient(dispatcher_addr).post(
endpoint, data=stripped.model_dump_json(), params=params
endpoint, data=stripped.model_dump_json(exclude_unset=True), params=params
)
r.raise_for_status()

Expand Down
4 changes: 2 additions & 2 deletions covalent/_serialize/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ def _get_node_custom_assets(node_attrs: dict) -> Dict[str, AssetSchema]:
def serialize_node(node_id: int, node_attrs: dict, node_storage_path) -> ElectronSchema:
meta = _serialize_node_metadata(node_attrs, node_storage_path)
assets = _serialize_node_assets(node_attrs, node_storage_path)
custom_assets = _get_node_custom_assets(node_attrs)
return ElectronSchema(id=node_id, metadata=meta, assets=assets, custom_assets=custom_assets)
assets._custom = _get_node_custom_assets(node_attrs)
return ElectronSchema(id=node_id, metadata=meta, assets=assets)


def deserialize_node(e: ElectronSchema, metadata_only: bool = False) -> dict:
Expand Down
6 changes: 2 additions & 4 deletions covalent/_serialize/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,10 @@ def _get_lattice_custom_assets(lat: Lattice) -> Dict[str, AssetSchema]:
def serialize_lattice(lat, storage_path: str) -> LatticeSchema:
meta = _serialize_lattice_metadata(lat)
assets = _serialize_lattice_assets(lat, storage_path)
custom_assets = _get_lattice_custom_assets(lat)
assets._custom = _get_lattice_custom_assets(lat)
tg = serialize_transport_graph(lat.transport_graph, storage_path)

return LatticeSchema(
metadata=meta, assets=assets, custom_assets=custom_assets, transport_graph=tg
)
return LatticeSchema(metadata=meta, assets=assets, transport_graph=tg)


def deserialize_lattice(model: LatticeSchema) -> Lattice:
Expand Down
15 changes: 5 additions & 10 deletions covalent/_shared_files/schemas/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datetime import datetime
from typing import Dict, Optional

from pydantic import BaseModel, field_validator
from pydantic import BaseModel

from .asset import AssetSchema
from .common import StatusEnum
Expand Down Expand Up @@ -91,6 +91,8 @@ class ElectronAssets(BaseModel):
# user dependent assets
hooks: AssetSchema

_custom: Optional[Dict[str, AssetSchema]] = None


class ElectronMetadata(BaseModel):
task_group_id: int
Expand All @@ -103,6 +105,8 @@ class ElectronMetadata(BaseModel):
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None

_custom: Optional[Dict] = None

# For use by redispatch
def reset(self):
self.status = StatusEnum.NEW_OBJECT
Expand All @@ -114,12 +118,3 @@ class ElectronSchema(BaseModel):
id: int
metadata: ElectronMetadata
assets: ElectronAssets
custom_assets: Optional[Dict[str, AssetSchema]] = None

@field_validator("custom_assets")
def check_custom_asset_keys(cls, v):
if v is not None:
for key in v:
if key in ASSET_FILENAME_MAP:
raise ValueError(f"Asset {key} conflicts with built-in key")
return v
15 changes: 5 additions & 10 deletions covalent/_shared_files/schemas/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from typing import Dict, Optional

from pydantic import BaseModel, field_validator
from pydantic import BaseModel

from .asset import AssetSchema
from .transport_graph import TransportGraphSchema
Expand Down Expand Up @@ -91,6 +91,8 @@ class LatticeAssets(BaseModel):
# lattice.metadata
hooks: AssetSchema

_custom: Optional[Dict[str, AssetSchema]] = None


class LatticeMetadata(BaseModel):
name: str # __name__
Expand All @@ -101,18 +103,11 @@ class LatticeMetadata(BaseModel):
python_version: Optional[str] = None
covalent_version: Optional[str] = None

_custom: Optional[Dict] = None


class LatticeSchema(BaseModel):
metadata: LatticeMetadata
assets: LatticeAssets
custom_assets: Optional[Dict[str, AssetSchema]] = None

transport_graph: TransportGraphSchema

@field_validator("custom_assets")
def check_custom_asset_keys(cls, v):
if v is not None:
for key in v:
if key in ASSET_FILENAME_MAP:
raise ValueError(f"Asset {key} conflicts with built-in key")
return v
6 changes: 5 additions & 1 deletion covalent/_shared_files/schemas/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""FastAPI models for /api/v1/resultv2 endpoints"""

from datetime import datetime
from typing import Optional
from typing import Dict, Optional

from pydantic import BaseModel

Expand Down Expand Up @@ -54,6 +54,8 @@ class ResultMetadata(BaseModel):
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None

_custom: Optional[Dict] = None

# For use by redispatch
def reset(self):
self.dispatch_id = ""
Expand All @@ -67,6 +69,8 @@ class ResultAssets(BaseModel):
result: AssetSchema
error: AssetSchema

_custom: Optional[Dict[str, AssetSchema]] = None


class ResultSchema(BaseModel):
metadata: ResultMetadata
Expand Down
8 changes: 6 additions & 2 deletions covalent_dispatcher/_dal/importers/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def import_electron_assets(
asset_recs = {}

for asset_key, asset in e.assets:
# Register these later
if asset_key == "_custom":
continue

node_storage_path, object_key = object_store.get_uri_components(
dispatch_id,
e.id,
Expand All @@ -157,8 +161,8 @@ def import_electron_assets(
asset.remote_uri = f"file://{local_uri}"

# Register custom assets
if e.custom_assets:
for asset_key, asset in e.custom_assets.items():
if e.assets._custom:
for asset_key, asset in e.assets._custom.items():
object_key = f"{asset_key}.data"
local_uri = os.path.join(node_storage_path, object_key)

Expand Down
8 changes: 6 additions & 2 deletions covalent_dispatcher/_dal/importers/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def import_lattice_assets(

# Register built-in assets
for asset_key, asset in lat.assets:
# Deal with these later
if asset_key == "_custom":
continue

storage_path, object_key = object_store.get_uri_components(
dispatch_id=dispatch_id,
node_id=None,
Expand All @@ -118,8 +122,8 @@ def import_lattice_assets(
asset.remote_uri = f"file://{local_uri}"

# Register custom assets
if lat.custom_assets:
for asset_key, asset in lat.custom_assets.items():
if lat.assets._custom:
for asset_key, asset in lat.assets._custom.items():
object_key = f"{asset_key}.data"
local_uri = os.path.join(storage_path, object_key)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ def test_import_result_with_custom_assets(mocker, test_db):
prefix="covalent-"
) as srv_dir:
manifest = get_mock_result(dispatch_id, sdk_dir)
manifest.lattice.custom_assets = {"custom_lattice_asset": AssetSchema(size=0)}
manifest.lattice.transport_graph.nodes[0].custom_assets = {
manifest.lattice.assets._custom = {"custom_lattice_asset": AssetSchema(size=0)}
manifest.lattice.transport_graph.nodes[0].assets._custom = {
"custom_electron_asset": AssetSchema(size=0)
}
filtered_res = import_result(manifest, srv_dir, None)
Expand Down
6 changes: 3 additions & 3 deletions tests/covalent_tests/serialize/lattice_serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def workflow(x, y):

with tempfile.TemporaryDirectory() as d:
manifest = serialize_lattice(workflow, d)
assert ["custom_lat_asset"] == list(manifest.custom_assets.keys())
assert ["custom_lat_asset"] == list(manifest.assets._custom.keys())

node_0 = manifest.transport_graph.nodes[0]
assert "custom_electron_asset" in node_0.custom_assets
assert "custom_electron_asset" in node_0.assets._custom

node_1 = manifest.transport_graph.nodes[1]
assert not node_1.custom_assets
assert not node_1.assets._custom

0 comments on commit ee78a82

Please sign in to comment.