Split Bijector #103

Open
wants to merge 31 commits into base: main
8 changes: 8 additions & 0 deletions flowtorch/bijectors/__init__.py
@@ -16,6 +16,8 @@
 from flowtorch.bijectors.autoregressive import Autoregressive
 from flowtorch.bijectors.base import Bijector
 from flowtorch.bijectors.compose import Compose
+from flowtorch.bijectors.coupling import ConvCouplingBijector
+from flowtorch.bijectors.coupling import CouplingBijector
 from flowtorch.bijectors.elementwise import Elementwise
 from flowtorch.bijectors.elu import ELU
 from flowtorch.bijectors.exp import Exp
@@ -28,13 +30,17 @@
 from flowtorch.bijectors.softplus import Softplus
 from flowtorch.bijectors.spline import Spline
 from flowtorch.bijectors.spline_autoregressive import SplineAutoregressive
+from flowtorch.bijectors.split_bijector import ReshapeBijector
+from flowtorch.bijectors.split_bijector import SplitBijector
 from flowtorch.bijectors.tanh import Tanh
 from flowtorch.bijectors.volume_preserving import VolumePreserving

 standard_bijectors = [
     ("Affine", Affine),
     ("AffineAutoregressive", AffineAutoregressive),
     ("AffineFixed", AffineFixed),
+    ("ConvCouplingBijector", ConvCouplingBijector),
+    ("CouplingBijector", CouplingBijector),
     ("ELU", ELU),
     ("Exp", Exp),
     ("LeakyReLU", LeakyReLU),
@@ -55,6 +61,8 @@
     ("Compose", Compose),
     ("Invert", Invert),
     ("VolumePreserving", VolumePreserving),
+    ("ReshapeBijector", ReshapeBijector),
+    ("SplitBijector", SplitBijector),
 ]

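Once this diff is applied, the new classes are importable from the package root and CouplingBijector is registered in standard_bijectors. A minimal import-level sketch of that wiring (this assumes a checkout of this branch is installed; it is not part of the diff itself):

# Minimal sketch: checks that the registry wiring above is picked up.
# Assumes a checkout of this branch is installed.
from flowtorch.bijectors import (
    ConvCouplingBijector,
    CouplingBijector,
    ReshapeBijector,
    SplitBijector,
    standard_bijectors,
)

# CouplingBijector and ConvCouplingBijector are appended to standard_bijectors above.
assert any(name == "CouplingBijector" for name, _ in standard_bijectors)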
21 changes: 17 additions & 4 deletions flowtorch/bijectors/affine_autoregressive.py
@@ -16,15 +16,28 @@ def __init__(
         *,
         shape: torch.Size,
         context_shape: Optional[torch.Size] = None,
+        clamp_values: bool = False,
         log_scale_min_clip: float = -5.0,
         log_scale_max_clip: float = 3.0,
         sigmoid_bias: float = 2.0,
+        positive_map: str = "softplus",
+        positive_bias: Optional[float] = None,
     ) -> None:
-        super().__init__(
+        AffineOp.__init__(
+            self,
             params_fn,
             shape=shape,
             context_shape=context_shape,
+            clamp_values=clamp_values,
+            log_scale_min_clip=log_scale_min_clip,
+            log_scale_max_clip=log_scale_max_clip,
+            sigmoid_bias=sigmoid_bias,
+            positive_map=positive_map,
+            positive_bias=positive_bias,
+        )
+        Autoregressive.__init__(
+            self,
+            params_fn,
+            shape=shape,
+            context_shape=context_shape,
         )
-        self.log_scale_min_clip = log_scale_min_clip
-        self.log_scale_max_clip = log_scale_max_clip
-        self.sigmoid_bias = sigmoid_bias
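AffineAutoregressive now takes its affine math from AffineOp and its ordering/inversion logic from Autoregressive, and the two bases need different constructor arguments, so the diff calls each base's __init__ explicitly instead of relying on a single super().__init__ chain. A generic sketch of that pattern (the classes below are illustrative only, not flowtorch code):

# Illustrative only: explicit base-class initialization when two bases
# need different argument sets, mirroring the AffineOp.__init__ +
# Autoregressive.__init__ calls above.
class ScaleOp:
    def __init__(self, scale: float) -> None:
        self.scale = scale


class Ordering:
    def __init__(self, reverse: bool) -> None:
        self.reverse = reverse


class ScaledOrdering(ScaleOp, Ordering):
    def __init__(self, scale: float, reverse: bool) -> None:
        ScaleOp.__init__(self, scale)  # affine-style state
        Ordering.__init__(self, reverse)  # ordering-style state


obj = ScaledOrdering(2.0, True)
assert obj.scale == 2.0 and obj.reverse is True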
13 changes: 9 additions & 4 deletions flowtorch/bijectors/affine_fixed.py
@@ -24,33 +24,38 @@ def __init__(
         shape: torch.Size,
         context_shape: Optional[torch.Size] = None,
         loc: float = 0.0,
-        scale: float = 1.0
+        scale: float = 1.0,
     ) -> None:
         super().__init__(params_fn, shape=shape, context_shape=context_shape)
         self.loc = loc
         self.scale = scale

     def _forward(
         self,
-        x: torch.Tensor,
+        *inputs: torch.Tensor,
         params: Optional[Sequence[torch.Tensor]],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        x = inputs[0]
         y = self.loc + self.scale * x
         ladj: Optional[torch.Tensor] = None
         if requires_log_detJ():
             ladj = self._log_abs_det_jacobian(x, y, params)
         return y, ladj

     def _inverse(
-        self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
+        self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        y = inputs[0]
         x = (y - self.loc) / self.scale
         ladj: Optional[torch.Tensor] = None
         if requires_log_detJ():
             ladj = self._log_abs_det_jacobian(x, y, params)
         return x, ladj

     def _log_abs_det_jacobian(
-        self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        params: Optional[Sequence[torch.Tensor]],
     ) -> torch.Tensor:
         return torch.full_like(x, math.log(abs(self.scale)))
12 changes: 9 additions & 3 deletions flowtorch/bijectors/autoregressive.py
@@ -7,7 +7,10 @@
 import torch
 import torch.distributions.constraints as constraints
 from flowtorch.bijectors.base import Bijector
-from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor
+from flowtorch.bijectors.bijective_tensor import (
+    BijectiveTensor,
+    to_bijective_tensor,
+)
 from flowtorch.bijectors.utils import is_record_flow_graph_enabled
 from flowtorch.parameters.dense_autoregressive import DenseAutoregressive

@@ -60,7 +63,7 @@ def inverse(
         # TODO: Make permutation, inverse work for other event shapes
         log_detJ: Optional[torch.Tensor] = None
         for idx in cast(torch.LongTensor, permutation):
-            _params = self._params_fn(x_new.clone(), context=context)
+            _params = self._params_fn(x_new.clone(), inverse=False, context=context)
             x_temp, log_detJ = self._inverse(y, params=_params)
             x_new[..., idx] = x_temp[..., idx]
             # _log_detJ = out[1]
@@ -78,6 +81,9 @@ def inverse(
         return x_new

     def _log_abs_det_jacobian(
-        self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        params: Optional[Sequence[torch.Tensor]],
     ) -> torch.Tensor:
         raise NotImplementedError
42 changes: 32 additions & 10 deletions flowtorch/bijectors/base.py
@@ -7,13 +7,17 @@
 import flowtorch.parameters
 import torch
 import torch.distributions
-from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor
+from flowtorch.bijectors.bijective_tensor import (
+    BijectiveTensor,
+    to_bijective_tensor,
+)
 from flowtorch.bijectors.utils import is_record_flow_graph_enabled
 from flowtorch.parameters import Parameters
 from torch.distributions import constraints

 ParamFnType = Callable[
-    [Optional[torch.Tensor], Optional[torch.Tensor]], Optional[Sequence[torch.Tensor]]
+    [Optional[torch.Tensor], Optional[torch.Tensor]],
+    Optional[Sequence[torch.Tensor]],
 ]


@@ -60,6 +64,9 @@ def _check_bijective_x(
             and x.check_context(context)
         )

+    def _forward_pre_ops(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        return (x,)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -71,8 +78,13 @@ def forward(
             assert isinstance(x, BijectiveTensor)
             return x.get_parent_from_bijector(self)

-        params = self._params_fn(x, context) if self._params_fn is not None else None
-        y, log_detJ = self._forward(x, params)
+        x_tuple = self._forward_pre_ops(x)
+        params = (
+            self._params_fn(*x_tuple, inverse=False, context=context)
+            if self._params_fn is not None
+            else None
+        )
+        y, log_detJ = self._forward(*x_tuple, params=params)
         if (
             is_record_flow_graph_enabled()
             and not isinstance(y, BijectiveTensor)
@@ -84,7 +96,7 @@

     def _forward(
         self,
-        x: torch.Tensor,
+        *x: torch.Tensor,
         params: Optional[Sequence[torch.Tensor]],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
@@ -104,6 +116,9 @@ def _check_bijective_y(
             and y.check_context(context)
         )

+    def _inverse_pre_ops(self, y: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        return (y,)
+
     def inverse(
         self,
         y: torch.Tensor,
@@ -117,8 +132,13 @@ def inverse(
             return y.get_parent_from_bijector(self)

         # TODO: What to do in this line?
-        params = self._params_fn(x, context) if self._params_fn is not None else None
-        x, log_detJ = self._inverse(y, params)
+        y_tuple = self._inverse_pre_ops(y)
+        params = (
+            self._params_fn(*y_tuple, inverse=True, context=context)
+            if self._params_fn is not None
+            else None
+        )
+        x, log_detJ = self._inverse(*y_tuple, params=params)

         if (
             is_record_flow_graph_enabled()
@@ -130,7 +150,7 @@

     def _inverse(
         self,
-        y: torch.Tensor,
+        *y: torch.Tensor,
         params: Optional[Sequence[torch.Tensor]],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
@@ -170,10 +190,12 @@ def log_abs_det_jacobian(
         if ladj is None:
             if is_record_flow_graph_enabled():
                 warnings.warn(
-                    "Computing _log_abs_det_jacobian from values and not " "from cache."
+                    "Computing _log_abs_det_jacobian from values and not from cache."
                 )
             params = (
-                self._params_fn(x, context) if self._params_fn is not None else None
+                self._params_fn(x, inverse=False, context=context)
+                if self._params_fn is not None
+                else None
             )
             return self._log_abs_det_jacobian(x, y, params)
         return ladj
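The base.py changes are the hook points the rest of the PR builds on: forward and inverse now run the input through _forward_pre_ops/_inverse_pre_ops, and the resulting tuple, together with an inverse flag, is handed to the parameter network and to _forward/_inverse. A standalone sketch of that dispatch pattern follows (all names below are illustrative stand-ins, not flowtorch APIs):

# Standalone sketch of the pre-ops dispatch introduced above; names are
# illustrative and params_fn is a stand-in for a Parameters module.
from typing import Tuple

import torch


def default_pre_ops(x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # Default behaviour, as in Bijector._forward_pre_ops: pass through.
    return (x,)


def split_pre_ops(x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # A split-style override: expose the two halves of the event dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return (x1, x2)


def params_fn(*xs: torch.Tensor, inverse: bool, context=None) -> Tuple[torch.Tensor, ...]:
    # Stand-in parameter network: one parameter tensor per input half.
    return tuple(torch.zeros_like(x) for x in xs)


x = torch.randn(4, 32)
x_tuple = split_pre_ops(x)  # what _forward_pre_ops returns
params = params_fn(*x_tuple, inverse=False)  # how forward() calls _params_fn
assert len(x_tuple) == 2 and all(p.shape == (4, 16) for p in params)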
10 changes: 8 additions & 2 deletions flowtorch/bijectors/compose.py
@@ -6,8 +6,14 @@
 import torch
 import torch.distributions
 from flowtorch.bijectors.base import Bijector
-from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor
-from flowtorch.bijectors.utils import is_record_flow_graph_enabled, requires_log_detJ
+from flowtorch.bijectors.bijective_tensor import (
+    BijectiveTensor,
+    to_bijective_tensor,
+)
+from flowtorch.bijectors.utils import (
+    is_record_flow_graph_enabled,
+    requires_log_detJ,
+)
 from torch.distributions.utils import _sum_rightmost

133 changes: 133 additions & 0 deletions flowtorch/bijectors/coupling.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc
+from copy import deepcopy
+from typing import Optional, Sequence, Tuple
+
+import flowtorch.parameters
+import torch
+from flowtorch.bijectors.ops.affine import Affine as AffineOp
+from flowtorch.parameters import ConvCoupling, DenseCoupling
+from torch.distributions import constraints
+
+
+_REAL3d = deepcopy(constraints.real)
+_REAL3d.event_dim = 3
+
+_REAL1d = deepcopy(constraints.real)
+_REAL1d.event_dim = 1
+
+
+class CouplingBijector(AffineOp):
+    """
+    Examples:
+        >>> params = DenseCoupling()
+        >>> bij = CouplingBijector(params)
+        >>> bij = bij(shape=torch.Size([32,]))
+        >>> for p in bij.parameters():
+        ...     p.data += torch.randn_like(p)/10
+        >>> x = torch.randn(1, 32, requires_grad=True)
+        >>> y = bij.forward(x).detach_from_flow()
+        >>> x_bis = bij.inverse(y)
+        >>> torch.testing.assert_allclose(x, x_bis)
+    """
+
+    domain: constraints.Constraint = _REAL1d
+    codomain: constraints.Constraint = _REAL1d
+
+    def __init__(
+        self,
+        params_fn: Optional[flowtorch.Lazy] = None,
+        *,
+        shape: torch.Size,
+        context_shape: Optional[torch.Size] = None,
+        clamp_values: bool = False,
+        log_scale_min_clip: float = -5.0,
+        log_scale_max_clip: float = 3.0,
+        sigmoid_bias: float = 2.0,
+        positive_map: str = "softplus",
+        positive_bias: Optional[float] = None,
+    ) -> None:
+
+        if params_fn is None:
+            params_fn = DenseCoupling()  # type: ignore
+
+        AffineOp.__init__(
+            self,
+            params_fn,
+            shape=shape,
+            context_shape=context_shape,
+            clamp_values=clamp_values,
+            log_scale_min_clip=log_scale_min_clip,
+            log_scale_max_clip=log_scale_max_clip,
+            sigmoid_bias=sigmoid_bias,
+            positive_map=positive_map,
+            positive_bias=positive_bias,
+        )
+
+    def _forward(
+        self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        x = inputs[0]
+        assert self._params_fn is not None
+
+        y, ldj = super()._forward(x, params=params)
+        return y, ldj
+
+    def _inverse(
+        self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        y = inputs[0]
+        assert self._params_fn is not None
+
+        x, ldj = super()._inverse(y, params=params)
+        return x, ldj
+
+
+class ConvCouplingBijector(CouplingBijector):
+    """
+    Examples:
+        >>> params = ConvCoupling()
+        >>> bij = ConvCouplingBijector(params)
+        >>> bij = bij(shape=torch.Size([3, 16, 16]))
+        >>> for p in bij.parameters():
+        ...     p.data += torch.randn_like(p)/10
+        >>> x = torch.randn(4, 3, 16, 16)
+        >>> y = bij.forward(x)
+        >>> x_bis = bij.inverse(y.detach_from_flow())
+        >>> torch.testing.assert_allclose(x, x_bis)
+    """
+
+    domain: constraints.Constraint = _REAL3d
+    codomain: constraints.Constraint = _REAL3d
+
+    def __init__(
+        self,
+        params_fn: Optional[flowtorch.Lazy] = None,
+        *,
+        shape: torch.Size,
+        context_shape: Optional[torch.Size] = None,
+        clamp_values: bool = False,
+        log_scale_min_clip: float = -5.0,
+        log_scale_max_clip: float = 3.0,
+        sigmoid_bias: float = 2.0,
+        positive_map: str = "softplus",
+        positive_bias: Optional[float] = None,
+    ) -> None:
+
+        if not len(shape) == 3:
+            raise ValueError(f"Expected a 3d-tensor shape, got {shape}")
+
+        if params_fn is None:
+            params_fn = ConvCoupling()  # type: ignore
+
+        AffineOp.__init__(
+            self,
+            params_fn,
+            shape=shape,
+            context_shape=context_shape,
+            clamp_values=clamp_values,
+            log_scale_min_clip=log_scale_min_clip,
+            log_scale_max_clip=log_scale_max_clip,
+            sigmoid_bias=sigmoid_bias,
+            positive_map=positive_map,
+            positive_bias=positive_bias,
+        )
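The docstrings above exercise the forward/inverse round trip; for context, the transform these classes implement follows the standard RealNVP-style affine coupling recipe, where one half of the event passes through unchanged and parameterises an affine map of the other half. The sketch below is schematic: the helper names and the toy conditioner are assumptions, and the actual masking/conditioning is delegated to DenseCoupling / ConvCoupling, which are not part of this file.

# Schematic affine coupling step (RealNVP-style); illustrative only.
import torch


def coupling_forward(x: torch.Tensor, net: torch.nn.Module) -> torch.Tensor:
    # One half conditions the other; the conditioning half is left unchanged
    # so the inverse can recompute the same shift/log_scale from y1 alone.
    x1, x2 = x.chunk(2, dim=-1)
    shift, log_scale = net(x1).chunk(2, dim=-1)
    y2 = x2 * log_scale.exp() + shift
    return torch.cat([x1, y2], dim=-1)


def coupling_inverse(y: torch.Tensor, net: torch.nn.Module) -> torch.Tensor:
    y1, y2 = y.chunk(2, dim=-1)
    shift, log_scale = net(y1).chunk(2, dim=-1)
    x2 = (y2 - shift) * torch.exp(-log_scale)
    return torch.cat([y1, x2], dim=-1)


# Round trip with a toy conditioner mapping 16 features to 16 shifts + 16 log-scales.
net = torch.nn.Sequential(
    torch.nn.Linear(16, 64), torch.nn.Tanh(), torch.nn.Linear(64, 32)
)
x = torch.randn(4, 32)
y = coupling_forward(x, net)
torch.testing.assert_allclose(coupling_inverse(y, net), x)

The log-abs-det-Jacobian of such a map is simply the sum of log_scale over the transformed half, which is what makes coupling layers cheap to both invert and score.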