TFDebertaModel and TFDebertaV2Model throws TypeError when keras.fit with Mixed Precision #31989

Closed
4 tasks
pinesnow72 opened this issue Jul 16, 2024 · 3 comments · Fixed by #32618 or pinesnow72/transformers#2

Comments

@pinesnow72
Contributor

System Info

  • transformers version: 4.41.2
  • Platform: Linux-5.15.0-107-generic-x86_64-with-glibc2.35
  • Python version: 3.12.3
  • Huggingface_hub version: 0.23.2
  • Safetensors version: 0.4.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): 2.16.1 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes

Who can help?

@ArthurZucker, @Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am trying to fine-tune TFDebertaModel and TFDebertaV2Model for an NER task with mixed precision enabled:

policy = keras.mixed_precision.Policy('mixed_float16')
keras.mixed_precision.set_global_policy(policy)

model = TFDebertaModel.from_pretrained('deberta-base')
# or 
# model = TFDebertaV2Model.from_pretrained('deberta-v3-base')

....

model.fit(x=train_data, validation_data=valid_data, epochs=10)

However, training this model raises a TypeError in TFDebertaEmbeddings:
TypeError: Exception encountered when calling layer 'embeddings' (type TFDebertaEmbeddings).
in user code:
File "/home/swlee/miniconda3/envs/tf216/lib/python3.12/site-packages/transformers/models/deberta/modeling_tf_deberta.py", line 929, in call *
final_embeddings = final_embeddings * mask
TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type float16 of argument 'x'.

The same error occurs with TFDebertaV2Model.
With mixed precision, TF and Keras require Layer.dtype to be used for a model's or layer's weights and Layer.compute_dtype for internal tensor computations. The current TFDebertaModel and TFDebertaV2Model code does not seem to follow this requirement and explicitly assumes that the dtype is tf.float32.
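
To illustrate the dtype/compute_dtype distinction described above, here is a minimal sketch. The layer `MaskedEmbeddingScale` and its argument names are hypothetical and are not part of transformers; the sketch only assumes that a Keras layer created with the "mixed_float16" policy keeps float32 weights and computes in float16.

```python
import tensorflow as tf


class MaskedEmbeddingScale(tf.keras.layers.Layer):
    """Hypothetical toy layer (not from transformers) that zeroes masked positions."""

    def call(self, embeddings, attention_mask):
        # Cast to the layer's compute dtype instead of hard-coding tf.float32,
        # so both operands of the multiplication share one dtype.
        embeddings = tf.cast(embeddings, dtype=self.compute_dtype)
        mask = tf.cast(tf.expand_dims(attention_mask, axis=2), dtype=self.compute_dtype)
        return embeddings * mask


# Per-layer policy: float32 variables, float16 computation.
layer = MaskedEmbeddingScale(dtype="mixed_float16")

embeddings = tf.ones((1, 3, 4), dtype=tf.float32)
attention_mask = tf.constant([[1, 1, 0]], dtype=tf.int32)
outputs = layer(embeddings, attention_mask)

print(layer.dtype, layer.compute_dtype, outputs.dtype)
```

If the mask were cast to tf.float32 here, the multiplication would mix float16 and float32 operands, which is the kind of mismatch reported in the traceback.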

Expected behavior

I hope this bug can be fixed soon so that mixed precision is supported.
I searched modeling_tf_deberta.py and modeling_tf_deberta_v2.py for error-prone code snippets and tried to correct them.
Here is the list (I am not sure it is exhaustive):

[in modeling_tf_deberta.py]

(lines: 105, 106)

  output = tf.where(rmask, float("-inf"), inputs)
  output = stable_softmax(output, self.axis)

(correction would be)

  output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs)  # mixed precision # float("-inf")
  output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis)  # mixed precision # output

(lines: 133, 135, 139)

  scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
  inputs = tf.where(mask, 0.0, inputs) * scale
  return tf.where(mask, 0.0, upstream) * scale

(correction would be)

  scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype)  # mixed precision # dtype=tf.float32)
  inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale  # mixed precision # 0.0
  return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale  # mixed precision # 0.0

(lines: 705, 707)

  qkvw = tf.TensorArray(dtype=tf.float32, size=3)
  qkvw_inside = tf.TensorArray(dtype=tf.float32, size=self.num_attention_heads)

(correction would be)

  qkvw = tf.TensorArray(dtype=self.dtype, size=3)  # mixed precision # tf.float32
  qkvw_inside = tf.TensorArray(dtype=self.dtype, size=self.num_attention_heads)  # mixed precision # tf.float32

(lines: 799)

  pos_query_layer /= tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=tf.float32))

(correction would be)

  pos_query_layer /= tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype))  # mixed precision # tf.float32

(lines: 927)

  mask = tf.cast(tf.expand_dims(mask, axis=2), tf.float32)

(correction would be)

  mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype)

[in modeling_tf_deberta_v2.py]

(lines: 106, 107)

  output = tf.where(rmask, float("-inf"), inputs)
  output = stable_softmax(output, self.axis)

(correction would be)

  output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs)  # mixed precision # float("-inf")
  output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis)  # mixed precision # output

(lines: 135, 137, 141)

  scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
  inputs = tf.where(mask, 0.0, inputs) * scale
  return tf.where(mask, 0.0, upstream) * scale

(correction would be)

  scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype)  # mixed precision # dtype=tf.float32)
  inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale  # mixed precision # 0.0
  return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale  # mixed precision # 0.0

(lines: 391, 404)

  out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out)
  input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), tf.float32)

(correction would be)

  out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out)
  input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), dtype=self.compute_dtype)  # mixed precision # tf.float32)

(lines: 770)

  scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, tf.float32))

(correction would be)

  scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, dtype=self.compute_dtype))  # mixed precision # tf.float32))

(lines: 853, 867)

  scale = tf.math.sqrt(tf.cast(shape_list(pos_key_layer)[-1] * scale_factor, tf.float32))
  scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, tf.float32))

(correction would be)

  scale = tf.math.sqrt(tf.cast(shape_list(pos_key_layer)[-1] * scale_factor, dtype=self.compute_dtype))  # mixed precision # tf.float32))
  scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype))  # mixed precision # tf.float32))

(lines: 1034)

  mask = tf.cast(tf.expand_dims(mask, axis=2), tf.float32)

(correction would be)

  mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype)  # mixed precision # tf.float32)
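
The masked-softmax correction listed above (casting the -inf fill value to the compute dtype and running the softmax in float32) can be sketched as a standalone function. This is only an illustration under assumptions: `masked_softmax` is a hypothetical helper, it uses `tf.nn.softmax` in place of the library's `stable_softmax`, and casting the result back to the input dtype is my own addition rather than part of the proposed patch.

```python
import tensorflow as tf


def masked_softmax(inputs, mask, axis=-1):
    """Hypothetical helper: softmax over `axis` that ignores positions where mask == 0."""
    compute_dtype = inputs.dtype
    rmask = tf.logical_not(tf.cast(mask, tf.bool))

    # Build the fill value in the same dtype as `inputs` (e.g. float16).
    neg_inf = tf.cast(float("-inf"), dtype=compute_dtype)
    scores = tf.where(rmask, neg_inf, inputs)

    # Run the softmax in float32 for numerical stability, then cast back
    # (the cast back is an assumption, not part of the proposed change).
    probs = tf.nn.softmax(tf.cast(scores, tf.float32), axis=axis)
    probs = tf.cast(probs, compute_dtype)

    # Zero out masked positions, again with a matching dtype.
    return tf.where(rmask, tf.cast(0.0, dtype=compute_dtype), probs)


scores = tf.constant([[1.0, 2.0, 3.0]], dtype=tf.float16)
mask = tf.constant([[1, 1, 0]], dtype=tf.int32)
probs = masked_softmax(scores, mask)
print(probs.dtype, probs.numpy())
```

Note that a row whose positions are all masked would still produce NaN from the softmax before the final tf.where; the sketch does not handle that case specially.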
@amyeroberts
Collaborator

Hi @pinesnow72, thanks for raising an issue!

As you're trying something with custom code, specifically training with mixed precision, this question is best placed in our forums. We try to reserve GitHub issues for feature requests and bug reports.

cc @Rocketknight1

@ArthurZucker
Collaborator

(But @pinesnow72, your intuition is correct: creating the -inf in the mask as a plain float is problematic here. Feel free to open a PR with your proposed changes; I am certain @Rocketknight1 will be able to review!)

@Rocketknight1
Member

Yes, agreed - if you're willing to open the PR, I think it would be a good change!
