diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py index 24f273f115..c3088ad342 100644 --- a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py +++ b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py @@ -913,7 +913,7 @@ def _create_quantized_module(module): # (lambda: custom.NotEqual(), lambda: ...), # (lambda: custom.Equal(), lambda: ...), (lambda: custom.Bmm(), lambda: (randn(1, 100, 100), randn(1, 100, 100))), - (lambda: custom.CumSum(), lambda: (randn(10, 100), 0)), + (lambda: custom.CumSum(), lambda: (randn(10, 100), tensor(0))), # (lambda: custom.MaskedFill(), lambda: ...), # (lambda: custom.Mean(), lambda: ...), # (lambda: custom.Sum(), lambda: ...),