diff --git a/backpack/core/derivatives/conv_transposend.py b/backpack/core/derivatives/conv_transposend.py index 210d8401..cb7d1855 100644 --- a/backpack/core/derivatives/conv_transposend.py +++ b/backpack/core/derivatives/conv_transposend.py @@ -52,6 +52,9 @@ def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): return jac_mat.expand(*expand_shape) def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): + if module.groups != 1: + raise NotImplementedError("Groups greater than 1 are not supported yet") + V = mat.shape[0] G = module.groups C_in = module.input0.shape[1] @@ -71,6 +74,9 @@ def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): return self.reshape_like_output(jac_mat, module) def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + if module.groups != 1: + raise NotImplementedError("Groups greater than 1 are not supported yet") + V = mat.shape[0] G = module.groups C_in = module.input0.shape[1] diff --git a/backpack/core/derivatives/convnd.py b/backpack/core/derivatives/convnd.py index 1956b6ff..0a261b88 100644 --- a/backpack/core/derivatives/convnd.py +++ b/backpack/core/derivatives/convnd.py @@ -100,6 +100,9 @@ def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): return mat.sum(axes) def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): + if module.groups != 1: + raise NotImplementedError("Groups greater than 1 are not supported yet") + dims = self.dim_text dims_joined = dims.replace(",", "") @@ -109,6 +112,9 @@ def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): return self.reshape_like_output(jac_mat, module) def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + if module.groups != 1: + raise NotImplementedError("Groups greater than 1 are not supported yet") + V = mat.shape[0] N, C_out = module.output_shape[0], module.output_shape[1] C_in = module.input0_shape[1] diff --git a/test/core/derivatives/convolution_settings.py b/test/core/derivatives/convolution_settings.py index 1a45f6dc..efc4cfd1 100644 --- a/test/core/derivatives/convolution_settings.py +++ b/test/core/derivatives/convolution_settings.py @@ -276,3 +276,152 @@ "id_prefix": "non-default-conv", }, ] + +CONVOLUTION_FAIL_SETTINGS = [ + # groups - 2 + { + "module_fn": lambda: torch.nn.Conv1d( + in_channels=4, + out_channels=6, + kernel_size=2, + padding=0, + dilation=2, + groups=2, + ), + "input_fn": lambda: torch.rand(size=(3, 4, 7)), + "id_prefix": "groups-2", + }, + { + "module_fn": lambda: torch.nn.Conv2d( + in_channels=4, + out_channels=6, + kernel_size=2, + padding=0, + dilation=2, + groups=2, + ), + "input_fn": lambda: torch.rand(size=(3, 4, 7, 7)), + "id_prefix": "groups-2", + }, + { + "module_fn": lambda: torch.nn.Conv3d( + in_channels=4, + out_channels=6, + kernel_size=2, + padding=0, + dilation=2, + groups=2, + ), + "input_fn": lambda: torch.rand(size=(3, 4, 3, 7, 7)), + "id_prefix": "groups-2", + }, + { + "module_fn": lambda: torch.nn.ConvTranspose1d( + in_channels=4, + out_channels=6, + kernel_size=2, + padding=0, + dilation=2, + groups=2, + ), + "input_fn": lambda: torch.rand(size=(3, 4, 7)), + "id_prefix": "groups-2", + }, + { + "module_fn": lambda: torch.nn.ConvTranspose2d( + in_channels=4, + out_channels=6, + kernel_size=2, + padding=0, + dilation=2, + groups=2, + ), + "input_fn": lambda: torch.rand(size=(3, 4, 7, 7)), + "id_prefix": "groups-2", + }, + { + "module_fn": lambda: torch.nn.ConvTranspose3d( + in_channels=4, + out_channels=6, + kernel_size=2, + padding=0, + dilation=2, + groups=2, + ), + "input_fn": lambda: torch.rand(size=(3, 4, 3, 7, 7)), + "id_prefix": "groups-2", + }, + # groups - 3 + { + "module_fn": lambda: torch.nn.Conv1d( + in_channels=6, + out_channels=9, + kernel_size=2, + padding=0, + dilation=2, + groups=3, + ), + "input_fn": lambda: torch.rand(size=(3, 6, 7)), + "id_prefix": "groups-3", + }, + { + "module_fn": lambda: torch.nn.Conv2d( + in_channels=6, + out_channels=9, + kernel_size=2, + padding=0, + dilation=2, + groups=3, + ), + "input_fn": lambda: torch.rand(size=(3, 6, 7, 7)), + "id_prefix": "groups-3", + }, + { + "module_fn": lambda: torch.nn.Conv3d( + in_channels=6, + out_channels=9, + kernel_size=2, + padding=0, + dilation=2, + groups=3, + ), + "input_fn": lambda: torch.rand(size=(3, 6, 3, 7, 7)), + "id_prefix": "groups-3", + }, + { + "module_fn": lambda: torch.nn.ConvTranspose1d( + in_channels=6, + out_channels=9, + kernel_size=2, + padding=0, + dilation=2, + groups=3, + ), + "input_fn": lambda: torch.rand(size=(3, 6, 7)), + "id_prefix": "groups-3", + }, + { + "module_fn": lambda: torch.nn.ConvTranspose2d( + in_channels=6, + out_channels=9, + kernel_size=2, + padding=0, + dilation=2, + groups=3, + ), + "input_fn": lambda: torch.rand(size=(3, 6, 7, 7)), + "id_prefix": "groups-3", + }, + { + "module_fn": lambda: torch.nn.ConvTranspose3d( + in_channels=6, + out_channels=9, + kernel_size=2, + padding=0, + dilation=2, + groups=3, + ), + "input_fn": lambda: torch.rand(size=(3, 6, 3, 7, 7)), + "id_prefix": "groups-3", + }, +] diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index f4661df7..06d6db82 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -13,6 +13,7 @@ from test.core.derivatives.problem import make_test_problems from test.core.derivatives.settings import SETTINGS from test.core.derivatives.loss_settings import LOSS_FAIL_SETTINGS +from test.core.derivatives.convolution_settings import CONVOLUTION_FAIL_SETTINGS import pytest import torch @@ -30,6 +31,9 @@ LOSS_FAIL_PROBLEMS = make_test_problems(LOSS_FAIL_SETTINGS) LOSS_FAIL_IDS = [problem.make_id() for problem in LOSS_FAIL_PROBLEMS] +CONVOLUTION_FAIL_PROBLEMS = make_test_problems(CONVOLUTION_FAIL_SETTINGS) +CONVOLUTION_FAIL_IDS = [problem.make_id() for problem in CONVOLUTION_FAIL_PROBLEMS] + @pytest.mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) def test_jac_mat_prod(problem, V=3): @@ -200,6 +204,21 @@ def test_sqrt_hessian_squared_equals_hessian(problem): problem.tear_down() +@pytest.mark.parametrize("problem", CONVOLUTION_FAIL_PROBLEMS, ids=CONVOLUTION_FAIL_IDS) +def test_weight_jac_mat_prod_should_fail(problem): + with pytest.raises(NotImplementedError): + test_weight_jac_mat_prod(problem) + + +@pytest.mark.parametrize( + "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] +) +@pytest.mark.parametrize("problem", CONVOLUTION_FAIL_PROBLEMS, ids=CONVOLUTION_FAIL_IDS) +def test_weight_jac_t_mat_prod_should_fail(problem, sum_batch): + with pytest.raises(NotImplementedError): + test_weight_jac_t_mat_prod(problem, sum_batch) + + @pytest.mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) def test_sqrt_hessian_should_fail(problem): with pytest.raises(ValueError): diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index 2f7962d6..a0291220 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -89,6 +89,30 @@ "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), }, + { + "input_fn": lambda: torch.rand(3, 3, 7), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.Conv1d(3, 2, 2, bias=False), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(12, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 5), + }, + { + "input_fn": lambda: torch.rand(3, 3, 8), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.Conv1d( + 3, 6, 2, stride=4, padding=2, padding_mode="zeros", dilation=3 + ), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(18, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 5), + }, { "input_fn": lambda: torch.rand(3, 3, 7), "module_fn": lambda: torch.nn.Sequential( @@ -100,6 +124,17 @@ "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), }, + { + "input_fn": lambda: torch.rand(3, 2, 7), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.Conv1d(2, 3, 2, padding=0, dilation=2, groups=1), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(15, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 5), + }, { "input_fn": lambda: torch.rand(3, 3, 7, 7), "module_fn": lambda: torch.nn.Sequential( @@ -111,6 +146,30 @@ "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((3,), 5), }, + { + "input_fn": lambda: torch.rand(3, 3, 7, 7), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.Conv2d(3, 2, 2, bias=False), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(72, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 5), + }, + { + "input_fn": lambda: torch.rand(3, 3, 8, 8), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.Conv2d( + 3, 6, 2, stride=4, padding=2, padding_mode="zeros", dilation=3 + ), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(54, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 5), + }, { "input_fn": lambda: torch.rand(3, 3, 7, 7), "module_fn": lambda: torch.nn.Sequential( @@ -122,6 +181,17 @@ "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((3,), 5), }, + { + "input_fn": lambda: torch.rand(3, 2, 7, 7), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.Conv2d(2, 3, 2, padding=0, dilation=2), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(75, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((3,), 5), + }, { "input_fn": lambda: torch.rand(3, 3, 2, 7, 7), "module_fn": lambda: torch.nn.Sequential( @@ -133,6 +203,30 @@ "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), }, + { + "input_fn": lambda: torch.rand(3, 3, 2, 7, 7), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.Conv3d(3, 2, 2, bias=False), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(72, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 5), + }, + { + "input_fn": lambda: torch.rand(3, 3, 4, 8, 8), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.Conv3d( + 3, 6, 2, padding=2, stride=4, dilation=3, padding_mode="zeros" + ), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(108, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 5), + }, { "input_fn": lambda: torch.rand(3, 3, 2, 7, 7), "module_fn": lambda: torch.nn.Sequential( @@ -145,12 +239,12 @@ "target_fn": lambda: classification_targets((3,), 5), }, { - "input_fn": lambda: torch.rand(3, 1, 7), + "input_fn": lambda: torch.rand(3, 2, 3, 7, 7), "module_fn": lambda: torch.nn.Sequential( - torch.nn.ConvTranspose1d(1, 2, 2), + torch.nn.Conv3d(2, 3, 2, dilation=2, padding=0), torch.nn.ReLU(), torch.nn.Flatten(), - torch.nn.Linear(16, 5), + torch.nn.Linear(75, 5), ), "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), @@ -166,6 +260,17 @@ "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), }, + { + "input_fn": lambda: torch.rand(3, 3, 7), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.ConvTranspose1d(3, 2, 2, bias=False), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(16, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 5), + }, { "input_fn": lambda: torch.rand(3, 3, 7), "module_fn": lambda: torch.nn.Sequential( @@ -177,6 +282,17 @@ "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), }, + { + "input_fn": lambda: torch.rand(3, 2, 7), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.ConvTranspose1d(2, 3, 2, padding=0, dilation=5, stride=3), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(72, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "target_fn": lambda: classification_targets((3,), 5), + }, { "input_fn": lambda: torch.rand(3, 3, 7, 7), "module_fn": lambda: torch.nn.Sequential( @@ -188,6 +304,28 @@ "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(), "target_fn": lambda: classification_targets((3,), 5), }, + { + "input_fn": lambda: torch.rand(3, 3, 7, 7), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.ConvTranspose2d(3, 2, 2, bias=False), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(128, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(), + "target_fn": lambda: classification_targets((3,), 5), + }, + { + "input_fn": lambda: torch.rand(3, 2, 9, 9), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.ConvTranspose2d(2, 4, 2, padding=0, dilation=2, groups=1), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(484, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(), + "target_fn": lambda: classification_targets((3,), 5), + }, { "input_fn": lambda: torch.rand(2, 3, 2, 7, 7), "module_fn": lambda: torch.nn.Sequential( @@ -199,13 +337,24 @@ "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((2,), 5), }, + { + "input_fn": lambda: torch.rand(2, 3, 2, 7, 7), + "module_fn": lambda: torch.nn.Sequential( + torch.nn.ConvTranspose3d(3, 2, 2, bias=False), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(384, 5), + ), + "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), + "target_fn": lambda: classification_targets((2,), 5), + }, { "input_fn": lambda: torch.rand(2, 3, 5, 5, 5), "module_fn": lambda: torch.nn.Sequential( - torch.nn.ConvTranspose3d(3, 2, 2, padding=2, dilation=1, stride=2), + torch.nn.ConvTranspose3d(3, 2, 2, padding=2, dilation=2, stride=2), torch.nn.ReLU(), torch.nn.Flatten(), - torch.nn.Linear(432, 5), + torch.nn.Linear(686, 5), ), "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), "target_fn": lambda: classification_targets((2,), 5),