From ab93b9ed14765d8fe03a8bfddad0769cb842323a Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:08:03 -0700 Subject: [PATCH] Handle skeleton decoding internally (#1961) * Reorganize (and add) imports * Add (and reorganize) imports * Modify decode_preview_image to return bytes if specified * Implement (minimally tested) replace_jsonpickle_decode * Add support for using idx_to_node map i.e. loading from Labels (slp file) * Ignore None items in reduce_list * Convert large function to SkeletonDecoder class * Update SkeletonDecoder.decode docstring * Move decode_preview_image to SkeletonDecoder * Use SkeletonDecoder instead of jsonpickle in tests * Remove unused imports * Add test for decoding dict vs tuple pystates --- sleap/gui/widgets/docks.py | 8 +- sleap/skeleton.py | 311 +++++++++++++++++- sleap/util.py | 18 - .../fly_skeleton_legs_pystate_dict.json | 1 + tests/fixtures/skeletons.py | 15 +- tests/test_skeleton.py | 31 +- tests/test_util.py | 8 - 7 files changed, 344 insertions(+), 48 deletions(-) create mode 100644 tests/data/skeleton/fly_skeleton_legs_pystate_dict.json diff --git a/sleap/gui/widgets/docks.py b/sleap/gui/widgets/docks.py index 3375e4713..bd20bf79a 100644 --- a/sleap/gui/widgets/docks.py +++ b/sleap/gui/widgets/docks.py @@ -30,10 +30,8 @@ ) from sleap.gui.dialogs.formbuilder import YamlFormWidget from sleap.gui.widgets.views import CollapsibleWidget -from sleap.skeleton import Skeleton -from sleap.util import decode_preview_image, find_files_by_suffix, get_package_file - -# from sleap.gui.app import MainWindow +from sleap.skeleton import Skeleton, SkeletonDecoder +from sleap.util import find_files_by_suffix, get_package_file class DockWidget(QDockWidget): @@ -365,7 +363,7 @@ def create_templates_groupbox(self) -> QGroupBox: def updatePreviewImage(preview_image_bytes: bytes): # Decode the preview image - preview_image = decode_preview_image(preview_image_bytes) + preview_image = SkeletonDecoder.decode_preview_image(preview_image_bytes) # Create a QImage from the Image preview_image = QtGui.QImage( diff --git a/sleap/skeleton.py b/sleap/skeleton.py index eca393b8e..ed083fc0e 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -6,24 +6,25 @@ their connection to each other, and needed meta-data. """ -import attr -import cattr -import numpy as np -import jsonpickle -import json -import h5py +import base64 import copy - +import json import operator from enum import Enum +from io import BytesIO from itertools import count -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Text +from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union +import attr +import cattr +import h5py +import jsonpickle import networkx as nx +import numpy as np from networkx.readwrite import json_graph +from PIL import Image from scipy.io import loadmat - NodeRef = Union[str, "Node"] H5FileRef = Union[str, h5py.File] @@ -85,6 +86,296 @@ def matches(self, other: "Node") -> bool: return other.name == self.name and other.weight == self.weight +class SkeletonDecoder: + """Replace jsonpickle.decode with our own decoder. + + This function will decode the following from jsonpickle's encoded format: + + `Node` objects from + { + "py/object": "sleap.skeleton.Node", + "py/state": { "py/tuple": ["thorax1", 1.0] } + } + to `Node(name="thorax1", weight=1.0)` + + `EdgeType` objects from + { + "py/reduce": [ + { "py/type": "sleap.skeleton.EdgeType" }, + { "py/tuple": [1] } + ] + } + to `EdgeType(1)` + + `bytes` from + { + "py/b64": "aVZC..." + } + to `b"iVBO..."` + + and any repeated objects from + { + "py/id": 1 + } + to the object with the same reconstruction id (from top to bottom). + """ + + def __init__(self): + self.decoded_objects: List[Union[Node, EdgeType]] = [] + + def _decode_id(self, id: int) -> Union[Node, EdgeType]: + """Decode the object with the given `py/id` value of `id`. + + Args: + id: The `py/id` value to decode (1-indexed). + objects: The dictionary of objects that have already been decoded. + + Returns: + The object with the given `py/id` value. + """ + return self.decoded_objects[id - 1] + + @staticmethod + def _decode_state(state: dict) -> Node: + """Reconstruct the `Node` object from 'py/state' key in the serialized nx_graph. + + We support states in either dictionary or tuple format: + { + "py/state": { "py/tuple": ["thorax1", 1.0] } + } + or + { + "py/state": {"name": "thorax1", "weight": 1.0} + } + + Args: + state: The state to decode, i.e. state = dict["py/state"] + + Returns: + The `Node` object reconstructed from the state. + """ + + if "py/tuple" in state: + return Node(*state["py/tuple"]) + + return Node(**state) + + @staticmethod + def _decode_object_dict(object_dict) -> Node: + """Decode dict containing `py/object` key in the serialized nx_graph. + + Args: + object_dict: The dict to decode, i.e. + object_dict = {"py/object": ..., "py/state":...} + + Raises: + ValueError: If object_dict does not have 'py/object' and 'py/state' keys. + ValueError: If object_dict['py/object'] is not 'sleap.skeleton.Node'. + + Returns: + The decoded `Node` object. + """ + + if object_dict["py/object"] != "sleap.skeleton.Node": + raise ValueError("Only 'sleap.skeleton.Node' objects are supported.") + + node: Node = SkeletonDecoder._decode_state(state=object_dict["py/state"]) + return node + + def _decode_node(self, encoded_node: dict) -> Node: + """Decode an item believed to be an encoded `Node` object. + + Also updates the list of decoded objects. + + Args: + encoded_node: The encoded node to decode. + + Returns: + The decoded node and the updated list of decoded objects. + """ + + if isinstance(encoded_node, int): + # Using index mapping to replace the object (load from Labels) + return encoded_node + elif "py/object" in encoded_node: + decoded_node: Node = SkeletonDecoder._decode_object_dict(encoded_node) + self.decoded_objects.append(decoded_node) + elif "py/id" in encoded_node: + decoded_node: Node = self._decode_id(encoded_node["py/id"]) + + return decoded_node + + def _decode_nodes(self, encoded_nodes: List[dict]) -> List[Dict[str, Node]]: + """Decode the 'nodes' key in the serialized nx_graph. + + The encoded_nodes is a list of dictionary of two types: + - A dictionary with 'py/object' and 'py/state' keys. + - A dictionary with 'py/id' key. + + Args: + encoded_nodes: The list of encoded nodes to decode. + + Returns: + The decoded nodes. + """ + + decoded_nodes: List[Dict[str, Node]] = [] + for e_node_dict in encoded_nodes: + e_node = e_node_dict["id"] + d_node = self._decode_node(e_node) + decoded_nodes.append({"id": d_node}) + + return decoded_nodes + + def _decode_reduce_dict(self, reduce_dict: Dict[str, List[dict]]) -> EdgeType: + """Decode the 'reduce' key in the serialized nx_graph. + + The reduce_dict is a dictionary in the following format: + { + "py/reduce": [ + { "py/type": "sleap.skeleton.EdgeType" }, + { "py/tuple": [1] } + ] + } + + Args: + reduce_dict: The dictionary to decode i.e. reduce_dict = {"py/reduce": ...} + + Returns: + The decoded `EdgeType` object. + """ + + reduce_list = reduce_dict["py/reduce"] + has_py_type = has_py_tuple = False + for reduce_item in reduce_list: + if reduce_item is None: + # Sometimes the reduce list has None values, skip them + continue + if ( + "py/type" in reduce_item + and reduce_item["py/type"] == "sleap.skeleton.EdgeType" + ): + has_py_type = True + elif "py/tuple" in reduce_item: + edge_type: int = reduce_item["py/tuple"][0] + has_py_tuple = True + + if not has_py_type or not has_py_tuple: + raise ValueError( + "Only 'sleap.skeleton.EdgeType' objects are supported. " + "The 'py/reduce' list must have dictionaries with 'py/type' and " + "'py/tuple' keys." + f"\n\tHas py/type: {has_py_type}\n\tHas py/tuple: {has_py_tuple}" + ) + + edge = EdgeType(edge_type) + self.decoded_objects.append(edge) + + return edge + + def _decode_edge_type(self, encoded_edge_type: dict) -> EdgeType: + """Decode the 'type' key in the serialized nx_graph. + + Args: + encoded_edge_type: a dictionary with either 'py/id' or 'py/reduce' key. + + Returns: + The decoded `EdgeType` object. + """ + + if "py/reduce" in encoded_edge_type: + edge_type = self._decode_reduce_dict(encoded_edge_type) + else: + # Expect a "py/id" instead of "py/reduce" + edge_type = self._decode_id(encoded_edge_type["py/id"]) + return edge_type + + def _decode_links( + self, links: List[dict] + ) -> List[Dict[str, Union[int, Node, EdgeType]]]: + """Decode the 'links' key in the serialized nx_graph. + + The links are the edges in the graph and will have the following keys: + - source: The source node of the edge. + - target: The destination node of the edge. + - type: The type of the edge (e.g. BODY, SYMMETRY). + and more. + + Args: + encoded_links: The list of encoded links to decode. + """ + + for link in links: + for key, value in link.items(): + if key == "source": + link[key] = self._decode_node(value) + elif key == "target": + link[key] = self._decode_node(value) + elif key == "type": + link[key] = self._decode_edge_type(value) + + return links + + @staticmethod + def decode_preview_image( + img_b64: bytes, return_bytes: bool = False + ) -> Union[Image.Image, bytes]: + """Decode a skeleton preview image byte string representation to a `PIL.Image` + + Args: + img_b64: a byte string representation of a skeleton preview image + return_bytes: whether to return the decoded image as bytes + + Returns: + Either a PIL.Image of the skeleton preview image or the decoded image as bytes + (if `return_bytes` is True). + """ + bytes = base64.b64decode(img_b64) + if return_bytes: + return bytes + + buffer = BytesIO(bytes) + img = Image.open(buffer) + return img + + def _decode(self, json_str: str): + dicts = json.loads(json_str) + + # Enforce same format across template and non-template skeletons + if "nx_graph" not in dicts: + # Non-template skeletons use the dicts as the "nx_graph" + dicts = {"nx_graph": dicts} + + # Decode the graph + nx_graph = dicts["nx_graph"] + + self.decoded_objects = [] # Reset the decoded objects incase reusing decoder + for key, value in nx_graph.items(): + if key == "nodes": + nx_graph[key] = self._decode_nodes(value) + elif key == "links": + nx_graph[key] = self._decode_links(value) + + # Decode the preview image (if it exists) + preview_image = dicts.get("preview_image", None) + if preview_image is not None: + dicts["preview_image"] = SkeletonDecoder.decode_preview_image( + preview_image["py/b64"], return_bytes=True + ) + + return dicts + + @classmethod + def decode(cls, json_str: str) -> Dict: + """Decode the given json string into a dictionary. + + Returns: + A dict with `Node`s, `EdgeType`s, and `bytes` decoded/reconstructed. + """ + decoder = cls() + return decoder._decode(json_str) + + class Skeleton: """The main object for representing animal skeletons. @@ -1071,7 +1362,7 @@ def from_json( Returns: An instance of the `Skeleton` object decoded from the JSON. """ - dicts = jsonpickle.decode(json_str) + dicts: dict = SkeletonDecoder.decode(json_str) nx_graph = dicts.get("nx_graph", dicts) graph = json_graph.node_link_graph(nx_graph) diff --git a/sleap/util.py b/sleap/util.py index eef762ff4..bc3389b7d 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -3,13 +3,11 @@ Try not to put things in here unless they really have no other place. """ -import base64 import json import os import re import shutil from collections import defaultdict -from io import BytesIO from pathlib import Path from typing import Any, Dict, Hashable, Iterable, List, Optional from urllib.parse import unquote, urlparse @@ -26,7 +24,6 @@ from importlib.resources import files # New in 3.9+ except ImportError: from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. -from PIL import Image import sleap.version as sleap_version @@ -374,18 +371,3 @@ def find_files_by_suffix( def parse_uri_path(uri: str) -> str: """Parse a URI starting with 'file:///' to a posix path.""" return Path(url2pathname(urlparse(unquote(uri)).path)).as_posix() - - -def decode_preview_image(img_b64: bytes) -> Image: - """Decode a skeleton preview image byte string representation to a `PIL.Image` - - Args: - img_b64: a byte string representation of a skeleton preview image - - Returns: - A PIL.Image of the skeleton preview - """ - bytes = base64.b64decode(img_b64) - buffer = BytesIO(bytes) - img = Image.open(buffer) - return img diff --git a/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json b/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json new file mode 100644 index 000000000..eae83d6bc --- /dev/null +++ b/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json @@ -0,0 +1 @@ +{"directed": true, "graph": {"name": "skeleton_legs.mat", "num_edges_inserted": 23}, "links": [{"edge_insert_idx": 1, "key": 0, "source": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "neck", "weight": 1.0}}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "head", "weight": 1.0}}, "type": {"py/reduce": [{"py/type": "sleap.skeleton.EdgeType"}, {"py/tuple": [1]}]}}, {"edge_insert_idx": 0, "key": 0, "source": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "thorax", "weight": 1.0}}, "target": {"py/id": 1}, "type": {"py/id": 3}}, {"edge_insert_idx": 2, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "abdomen", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 3, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "wingL", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 4, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "wingR", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 5, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 8, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 11, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 14, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 17, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 20, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 6, "key": 0, "source": {"py/id": 8}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 7, "key": 0, "source": {"py/id": 14}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 9, "key": 0, "source": {"py/id": 9}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 10, "key": 0, "source": {"py/id": 16}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 12, "key": 0, "source": {"py/id": 10}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 13, "key": 0, "source": {"py/id": 18}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 15, "key": 0, "source": {"py/id": 11}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 16, "key": 0, "source": {"py/id": 20}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 18, "key": 0, "source": {"py/id": 12}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 19, "key": 0, "source": {"py/id": 22}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 21, "key": 0, "source": {"py/id": 13}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 22, "key": 0, "source": {"py/id": 24}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR3", "weight": 1.0}}, "type": {"py/id": 3}}], "multigraph": true, "nodes": [{"id": {"py/id": 2}}, {"id": {"py/id": 1}}, {"id": {"py/id": 4}}, {"id": {"py/id": 5}}, {"id": {"py/id": 6}}, {"id": {"py/id": 7}}, {"id": {"py/id": 8}}, {"id": {"py/id": 14}}, {"id": {"py/id": 15}}, {"id": {"py/id": 9}}, {"id": {"py/id": 16}}, {"id": {"py/id": 17}}, {"id": {"py/id": 10}}, {"id": {"py/id": 18}}, {"id": {"py/id": 19}}, {"id": {"py/id": 11}}, {"id": {"py/id": 20}}, {"id": {"py/id": 21}}, {"id": {"py/id": 12}}, {"id": {"py/id": 22}}, {"id": {"py/id": 23}}, {"id": {"py/id": 13}}, {"id": {"py/id": 24}}, {"id": {"py/id": 25}}]} \ No newline at end of file diff --git a/tests/fixtures/skeletons.py b/tests/fixtures/skeletons.py index 311510e6a..b432ca2c7 100644 --- a/tests/fixtures/skeletons.py +++ b/tests/fixtures/skeletons.py @@ -3,14 +3,27 @@ from sleap.skeleton import Skeleton TEST_FLY_LEGS_SKELETON = "tests/data/skeleton/fly_skeleton_legs.json" +TEST_FLY_LEGS_SKELETON_DICT = "tests/data/skeleton/fly_skeleton_legs_pystate_dict.json" @pytest.fixture def fly_legs_skeleton_json(): - """Path to fly_skeleton_legs.json""" + """Path to fly_skeleton_legs.json + + This skeleton json has py/state in tuple format. + """ return TEST_FLY_LEGS_SKELETON +@pytest.fixture +def fly_legs_skeleton_dict_json(): + """Path to fly_skeleton_legs_pystate_dict.json + + This skeleton json has py/state dict format. + """ + return TEST_FLY_LEGS_SKELETON_DICT + + @pytest.fixture def stickman(): diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index 1f7c3a853..2fc32628a 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -1,10 +1,9 @@ -import os import copy +import os -import jsonpickle import pytest -from sleap.skeleton import Skeleton +from sleap.skeleton import Skeleton, SkeletonDecoder def test_add_dupe_node(skeleton): @@ -194,9 +193,9 @@ def test_json(skeleton: Skeleton, tmpdir): ) assert skeleton.is_template == False json_str = skeleton.to_json() - json_dict = jsonpickle.decode(json_str) + json_dict = SkeletonDecoder.decode(json_str) json_dict_keys = list(json_dict.keys()) - assert "nx_graph" not in json_dict_keys + assert "nx_graph" in json_dict_keys # SkeletonDecoder adds this key assert "preview_image" not in json_dict_keys assert "description" not in json_dict_keys @@ -208,7 +207,7 @@ def test_json(skeleton: Skeleton, tmpdir): skeleton._is_template = True json_str = skeleton.to_json() - json_dict = jsonpickle.decode(json_str) + json_dict = SkeletonDecoder.decode(json_str) json_dict_keys = list(json_dict.keys()) assert "nx_graph" in json_dict_keys assert "preview_image" in json_dict_keys @@ -224,6 +223,26 @@ def test_json(skeleton: Skeleton, tmpdir): assert skeleton.matches(skeleton_copy) +def test_decode_preview_image(flies13_skeleton: Skeleton): + skeleton = flies13_skeleton + img_b64 = skeleton.preview_image + img = SkeletonDecoder.decode_preview_image(img_b64) + assert img.mode == "RGBA" + + +def test_skeleton_decoder(fly_legs_skeleton_json, fly_legs_skeleton_dict_json): + """Test that SkeletonDecoder can decode both tuple and dict py/state formats.""" + + skeleton_tuple_pystate = Skeleton.load_json(fly_legs_skeleton_json) + assert isinstance(skeleton_tuple_pystate, Skeleton) + + skeleton_dict_pystate = Skeleton.load_json(fly_legs_skeleton_dict_json) + assert isinstance(skeleton_dict_pystate, Skeleton) + + # These are the same skeleton, so they should match + assert skeleton_dict_pystate.matches(skeleton_tuple_pystate) + + def test_hdf5(skeleton, stickman, tmpdir): filename = os.path.join(tmpdir, "skeleton.h5") diff --git a/tests/test_util.py b/tests/test_util.py index a7916d47f..35b41afa8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,5 +1,4 @@ import pytest -from sleap.skeleton import Skeleton from sleap.util import * @@ -147,10 +146,3 @@ def test_save_dict_to_hdf5(tmpdir): assert f["bar"][-1].decode() == "zop" assert f["cab"]["a"][()] == 2 - - -def test_decode_preview_image(flies13_skeleton: Skeleton): - skeleton = flies13_skeleton - img_b64 = skeleton.preview_image - img = decode_preview_image(img_b64) - assert img.mode == "RGBA"