From 954407b450c4e98e4bef22e597e1c30457aa9bf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Thu, 22 Dec 2022 16:16:17 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20Glow-like=20multi-scale=20flo?= =?UTF-8?q?w?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_flows.py | 48 +++++++ tests/test_nn.py | 17 +++ tests/test_transforms.py | 67 ++++++++-- zuko/distributions.py | 4 +- zuko/flows.py | 191 +++++++++++++++++++++++++++ zuko/nn.py | 125 +++++++++++++++++- zuko/transforms.py | 278 +++++++++++++++++++++++++++++++++++++-- 7 files changed, 709 insertions(+), 21 deletions(-) diff --git a/tests/test_flows.py b/tests/test_flows.py index 5dd84d7..fa614af 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -121,3 +121,51 @@ def test_autoregressive_transforms(): assert (torch.triu(J, diagonal=1) == 0).all(), t assert (torch.tril(J[:4, :4], diagonal=-1) == 0).all(), t assert (torch.tril(J[4:, 4:], diagonal=-1) == 0).all(), t + + +def test_Glow(tmp_path): + flow = Glow((3, 32, 32), context=[5, 0, 5]) + + # Evaluation of log_prob + x, y = randn(8, 3, 32, 32), [randn(5, 16, 16), None, randn(8, 5, 4, 4)] + log_p = flow(y).log_prob(x) + + assert log_p.shape == (8,) + assert log_p.requires_grad + + flow.zero_grad(set_to_none=True) + loss = -log_p.mean() + loss.backward() + + for p in flow.parameters(): + assert p.grad is not None + + # Sampling + x = flow(y).sample() + + assert x.shape == (8, 3, 32, 32) + + # Reparameterization trick + x = flow(y).rsample() + + flow.zero_grad(set_to_none=True) + loss = x.square().sum().sqrt() + loss.backward() + + for p in flow.parameters(): + assert p.grad is not None + + # Saving + torch.save(flow, tmp_path / 'flow.pth') + + # Loading + flow_bis = torch.load(tmp_path / 'flow.pth') + + x, y = randn(3, 32, 32), [randn(5, 16, 16), None, randn(5, 4, 4)] + + seed = torch.seed() + log_p = flow(y).log_prob(x) + torch.manual_seed(seed) + log_p_bis = flow_bis(y).log_prob(x) + + assert torch.allclose(log_p, log_p_bis) diff --git a/tests/test_nn.py b/tests/test_nn.py index 0ebdff6..742ed27 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -70,3 +70,20 @@ def test_MonotonicMLP(): J = torch.autograd.functional.jacobian(net, x) assert (J >= 0).all() + + +def test_FCN(): + net = FCN(3, 5) + + # Non-batched + x = randn(3, 64, 64) + y = net(x) + + assert y.shape == (5, 64, 64) + assert y.requires_grad + + # Batched + x = randn(8, 3, 32, 32) + y = net(x) + + assert y.shape == (8, 5, 32, 32) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 5d89af0..1170c30 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -57,6 +57,49 @@ def test_univariate_transforms(): assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t +def test_multivariate_transforms(): + ts = [ + LULinearTransform(randn(3, 3), dim=-2), + PermutationTransform(torch.randperm(3), dim=-2), + PixelShuffleTransform(dim=-2), + ] + + for t in ts: + # Shapes + x = randn(256, 3, 8) + y = t(x) + + assert t.forward_shape(x.shape) == y.shape, t + assert t.inverse_shape(y.shape) == x.shape, t + + # Inverse + z = t.inv(y) + + assert x.shape == z.shape, t + assert torch.allclose(x, z, atol=1e-4), t + + # Jacobian + x = randn(3, 8) + y = t(x) + + jacobian = torch.autograd.functional.jacobian(t, x) + jacobian = jacobian.reshape(3 * 8, 3 * 8) + + _, ladj = torch.slogdet(jacobian) + + assert torch.allclose(t.log_abs_det_jacobian(x, y), ladj, atol=1e-4), t + + # Inverse Jacobian + z = 
t.inv(y) + + jacobian = torch.autograd.functional.jacobian(t.inv, y) + jacobian = jacobian.reshape(3 * 8, 3 * 8) + + _, ladj = torch.slogdet(jacobian) + + assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t + + def test_FFJTransform(): a = torch.randn(3) f = lambda x, t: a * x @@ -79,20 +122,24 @@ def test_FFJTransform(): assert ladj.shape == x.shape[:-1] -def test_PermutationTransform(): - t = PermutationTransform(torch.randperm(8)) +def test_DropTransform(): + dist = Normal(randn(3), abs(randn(3)) + 1) + t = DropTransform(dist) - x = torch.randn(256, 8) + # Call + x = randn(256, 5) y = t(x) - assert x.shape == y.shape - - match = x[:, :, None] == y[:, None, :] - - assert (match.sum(dim=-1) == 1).all() - assert (match.sum(dim=-2) == 1).all() + assert t.forward_shape(x.shape) == y.shape + assert t.inverse_shape(y.shape) == x.shape + # Inverse z = t.inv(y) assert x.shape == z.shape - assert (x == z).all() + assert not torch.allclose(x, z) + + # Jacobian + ladj = t.log_abs_det_jacobian(x, y) + + assert ladj.shape == (256,) diff --git a/zuko/distributions.py b/zuko/distributions.py index 0a977c3..0a421a3 100644 --- a/zuko/distributions.py +++ b/zuko/distributions.py @@ -372,7 +372,7 @@ def __init__( ): super().__init__(batch_shape=base.batch_shape) - assert len(base.event_shape) < 1, "'base' has to be univariate" + assert not base.event_shape, "'base' has to be univariate" self.base = base self.uniform = Uniform(base.cdf(lower), base.cdf(upper)) @@ -430,7 +430,7 @@ def __init__( ): super().__init__(batch_shape=base.batch_shape) - assert len(base.event_shape) < 1, "'base' has to be univariate" + assert not base.event_shape, "'base' has to be univariate" self.base = base self.n = n diff --git a/zuko/flows.py b/zuko/flows.py index f208117..9247a7a 100644 --- a/zuko/flows.py +++ b/zuko/flows.py @@ -13,6 +13,8 @@ 'NAF', 'FreeFormJacobianTransform', 'CNF', + 'ConvCouplingTransform', + 'Glow', ] import abc @@ -752,3 +754,192 @@ def __init__( ) super().__init__(transforms, base) + + +class ConvCouplingTransform(TransformModule): + r"""Creates a convolution coupling transformation. + + Arguments: + channels: The number of channels. + context: The number of context channels. + spatial: The number of spatial dimensions. + univariate: The univariate transformation constructor. + shapes: The shapes of the univariate transformation parameters. + kwargs: Keyword arguments passed to :class:`zuko.nn.FCN`. 
+ """ + + def __init__( + self, + channels: int, + context: int = 0, + spatial: int = 2, + univariate: Callable[..., Transform] = MonotonicAffineTransform, + shapes: List[Size] = [(), ()], + **kwargs, + ): + super().__init__() + + self.d = channels // 2 + self.dim = -(spatial + 1) + + # Univariate transformation + self.univariate = univariate + self.shapes = list(map(Size, shapes)) + self.sizes = [s.numel() for s in self.shapes] + + # Hyper network + kwargs.setdefault('activation', nn.ELU) + kwargs.setdefault('normalize', True) + + self.hyper = FCN( + in_channels=self.d + context, + out_channels=(channels - self.d) * sum(self.sizes), + spatial=spatial, + **kwargs, + ) + + def extra_repr(self) -> str: + base = self.univariate(*map(torch.randn, self.shapes)) + + return f'(base): {base}' + + def meta(self, y: Tensor, x: Tensor) -> Transform: + if y is not None: + x = torch.cat(broadcast(x, y, ignore=abs(self.dim)), dim=self.dim) + + total = sum(self.sizes) + + phi = self.hyper(x) + phi = phi.unflatten(self.dim, (phi.shape[self.dim] // total, total)) + phi = phi.movedim(self.dim, -1) + phi = phi.split(self.sizes, -1) + phi = (p.unflatten(-1, s + (1,)) for p, s in zip(phi, self.shapes)) + phi = (p.squeeze(-1) for p in phi) + + return self.univariate(*phi) + + def forward(self, y: Tensor = None) -> Transform: + return CouplingTransform(partial(self.meta, y), self.d, self.dim) + + +class Glow(FlowModule): + r"""Creates a Glow-like multi-scale flow. + + References: + | Glow: Generative Flow with Invertible 1x1 Convolutions (Kingma et al., 2018) + | https://arxiv.org/abs/1807.03039 + + Arguments: + shape: The shape of a sample. + context: The number of context channels at each scale. + transforms: The number of coupling transformations at each scale. + kwargs: Keyword arguments passed to :class:`ConvCouplingTransform`. + """ + + def __init__( + self, + shape: Size, + context: Union[int, List[int]] = 0, + transforms: List[int] = [8, 8, 8], + **kwargs, + ): + nn.Module.__init__(self) + + channels, *space = shape + spatial = len(space) + dim = -len(shape) + scales = len(transforms) + + assert all(s % 2**scales == 0 for s in space), ( + f"'shape' cannot be downscaled {scales} times" + ) + + if isinstance(context, int): + context = [context] * len(transforms) + + self.flows = nn.ModuleList() + self.bases = nn.ModuleList() + + for i, K in enumerate(transforms): + flow = [] + flow.append(Unconditional(PixelShuffleTransform, dim=dim)) + + channels = channels * 2**spatial + space = [s // 2 for s in space] + + for _ in range(K): + flow.extend([ + Unconditional( + PermutationTransform, + torch.randperm(channels), + dim=dim, + buffer=True, + ), + Unconditional( + LULinearTransform, + torch.eye(channels), + dim=dim, + ), + ConvCouplingTransform( + channels=channels, + context=context[i], + spatial=spatial, + **kwargs, + ), + ]) + + self.flows.append(nn.ModuleList(flow)) + self.bases.append( + Unconditional( + DiagNormal, + torch.zeros(channels // 2, *space), + torch.ones(channels // 2, *space), + ndims=spatial + 1, + buffer=True, + ) + ) + + channels = channels // 2 + + self.bases.append( + Unconditional( + DiagNormal, + torch.zeros(channels, *space), + torch.ones(channels, *space), + ndims=spatial + 1, + buffer=True, + ) + ) + + def forward(self, y: List[Tensor] = None) -> NormalizingFlow: + r""" + Arguments: + y: A list of contexts :math:`y`. There should be one element :math:`y_i` + per scale, but elements can be :py:`None`. + + Returns: + A multi-scale flow :math:`p(X | y)`. 
+ """ + + if y is None: + y = [None] * len(self.flows) + + # Transforms + transforms = [] + + for flow, base, y_i in zip(self.flows, self.bases, y): + for t in flow: + transforms.append(t(y_i)) + + transforms.append(DropTransform(base(y_i))) + + # Base + base = self.bases[-1](None) + dim = -len(base.event_shape) + + batch_shapes = [y_i.shape[:dim] for y_i in y if y_i is not None] + batch_shape = torch.broadcast_shapes(*batch_shapes) + + base = base.expand(batch_shape) + + return NormalizingFlow(transforms, base) diff --git a/zuko/nn.py b/zuko/nn.py index 1c3667e..8aa902e 100644 --- a/zuko/nn.py +++ b/zuko/nn.py @@ -1,6 +1,6 @@ r"""Neural networks, layers and modules.""" -__all__ = ['MLP', 'MaskedMLP', 'MonotonicMLP'] +__all__ = ['MLP', 'MaskedMLP', 'MonotonicMLP', 'FCN'] import torch import torch.nn as nn @@ -283,3 +283,126 @@ def __init__(self, *args, **kwargs): layer.__class__ = MonotonicLinear elif isinstance(layer, nn.ELU): layer.__class__ = TwoWayELU + + +class FCN(nn.Sequential): + r"""Creates a fully convolutional neural network (FCN). + + The architecture is inspired by ConvNeXt blocks which mix depthwise and 1 by 1 + convolutions to improve the efficiency/accuracy trade-off. + + References: + | A ConvNet for the 2020s (Lui et al., 2022) + | https://arxiv.org/abs/2201.03545 + + Arguments: + in_channels: The number of input channels. + out_channels: The number of output channels. + hidden_channels: The number of hidden channels. + hidden_blocks: The number of hidden blocks. Each block consists in an optional + normalization, a depthwise convolution, an activation and a 1 by 1 + convolution. + kernel_size: The size of the convolution kernels. + activation: The activation function constructor. If :py:`None`, use + :class:`torch.nn.ReLU` instead. + normalize: Whether channels are normalized or not. + spatial: The number of spatial dimensions. Can be either 1, 2 or 3. + kwargs: Keyword arguments passed to :class:`torch.nn.Conv2d`. 
+ + Example: + >>> net = FCN(3, 16, 64, activation=nn.ELU) + >>> net + FCN( + (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) + (1): Conv2d(64, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64) + (2): ELU(alpha=1.0) + (3): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + (4): Conv2d(64, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64) + (5): ELU(alpha=1.0) + (6): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + (7): Conv2d(64, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64) + (8): ELU(alpha=1.0) + (9): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) + (10): Conv2d(64, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) + ) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int = 64, + hidden_blocks: int = 3, + kernel_size: int = 5, + activation: Callable[[], nn.Module] = None, + normalize: bool = False, + spatial: int = 2, + **kwargs, + ): + # Components + convolution = { + 1: nn.Conv1d, + 2: nn.Conv2d, + 3: nn.Conv3d, + }.get(spatial) + + if activation is None: + activation = nn.ReLU + + if normalize: + normalization = lambda: LayerNorm(dim=-(spatial + 1)) + else: + normalization = lambda: None + + layers = [ + convolution( + in_channels, + hidden_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + **kwargs, + ) + ] + + for i in range(hidden_blocks): + layers.extend([ + normalization(), + convolution( + hidden_channels, + hidden_channels * 4, + groups=hidden_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + **kwargs, + ), + activation(), + convolution(hidden_channels * 4, hidden_channels, kernel_size=1), + ]) + + layers.append( + convolution( + hidden_channels, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + **kwargs, + ) + ) + + layers = filter(lambda l: l is not None, layers) + + super().__init__(*layers) + + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial = spatial + + def forward(self, x: Tensor) -> Tensor: + dim = -(self.spatial + 1) + batch_shape = x.shape[:dim] + + x = x.reshape(-1, *x.shape[dim:]) + y = super().forward(x) + y = y.reshape(*batch_shape, *y.shape[dim:]) + + return y diff --git a/zuko/transforms.py b/zuko/transforms.py index 730299b..1911ba8 100644 --- a/zuko/transforms.py +++ b/zuko/transforms.py @@ -4,7 +4,7 @@ import torch import torch.nn.functional as F -from torch import Tensor, LongTensor +from torch import Tensor, LongTensor, Size from torch.distributions import * from torch.distributions import constraints from typing import * @@ -553,21 +553,148 @@ def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: return y, ladj.sum(dim=-1) +class CouplingTransform(Transform): + r"""Transform via a coupling scheme. + + .. 
math:: \begin{cases}
+            y_{<d} = x_{<d} \\
+            y_{\geq d} = f(x_{\geq d} \mid x_{<d})
+        \end{cases}
+
+    Arguments:
+        meta: A meta function which returns a transformation :math:`f` given
+            the unaltered elements :math:`x_{<d}`.
+        d: The number of unaltered elements :math:`d`.
+        dim: The dimension along which the elements are split.
+    """
+
+    bijective = True
+
+    def __init__(
+        self,
+        meta: Callable[[Tensor], Transform],
+        d: int,
+        dim: int = -1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.meta = meta
+        self.d = d
+        self.dim = dim
+
+    @property
+    def domain(self):
+        return constraints.independent(constraints.real, abs(self.dim))
+
+    @property
+    def codomain(self):
+        return constraints.independent(constraints.real, abs(self.dim))
+
+    def _call(self, x: Tensor) -> Tensor:
+        x0, x1 = x.tensor_split((self.d,), dim=self.dim)
+        y1 = self.meta(x0)(x1)
+
+        return torch.cat((x0, y1), dim=self.dim)
+
+    def _inverse(self, y: Tensor) -> Tensor:
+        x0, y1 = y.tensor_split((self.d,), dim=self.dim)
+        x1 = self.meta(x0).inv(y1)
+
+        return torch.cat((x0, x1), dim=self.dim)
+
+    def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
+        x0, x1 = x.tensor_split((self.d,), dim=self.dim)
+        _, y1 = y.tensor_split((self.d,), dim=self.dim)
+
+        return self.meta(x0).log_abs_det_jacobian(x1, y1).flatten(self.dim).sum(dim=-1)
+
+    def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        x0, x1 = x.tensor_split((self.d,), dim=self.dim)
+        y1, ladj = self.meta(x0).call_and_ladj(x1)
+
+        return torch.cat((x0, y1), dim=self.dim), ladj.flatten(self.dim).sum(dim=-1)
+
+
+class LULinearTransform(Transform):
+    r"""Creates a transformation :math:`f(x) = LU x`.
+
+    Arguments:
+        LU: A matrix whose lower and upper triangular parts are the non-zero elements
+            of :math:`L` and :math:`U`, with shape :math:`(*, D, D)`.
+        dim: The dimension along which the product is applied.
+    """
+
+    bijective = True
+
+    def __init__(
+        self,
+        LU: Tensor,
+        dim: int = -1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        I = torch.eye(LU.shape[-1]).to(LU)
+
+        self.L = torch.tril(LU, diagonal=-1) + I
+        self.U = torch.triu(LU, diagonal=+1) + I
+
+        if hasattr(torch.linalg, 'solve_triangular'):
+            self.solve = torch.linalg.solve_triangular
+        else:
+            self.solve = lambda A, B, **kws: torch.triangular_solve(B, A, **kws).solution
+
+        self.dim = dim
+
+    @property
+    def domain(self):
+        return constraints.independent(constraints.real, abs(self.dim))
+
+    @property
+    def codomain(self):
+        return constraints.independent(constraints.real, abs(self.dim))
+
+    def _call(self, x: Tensor) -> Tensor:
+        shape = x.shape
+        flat = shape[: self.dim] + (shape[self.dim], -1)
+
+        return ((self.L @ self.U) @ x.reshape(flat)).reshape(shape)
+
+    def _inverse(self, y: Tensor) -> Tensor:
+        shape = y.shape
+        flat = shape[: self.dim] + (shape[self.dim], -1)
+
+        return self.solve(
+            self.U,
+            self.solve(
+                self.L,
+                y.reshape(flat),
+                upper=False,
+                unitriangular=True,
+            ),
+            upper=True,
+            unitriangular=True,
+        ).reshape(shape)
+
+    def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
+        return x.new_zeros(x.shape[: self.dim])
+
+
 class PermutationTransform(Transform):
-    r"""Creates a transformation that permutes the elements.
+    r"""Creates a transformation that permutes the elements along a dimension.
 
     Arguments:
         order: The permutation order, with shape :math:`(*, D)`.
+        dim: The dimension along which the elements are permuted.
""" - domain = constraints.real_vector - codomain = constraints.real_vector bijective = True - def __init__(self, order: LongTensor, **kwargs): + def __init__(self, order: LongTensor, dim: int = -1, **kwargs): super().__init__(**kwargs) self.order = order + self.dim = dim def __repr__(self) -> str: order = self.order.tolist() @@ -577,11 +704,146 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}({order})' + @property + def domain(self): + return constraints.independent(constraints.real, abs(self.dim)) + + @property + def codomain(self): + return constraints.independent(constraints.real, abs(self.dim)) + + def _call(self, x: Tensor) -> Tensor: + return x.index_select(self.dim, self.order) + + def _inverse(self, y: Tensor) -> Tensor: + return y.index_select(self.dim, torch.argsort(self.order)) + + def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: + return x.new_zeros(x.shape[: self.dim]) + + +class PixelShuffleTransform(Transform): + r"""Creates a transformation that rearranges pixels into channels. + + See :class:`torch.nn.PixelShuffle` for a 2-d equivalent. + + Arguments: + dim: The channel dimension. + """ + + bijective = True + + def __init__(self, dim: int = -3, **kwargs): + super().__init__(**kwargs) + + self.dim = dim + self.src = [i * 2 + 1 for i in range(dim + 1, 0)] + self.dst = [i + dim + 1 for i in range(dim + 1, 0)] + + @property + def domain(self): + return constraints.independent(constraints.real, abs(self.dim)) + + @property + def codomain(self): + return constraints.independent(constraints.real, abs(self.dim)) + def _call(self, x: Tensor) -> Tensor: - return x[..., self.order] + space = ((s // 2, 2) for s in x.shape[self.dim + 1 :]) + space = (b for a in space for b in a) + + x = x.reshape(*x.shape[: self.dim], -1, *space) + x = x.movedim(self.src, self.dst) + x = x.flatten(self.dim * 2 + 1, self.dim) + + return x + + def _inverse(self, y: Tensor) -> Tensor: + shape = self.inverse_shape(y.shape) + + y = y.unflatten(self.dim, [shape[self.dim]] + [2] * (abs(self.dim) - 1)) + y = y.movedim(self.dst, self.src) + y = y.reshape(shape) + + return y + + def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: + return x.new_zeros(x.shape[: self.dim]) + + def forward_shape(self, shape: Size) -> Size: + shape = list(shape) + shape[self.dim] *= 2 ** (abs(self.dim) - 1) + + for i in range(self.dim + 1, 0): + shape[i] //= 2 + + return Size(shape) + + def inverse_shape(self, shape: Size) -> Size: + shape = list(shape) + shape[self.dim] //= 2 ** (abs(self.dim) - 1) + + for i in range(self.dim + 1, 0): + shape[i] *= 2 + + return Size(shape) + + +class DropTransform(Transform): + r"""Creates a transformation that drops elements along a dimension. + + The :py:`log_abs_det_jacobian` method returns the log-density of the dropped + elements :math:`z` within a distribution :math:`p(Z)`. The inverse transformation + augments the dimension with a random variable :math:`z \sim p(Z)`. + + References: + | Augmented Normalizing Flows: Bridging the Gap Between Generative Flows and Latent Variable Models (Huang et al., 2020) + | https://arxiv.org/abs/2002.07101 + + Arguments: + dist: The distribution :math:`p(Z)`. 
+ """ + + bijective = False + + def __init__(self, dist: Distribution, **kwargs): + super().__init__(**kwargs) + + if dist.batch_shape: + dist = Independent(dist, len(dist.batch_shape)) + + assert dist.event_shape, "'dist' has to be multivariate" + + self.dist = dist + self.dim = -len(dist.event_shape) + self.d = dist.event_shape[0] + + @property + def domain(self): + return constraints.independent(constraints.real, abs(self.dim)) + + @property + def codomain(self): + return constraints.independent(constraints.real, abs(self.dim)) + + def _call(self, x: Tensor) -> Tensor: + z, x = x.tensor_split((self.d,), dim=self.dim) + return x def _inverse(self, y: Tensor) -> Tensor: - return y[..., torch.argsort(self.order)] + z = self.dist.sample(y.shape[: self.dim]) + return torch.cat((z, y), dim=self.dim) def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: - return x.new_zeros(x.shape[:-1]) + z, x = x.tensor_split((self.d,), dim=self.dim) + return self.dist.log_prob(z) + + def forward_shape(self, shape: Size) -> Size: + shape = list(shape) + shape[self.dim] -= self.d + return Size(shape) + + def inverse_shape(self, shape: Size) -> Size: + shape = list(shape) + shape[self.dim] += self.d + return Size(shape)