Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sort encoded Skeleton dictionary for backwards compatibility #1975

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions sleap/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import attr
import cattr
import h5py
import jsonpickle
import networkx as nx
import numpy as np
from networkx.readwrite import json_graph
Expand Down Expand Up @@ -421,11 +420,30 @@ def encode(cls, data: Dict[str, Any]) -> str:
Returns:
json_str: The JSON string representation of the data.
"""

# This is required for backwards compatibility with SLEAP <=1.3.4
sorted_data = cls._recursively_sort_dict(data)

encoder = cls()
encoded_data = encoder._encode(data)
encoded_data = encoder._encode(sorted_data)
json_str = json.dumps(encoded_data)
return json_str

@staticmethod
def _recursively_sort_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively sorts the dictionary by keys."""
sorted_dict = dict(sorted(dictionary.items()))
for key, value in sorted_dict.items():
if isinstance(value, dict):
sorted_dict[key] = SkeletonEncoder._recursively_sort_dict(value)
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, dict):
sorted_dict[key][i] = SkeletonEncoder._recursively_sort_dict(
item
)
return sorted_dict

def _encode(self, obj: Any) -> Any:
"""Recursively encodes the input object.

Expand Down Expand Up @@ -1477,7 +1495,7 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
Returns:
A string containing the JSON representation of the skeleton.
"""
jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4)

if node_to_idx is not None:
# Map Nodes to int
indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ def test_decoded_encoded_Skeleton(skeleton_fixture_name, request):
# Encode the graph as a json string to test .encode method
encoded_json_str = SkeletonEncoder.encode(graph)

# Assert that the encoded json has keys in sorted order (backwards compatibility)
encoded_dict = json.loads(encoded_json_str)
sorted_keys = sorted(encoded_dict.keys())
assert list(encoded_dict.keys()) == sorted_keys
for key, value in encoded_dict.items():
if isinstance(value, dict):
assert list(value.keys()) == sorted(value.keys())
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
assert list(item.keys()) == sorted(item.keys())

# Get the skeleton from the encoded json string
decoded_skeleton = Skeleton.from_json(encoded_json_str)

Expand Down
Loading