Added gradient checkpointing in neural models and propagation
DeanHazineh committed May 1, 2024
1 parent 3f5b896 commit a214495
Showing 8 changed files with 128 additions and 132 deletions.
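For context: gradient checkpointing discards intermediate activations during the forward pass and recomputes them during backward, trading extra compute for lower peak memory. The sketch below is illustrative only and not part of this commit; it uses PyTorch's built-in torch.utils.checkpoint.checkpoint (the same utility the propagators below import), and the toy mlp module and tensor sizes are invented for the example.

import torch
from torch.utils.checkpoint import checkpoint

# Toy module standing in for any block with large intermediate activations.
mlp = torch.nn.Sequential(
    torch.nn.Linear(64, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 64),
)

x = torch.randn(32, 64, requires_grad=True)

# Plain forward: activations are kept alive for the backward pass.
mlp(x).sum().backward()
g_plain = x.grad.clone()
x.grad = None

# Checkpointed forward: activations are dropped and recomputed during backward,
# lowering peak memory at the cost of a second forward pass.
checkpoint(mlp, x).sum().backward()

# Both paths produce the same gradients.
assert torch.allclose(g_plain, x.grad)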
21 changes: 12 additions & 9 deletions dflat/metasurface/load_utils.py
@@ -16,8 +16,9 @@
"Nanocylinders_TiO2_U300H600": "https://www.dropbox.com/scl/fi/sn44f2xzadcrag0jgzdsq/Nanocylinders_TiO2_U300H600.zip?rlkey=5hivknv8cvfy3gzzsyolb8bz5&dl=1",
"Nanocylinders_TiO2_U350H600": "https://www.dropbox.com/scl/fi/43mf1xidor3mti9dv8bce/Nanocylinders_TiO2_U350H600.zip?rlkey=cyj6xb3reh5iv2rj9l1byxk27&dl=1",
"Nanoellipse_TiO2_U350H600": "https://www.dropbox.com/scl/fi/6phh6a0kztbccy76vzwjd/Nanoellipse_TiO2_U350H600.zip?rlkey=0hn8cr2kgs3t9134kmrhf1ogx&dl=1",
"Nanofins_TiO2_U350H600": "https://www.dropbox.com/scl/fi/co65yfwugkvugi7r8bqaj/Nanofins_TiO2_U350H600.zip?rlkey=8e0pzvzul8xlzl9szf15lbrzx&dl=1"
}
"Nanofins_TiO2_U350H600": "https://www.dropbox.com/scl/fi/co65yfwugkvugi7r8bqaj/Nanofins_TiO2_U350H600.zip?rlkey=8e0pzvzul8xlzl9szf15lbrzx&dl=1",
}


def model_config_path(model_name):
dir_path = Path("DFlat/Models/")
@@ -26,18 +27,18 @@ def model_config_path(model_name):
config_exists = os.path.exists(os.path.join(dir_path, model_name, "config.yaml"))
if not config_exists:
print("Downloading the model from online storage.")
zip_path = dir_path / "data.zip"
zip_path = dir_path / "data.zip"
load_url = req_paths[model_name]

with requests.get(load_url, stream=True) as response:
response.raise_for_status()
with open(zip_path, 'wb') as file:
for chunk in response.iter_content(chunk_size=8192):
with open(zip_path, "wb") as file:
for chunk in response.iter_content(chunk_size=8192):
file.write(chunk)

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(dir_path)

zip_path.unlink()

config_path = os.path.join(dir_path, model_name, "config.yaml")
@@ -47,6 +48,7 @@ def model_config_path(model_name):

return config_path, ckpt_path


def load_optical_model(model_name):
"""Loads a neural optical model.
@@ -59,10 +61,11 @@ def load_optical_model(model_name):

config_path, ckpt_path = model_config_path(model_name)
config = OmegaConf.load(config_path)

optical_model = instantiate_from_config(config.model, ckpt_path, strict=True)
return optical_model


def instantiate_from_config(config_model, ckpt_path=None, strict=False):
assert "target" in config_model, "Expected key `target` to instantiate."
target_str = config_model["target"]
@@ -82,6 +85,7 @@ def instantiate_from_config(config_model, ckpt_path=None, strict=False):

return loaded_module


def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
@@ -115,4 +119,3 @@ def load_trainer(config_path):
**config_trainer.get("params", dict()),
)
return trainer
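For reference, a minimal usage sketch of the loader above (not part of the diff): the model name must be one of the keys in req_paths, and the first call downloads and unzips the weights into DFlat/Models/ before building the model from its config.yaml and checkpoint. The import path simply mirrors the file location shown in this diff.

from dflat.metasurface.load_utils import load_optical_model

# First call fetches and extracts the pretrained weights; later calls reuse
# the files already on disk.
optical_model = load_optical_model("Nanocylinders_TiO2_U300H600")
print(type(optical_model))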

55 changes: 55 additions & 0 deletions dflat/metasurface/nn.py
@@ -3,15 +3,67 @@
import torch.nn as nn
import torch.nn.functional as F


def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])

with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors

@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads


class VanillaMLP(nn.Module):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
use_checkpoint=True,
):
super().__init__()
self.blocks = nn.ModuleList()
self.use_checkpoint = use_checkpoint

chi = in_channels
for ch in channels:
@@ -24,6 +76,9 @@ def __init__(
self.blocks.append(nn.Linear(chi, out_channels))

def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

def _forward(self, x):
for block in self.blocks:
x = block(x)
return x
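A short sketch of exercising the checkpoint helper added above on its own (illustrative only; the toy linear layer and shapes are invented). Parameters that the function reaches only through a closure are passed via params so CheckpointFunction still returns gradients for them:

import torch
from dflat.metasurface.nn import checkpoint

lin = torch.nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

def run(t):
    # Any differentiable function of the checkpointed input.
    return torch.tanh(lin(t))

# flag=True: nothing is cached in forward; run() is re-executed during backward.
out = checkpoint(run, (x,), lin.parameters(), True)
out.sum().backward()

# Gradients reach the explicit input and the closed-over parameters.
assert x.grad is not None and lin.weight.grad is not None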
94 changes: 14 additions & 80 deletions dflat/metasurface/nn_siren.py
@@ -1,4 +1,5 @@
# COPIED FROM LUCIDRAIN
# LATER MODIFIED WITH VARIOUS TWEAKS AND ADJUSTMENTS
## https://github.com/lucidrains/siren-pytorch/blob/master/siren_pytorch/siren_pytorch.py


@@ -8,15 +9,13 @@
import torch.nn.functional as F
from einops import rearrange

from .nn import checkpoint


def exists(val):
return val is not None


def cast_tuple(val, repeat=1):
return val if isinstance(val, tuple) else ((val,) * repeat)


class Sine(nn.Module):
def __init__(self, w0=1.0):
super().__init__()
@@ -37,10 +36,12 @@ def __init__(
use_bias=True,
activation=None,
dropout=0.0,
use_checkpoint=True,
):
super().__init__()
self.dim_in = dim_in
self.is_first = is_first
self.use_checkpoint = use_checkpoint

weight = torch.zeros(dim_out, dim_in)
bias = torch.zeros(dim_out) if use_bias else None
@@ -53,14 +54,15 @@ def __init__(

def init_(self, weight, bias, c, w0):
dim = self.dim_in

w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
weight.uniform_(-w_std, w_std)

if exists(bias):
bias.uniform_(-w_std, w_std)

def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

def _forward(self, x):
out = F.linear(x, self.weight, self.bias)
out = self.activation(out)
out = self.dropout(out)
@@ -79,10 +81,12 @@ def __init__(
use_bias=True,
final_activation=None,
dropout=0.0,
use_checkpoint=True,
):
super().__init__()
self.num_layers = num_layers
self.dim_hidden = dim_hidden
self.use_checkpoint = use_checkpoint

self.layers = nn.ModuleList([])
for ind in range(num_layers):
@@ -98,7 +102,6 @@ def __init__(
is_first=is_first,
dropout=dropout,
)

self.layers.append(layer)

final_activation = (
@@ -112,79 +115,10 @@ def __init__(
activation=final_activation,
)

def forward(self, x, mods=None):
mods = cast_tuple(mods, self.num_layers)

for layer, mod in zip(self.layers, mods):
x = layer(x)

if exists(mod):
x *= rearrange(mod, "d -> () d")

return self.last_layer(x)


class Modulator(nn.Module):
def __init__(self, dim_in, dim_hidden, num_layers):
super().__init__()
self.layers = nn.ModuleList([])

for ind in range(num_layers):
is_first = ind == 0
dim = dim_in if is_first else (dim_hidden + dim_in)

self.layers.append(nn.Sequential(nn.Linear(dim, dim_hidden), nn.ReLU()))

def forward(self, z):
x = z
hiddens = []
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

def _forward(self, x):
for layer in self.layers:
x = layer(x)
hiddens.append(x)
x = torch.cat((x, z))

return tuple(hiddens)


class SirenWrapper(nn.Module):
def __init__(self, net, image_width, image_height, latent_dim=None):
super().__init__()
assert isinstance(net, SirenNet), "SirenWrapper must receive a Siren network"

self.net = net
self.image_width = image_width
self.image_height = image_height

self.modulator = None
if exists(latent_dim):
self.modulator = Modulator(
dim_in=latent_dim, dim_hidden=net.dim_hidden, num_layers=net.num_layers
)

tensors = [
torch.linspace(-1, 1, steps=image_height),
torch.linspace(-1, 1, steps=image_width),
]
mgrid = torch.stack(torch.meshgrid(*tensors, indexing="ij"), dim=-1)
mgrid = rearrange(mgrid, "h w c -> (h w) c")
self.register_buffer("grid", mgrid)

def forward(self, img=None, *, latent=None):
modulate = exists(self.modulator)
assert not (
modulate ^ exists(latent)
), "latent vector must be only supplied if `latent_dim` was passed in on instantiation"

mods = self.modulator(latent) if modulate else None

coords = self.grid.clone().detach().requires_grad_()
out = self.net(coords, mods)
out = rearrange(
out, "(h w) c -> () c h w", h=self.image_height, w=self.image_width
)

if exists(img):
return F.mse_loss(img, out)

return out
return self.last_layer(x)
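The header of this file points to lucidrains' siren-pytorch, so, assuming the constructor keeps that library's (dim_in, dim_hidden, dim_out, num_layers, ...) signature alongside the use_checkpoint flag used here, a checkpointed forward pass would look roughly like this (the dimensions are placeholders):

import torch
from dflat.metasurface.nn_siren import SirenNet

# Constructor arguments assumed from the upstream siren-pytorch API;
# use_checkpoint enables the per-layer gradient checkpointing added here.
net = SirenNet(dim_in=2, dim_hidden=256, dim_out=3, num_layers=4, use_checkpoint=True)

coords = torch.rand(1024, 2, requires_grad=True)
out = net(coords)   # each layer's activations are recomputed during backward
out.mean().backward()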
24 changes: 10 additions & 14 deletions dflat/metasurface/optical_model.py
@@ -8,7 +8,7 @@


class NeuralCells(nn.Module):
def __init__(self, nn_config, param_bounds, trainable_model=False):
def __init__(self, nn_config, param_bounds, use_checkpoint=True, **kwargs):
"""Initializes a neural cell model from a config dictionary.
Args:
@@ -17,11 +17,19 @@ def __init__(self, nn_config, param_bounds, trainable_model=False):
trainable_model (bool, optional): Flag to set model parameters to trainable (requires grad). Defaults to False.
"""
super().__init__()

if "trainable_model" in kwargs:
print(
"Note: trainable_model key in NeuralCells is deprecated. Model will be set to requires_grad."
)
self.dim_in = nn_config.params.dim_in
self.dim_out = nn_config.params.dim_out
self.model = self._initialize_model(nn_config, trainable_model)
self.model = instantiate_from_config(
nn_config, ckpt_path=nn_config["ckpt_path"], strict=False
)
self.param_bounds = param_bounds
self.loss = get_obj_from_str(nn_config.loss)()
self.use_checkpoint = use_checkpoint

def training_step(self, x, y):
pred = self.model(x)
@@ -113,18 +121,6 @@ def denormalize(self, params):
)
return out

def _initialize_model(self, config, trainable_model):
model = instantiate_from_config(
config, ckpt_path=config["ckpt_path"], strict=False
)

if not trainable_model:
model = model.eval()
for param in model.parameters():
param.requires_grad = False

return model


class NeuralFields(nn.Module):
def __init__(self, nn_config, trainable_model=False):
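With the old _initialize_model helper removed, NeuralCells now hands its nn_config straight to instantiate_from_config. Pieced together from the keys read above (target, ckpt_path, loss, params.dim_in, params.dim_out), a config of roughly this shape is expected; the values below are placeholders rather than the shipped model configs:

from omegaconf import OmegaConf

# Hypothetical config fragment: the keys mirror what NeuralCells reads, while
# the target, loss, and dimensions are stand-ins for the released configs.
nn_config = OmegaConf.create(
    {
        "target": "dflat.metasurface.nn_siren.SirenNet",
        "ckpt_path": None,
        "loss": "torch.nn.MSELoss",
        "params": {
            "dim_in": 3,
            "dim_hidden": 256,
            "dim_out": 6,
            "num_layers": 4,
            "use_checkpoint": True,
        },
    }
)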
1 change: 0 additions & 1 deletion dflat/metasurface/reverse_lookup.py
@@ -45,7 +45,6 @@ def reverse_lookup_optimize(

print(f"Running optimization with device {device}")
model = load_optical_model(model_name).to(device)

pg = model.dim_out // 3
assert pg == P, f"Polarization dimension of amp, phase (dim1) expected to be {pg}."

5 changes: 3 additions & 2 deletions dflat/propagation/propagators.py
@@ -2,6 +2,7 @@
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch.fft import fftshift, ifftshift, fft2, ifft2
from einops import rearrange

@@ -268,7 +269,7 @@ def forward(self, amplitude, phase, **kwargs):
amplitude, phase = self._regularize_field(amplitude, phase)

# propagate by the fresnel method
amplitude, phase = self.fresnel_transform(amplitude, phase)
amplitude, phase = checkpoint(self.fresnel_transform, amplitude, phase)

# Transform field back to the specified output grid
amplitude, phase = self._resample_field(amplitude, phase)
@@ -517,7 +518,7 @@ def forward(self, amplitude, phase, **kwargs):
amplitude, phase = self._regularize_field(amplitude, phase)

# propagate by the asm method
amplitude, phase = self.ASM_transform(amplitude, phase)
amplitude, phase = checkpoint(self.ASM_transform, amplitude, phase)

# Transform field back to the specified output grid and convert to 2D
amplitude, phase = self._resample_field(amplitude, phase)
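Both propagator forward methods now route their transform through torch.utils.checkpoint.checkpoint, so the transform's intermediate tensors are recomputed during backward instead of being stored. A stand-in sketch of the pattern (the real Fresnel/ASM classes take constructor arguments and do math not shown in this diff, so a toy transform is used):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyPropagator(nn.Module):
    # Stand-in for the Fresnel/ASM propagators above; only the wrapping of
    # the transform in torch.utils.checkpoint is illustrated.
    def _transform(self, amplitude, phase):
        # Placeholder math, not the actual propagation kernels.
        return amplitude * torch.cos(phase), phase + amplitude

    def forward(self, amplitude, phase):
        # Mirrors the pattern used by the Fresnel/ASM forward methods above.
        return checkpoint(self._transform, amplitude, phase)


prop = TinyPropagator()
amp = torch.rand(1, 128, 128, requires_grad=True)
phs = torch.rand(1, 128, 128, requires_grad=True)
out_amp, out_phs = prop(amp, phs)
(out_amp.sum() + out_phs.sum()).backward()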