From f804c28060fc0b66608d4bbe7ce37bba80560832 Mon Sep 17 00:00:00 2001
From: Rishabh Thakur
Date: Mon, 9 Sep 2024 21:09:25 +0530
Subject: [PATCH] Keras QAT docs update for standalone BatchNorms

Signed-off-by: Rishabh Thakur
---
 Docs/api_docs/keras_quantsim.rst         | 2 +-
 Docs/keras_code_examples/quantization.py | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/Docs/api_docs/keras_quantsim.rst b/Docs/api_docs/keras_quantsim.rst
index ad2c9936db..aeb436ed1e 100644
--- a/Docs/api_docs/keras_quantsim.rst
+++ b/Docs/api_docs/keras_quantsim.rst
@@ -43,7 +43,7 @@ Code Examples
 
 .. literalinclude:: ../keras_code_examples/quantization.py
     :language: python
-    :lines: 37-40
+    :lines: 37-42
 
 **Quantize with Fine tuning**
 
diff --git a/Docs/keras_code_examples/quantization.py b/Docs/keras_code_examples/quantization.py
index 7fda020437..3932ae5f40 100644
--- a/Docs/keras_code_examples/quantization.py
+++ b/Docs/keras_code_examples/quantization.py
@@ -38,6 +38,8 @@
 import tensorflow as tf
 
 from aimet_tensorflow.keras import quantsim
+# Optional import, only required for fine-tuning
+from aimet_tensorflow.keras.quant_sim.qc_quantize_wrapper import QcQuantizeWrapper
 
 def evaluate(model: tf.keras.Model, forward_pass_callback_args):
     """
@@ -68,6 +70,13 @@ def quantize_model():
     sim.compute_encodings(evaluate, forward_pass_callback_args=(dummy_x, dummy_y))
 
     # Do some fine-tuning
+    # Note: Fine-tuning on GPU is not supported for models that contain
+    # non-trainable BatchNorms, so they must be explicitly set to trainable.
+    # The snippet below marks every wrapped BatchNorm as trainable.
+    for layer in sim.model.layers:
+        if isinstance(layer, QcQuantizeWrapper) and isinstance(layer._layer_to_wrap, tf.keras.layers.BatchNormalization):
+            layer._layer_to_wrap.trainable = True
+
     sim.model.fit(x=dummy_x, y=dummy_y, epochs=10)
 
 quantize_model()
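
Editor's note: for readers who want to try the documented workaround outside
the AIMET docs build, here is a minimal sketch of the fine-tuning flow this
patch describes. The toy functional model, random data, and forward_pass
helper are illustrative assumptions, not part of the patch; only the
QcQuantizeWrapper check and the trainable flip on the wrapped layer come
from the patch itself, and exact QuantizationSimModel behavior may vary
across AIMET versions.

    # Minimal sketch (not part of the patch): quantize a toy Keras model
    # with AIMET and apply the BatchNorm-trainable workaround before
    # fine-tuning.
    import numpy as np
    import tensorflow as tf

    from aimet_tensorflow.keras import quantsim
    from aimet_tensorflow.keras.quant_sim.qc_quantize_wrapper import QcQuantizeWrapper

    # Toy model and data -- placeholders for illustration only.
    inputs = tf.keras.Input(shape=(16,))
    x = tf.keras.layers.Dense(8, activation="relu")(inputs)
    x = tf.keras.layers.BatchNormalization()(x)  # a standalone BatchNorm
    outputs = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs, outputs)

    dummy_x = np.random.rand(32, 16).astype(np.float32)
    dummy_y = np.random.rand(32, 1).astype(np.float32)

    sim = quantsim.QuantizationSimModel(model)

    def forward_pass(m, args):
        # Run representative data through the model so AIMET can
        # compute quantization encodings.
        data, _ = args
        m(data)

    sim.compute_encodings(forward_pass, forward_pass_callback_args=(dummy_x, dummy_y))

    # The workaround documented by the patch: GPU fine-tuning with
    # non-trainable BatchNorms is not supported, so mark them trainable.
    for layer in sim.model.layers:
        if isinstance(layer, QcQuantizeWrapper) and \
                isinstance(layer._layer_to_wrap, tf.keras.layers.BatchNormalization):
            layer._layer_to_wrap.trainable = True

    sim.model.compile(optimizer="adam", loss="mse")
    sim.model.fit(x=dummy_x, y=dummy_y, epochs=1)

The key point is that the trainable flag is flipped on the wrapped layer
(layer._layer_to_wrap), not on the QcQuantizeWrapper itself.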