Skip to content

Commit

Permalink
added params_size and default AffineTransform shift to FillScaleTriL
Browse files Browse the repository at this point in the history
  • Loading branch information
LuisA92 committed May 22, 2024
1 parent 11baf53 commit fd7cb3d
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions src/rs_distributions/transforms/fill_scale_tril.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch.distributions import Transform, ComposeTransform, constraints
from torch.distributions.transforms import SoftplusTransform
from torch.distributions.transforms import SoftplusTransform, AffineTransform
from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec


Expand Down Expand Up @@ -94,7 +94,12 @@ class FillScaleTriL(ComposeTransform):
The diagonal of the matrix is transformed with `diag_transform`.
"""

def __init__(self, diag_transform=SoftplusTransform()):
def __init__(self, diag_transform=None):
if diag_transform is None:
diag_transform = torch.distributions.ComposeTransform((
SoftplusTransform(),
AffineTransform(1e-5, 1.),
))
super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])
self.diag_transform = diag_transform

Expand All @@ -106,3 +111,17 @@ def log_abs_det_jacobian(self, x, y):
x = FillTriL()._call(x)
diagonal = x.diagonal(dim1=-2, dim2=-1)
return self.diag_transform.log_abs_det_jacobian(diagonal, diagonal)

@staticmethod
def params_size(event_size):
"""
Returns the number of parameters required to create lower triangular matrix, which is given by n*(n+1)//2
Args:
event_size (int): size of event
Returns:
int: Number of parameters needed
"""
return event_size * (event_size + 1) // 2

0 comments on commit fd7cb3d

Please sign in to comment.