Added test for FillScaleTriL().log_abs_det_jacobian
LuisA92 committed May 17, 2024
1 parent 70e48f3 commit c083949
Showing 2 changed files with 73 additions and 3 deletions.
40 changes: 37 additions & 3 deletions src/rs_distributions/transforms/fill_scale_tril.py
@@ -5,6 +5,10 @@


class FillTriL(Transform):
    """
    Transform for converting a real-valued vector into a lower triangular matrix.
    """

    def __init__(self):
        super().__init__()

@@ -21,16 +25,30 @@ def bijective(self):
        return True

    def _call(self, x):
        """
        Convert a real-valued vector into a lower triangular matrix.

        Args:
            x (torch.Tensor): input real-valued vector
        Returns:
            torch.Tensor: lower triangular matrix
        """
        return vec_to_tril_matrix(x)

    def _inverse(self, y):
        return tril_matrix_to_vec(y)

    def log_abs_det_jacobian(self, x, y):
-       return torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
+       batch_shape = x.shape[:-1]
+       return torch.zeros(batch_shape, dtype=x.dtype, device=x.device)
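For context, a minimal round-trip sketch for FillTriL (not part of the commit; it assumes rs_distributions is importable and picks a 3x3 target, so the vector length is 3*4/2 = 6):

import torch

from rs_distributions.transforms.fill_scale_tril import FillTriL

t = FillTriL()
x = torch.randn(2, 6)  # batch of 2 vectors; 6 entries fill a 3x3 lower triangle
L = t(x)               # shape (2, 3, 3), zero above the diagonal
assert torch.allclose(t.inv(L), x)  # exact round trip
# filling is a pure re-indexing of coordinates, so log|det J| is zero per batch element
assert torch.equal(t.log_abs_det_jacobian(x, L), torch.zeros(2))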


class DiagTransform(Transform):
    """
    Applies a transformation to the diagonal of a square matrix.
    """

    def __init__(self, diag_transform):
        super().__init__()
        self.diag_transform = diag_transform
@@ -48,10 +66,15 @@ def bijective(self):
        return self.diag_transform.bijective

    def _call(self, x):
        """
        Transform the diagonal of a square matrix.

        Args:
            x (torch.Tensor): input matrix
        Returns:
            torch.Tensor: matrix with its diagonal transformed
        """
        diagonal = x.diagonal(dim1=-2, dim2=-1)
        transformed_diagonal = self.diag_transform(diagonal)
-       shifted_diag = transformed_diagonal
-       result = x.diagonal_scatter(shifted_diag, dim1=-2, dim2=-1)
+       result = x.diagonal_scatter(transformed_diagonal, dim1=-2, dim2=-1)

        return result
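A small usage sketch of DiagTransform on its own (not from the commit; same import path as the tests below):

import torch
from torch.distributions.transforms import SoftplusTransform

from rs_distributions.transforms.fill_scale_tril import DiagTransform

dt = DiagTransform(diag_transform=SoftplusTransform())
m = torch.randn(3, 3)
out = dt(m)
# the diagonal is mapped through softplus, making it strictly positive
assert torch.allclose(out.diagonal(), torch.nn.functional.softplus(m.diagonal()))
# everything off the diagonal passes through unchanged
assert torch.equal(out.tril(-1), m.tril(-1)) and torch.equal(out.triu(1), m.triu(1))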

@@ -66,9 +89,20 @@ def log_abs_det_jacobian(self, x, y):


class FillScaleTriL(ComposeTransform):
    """
    A `ComposeTransform` that reshapes a real-valued vector into a lower
    triangular matrix whose diagonal is transformed with `diag_transform`.
    """

    def __init__(self, diag_transform=SoftplusTransform()):
        super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])
        self.diag_transform = diag_transform

    @property
    def bijective(self):
        return True

    def log_abs_det_jacobian(self, x, y):
        # the fill step is a pure re-indexing, so only the diagonal
        # transform contributes to log|det J|
        tril = FillTriL()(x)
        diagonal_x = tril.diagonal(dim1=-2, dim2=-1)
        diagonal_y = y.diagonal(dim1=-2, dim2=-1)
        return self.diag_transform.log_abs_det_jacobian(diagonal_x, diagonal_y)
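End-to-end, the composed transform turns an unconstrained vector into a valid Cholesky factor. A brief sketch (not part of the commit):

import torch
from torch.distributions.constraints import lower_cholesky

from rs_distributions.transforms.fill_scale_tril import FillScaleTriL

t = FillScaleTriL()  # softplus on the diagonal by default
x = torch.randn(10)  # 10 = 4*5/2 entries fill a 4x4 factor
L = t(x)
assert lower_cholesky.check(L)  # lower triangular with a positive diagonal

# such factors can parameterize, e.g., a MultivariateNormal scale_tril
mvn = torch.distributions.MultivariateNormal(torch.zeros(4), scale_tril=L)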
36 changes: 36 additions & 0 deletions tests/transforms/fill_scale_tril.py
@@ -1,9 +1,12 @@
import pytest
from rs_distributions.transforms.fill_scale_tril import (
FillScaleTriL,
FillTriL,
DiagTransform,
)
import torch
from torch.distributions.constraints import lower_cholesky
from torch.distributions.transforms import SoftplusTransform, ExpTransform


@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
@@ -35,3 +38,36 @@ def test_forward_equals_inverse(batch_shape, d):
    assert torch.allclose(
        input_vector, invL, atol=1e-4
    ), "Original input and the result of applying the inverse transform are not close enough"


@pytest.mark.parametrize(
    "batch_shape, d, diag_transform",
    [
        ((2, 3), 6, SoftplusTransform()),
        ((1, 4, 5), 10, SoftplusTransform()),
        ((2, 3), 6, ExpTransform()),
        ((1, 4, 5), 10, ExpTransform()),
    ],
)
def test_log_abs_det_jacobian_softplus_and_exp(batch_shape, d, diag_transform):
    transform = FillScaleTriL(diag_transform=diag_transform)
    input_shape = batch_shape + (d,)
    input_vector = torch.randn(input_shape, requires_grad=True)
    transformed_vector = transform(input_vector)

    # Analytic log|det J| of the composed transform
    log_abs_det_jacobian = transform.log_abs_det_jacobian(
        input_vector, transformed_vector
    )

    # Apply the two component transforms separately
    filltril = FillTriL()
    diagtransform = DiagTransform(diag_transform=diag_transform)
    tril = filltril(input_vector)
    diagonal_transformed = diagtransform(tril)

    # log|det J| of the diagonal transform alone; the fill step contributes zero
    diag_jacobian = diagtransform.log_abs_det_jacobian(tril, diagonal_transformed)

    # The composed transform's log|det J| should equal the diagonal transform's
    assert torch.allclose(diag_jacobian, log_abs_det_jacobian, atol=1e-4)
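As an extra cross-check (not in the commit), the analytic value can be compared against a brute-force Jacobian from autograd. Because the fill step is a pure re-indexing, the Jacobian is square once restricted to the lower triangular entries; and because the committed log_abs_det_jacobian returns one term per diagonal element, its terms are summed before comparing:

import torch
from torch.distributions.utils import tril_matrix_to_vec

from rs_distributions.transforms.fill_scale_tril import FillScaleTriL

t = FillScaleTriL()
x = torch.randn(6)

# map the vector to the independent tril entries so the Jacobian is square
f = lambda v: tril_matrix_to_vec(t(v))
J = torch.autograd.functional.jacobian(f, x)  # shape (6, 6)

analytic = t.log_abs_det_jacobian(x, t(x)).sum()
brute_force = torch.linalg.slogdet(J).logabsdet
assert torch.allclose(analytic, brute_force, atol=1e-5)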
