Skip to content

Commit

Permalink
Tweak actions to ignore certain tests
Browse files Browse the repository at this point in the history
ignore dataset tests and cuda call
  • Loading branch information
DeanHazineh committed Jan 25, 2024
1 parent 81e43b7 commit 4240619
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ jobs:
pip install .
- name: Run pytest suite
run: |
pytest tests/
pytest tests/ --ignore=tests/test_metasurface.py tests/test_propagation.py
2 changes: 1 addition & 1 deletion dflat/metasurface/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def __init__(self):
data = scipy.io.loadmat(datpath)
self.phase = data["phase"]
self.trans = np.sqrt(np.clip(data["transmission"], 0, np.finfo(np.float32).max))
self.params = [data["radius_m"], data["wavelength_m"].flatten()]
self.params = [data["radius_m"], data["wavelength_m"].flatten()] # This is diameter not radius
self.param_limits = [[30e-9, 150e-9], [310e-9, 750e-9]]

# Transform data into a cell-level dataset ([0, 1])
Expand Down
4 changes: 2 additions & 2 deletions dflat/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import matplotlib.pyplot as plt

# Call trainer on model
#config_path = 'metasurface/ckpt/Nanofins_TiO2_U350H600_Medium/config.yaml'
config_path = 'metasurface/ckpt/Nanocylinders_TiO2_U180H600_Medium/config.yaml'
config_path = 'metasurface/ckpt/Nanofins_TiO2_U350H600_Medium/config.yaml'
#config_path = 'metasurface/ckpt/Nanocylinders_TiO2_U180H600_Medium/config.yaml'
trainer = load_trainer(config_path)
trainer.train()

Expand Down
210 changes: 210 additions & 0 deletions tests/test_propagation_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import pytest
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from functools import partial

from dflat.propagation import PointSpreadFunction, ASMPropagation, FresnelPropagation
from dflat.initialize import focusing_lens


@pytest.fixture
def shared_init():
sd = {
"in_size": [201, 201],
"in_dx_m": [2e-6, 2e-6],
"out_distance_m": 10e-3,
"out_size": [200, 200],
"out_dx_m": [1e-6, 1e-6],
"out_resample_dx_m": [2e-6, 2e-6],
"manual_upsample_factor": 1.0,
"radial_symmetry": False,
"diffraction_engine": "asm",
}

get_lens = partial(
focusing_lens,
in_size=sd["in_size"],
in_dx_m=sd["in_dx_m"],
wavelength_set_m=[600e-9],
depth_set_m=[20e-3],
fshift_set_m=[[0.0, 0.0]],
out_distance_m=sd["out_distance_m"],
aperture_radius_m=200e-6,
)

def rewrap_phase(x):
x = x.cpu().numpy() if torch.is_tensor(x) else x
cidx = x.shape[-1] // 2
return np.angle(np.exp(1j * (x - x[..., cidx : cidx + 1, cidx : cidx + 1])))

return sd, get_lens, rewrap_phase


class Test_PointSpreadFunction:
@pytest.fixture(autouse=True)
def initialize(self, shared_init):
self.init_dict, self.get_lens, self.rewrap_phase = shared_init

def test_init(self):
for engine in ["asm", "fresnel"]:
sd = deepcopy(self.init_dict)
sd["diffraction_engine"] = engine
PointSpreadFunction(**sd)

with pytest.raises(AssertionError):
sd["diffraction_engine"] = "invalid"
PointSpreadFunction(**sd)

def test_forward(self):
def reshape_dat(x):
return x.view(-1, *self.init_dict["out_size"]).cpu().numpy()

# Repeat claculate for radial symmetry and 2D
device_list = []
device_list.append("cpu")

for radial_flag in [True, False]:
amp, phase, aperture = self.get_lens(radial_symmetry=radial_flag)
amp = torch.tensor(amp, dtype=torch.float32)
phase = torch.tensor(phase, dtype=torch.float32)

wavelength_set_m = [400e-9, 600e-9]
ps_locs_m = [[0, 0, 10e-3], [0, 0, 20e-3]]
ps_locs_m = torch.tensor(ps_locs_m, dtype=torch.float32)

sd = deepcopy(self.init_dict)
sd["radial_symmetry"] = radial_flag

# Repeat calculation on cuda and cpu
for device in device_list:
amp = amp.to(device=device)
phase = phase.to(device=device)
ps_locs_m = ps_locs_m.to(device=device)

sd["diffraction_engine"] = "fresnel"
fresnel = PointSpreadFunction(**sd)
fres_int, fres_phase = fresnel(
amp,
phase,
wavelength_set_m,
ps_locs_m,
aperture,
normalize_to_aperture=True,
)

sd["diffraction_engine"] = "asm"
asm = PointSpreadFunction(**sd)
asm_int, asm_phase = asm(
amp,
phase,
wavelength_set_m,
ps_locs_m,
aperture,
normalize_to_aperture=True,
)

assert (
fres_int.shape
== fres_phase.shape
== asm_int.shape
== asm_phase.shape
)
assert list(fres_int.shape[-2:]) == self.init_dict["out_size"]
assert fres_int.shape[0] == amp.shape[0]
assert fres_int.shape[1] == len(ps_locs_m)
assert fres_int.shape[2] == len(wavelength_set_m)

fres_int = reshape_dat(fres_int)
asm_int = reshape_dat(asm_int)
fres_phase = self.rewrap_phase(reshape_dat(fres_phase))
asm_phase = self.rewrap_phase(reshape_dat(asm_phase))

mse_int = np.mean((fres_int - asm_int) ** 2)
mse_phase = np.mean((fres_phase - asm_phase) ** 2)
assert mse_int < 1e-8

fig, ax = plt.subplots(4, 4)
for c in range(4):
ax[0, c].imshow(fres_int[c])
ax[1, c].imshow(asm_int[c])
ax[2, c].imshow(fres_phase[c], cmap="hsv")
ax[3, c].imshow(asm_phase[c], cmap="hsv")
for axi in ax.flatten():
axi.axis("off")
script_dir = os.path.dirname(os.path.abspath(__file__))
out_dir = os.path.join(script_dir, "out")
if not os.path.exists(out_dir):
os.makedirs(out_dir)
plot_path = os.path.join(
out_dir, f"psf_radial_{radial_flag}_{device}.png"
)
plt.savefig(plot_path)
plt.close()


class Test_Propagation:
@pytest.fixture(autouse=True)
def initialize(self, shared_init):
self.init_dict, self.get_lens, self.rewrap_phase = shared_init
del self.init_dict["diffraction_engine"]

def test_forward(self):
#device_list = ["cuda"] if torch.cuda.is_available else []
device_list = []
device_list.append("cpu")

for radial_flag in [True, False]:
amp, phase, aperture = self.get_lens(radial_symmetry=radial_flag)
amp = torch.tensor(amp, dtype=torch.float32)
phase = torch.tensor(phase, dtype=torch.float32)
wavelength_set_m = [400e-9, 600e-9]

sd = deepcopy(self.init_dict)
sd["radial_symmetry"] = radial_flag

# Repeat calculation on cuda and cpu
for device in device_list:
amp = amp.to(device=device)
phase = phase.to(device=device)

asm = ASMPropagation(**sd)
asm_amp, asm_phase = asm(amp, phase, wavelength_set_m)

fresnel = FresnelPropagation(**sd)
fres_amp, fres_phase = fresnel(amp, phase, wavelength_set_m)

assert (
asm_amp.shape
== asm_phase.shape
== fres_amp.shape
== fres_phase.shape
)
assert list(fres_amp.shape[-2:]) == self.init_dict["out_size"]
assert fres_amp.shape[0] == amp.shape[0]
assert fres_amp.shape[1] == len(wavelength_set_m)

fres_amp = fres_amp.cpu().numpy()
asm_amp = asm_amp.cpu().numpy()
fres_phase = self.rewrap_phase(fres_phase)
asm_phase = self.rewrap_phase(asm_phase)

fig, ax = plt.subplots(4, 2)
for c in range(2):
ax[0, c].imshow(fres_amp[0, c])
ax[1, c].imshow(asm_amp[0, c])
ax[2, c].imshow(fres_phase[0, c], cmap="hsv")
ax[3, c].imshow(asm_phase[0, c], cmap="hsv")
for axi in ax.flatten():
axi.axis("off")
script_dir = os.path.dirname(os.path.abspath(__file__))
out_dir = os.path.join(script_dir, "out")
if not os.path.exists(out_dir):
os.makedirs(out_dir)
plot_path = os.path.join(
out_dir, f"prop_radial_{radial_flag}_{device}.png"
)
plt.savefig(plot_path)
plt.close()

0 comments on commit 4240619

Please sign in to comment.