Skip to content

Commit

Permalink
Handle skeleton decoding internally (#1961)
Browse files Browse the repository at this point in the history
* Reorganize (and add) imports

* Add (and reorganize) imports

* Modify decode_preview_image to return bytes if specified

* Implement (minimally tested) replace_jsonpickle_decode

* Add support for using idx_to_node map
i.e. loading from Labels (slp file)

* Ignore None items in reduce_list

* Convert large function to SkeletonDecoder class

* Update SkeletonDecoder.decode docstring

* Move decode_preview_image to SkeletonDecoder

* Use SkeletonDecoder instead of jsonpickle in tests

* Remove unused imports

* Add test for decoding dict vs tuple pystates
  • Loading branch information
roomrys authored Sep 25, 2024
1 parent 3c7f5af commit ab93b9e
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 48 deletions.
8 changes: 3 additions & 5 deletions sleap/gui/widgets/docks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@
)
from sleap.gui.dialogs.formbuilder import YamlFormWidget
from sleap.gui.widgets.views import CollapsibleWidget
from sleap.skeleton import Skeleton
from sleap.util import decode_preview_image, find_files_by_suffix, get_package_file

# from sleap.gui.app import MainWindow
from sleap.skeleton import Skeleton, SkeletonDecoder
from sleap.util import find_files_by_suffix, get_package_file


class DockWidget(QDockWidget):
Expand Down Expand Up @@ -365,7 +363,7 @@ def create_templates_groupbox(self) -> QGroupBox:
def updatePreviewImage(preview_image_bytes: bytes):

# Decode the preview image
preview_image = decode_preview_image(preview_image_bytes)
preview_image = SkeletonDecoder.decode_preview_image(preview_image_bytes)

# Create a QImage from the Image
preview_image = QtGui.QImage(
Expand Down
311 changes: 301 additions & 10 deletions sleap/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,25 @@
their connection to each other, and needed meta-data.
"""

import attr
import cattr
import numpy as np
import jsonpickle
import json
import h5py
import base64
import copy

import json
import operator
from enum import Enum
from io import BytesIO
from itertools import count
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Text
from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union

import attr
import cattr
import h5py
import jsonpickle
import networkx as nx
import numpy as np
from networkx.readwrite import json_graph
from PIL import Image
from scipy.io import loadmat


NodeRef = Union[str, "Node"]
H5FileRef = Union[str, h5py.File]

Expand Down Expand Up @@ -85,6 +86,296 @@ def matches(self, other: "Node") -> bool:
return other.name == self.name and other.weight == self.weight


class SkeletonDecoder:
"""Replace jsonpickle.decode with our own decoder.
This function will decode the following from jsonpickle's encoded format:
`Node` objects from
{
"py/object": "sleap.skeleton.Node",
"py/state": { "py/tuple": ["thorax1", 1.0] }
}
to `Node(name="thorax1", weight=1.0)`
`EdgeType` objects from
{
"py/reduce": [
{ "py/type": "sleap.skeleton.EdgeType" },
{ "py/tuple": [1] }
]
}
to `EdgeType(1)`
`bytes` from
{
"py/b64": "aVZC..."
}
to `b"iVBO..."`
and any repeated objects from
{
"py/id": 1
}
to the object with the same reconstruction id (from top to bottom).
"""

def __init__(self):
self.decoded_objects: List[Union[Node, EdgeType]] = []

def _decode_id(self, id: int) -> Union[Node, EdgeType]:
"""Decode the object with the given `py/id` value of `id`.
Args:
id: The `py/id` value to decode (1-indexed).
objects: The dictionary of objects that have already been decoded.
Returns:
The object with the given `py/id` value.
"""
return self.decoded_objects[id - 1]

@staticmethod
def _decode_state(state: dict) -> Node:
"""Reconstruct the `Node` object from 'py/state' key in the serialized nx_graph.
We support states in either dictionary or tuple format:
{
"py/state": { "py/tuple": ["thorax1", 1.0] }
}
or
{
"py/state": {"name": "thorax1", "weight": 1.0}
}
Args:
state: The state to decode, i.e. state = dict["py/state"]
Returns:
The `Node` object reconstructed from the state.
"""

if "py/tuple" in state:
return Node(*state["py/tuple"])

return Node(**state)

@staticmethod
def _decode_object_dict(object_dict) -> Node:
"""Decode dict containing `py/object` key in the serialized nx_graph.
Args:
object_dict: The dict to decode, i.e.
object_dict = {"py/object": ..., "py/state":...}
Raises:
ValueError: If object_dict does not have 'py/object' and 'py/state' keys.
ValueError: If object_dict['py/object'] is not 'sleap.skeleton.Node'.
Returns:
The decoded `Node` object.
"""

if object_dict["py/object"] != "sleap.skeleton.Node":
raise ValueError("Only 'sleap.skeleton.Node' objects are supported.")

node: Node = SkeletonDecoder._decode_state(state=object_dict["py/state"])
return node

def _decode_node(self, encoded_node: dict) -> Node:
"""Decode an item believed to be an encoded `Node` object.
Also updates the list of decoded objects.
Args:
encoded_node: The encoded node to decode.
Returns:
The decoded node and the updated list of decoded objects.
"""

if isinstance(encoded_node, int):
# Using index mapping to replace the object (load from Labels)
return encoded_node
elif "py/object" in encoded_node:
decoded_node: Node = SkeletonDecoder._decode_object_dict(encoded_node)
self.decoded_objects.append(decoded_node)
elif "py/id" in encoded_node:
decoded_node: Node = self._decode_id(encoded_node["py/id"])

return decoded_node

def _decode_nodes(self, encoded_nodes: List[dict]) -> List[Dict[str, Node]]:
"""Decode the 'nodes' key in the serialized nx_graph.
The encoded_nodes is a list of dictionary of two types:
- A dictionary with 'py/object' and 'py/state' keys.
- A dictionary with 'py/id' key.
Args:
encoded_nodes: The list of encoded nodes to decode.
Returns:
The decoded nodes.
"""

decoded_nodes: List[Dict[str, Node]] = []
for e_node_dict in encoded_nodes:
e_node = e_node_dict["id"]
d_node = self._decode_node(e_node)
decoded_nodes.append({"id": d_node})

return decoded_nodes

def _decode_reduce_dict(self, reduce_dict: Dict[str, List[dict]]) -> EdgeType:
"""Decode the 'reduce' key in the serialized nx_graph.
The reduce_dict is a dictionary in the following format:
{
"py/reduce": [
{ "py/type": "sleap.skeleton.EdgeType" },
{ "py/tuple": [1] }
]
}
Args:
reduce_dict: The dictionary to decode i.e. reduce_dict = {"py/reduce": ...}
Returns:
The decoded `EdgeType` object.
"""

reduce_list = reduce_dict["py/reduce"]
has_py_type = has_py_tuple = False
for reduce_item in reduce_list:
if reduce_item is None:
# Sometimes the reduce list has None values, skip them
continue
if (
"py/type" in reduce_item
and reduce_item["py/type"] == "sleap.skeleton.EdgeType"
):
has_py_type = True
elif "py/tuple" in reduce_item:
edge_type: int = reduce_item["py/tuple"][0]
has_py_tuple = True

if not has_py_type or not has_py_tuple:
raise ValueError(
"Only 'sleap.skeleton.EdgeType' objects are supported. "
"The 'py/reduce' list must have dictionaries with 'py/type' and "
"'py/tuple' keys."
f"\n\tHas py/type: {has_py_type}\n\tHas py/tuple: {has_py_tuple}"
)

edge = EdgeType(edge_type)
self.decoded_objects.append(edge)

return edge

def _decode_edge_type(self, encoded_edge_type: dict) -> EdgeType:
"""Decode the 'type' key in the serialized nx_graph.
Args:
encoded_edge_type: a dictionary with either 'py/id' or 'py/reduce' key.
Returns:
The decoded `EdgeType` object.
"""

if "py/reduce" in encoded_edge_type:
edge_type = self._decode_reduce_dict(encoded_edge_type)
else:
# Expect a "py/id" instead of "py/reduce"
edge_type = self._decode_id(encoded_edge_type["py/id"])
return edge_type

def _decode_links(
self, links: List[dict]
) -> List[Dict[str, Union[int, Node, EdgeType]]]:
"""Decode the 'links' key in the serialized nx_graph.
The links are the edges in the graph and will have the following keys:
- source: The source node of the edge.
- target: The destination node of the edge.
- type: The type of the edge (e.g. BODY, SYMMETRY).
and more.
Args:
encoded_links: The list of encoded links to decode.
"""

for link in links:
for key, value in link.items():
if key == "source":
link[key] = self._decode_node(value)
elif key == "target":
link[key] = self._decode_node(value)
elif key == "type":
link[key] = self._decode_edge_type(value)

return links

@staticmethod
def decode_preview_image(
img_b64: bytes, return_bytes: bool = False
) -> Union[Image.Image, bytes]:
"""Decode a skeleton preview image byte string representation to a `PIL.Image`
Args:
img_b64: a byte string representation of a skeleton preview image
return_bytes: whether to return the decoded image as bytes
Returns:
Either a PIL.Image of the skeleton preview image or the decoded image as bytes
(if `return_bytes` is True).
"""
bytes = base64.b64decode(img_b64)
if return_bytes:
return bytes

buffer = BytesIO(bytes)
img = Image.open(buffer)
return img

def _decode(self, json_str: str):
dicts = json.loads(json_str)

# Enforce same format across template and non-template skeletons
if "nx_graph" not in dicts:
# Non-template skeletons use the dicts as the "nx_graph"
dicts = {"nx_graph": dicts}

# Decode the graph
nx_graph = dicts["nx_graph"]

self.decoded_objects = [] # Reset the decoded objects incase reusing decoder
for key, value in nx_graph.items():
if key == "nodes":
nx_graph[key] = self._decode_nodes(value)
elif key == "links":
nx_graph[key] = self._decode_links(value)

# Decode the preview image (if it exists)
preview_image = dicts.get("preview_image", None)
if preview_image is not None:
dicts["preview_image"] = SkeletonDecoder.decode_preview_image(
preview_image["py/b64"], return_bytes=True
)

return dicts

@classmethod
def decode(cls, json_str: str) -> Dict:
"""Decode the given json string into a dictionary.
Returns:
A dict with `Node`s, `EdgeType`s, and `bytes` decoded/reconstructed.
"""
decoder = cls()
return decoder._decode(json_str)


class Skeleton:
"""The main object for representing animal skeletons.
Expand Down Expand Up @@ -1071,7 +1362,7 @@ def from_json(
Returns:
An instance of the `Skeleton` object decoded from the JSON.
"""
dicts = jsonpickle.decode(json_str)
dicts: dict = SkeletonDecoder.decode(json_str)
nx_graph = dicts.get("nx_graph", dicts)
graph = json_graph.node_link_graph(nx_graph)

Expand Down
Loading

0 comments on commit ab93b9e

Please sign in to comment.