Skip to content

Commit

Permalink
Fix stochastic deadlocks in tf.py_functions.
Browse files Browse the repository at this point in the history
- These resulted random hangs during training that could not be
  interrupted via SIGINT or otherwise.
- Solution was to remove TF ops from inside py_functions in the data
  pipeline.
- Interesting side effect is a big performance speedup.
- See: tensorflow/tensorflow#32454
  • Loading branch information
talmo committed Feb 13, 2020
1 parent c120486 commit 613c201
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
16 changes: 10 additions & 6 deletions sleap/nn/data/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
if hasattr(numpy.random, "_bit_generator"):
numpy.random.bit_generator = numpy.random._bit_generator

import numpy as np
import tensorflow as tf
import attr
from typing import List, Text
Expand Down Expand Up @@ -190,22 +191,23 @@ def py_augment(image, instances):
aug_det = self.augmenter.to_deterministic()

# Augment the image.
aug_img = aug_det.augment_image(image)
aug_img = aug_det.augment_image(image.numpy())

# Augment each set of points for each instance.
aug_instances = []
for instance in instances:
kps = ia.KeypointsOnImage.from_xy_array(instance, tuple(image.shape))
kps = ia.KeypointsOnImage.from_xy_array(instance.numpy(), tuple(image.shape))
aug_instance = aug_det.augment_keypoints(kps).to_xy_array()
aug_instances.append(aug_instance)

# Convert the results to tensors.
aug_img = tf.convert_to_tensor(aug_img, dtype=image.dtype)
# aug_img = tf.convert_to_tensor(aug_img, dtype=image.dtype)

# This will get converted to a rank 3 tensor (n_instances, n_nodes, 2).
aug_instances = [
tf.convert_to_tensor(x, dtype=instances.dtype) for x in aug_instances
]
aug_instances = np.stack(aug_instances, axis=0)
# aug_instances = [
# tf.convert_to_tensor(x, dtype=instances.dtype) for x in aug_instances
# ]

return aug_img, aug_instances

Expand All @@ -216,6 +218,8 @@ def augment(frame_data):
[frame_data["image"], frame_data["instances"]],
[frame_data["image"].dtype, frame_data["instances"].dtype],
)
image.set_shape(frame_data["image"].get_shape())
instances.set_shape(frame_data["instances"].get_shape())
frame_data.update({"image": image, "instances": instances})
return frame_data

Expand Down
35 changes: 20 additions & 15 deletions sleap/nn/data/providers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Data providers for pipeline I/O."""

import numpy as np
import tensorflow as tf
import attr
from typing import Text, Optional, List
Expand Down Expand Up @@ -119,19 +120,24 @@ def make_dataset(
def py_fetch_lf(ind):
"""Local function that will not be autographed."""
lf = self.labels[int(ind.numpy())]
video_ind = self.labels.videos.index(lf.video)
frame_ind = lf.frame_idx
video_ind = np.array(self.labels.videos.index(lf.video)).astype("int32")
frame_ind = np.array(lf.frame_idx).astype("int64")
raw_image = lf.image
image = tf.convert_to_tensor(raw_image)
raw_image_size = tf.convert_to_tensor(raw_image.shape, dtype=tf.int32)
instances = [
tf.convert_to_tensor(inst.points_array, dtype=tf.float32)
for inst in lf.instances
]
skeleton_inds = [
self.labels.skeletons.index(inst.skeleton) for inst in lf.instances
]
return image, raw_image_size, instances, video_ind, frame_ind, skeleton_inds
raw_image_size = np.array(raw_image.shape).astype("int32")
instances = np.stack(
[inst.points_array.astype("float32") for inst in lf.instances], axis=0
)
skeleton_inds = np.array(
[self.labels.skeletons.index(inst.skeleton) for inst in lf.instances]
).astype("int32")
return (
raw_image,
raw_image_size,
instances,
video_ind,
frame_ind,
skeleton_inds,
)

def fetch_lf(ind):
"""Local function that fetches a sample given the index."""
Expand Down Expand Up @@ -245,9 +251,8 @@ def py_fetch_frame(ind):
"""Local function that will not be autographed."""
frame_ind = int(ind.numpy())
raw_image = self.video.get_frame(frame_ind)
image = tf.convert_to_tensor(raw_image)
raw_image_size = tf.convert_to_tensor(raw_image.shape, dtype=tf.int32)
return image, raw_image_size, frame_ind
raw_image_size = np.array(raw_image.shape).astype("int32")
return raw_image, raw_image_size, np.array(frame_ind).astype("int64")

def fetch_frame(ind):
"""Local function that fetches a sample given the index."""
Expand Down

0 comments on commit 613c201

Please sign in to comment.