Skip to content

Commit

Permalink
allow time dimension to be specified on init of the four time-related…
Browse files Browse the repository at this point in the history
… modules to be added to text-to-image model. allow it to be changed with a function
  • Loading branch information
lucidrains committed Feb 10, 2024
1 parent 957b3c7 commit 7ebfa61
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 13 deletions.
3 changes: 2 additions & 1 deletion lumiere_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
ConvolutionInflationBlock,
AttentionInflationBlock,
TemporalDownsample,
TemporalUpsample
TemporalUpsample,
set_time_dim_
)
54 changes: 43 additions & 11 deletions lumiere_pytorch/lumiere_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def divisible_by(num, den):
def is_odd(n):
return not divisible_by(n, 2)

def compact_values(d: dict):
    """Return a copy of *d* keeping only the entries whose value passes `exists`
    (presumably filtering out None values — matches its use for optional
    einops rearrange kwargs)."""
    out = {}
    for key, value in d.items():
        if exists(value):
            out[key] = value
    return out

# function that takes in the entire text-to-video network, and sets the time dimension

def set_time_dim_(model: Module, time_dim: int):
    """In-place setter (trailing underscore = mutates its argument).

    Walks every submodule of *model* and, for each of the four time-aware
    Lumiere blocks, overwrites its `time_dim` attribute so the network can be
    re-targeted to a different number of video frames after construction.

    Fix over the original: the loop variable was named `model`, shadowing the
    parameter — harmless here only because `model.modules()` is evaluated once
    before the first rebinding, but a trap for future edits. Renamed to
    `module`.
    """
    for module in model.modules():
        if isinstance(module, (AttentionInflationBlock, ConvolutionInflationBlock, TemporalUpsample, TemporalDownsample)):
            module.time_dim = time_dim

# decorator for converting an input tensor from either image or video format to 1d time

def image_or_video_to_time(fn):
Expand All @@ -45,14 +55,14 @@ def inner(
):

is_video = x.ndim == 5
assert is_video ^ exists(batch_size), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'

if is_video:
batch_size = x.shape[0]
x = rearrange(x, 'b c t h w -> b h w c t')
else:
assert exists(batch_size)
x = rearrange(x, '(b t) c h w -> b h w c t', b = batch_size)
assert exists(batch_size) or exists(self.time_dim)
rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
x = rearrange(x, '(b t) c h w -> b h w c t', **compact_values(rearrange_kwargs))

x, ps = pack_one(x, '* c t')

Expand Down Expand Up @@ -96,8 +106,14 @@ def init_bilinear_kernel_1d_(conv: Module):
conv.weight.data[diag_mask] = bilinear_kernel

class TemporalDownsample(Module):
def __init__(self, dim):
def __init__(
self,
dim,
time_dim = None
):
super().__init__()
self.time_dim = time_dim

self.conv = nn.Conv1d(dim, dim, kernel_size = 3, stride = 2, padding = 1)
init_bilinear_kernel_1d_(self.conv)

Expand All @@ -109,8 +125,14 @@ def forward(
return self.conv(x)

class TemporalUpsample(Module):
def __init__(self, dim):
def __init__(
self,
dim,
time_dim = None
):
super().__init__()
self.time_dim = time_dim

self.conv = nn.ConvTranspose1d(dim, dim, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
init_bilinear_kernel_1d_(self.conv)

Expand All @@ -130,12 +152,15 @@ def __init__(
dim,
conv2d_kernel_size = 3,
conv1d_kernel_size = 3,
groups = 8
groups = 8,
time_dim = None
):
super().__init__()
assert is_odd(conv2d_kernel_size)
assert is_odd(conv1d_kernel_size)

self.time_dim = time_dim

self.spatial_conv = nn.Sequential(
nn.Conv2d(dim, dim, conv2d_kernel_size, padding = conv2d_kernel_size // 2),
nn.GroupNorm(groups, num_channels = dim),
Expand All @@ -160,15 +185,17 @@ def forward(
):
residual = x
is_video = x.ndim == 5
assert is_video ^ exists(batch_size), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'

if is_video:
batch_size = x.shape[0]
x = rearrange(x, 'b c t h w -> (b t) c h w')

x = self.spatial_conv(x)

x = rearrange(x, '(b t) c h w -> b h w c t', b = batch_size)
rearrange_kwargs = compact_values(dict(b = batch_size, t = self.time_dim))

assert len(rearrange_kwargs) > 0, 'either batch_size is passed in on forward, or time_dim is set on init'
x = rearrange(x, '(b t) c h w -> b h w c t', **rearrange_kwargs)

x, ps = pack_one(x, '* c t')

Expand All @@ -192,10 +219,13 @@ def __init__(
depth = 1,
prenorm = True,
residual_attn = True,
time_dim = None,
**attn_kwargs
):
super().__init__()

self.time_dim = time_dim

self.temporal_attns = ModuleList([])

for _ in range(depth):
Expand Down Expand Up @@ -223,14 +253,16 @@ def forward(
batch_size = None
):
is_video = x.ndim == 5
assert is_video ^ exists(batch_size), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'
assert is_video ^ (exists(batch_size) or exists(self.time_dim)), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'

if is_video:
batch_size = x.shape[0]
x = rearrange(x, 'b c t h w -> b h w t c')
else:
assert exists(batch_size)
x = rearrange(x, '(b t) c h w -> b h w t c', b = batch_size)
assert exists(batch_size) or exists(self.time_dim)

rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
x = rearrange(x, '(b t) c h w -> b h w t c', **compact_values(rearrange_kwargs))

x, ps = pack_one(x, '* t c')

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'lumiere-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'Lumiere',
author = 'Phil Wang',
Expand Down

0 comments on commit 7ebfa61

Please sign in to comment.