diff --git a/src/rs_distributions/transforms/fill_scale_tril.py b/src/rs_distributions/transforms/fill_scale_tril.py
index 03c97b4..f248f3f 100644
--- a/src/rs_distributions/transforms/fill_scale_tril.py
+++ b/src/rs_distributions/transforms/fill_scale_tril.py
@@ -5,6 +5,10 @@ class FillTriL(Transform):
+    """
+    Transform for converting a real-valued vector into a lower triangular matrix.
+    """
+
     def __init__(self):
         super().__init__()
@@ -21,16 +25,30 @@ def bijective(self):
         return True
 
     def _call(self, x):
+        """
+        Converts a real-valued vector to a lower triangular matrix.
+
+        Args:
+            x (torch.Tensor): input real-valued vector
+        Returns:
+            torch.Tensor: lower triangular matrix
+        """
+
         return vec_to_tril_matrix(x)
 
     def _inverse(self, y):
         return tril_matrix_to_vec(y)
 
     def log_abs_det_jacobian(self, x, y):
-        return torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
+        batch_shape = x.shape[:-1]
+        return torch.zeros(batch_shape, dtype=x.dtype, device=x.device)
 
 
 class DiagTransform(Transform):
+    """
+    Applies a transformation to the diagonal of a square matrix.
+    """
+
     def __init__(self, diag_transform):
         super().__init__()
         self.diag_transform = diag_transform
@@ -48,10 +66,15 @@ def bijective(self):
         return self.diag_transform.bijective
 
     def _call(self, x):
+        """
+        Args:
+            x (torch.Tensor): input matrix
+        Returns:
+            torch.Tensor: matrix with transformed diagonal
+        """
         diagonal = x.diagonal(dim1=-2, dim2=-1)
         transformed_diagonal = self.diag_transform(diagonal)
-        shifted_diag = transformed_diagonal
-        result = x.diagonal_scatter(shifted_diag, dim1=-2, dim2=-1)
+        result = x.diagonal_scatter(transformed_diagonal, dim1=-2, dim2=-1)
         return result
@@ -66,9 +89,20 @@ def log_abs_det_jacobian(self, x, y):
 
 
 class FillScaleTriL(ComposeTransform):
+    """
+    A `ComposeTransform` that reshapes a real-valued vector into a lower triangular matrix.
+    The diagonal of the matrix is transformed with `diag_transform`.
+ """ + def __init__(self, diag_transform=SoftplusTransform()): super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)]) + self.diag_transform = diag_transform @property def bijective(self): return True + + def log_abs_det_jacobian(self, x, y): + x = FillTriL()._call(x) + diagonal = x.diagonal(dim1=-2, dim2=-1) + return self.diag_transform.log_abs_det_jacobian(diagonal, diagonal) diff --git a/tests/transforms/fill_scale_tril.py b/tests/transforms/fill_scale_tril.py index 475960d..8d55a94 100644 --- a/tests/transforms/fill_scale_tril.py +++ b/tests/transforms/fill_scale_tril.py @@ -1,9 +1,12 @@ import pytest from rs_distributions.transforms.fill_scale_tril import ( FillScaleTriL, + FillTriL, + DiagTransform, ) import torch from torch.distributions.constraints import lower_cholesky +from torch.distributions.transforms import SoftplusTransform, ExpTransform @pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)]) @@ -35,3 +38,36 @@ def test_forward_equals_inverse(batch_shape, d): assert torch.allclose( input_vector, invL, atol=1e-4 ), "Original input and the result of applying inverse transformation are not close enough" + + +@pytest.mark.parametrize( + "batch_shape, d, diag_transform", + [ + ((2, 3), 6, SoftplusTransform()), + ((1, 4, 5), 10, SoftplusTransform()), + ((2, 3), 6, ExpTransform()), + ((1, 4, 5), 10, ExpTransform()), + ], +) +def test_log_abs_det_jacobian_softplus_and_exp(batch_shape, d, diag_transform): + transform = FillScaleTriL(diag_transform=diag_transform) + input_shape = batch_shape + (d,) + input_vector = torch.randn(input_shape, requires_grad=True) + transformed_vector = transform(input_vector) + + # Calculate log abs det jacobian with autograd + log_abs_det_jacobian = transform.log_abs_det_jacobian( + input_vector, transformed_vector + ) + + # Extract diagonal elements from input and transformed vectors + filltril = FillTriL() + diagtransform = DiagTransform(diag_transform=diag_transform) + tril = filltril(input_vector) + diagonal_transformed = diagtransform(tril) + + # Calculate diagonal gradients with autograd + diag_jacobian = diagtransform.log_abs_det_jacobian(tril, diagonal_transformed) + + # Assert diagonal gradients are approximately equal + assert torch.allclose(diag_jacobian, log_abs_det_jacobian, atol=1e-4)