Skip to content

Commit

Permalink
TF SAM shape flexibility fixes (#23842)
Browse files Browse the repository at this point in the history
SAM shape flexibility fixes for compilation
  • Loading branch information
Rocketknight1 authored May 30, 2023
1 parent af45ec0 commit ac224de
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/transformers/models/sam/modeling_tf_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> t
batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
return tf.reshape(
hidden_states, (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head)
hidden_states,
(batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
)

def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
Expand Down Expand Up @@ -509,7 +510,7 @@ def call(
# Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
# happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
# it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
if sparse_prompt_embeddings.shape[1] != 0:
if shape_list(sparse_prompt_embeddings)[1] != 0:
tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
else:
tokens = output_tokens
Expand Down Expand Up @@ -695,8 +696,8 @@ def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.T
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
target_labels_shape = (points.shape[0], points.shape[1], 1)
target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
points = tf.concat([points, padding_point], axis=2)
Expand All @@ -722,12 +723,12 @@ def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.T
def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
batch_size, nb_boxes = boxes.shape[:2]
batch_size, nb_boxes = shape_list(boxes)[:2]
coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
input_shape = (self.input_image_size, self.input_image_size)
corner_embedding = self.shared_embedding(coords, input_shape)
corner_embedding += tf.where(
tf.range(corner_embedding.shape[2])[None, None, :, None] == 0,
tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
self.point_embed[2][0],
self.point_embed[3][0],
)
Expand All @@ -754,7 +755,7 @@ def call(
"""
sparse_embeddings = None
if input_points is not None:
batch_size, point_batch_size = input_points.shape[:2]
batch_size, point_batch_size = shape_list(input_points)[:2]
if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
Expand All @@ -763,7 +764,7 @@ def call(
)
sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
if input_boxes is not None:
batch_size = input_boxes.shape[0]
batch_size = shape_list(input_boxes)[0]
box_embeddings = self._embed_boxes(input_boxes)
if sparse_embeddings is None:
sparse_embeddings = box_embeddings
Expand Down Expand Up @@ -1376,8 +1377,8 @@ def call(
" got {}.".format(input_boxes.shape),
)
if input_points is not None and input_boxes is not None:
point_batch_size = input_points.shape[1]
box_batch_size = input_boxes.shape[1]
point_batch_size = shape_list(input_points)[1]
box_batch_size = shape_list(input_boxes)[1]
if point_batch_size != box_batch_size:
raise ValueError(
"You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
Expand Down

0 comments on commit ac224de

Please sign in to comment.