Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Glow-like multi-scale flow #7

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,51 @@ def test_autoregressive_transforms():
assert (torch.triu(J, diagonal=1) == 0).all(), t
assert (torch.tril(J[:4, :4], diagonal=-1) == 0).all(), t
assert (torch.tril(J[4:, 4:], diagonal=-1) == 0).all(), t


def test_Glow(tmp_path):
flow = Glow((3, 32, 32), context=[5, 0, 5])

# Evaluation of log_prob
x, y = randn(8, 3, 32, 32), [randn(5, 16, 16), None, randn(8, 5, 4, 4)]
log_p = flow(y).log_prob(x)

assert log_p.shape == (8,)
assert log_p.requires_grad

flow.zero_grad(set_to_none=True)
loss = -log_p.mean()
loss.backward()

for p in flow.parameters():
assert p.grad is not None

# Sampling
x = flow(y).sample()

assert x.shape == (8, 3, 32, 32)

# Reparameterization trick
x = flow(y).rsample()

flow.zero_grad(set_to_none=True)
loss = x.square().sum().sqrt()
loss.backward()

for p in flow.parameters():
assert p.grad is not None

# Saving
torch.save(flow, tmp_path / 'flow.pth')

# Loading
flow_bis = torch.load(tmp_path / 'flow.pth')

x, y = randn(3, 32, 32), [randn(5, 16, 16), None, randn(5, 4, 4)]

seed = torch.seed()
log_p = flow(y).log_prob(x)
torch.manual_seed(seed)
log_p_bis = flow_bis(y).log_prob(x)

assert torch.allclose(log_p, log_p_bis)
17 changes: 17 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,20 @@ def test_MonotonicMLP():
J = torch.autograd.functional.jacobian(net, x)

assert (J >= 0).all()


def test_FCN():
net = FCN(3, 5)

# Non-batched
x = randn(3, 64, 64)
y = net(x)

assert y.shape == (5, 64, 64)
assert y.requires_grad

# Batched
x = randn(8, 3, 32, 32)
y = net(x)

assert y.shape == (8, 5, 32, 32)
67 changes: 57 additions & 10 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,49 @@ def test_univariate_transforms():
assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t


def test_multivariate_transforms():
ts = [
LULinearTransform(randn(3, 3), dim=-2),
PermutationTransform(torch.randperm(3), dim=-2),
PixelShuffleTransform(dim=-2),
]

for t in ts:
# Shapes
x = randn(256, 3, 8)
y = t(x)

assert t.forward_shape(x.shape) == y.shape, t
assert t.inverse_shape(y.shape) == x.shape, t

# Inverse
z = t.inv(y)

assert x.shape == z.shape, t
assert torch.allclose(x, z, atol=1e-4), t

# Jacobian
x = randn(3, 8)
y = t(x)

jacobian = torch.autograd.functional.jacobian(t, x)
jacobian = jacobian.reshape(3 * 8, 3 * 8)

_, ladj = torch.slogdet(jacobian)

assert torch.allclose(t.log_abs_det_jacobian(x, y), ladj, atol=1e-4), t

# Inverse Jacobian
z = t.inv(y)

jacobian = torch.autograd.functional.jacobian(t.inv, y)
jacobian = jacobian.reshape(3 * 8, 3 * 8)

_, ladj = torch.slogdet(jacobian)

assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t


def test_FFJTransform():
a = torch.randn(3)
f = lambda x, t: a * x
Expand All @@ -80,20 +123,24 @@ def test_FFJTransform():
assert ladj.shape == x.shape[:-1]


def test_PermutationTransform():
t = PermutationTransform(torch.randperm(8))
def test_DropTransform():
dist = Normal(randn(3), abs(randn(3)) + 1)
t = DropTransform(dist)

x = torch.randn(256, 8)
# Call
x = randn(256, 5)
y = t(x)

assert x.shape == y.shape

match = x[:, :, None] == y[:, None, :]

assert (match.sum(dim=-1) == 1).all()
assert (match.sum(dim=-2) == 1).all()
assert t.forward_shape(x.shape) == y.shape
assert t.inverse_shape(y.shape) == x.shape

# Inverse
z = t.inv(y)

assert x.shape == z.shape
assert (x == z).all()
assert not torch.allclose(x, z)

# Jacobian
ladj = t.log_abs_det_jacobian(x, y)

assert ladj.shape == (256,)
192 changes: 192 additions & 0 deletions zuko/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
'NAF',
'FreeFormJacobianTransform',
'CNF',
'ConvCouplingTransform',
'Glow',
]

import abc
Expand Down Expand Up @@ -753,3 +755,193 @@ def __init__(
)

super().__init__(transforms, base)


class ConvCouplingTransform(TransformModule):
r"""Creates a convolution coupling transformation.

Arguments:
channels: The number of channels.
context: The number of context channels.
spatial: The number of spatial dimensions.
univariate: The univariate transformation constructor.
shapes: The shapes of the univariate transformation parameters.
kwargs: Keyword arguments passed to :class:`zuko.nn.FCN`.
"""

def __init__(
self,
channels: int,
context: int = 0,
spatial: int = 2,
univariate: Callable[..., Transform] = MonotonicAffineTransform,
shapes: List[Size] = [(), ()],
**kwargs,
):
super().__init__()

self.d = channels // 2
self.dim = -(spatial + 1)

# Univariate transformation
self.univariate = univariate
self.shapes = list(map(Size, shapes))
self.sizes = [s.numel() for s in self.shapes]

# Hyper network
kwargs.setdefault('activation', nn.ELU)
kwargs.setdefault('normalize', True)

self.hyper = FCN(
in_channels=self.d + context,
out_channels=(channels - self.d) * sum(self.sizes),
spatial=spatial,
**kwargs,
)

def extra_repr(self) -> str:
base = self.univariate(*map(torch.randn, self.shapes))

return f'(base): {base}'

def meta(self, y: Tensor, x: Tensor) -> Transform:
if y is not None:
x = torch.cat(broadcast(x, y, ignore=abs(self.dim)), dim=self.dim)

total = sum(self.sizes)

phi = self.hyper(x)
phi = phi.unflatten(self.dim, (phi.shape[self.dim] // total, total))
phi = phi.movedim(self.dim, -1)
phi = phi.split(self.sizes, -1)
phi = (p.unflatten(-1, s + (1,)) for p, s in zip(phi, self.shapes))
phi = (p.squeeze(-1) for p in phi)

return self.univariate(*phi)

def forward(self, y: Tensor = None) -> Transform:
return CouplingTransform(partial(self.meta, y), self.d, self.dim)


class Glow(DistributionModule):
r"""Creates a Glow-like multi-scale flow.

References:
| Glow: Generative Flow with Invertible 1x1 Convolutions (Kingma et al., 2018)
| https://arxiv.org/abs/1807.03039

Arguments:
shape: The shape of a sample.
context: The number of context channels at each scale.
transforms: The number of coupling transformations at each scale.
kwargs: Keyword arguments passed to :class:`ConvCouplingTransform`.
"""

def __init__(
self,
shape: Size,
context: Union[int, List[int]] = 0,
transforms: List[int] = [8, 8, 8],
**kwargs,
):
super().__init__()

channels, *space = shape
spatial = len(space)
dim = -len(shape)
scales = len(transforms)

assert all(s % 2**scales == 0 for s in space), (
f"'shape' cannot be downscaled {scales} times"
)

if isinstance(context, int):
context = [context] * len(transforms)

self.flows = nn.ModuleList()
self.bases = nn.ModuleList()

for i, K in enumerate(transforms):
flow = []
flow.append(Unconditional(PixelShuffleTransform, dim=dim))

channels = channels * 2**spatial
space = [s // 2 for s in space]

for _ in range(K):
flow.extend([
Unconditional(
PermutationTransform,
torch.randperm(channels),
dim=dim,
buffer=True,
),
Unconditional(
LULinearTransform,
torch.eye(channels),
dim=dim,
),
ConvCouplingTransform(
channels=channels,
context=context[i],
spatial=spatial,
**kwargs,
),
])

if i < len(transforms) - 1:
drop = channels // 2
else:
drop = channels

self.flows.append(nn.ModuleList(flow))
self.bases.append(
Unconditional(
DiagNormal,
torch.zeros(drop, *space),
torch.ones(drop, *space),
ndims=spatial + 1,
buffer=True,
)
)

channels = channels - drop

def forward(self, y: Iterable[Tensor] = None) -> NormalizingFlow:
r"""
Arguments:
y: A sequence of contexts :math:`y_i`. There should be one context
per scale, but a context can be :py:`None`.

Returns:
A multi-scale flow :math:`p(X | y)`.
"""

if y is None:
y = [None] * len(self.flows)

# Transforms
transforms = []
context_shapes = []

for flow, base, y_i in zip(self.flows, self.bases, y):
for t in flow:
transforms.append(t(y_i))

transforms.append(DropTransform(base(y_i)))

if y_i is not None:
context_shapes.append(y_i.shape)

transform = ComposedTransform(*transforms[:-1])

# Base
base = transforms[-1].dist
dim = -len(base.event_shape)

batch_shapes = (shape[:dim] for shape in context_shapes)
batch_shape = torch.broadcast_shapes(*batch_shapes)

base = base.expand(batch_shape)

return NormalizingFlow(transform, base)
Loading