
Commit

Improve error message for instance norm when channels is incorrect
ghstack-source-id: 100724569b7ad7333a8e24f232e7306c713b080d
Pull Request resolved: #94624
soulitzer committed Mar 3, 2023
1 parent ec8bffe commit a96d231
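
For illustration, a minimal sketch of the error this change raises when affine=True and the channel count does not match num_features (module size and input shape are arbitrary, not taken from the PR):

import torch
import torch.nn as nn

# num_features=4, but the input carries 3 channels at dim=1
inst_norm = nn.InstanceNorm2d(4, affine=True)
x = torch.randn(2, 3, 8, 8)
try:
    inst_norm(x)
except ValueError as e:
    print(e)  # expected input's size at dim=1 to match num_features (4), but got: 3.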
Showing 2 changed files with 30 additions and 0 deletions.
17 changes: 17 additions & 0 deletions test/test_nn.py
@@ -8143,6 +8143,23 @@ def test_InstanceNorm3d_general(self, device):
        if self.device_type == 'cuda':
            self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input, device)

    @parametrize_test("instance_norm_cls", [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d], name_fn=lambda c: c.__name__)
    @parametrize_test("no_batch_dim", [True, False])
    @parametrize_test("affine", [True, False])
    def test_instancenorm_raises_error_if_input_channels_is_not_num_features(self, device, instance_norm_cls, no_batch_dim, affine):
        inst_norm = instance_norm_cls(4, affine=affine)
        size = [2] * inst_norm._get_no_batch_dim()
        if not no_batch_dim:
            size = [3] + size
        t = torch.randn(size)
        if affine:
            with self.assertRaisesRegex(ValueError, "expected input's size at dim="):
                inst_norm(t)
        else:
            with warnings.catch_warnings(record=True) as w:
                inst_norm(t)
            self.assertIn("which is not used because affine=False", str(w[0].message))

    def test_instancenorm_raises_error_if_less_than_one_value_per_channel(self, device):
        x = torch.rand(10)[None, :, None]
        with self.assertRaises(ValueError):
13 changes: 13 additions & 0 deletions torch/nn/modules/instancenorm.py
@@ -1,3 +1,5 @@

import warnings
from torch import Tensor

from .batchnorm import _LazyNormBase, _NormBase
@@ -68,6 +70,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        feature_dim = input.dim() - self._get_no_batch_dim()
        if input.size(feature_dim) != self.num_features:
            if self.affine:
                raise ValueError(
                    f"expected input's size at dim={feature_dim} to match num_features"
                    f" ({self.num_features}), but got: {input.size(feature_dim)}.")
            else:
                warnings.warn(f"input's size at dim={feature_dim} does not match num_features. "
                              "You can silence this warning by not passing in num_features, "
                              "which is not used because affine=False")

        if input.dim() == self._get_no_batch_dim():
            return self._handle_no_batch_input(input)

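And a short sketch of the affine=False branch added above: with this change applied, the same mismatch only warns, since num_features is not actually used when there are no affine parameters (shapes again chosen arbitrarily):

import warnings
import torch
import torch.nn as nn

inst_norm = nn.InstanceNorm2d(4, affine=False)
x = torch.randn(2, 3, 8, 8)  # 3 channels vs. num_features=4
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    out = inst_norm(x)  # normalization still runs; only a UserWarning is emitted
print(w[0].message)  # "...which is not used because affine=False"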
