diff --git a/TrainingExtensions/torch/src/python/aimet_torch/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/utils.py index ad650f4f36..7d2b1c1fcc 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/utils.py @@ -64,7 +64,7 @@ logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils) dtypes_to_ignore_for_quantization = (int, float, bool, str, tuple, type(None)) -torch_dtypes_to_ignore_for_quantization = [torch.int, torch.int8, torch.int16, torch.int32, torch.int64, torch.bool] +torch_dtypes_to_ignore_for_quantization = [torch.int, torch.int8, torch.int16, torch.int32, torch.int64, torch.bool, torch.uint8] allowed_output_types = (torch.Tensor, *dtypes_to_ignore_for_quantization) DROPOUT_TYPES = (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)