Skip to content

Commit

Permalink
Fixed documentation errors
Browse files Browse the repository at this point in the history
  • Loading branch information
LuisA92 committed May 17, 2024
1 parent c083949 commit 11baf53
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/transforms/fill_scale_tril.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,22 @@ def test_forward_equals_inverse(batch_shape, d):
)
def test_log_abs_det_jacobian_softplus_and_exp(batch_shape, d, diag_transform):
transform = FillScaleTriL(diag_transform=diag_transform)
filltril = FillTriL()
diagtransform = DiagTransform(diag_transform=diag_transform)
input_shape = batch_shape + (d,)
input_vector = torch.randn(input_shape, requires_grad=True)
transformed_vector = transform(input_vector)

# Calculate log abs det jacobian with autograd
# Calculate gradients log_abs_det_jacobian from FillScaleTriL
log_abs_det_jacobian = transform.log_abs_det_jacobian(
input_vector, transformed_vector
)

# Extract diagonal elements from input and transformed vectors
filltril = FillTriL()
diagtransform = DiagTransform(diag_transform=diag_transform)
tril = filltril(input_vector)
diagonal_transformed = diagtransform(tril)

# Calculate diagonal gradients with autograd
# Calculate diagonal gradients
diag_jacobian = diagtransform.log_abs_det_jacobian(tril, diagonal_transformed)

# Assert diagonal gradients are approximately equal
Expand Down

0 comments on commit 11baf53

Please sign in to comment.