Skip to content

Commit

Permalink
Handle skeleton encoding internally (#1970)
Browse files Browse the repository at this point in the history
* start class `SkeletonEncoder`

* _encoded_objects need to be a dict to add to

* add notebook for testing

* format

* fix type in docstring

* finish classmethod for encoding Skeleton as a json string

* test encoded Skeleton as json string by decoding it

* add test for decoded encoded skeleton

* update jupyter notebook for easy testing

* constraining attrs in dev environment to make sure decode format is always the same locally

* encode links first then encode source then target then type

* save first enconding statically as an input to _get_or_assign_id so that we do not always get py/id

* save first encoding statically

* first encoding is passed to _get_or_assign_id

* use first_encoding variable to determine if we should assign a py/id

* add print statements for debugging

* update notebook for easy testing

* black

* remove comment

* adding attrs constraint to show this passes for certain attrs version only

* add import

* switch out jsonpickle.encode

* oops remove import

* can attrs be unconstrained?

* forgot comma

* pin attrs for testing

* test Skeleton from json, template, with symmetries, and template

* use SkeletonEncoder.encode

* black

* try removing None values in EdgeType reduced

* Handle case when nodes are replaced by integer indices from caller

* Remove prototyping notebook

* Remove attrs pins

* Remove sort keys (which flips the neccessary ordering of our py/ids)

* Do not add extra indents to encoded file

* Only append links after fully encoded (fat-finger)

* Remove outdated comment

* Lint

---------

Co-authored-by: Talmo Pereira <[email protected]>
Co-authored-by: roomrys <[email protected]>
  • Loading branch information
3 people authored Sep 25, 2024
1 parent ab93b9e commit ef803f6
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 8 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ channels:

dependencies:
# Packages SLEAP uses directly
- conda-forge::attrs >=21.2.0 #,<=21.4.0
- conda-forge::attrs >=21.2.0
- conda-forge::cattrs ==1.1.1
- conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg
- conda-forge::jsmin
Expand Down
2 changes: 1 addition & 1 deletion environment_no_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ channels:

dependencies:
# Packages SLEAP uses directly
- conda-forge::attrs >=21.2.0 #,<=21.4.0
- conda-forge::attrs >=21.2.0
- conda-forge::cattrs ==1.1.1
- conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg
- conda-forge::jsmin
Expand Down
197 changes: 192 additions & 5 deletions sleap/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,193 @@ def decode(cls, json_str: str) -> Dict:
return decoder._decode(json_str)


class SkeletonEncoder:
"""Replace jsonpickle.encode with our own encoder.
The input is a dictionary containing python objects that need to be encoded as
JSON strings. The output is a JSON string that represents the input dictionary.
`Node(name='neck', weight=1.0)` =>
{
"py/object": "sleap.Skeleton.Node",
"py/state": {"py/tuple" ["neck", 1.0]}
}
`<EdgeType.BODY: 1>` =>
{"py/reduce": [
{"py/type": "sleap.Skeleton.EdgeType"},
{"py/tuple": [1] }
]
}`
Where `name` and `weight` are the attributes of the `Node` class; weight is always 1.0.
`EdgeType` is an enum with values `BODY = 1` and `SYMMETRY = 2`.
See sleap.skeleton.Node and sleap.skeleton.EdgeType.
If the object has been "seen" before, it will not be encoded as the full JSON string
but referenced by its `py/id`, which starts at 1 and indexes the objects in the
order they are seen so that the second time the first object is used, it will be
referenced as `{"py/id": 1}`.
"""

def __init__(self):
"""Initializes a SkeletonEncoder instance."""
# Maps object id to py/id
self._encoded_objects: Dict[int, int] = {}

@classmethod
def encode(cls, data: Dict[str, Any]) -> str:
"""Encodes the input dictionary as a JSON string.
Args:
data: The data to encode.
Returns:
json_str: The JSON string representation of the data.
"""
encoder = cls()
encoded_data = encoder._encode(data)
json_str = json.dumps(encoded_data)
return json_str

def _encode(self, obj: Any) -> Any:
"""Recursively encodes the input object.
Args:
obj: The object to encode. Can be a dictionary, list, Node, EdgeType or
primitive data type.
Returns:
The encoded object as a dictionary.
"""
if isinstance(obj, dict):
encoded_obj = {}
for key, value in obj.items():
if key == "links":
encoded_obj[key] = self._encode_links(value)
else:
encoded_obj[key] = self._encode(value)
return encoded_obj
elif isinstance(obj, list):
return [self._encode(v) for v in obj]
elif isinstance(obj, EdgeType):
return self._encode_edge_type(obj)
elif isinstance(obj, Node):
return self._encode_node(obj)
else:
return obj # Primitive data types

def _encode_links(self, links: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Encodes the list of links (edges) in the skeleton graph.
Args:
links: A list of dictionaries, each representing an edge in the graph.
Returns:
A list of encoded edge dictionaries with keys ordered as specified.
"""
encoded_links = []
for link in links:
# Use a regular dict (insertion order preserved in Python 3.7+)
encoded_link = {}

for key, value in link.items():
if key in ("source", "target"):
encoded_link[key] = self._encode_node(value)
elif key == "type":
encoded_link[key] = self._encode_edge_type(value)
else:
encoded_link[key] = self._encode(value)
encoded_links.append(encoded_link)

return encoded_links

def _encode_node(self, node: Union["Node", int]) -> Dict[str, Any]:
"""Encodes a Node object.
Args:
node: The Node object to encode or integer index. The latter requires that
the class has the `idx_to_node` attribute set.
Returns:
The encoded `Node` object as a dictionary.
"""
if isinstance(node, int):
# We sometimes have the node object already replaced by its index (when
# `node_to_idx` is provided). In this case, the node is already encoded.
return node

# Check if object has been encoded before
first_encoding = self._is_first_encoding(node)
py_id = self._get_or_assign_id(node, first_encoding)
if first_encoding:
# Full encoding
return {
"py/object": "sleap.skeleton.Node",
"py/state": {"py/tuple": [node.name, node.weight]},
}
else:
# Reference by py/id
return {"py/id": py_id}

def _encode_edge_type(self, edge_type: "EdgeType") -> Dict[str, Any]:
"""Encodes an EdgeType object.
Args:
edge_type: The EdgeType object to encode. Either `EdgeType.BODY` or
`EdgeType.SYMMETRY` enum with values 1 and 2 respectively.
Returns:
The encoded EdgeType object as a dictionary.
"""
# Check if object has been encoded before
first_encoding = self._is_first_encoding(edge_type)
py_id = self._get_or_assign_id(edge_type, first_encoding)
if first_encoding:
# Full encoding
return {
"py/reduce": [
{"py/type": "sleap.skeleton.EdgeType"},
{"py/tuple": [edge_type.value]},
]
}
else:
# Reference by py/id
return {"py/id": py_id}

def _get_or_assign_id(self, obj: Any, first_encoding: bool) -> int:
"""Gets or assigns a py/id for the object.
Args:
The object to get or assign a py/id for.
Returns:
The py/id assigned to the object.
"""
# Object id is unique for each object in the current session
obj_id = id(obj)
# Assign a py/id to the object if it hasn't been assigned one yet
if first_encoding:
py_id = len(self._encoded_objects) + 1 # py/id starts at 1
# Assign the py/id to the object and store it in _encoded_objects
self._encoded_objects[obj_id] = py_id
return self._encoded_objects[obj_id]

def _is_first_encoding(self, obj: Any) -> bool:
"""Checks if the object is being encoded for the first time.
Args:
obj: The object to check.
Returns:
True if this is the first encoding of the object, False otherwise.
"""
obj_id = id(obj)
first_time = obj_id not in self._encoded_objects
return first_time


class Skeleton:
"""The main object for representing animal skeletons.
Expand Down Expand Up @@ -1228,7 +1415,7 @@ def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> D

# This is a weird hack to serialize the whole _graph into a dict.
# I use the underlying to_json and parse it.
return json.loads(obj.to_json(node_to_idx))
return json.loads(obj.to_json(node_to_idx=node_to_idx))

@classmethod
def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton":
Expand Down Expand Up @@ -1292,10 +1479,10 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
"""
jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4)
if node_to_idx is not None:
indexed_node_graph = nx.relabel_nodes(
G=self._graph, mapping=node_to_idx
) # map nodes to int
# Map Nodes to int
indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx)
else:
# Keep graph nodes as Node objects
indexed_node_graph = self._graph

# Encode to JSON
Expand All @@ -1314,7 +1501,7 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
else:
data = graph

json_str = jsonpickle.encode(data)
json_str = SkeletonEncoder.encode(data)

return json_str

Expand Down
55 changes: 54 additions & 1 deletion tests/test_skeleton.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,62 @@
import copy
import os

import pytest
import json

from networkx.readwrite import json_graph
from sleap.skeleton import Skeleton, SkeletonDecoder
from sleap.skeleton import SkeletonEncoder


def test_decoded_encoded_Skeleton_from_load_json(fly_legs_skeleton_json):
"""
Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton.
"""
# Get the skeleton from the fixture
skeleton = Skeleton.load_json(fly_legs_skeleton_json)
# Get the graph from the skeleton
indexed_node_graph = skeleton._graph
graph = json_graph.node_link_data(indexed_node_graph)

# Encode the graph as a json string to test .encode method
encoded_json_str = SkeletonEncoder.encode(graph)

# Get the skeleton from the encoded json string
decoded_skeleton = Skeleton.from_json(encoded_json_str)

# Check that the decoded skeleton is the same as the original skeleton
assert skeleton.matches(decoded_skeleton)


@pytest.mark.parametrize(
"skeleton_fixture_name", ["flies13_skeleton", "skeleton", "stickman"]
)
def test_decoded_encoded_Skeleton(skeleton_fixture_name, request):
"""
Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton.
"""
# Use request.getfixturevalue to get the actual fixture value by name
skeleton = request.getfixturevalue(skeleton_fixture_name)

# Get the graph from the skeleton
indexed_node_graph = skeleton._graph
graph = json_graph.node_link_data(indexed_node_graph)

# Encode the graph as a json string to test .encode method
encoded_json_str = SkeletonEncoder.encode(graph)

# Get the skeleton from the encoded json string
decoded_skeleton = Skeleton.from_json(encoded_json_str)

# Check that the decoded skeleton is the same as the original skeleton
assert skeleton.matches(decoded_skeleton)

# Now make everything into a JSON string
skeleton_json_str = skeleton.to_json()
decoded_skeleton_json_str = decoded_skeleton.to_json()

# Check that the JSON strings are the same
assert json.loads(skeleton_json_str) == json.loads(decoded_skeleton_json_str)


def test_add_dupe_node(skeleton):
Expand Down

0 comments on commit ef803f6

Please sign in to comment.