Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load_nodes_from_json works with both str and dict args #398

Merged
merged 4 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions src/cript/nodes/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,13 +394,13 @@ def _is_node_field_valid(node_type_list: list) -> bool:
return False


def load_nodes_from_json(nodes_json: str):
def load_nodes_from_json(nodes_json: Union[str, Dict]):
"""
User facing function, that return a node and all its children from a json string input.

Parameters
----------
nodes_json: str
nodes_json: Union[str, dict]
JSON string representation of a CRIPT node

Examples
Expand All @@ -416,7 +416,7 @@ def load_nodes_from_json(nodes_json: str):
>>> my_project_from_api_dict: dict = my_paginator.current_page_results[0] # doctest: +SKIP
>>> # Deserialize your Project dict into a Project node
>>> my_project_node_from_api = cript.load_nodes_from_json( # doctest: +SKIP
... nodes_json=json.dumps(my_project_from_api_dict)
... nodes_json=my_project_from_api_dict
... )

Raises
Expand All @@ -436,16 +436,28 @@ def load_nodes_from_json(nodes_json: str):

The function is intended for deserializing CRIPT nodes and should not be used for generic JSON.


Returns
-------
Union[CRIPT Node, List[CRIPT Node]]
Typically returns a single CRIPT node,
but if given a list of nodes, then it will serialize them and return a list of CRIPT nodes
"""
# Initialize the custom decoder hook for JSON deserialization
node_json_hook = _NodeDecoderHook()
json_nodes = json.loads(nodes_json, object_hook=node_json_hook)

# Check if the input is already a Python dictionary
if isinstance(nodes_json, Dict):
# If it's a dictionary, directly use the decoder hook to deserialize it
return node_json_hook(nodes_json)

# Check if the input is a JSON-formatted string
elif isinstance(nodes_json, str):
# If it's a JSON string, parse and deserialize it using the decoder hook
return json.loads(nodes_json, object_hook=node_json_hook)

# Raise an error if the input type is unsupported
else:
raise TypeError(f"Unsupported type for nodes_json: {type(nodes_json)}")
# TODO: enable this logic to replace proxies, once beartype is OK with that.
# def recursive_proxy_replacement(node, handled_nodes):
# if isinstance(node, _UIDProxy):
Expand All @@ -470,7 +482,6 @@ def load_nodes_from_json(nodes_json: str):
# return node
# handled_nodes = set()
# recursive_proxy_replacement(json_nodes, handled_nodes)
return json_nodes


def add_orphaned_nodes_to_project(project: Project, active_experiment: Experiment, max_iteration: int = -1):
Expand Down
35 changes: 35 additions & 0 deletions tests/nodes/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import cript
from cript.nodes.util import _is_node_field_valid


Expand All @@ -16,3 +17,37 @@ def test_is_node_field_valid() -> None:
assert _is_node_field_valid(node_type_list="Project") is False

assert _is_node_field_valid(node_type_list=[]) is False


def test_load_node_from_json_dict_argument() -> None:
"""
tests that `cript.load_nodes_from_json` can correctly load the material node from a dict
instead of JSON string
"""
material_name = "my material name"
material_notes = "my material node notes"
material_bigsmiles = "my bigsmiles"
material_uuid = "29c796a1-8f08-41ea-8524-29e925f384af"

material_dict = {
"node": ["Material"],
"uid": f"_:{material_uuid}",
"uuid": material_uuid,
"name": material_name,
"notes": material_notes,
"property": [{"node": ["Property"], "uid": "_:aedce614-7acb-49d2-a2f6-47463f15b707", "uuid": "aedce614-7acb-49d2-a2f6-47463f15b707", "key": "modulus_shear", "type": "value", "value": 5.0, "unit": "GPa"}],
"computational_forcefield": {"node": ["ComputationalForcefield"], "uid": "_:059952a3-20f2-4739-96bd-a5ea43068065", "uuid": "059952a3-20f2-4739-96bd-a5ea43068065", "key": "amber", "building_block": "atom"},
"keyword": ["acetylene"],
"bigsmiles": material_bigsmiles,
}

# convert material from dict to node
my_material_node_from_dict = cript.load_nodes_from_json(nodes_json=material_dict)

# assert material is correctly deserialized from JSON dict to Material Python object
assert type(my_material_node_from_dict) == cript.Material
assert my_material_node_from_dict.name == material_name
assert my_material_node_from_dict.identifier[0]["bigsmiles"] == "my bigsmiles"

# convert UUID object to UUID str and compare
assert str(my_material_node_from_dict.uuid) == material_uuid