diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index e2a3cc5..90358e4 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -3,6 +3,7 @@
 import pytest
 import torch
 
+from torch.distributions import *
 from zuko.distributions import *
 
 
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 5d89af0..70da322 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -4,6 +4,7 @@
 import torch
 
 from torch import randn
+from torch.distributions import *
 from zuko.transforms import *
 
 
diff --git a/zuko/distributions.py b/zuko/distributions.py
index 0a977c3..3668337 100644
--- a/zuko/distributions.py
+++ b/zuko/distributions.py
@@ -1,5 +1,19 @@
 r"""Parameterizable probability distributions."""
 
+__all__ = [
+    'NormalizingFlow',
+    'Joint',
+    'GeneralizedNormal',
+    'DiagNormal',
+    'BoxUniform',
+    'TransformedUniform',
+    'Truncated',
+    'Sort',
+    'TopK',
+    'Minimum',
+    'Maximum',
+]
+
 import math
 import torch
 
@@ -10,6 +24,8 @@
 from torch.distributions.utils import _sum_rightmost
 from typing import *
 
+from .transforms import ComposedTransform
+
 
 Distribution._validate_args = False
 Distribution.arg_constraints = {}
@@ -17,8 +33,8 @@
 
 class NormalizingFlow(Distribution):
     r"""Creates a normalizing flow for a random variable :math:`X` towards a base
-    distribution :math:`p(Z)` through a series of :math:`n` invertible and differentiable
-    transformations :math:`f_1, f_2, \dots, f_n`.
+    distribution :math:`p(Z)` through a sequence of :math:`n` invertible and
+    differentiable transformations :math:`f_1, f_2, \dots, f_n`.
 
     The density of a realization :math:`x` is given by the change of variables
 
@@ -36,7 +52,7 @@ class NormalizingFlow(Distribution):
         | https://arxiv.org/abs/1912.02762
 
     Arguments:
-        transforms: A list of transformations :math:`f_i`.
+        transforms: A sequence of transformations :math:`f_i`.
         base: A base distribution :math:`p(Z)`.
 
     Example:
@@ -45,25 +61,29 @@ class NormalizingFlow(Distribution):
         tensor(1.1316)
     """
 
+    has_rsample = True
+
     def __init__(
         self,
-        transforms: List[Transform],
+        transforms: Iterable[Transform],
         base: Distribution,
     ):
         super().__init__()
 
-        codomain_dim = ComposeTransform(transforms).codomain.event_dim
-        reinterpreted = codomain_dim - len(base.event_shape)
+        transform = ComposedTransform(*transforms)
+        reinterpreted = transform.codomain_dim - len(base.event_shape)
 
         if reinterpreted > 0:
             base = Independent(base, reinterpreted)
 
-        self.transforms = transforms
+        self.transform = transform
         self.base = base
 
     def __repr__(self) -> str:
-        lines = [f'({i + 1}): {t}' for i, t in enumerate(self.transforms)]
-        lines.append(f'(base): {self.base}')
+        lines = [
+            f'(transform): {self.transform}',
+            f'(base): {self.base}',
+        ]
         lines = indent('\n'.join(lines), '  ')
 
         return self.__class__.__name__ + '(\n' + lines + '\n)'
@@ -74,53 +94,33 @@ def batch_shape(self) -> Size:
 
     @property
     def event_shape(self) -> Size:
-        shape = self.base.event_shape
-
-        for t in reversed(self.transforms):
-            shape = t.inverse_shape(shape)
-
-        return shape
+        return self.transform.inverse_shape(self.base.event_shape)
 
     def expand(self, batch_shape: Size, new: Distribution = None):
         new = self._get_checked_instance(NormalizingFlow, new)
-        new.transforms = self.transforms
+        new.transform = self.transform
         new.base = self.base.expand(batch_shape)
 
-        Distribution.__init__(new, batch_shape=batch_shape, validate_args=False)
+        Distribution.__init__(new, validate_args=False)
 
         return new
 
     def log_prob(self, x: Tensor) -> Tensor:
-        acc = 0
-        event_dim = len(self.event_shape)
-
-        for t in self.transforms:
-            x, ladj = t.call_and_ladj(x)
-            acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
-            event_dim += t.codomain.event_dim - t.domain.event_dim
+        z, ladj = self.transform.call_and_ladj(x)
+        ladj = _sum_rightmost(
+            ladj,
+            len(self.base.event_shape) - self.transform.codomain_dim,
+        )
 
-        return self.base.log_prob(x) + acc
-
-    @property
-    def has_rsample(self) -> bool:
-        return self.base.has_rsample
+        return self.base.log_prob(z) + ladj
 
     def rsample(self, shape: Size = ()):
-        x = self.base.rsample(shape)
-
-        for t in reversed(self.transforms):
-            x = t.inv(x)
-
-        return x
-
-    def sample(self, shape: Size = ()):
-        with torch.no_grad():
-            x = self.base.sample(shape)
-
-            for t in reversed(self.transforms):
-                x = t.inv(x)
+        if self.base.has_rsample:
+            z = self.base.rsample(shape)
+        else:
+            z = self.base.sample(shape)
 
-            return x
+        return self.transform.inv(z)
 
 
 class Joint(Distribution):
@@ -372,7 +372,7 @@ def __init__(
     ):
         super().__init__(batch_shape=base.batch_shape)
 
-        assert len(base.event_shape) < 1, "'base' has to be univariate"
+        assert not base.event_shape, "'base' has to be univariate"
 
         self.base = base
         self.uniform = Uniform(base.cdf(lower), base.cdf(upper))
@@ -430,7 +430,7 @@ def __init__(
     ):
         super().__init__(batch_shape=base.batch_shape)
 
-        assert len(base.event_shape) < 1, "'base' has to be univariate"
+        assert not base.event_shape, "'base' has to be univariate"
 
         self.base = base
         self.n = n
diff --git a/zuko/flows.py b/zuko/flows.py
index f208117..faa6ab5 100644
--- a/zuko/flows.py
+++ b/zuko/flows.py
@@ -22,6 +22,7 @@
 from functools import partial
 from math import ceil
 from torch import Tensor, LongTensor, Size
+from torch.distributions import *
 from typing import *
 
 from .distributions import *
diff --git a/zuko/nn.py b/zuko/nn.py
index 1c3667e..99a0d81 100644
--- a/zuko/nn.py
+++ b/zuko/nn.py
@@ -39,7 +39,7 @@ def forward(self, x: Tensor) -> Tensor:
 class MLP(nn.Sequential):
     r"""Creates a multi-layer perceptron (MLP).
 
-    Also known as fully connected feedforward network, an MLP is a series of
+    Also known as fully connected feedforward network, an MLP is a sequence of
     non-linear parametric functions
 
     .. math:: h_{i + 1} = a_{i + 1}(h_i W_{i + 1}^T + b_{i + 1}),
diff --git a/zuko/transforms.py b/zuko/transforms.py
index 730299b..121885f 100644
--- a/zuko/transforms.py
+++ b/zuko/transforms.py
@@ -1,12 +1,30 @@
 r"""Parameterizable transformations."""
 
+__all__ = [
+    'ComposedTransform',
+    'IdentityTransform',
+    'CosTransform',
+    'SinTransform',
+    'SoftclipTransform',
+    'MonotonicAffineTransform',
+    'MonotonicRQSTransform',
+    'MonotonicTransform',
+    'UnconstrainedMonotonicTransform',
+    'SOSPolynomialTransform',
+    'FFJTransform',
+    'AutoregressiveTransform',
+    'PermutationTransform',
+]
+
 import math
 import torch
 import torch.nn.functional as F
 
-from torch import Tensor, LongTensor
+from textwrap import indent
+from torch import Tensor, LongTensor, Size
 from torch.distributions import *
 from torch.distributions import constraints
+from torch.distributions.utils import _sum_rightmost
 from typing import *
 
 from .utils import bisection, broadcast, gauss_legendre, odeint
@@ -28,6 +46,97 @@ def _call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
 Transform.call_and_ladj = _call_and_ladj
 
 
+class ComposedTransform(Transform):
+    r"""Creates a transformation :math:`f(x) = f_n \circ \dots \circ f_1(x)`.
+
+    Arguments:
+        transforms: A sequence of transformations :math:`f_i`.
+    """
+
+    def __init__(self, *transforms: Transform, **kwargs):
+        super().__init__(**kwargs)
+
+        assert transforms, "'transforms' cannot be empty"
+
+        event_dim = 0
+
+        for t in reversed(transforms):
+            event_dim = t.domain.event_dim + max(event_dim - t.codomain.event_dim, 0)
+
+        self.domain_dim = event_dim
+
+        for t in transforms:
+            event_dim += t.codomain.event_dim - t.domain.event_dim
+
+        self.codomain_dim = event_dim
+        self.transforms = transforms
+
+    def __repr__(self) -> str:
+        lines = [f'({i + 1}): {t}' for i, t in enumerate(self.transforms)]
+        lines = indent('\n'.join(lines), '  ')
+
+        return f'{self.__class__.__name__}(\n' + lines + '\n)'
+
+    @property
+    def domain(self) -> constraints.Constraint:
+        domain = self.transforms[0].domain
+        reinterpreted = self.domain_dim - domain.event_dim
+
+        if reinterpreted > 0:
+            return constraints.independent(domain, reinterpreted)
+        else:
+            return domain
+
+    @property
+    def codomain(self) -> constraints.Constraint:
+        codomain = self.transforms[-1].codomain
+        reinterpreted = self.codomain_dim - codomain.event_dim
+
+        if reinterpreted > 0:
+            return constraints.independent(codomain, reinterpreted)
+        else:
+            return codomain
+
+    @property
+    def bijective(self) -> bool:
+        return all(t.bijective for t in self.transforms)
+
+    def _call(self, x: Tensor) -> Tensor:
+        for t in self.transforms:
+            x = t(x)
+        return x
+
+    def _inverse(self, y: Tensor) -> Tensor:
+        for t in reversed(self.transforms):
+            y = t.inv(y)
+        return y
+
+    def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
+        _, ladj = self.call_and_ladj(x)
+        return ladj
+
+    def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        event_dim = self.domain_dim
+        acc = 0
+
+        for t in self.transforms:
+            x, ladj = t.call_and_ladj(x)
+            acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
+            event_dim += t.codomain.event_dim - t.domain.event_dim
+
+        return x, acc
+
+    def forward_shape(self, shape: Size) -> Size:
+        for t in self.transforms:
+            shape = t.forward_shape(shape)
+        return shape
+
+    def inverse_shape(self, shape: Size) -> Size:
+        for t in reversed(self.transforms):
+            shape = t.inverse_shape(shape)
+        return shape
+
+
 class IdentityTransform(Transform):
     r"""Creates a transformation :math:`f(x) = x`."""
 
diff --git a/zuko/utils.py b/zuko/utils.py
index 843cce0..f9b4c58 100644
--- a/zuko/utils.py
+++ b/zuko/utils.py
@@ -6,7 +6,6 @@
 
 import numpy as np
 import torch
-import torch.nn as nn
 
 from functools import lru_cache
 from torch import Tensor
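
Usage sketch (not part of the patch): a minimal example of the new ComposedTransform added above. It relies only on names that appear in this diff, except for the MonotonicAffineTransform constructor, whose (shift, scale) signature is assumed here.

import torch

from zuko.transforms import ComposedTransform, MonotonicAffineTransform

# Chain two elementwise transformations into a single Transform object.
# ComposedTransform tracks the domain/codomain event dimensions and
# accumulates the log-absolute-determinant of the Jacobian.
f = ComposedTransform(
    MonotonicAffineTransform(torch.tensor(1.0), torch.tensor(0.5)),  # assumed (shift, scale)
    MonotonicAffineTransform(torch.tensor(-1.0), torch.tensor(2.0)),
)

x = torch.randn(256)
y = f(x)                          # forward pass: f_2(f_1(x))
x_bis = f.inv(y)                  # inverses applied in reverse order
y_bis, ladj = f.call_and_ladj(x)  # forward pass plus accumulated log |det J|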
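
A similar sketch for the reworked NormalizingFlow: it now wraps its transformations in a single ComposedTransform, declares has_rsample = True, and falls back to base.sample() when the base distribution has no rsample(). The DiagNormal (loc, scale) signature is assumed, not taken from this patch.

import torch

from zuko.distributions import DiagNormal, NormalizingFlow
from zuko.transforms import MonotonicAffineTransform

# A flow couples a sequence of transformations f_i with a base distribution p(Z).
flow = NormalizingFlow(
    [MonotonicAffineTransform(torch.zeros(3), torch.ones(3))],
    DiagNormal(torch.zeros(3), torch.ones(3)),  # assumed (loc, scale)
)

x = flow.sample((5,))     # z drawn from the base, then x = f^{-1}(z)
log_p = flow.log_prob(x)  # change-of-variables density, shape (5,)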