
Commit

Update load module for absolute and relative pathing
DeanHazineh committed Jan 27, 2024
1 parent 40c7b1d commit 88649b7
Showing 4 changed files with 40,702 additions and 64 deletions.
@@ -1,6 +1,7 @@
 model:
   target: dflat.metasurface.optical_model.NeuralCells
   ckpt_path: metasurface/ckpt/Nanocylinders_TiO2_U180H600_Medium/model.ckpt
+  relative_ckpt: True
   params:
     trainable_model: False
     param_bounds:
@@ -24,12 +25,13 @@ trainer:
   target: dflat.metasurface.trainer.Trainer_v1
   data: dflat.metasurface.datasets.Nanocylinders_TiO2_U180nm_H600nm
   ckpt_path: metasurface/ckpt/Nanocylinders_TiO2_U180H600_Medium/model.ckpt
+  relative_ckpt: True
   params:
     test_split: 0.10
     learning_rate: .001
     epochs: 6000
     batch_size: 65536
-    checkpoint_every_n: 50
+    checkpoint_every_n: 100
     update_figure_every_epoch: True
     gradient_accumulation_steps: 1
     cosine_anneal_warm_restart: False
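The new `relative_ckpt: True` flag tells the loaders touched by this commit to resolve `ckpt_path` against the installed `dflat` package rather than treating it as a plain filesystem path. Below is a minimal sketch of that resolution logic, mirroring the change to `load_trainer` further down; the helper name `resolve_ckpt_path` is illustrative and not part of the repository.

import pkg_resources
from omegaconf import OmegaConf


def resolve_ckpt_path(config_trainer):
    # Use the path verbatim unless relative_ckpt asks for a package-relative lookup.
    ckpt_path = config_trainer["ckpt_path"]
    if config_trainer.get("relative_ckpt", False):
        # Map e.g. "metasurface/ckpt/.../model.ckpt" to its location inside the
        # installed dflat package.
        ckpt_path = pkg_resources.resource_filename("dflat", ckpt_path)
    return ckpt_path


if __name__ == "__main__":
    # Trainer block taken from the config shown above.
    cfg = OmegaConf.create(
        {
            "trainer": {
                "ckpt_path": "metasurface/ckpt/Nanocylinders_TiO2_U180H600_Medium/model.ckpt",
                "relative_ckpt": True,
            }
        }
    )
    print(resolve_ckpt_path(cfg.trainer))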
@@ -25,10 +25,11 @@ trainer:
   target: dflat.metasurface.trainer.Trainer_v1
   data: dflat.metasurface.datasets.Nanofins_TiO2_U350nm_H600nm
   ckpt_path: metasurface/ckpt/Nanofins_TiO2_U350H600_Medium/model.ckpt
+  relative_ckpt: True
   params:
     test_split: 0.10
     learning_rate: .001
-    epochs: 2000
+    epochs: 5000
     batch_size: 65536
     checkpoint_every_n: 50
     update_figure_every_epoch: True
dflat/metasurface/load_utils.py (49 changes: 36 additions & 13 deletions)
@@ -13,8 +13,7 @@ def load_optical_model(config_path):
     Returns:
         nn.Module: Model with pretrained weights loaded if ckpt_path is specified in the config file.
     """
-    config_path = pkg_resources.resource_filename("dflat", config_path)
-    config = OmegaConf.load(config_path)
+    config = load_config_from_path(config_path)
     ckpt_path = config.model["ckpt_path"]
     optical_model = instantiate_from_config(config.model, ckpt_path, strict=True)
     return optical_model
@@ -29,14 +28,17 @@ def load_trainer(config_path):
     Returns:
         object: Trainer
     """
-    config_path = pkg_resources.resource_filename("dflat", config_path)
-    print(f"Loading trainer from config path: {config_path}")
-
-    config = OmegaConf.load(config_path)
+    config = load_config_from_path(config_path)
     config_model = config.model
+
     config_trainer = config.trainer
-    ckpt_path = pkg_resources.resource_filename("dflat", config_trainer["ckpt_path"])
-
+    rel_ckpt = config_trainer["relative_ckpt"]
+    ckpt_path = config_trainer["ckpt_path"]
+    ckpt_path = (
+        ckpt_path
+        if not rel_ckpt
+        else pkg_resources.resource_filename("dflat", ckpt_path)
+    )
     dataset = get_obj_from_str(config_trainer["data"])()
 
     trainer = get_obj_from_str(config_trainer["target"])(
@@ -48,20 +50,41 @@
     return trainer
 
 
+def load_config_from_path(config_path):
+    try:
+        use_path = pkg_resources.resource_filename("dflat", config_path)
+        config = OmegaConf.load(use_path)
+    except Exception as e1:
+        try:
+            use_path = config_path
+            config = OmegaConf.load(use_path)
+        except Exception as e2:
+            print(f"Failed absolute path identification. Errors \n {e1} \n {e2}.")
+
+    return config
+
+
 def instantiate_from_config(config_model, ckpt_path=None, strict=False):
     assert "target" in config_model, "Expected key `target` to instantiate."
     target_str = config_model["target"]
     print(f"Target Module: {target_str}")
     loaded_module = get_obj_from_str(target_str)(**config_model.get("params", dict()))
 
     # Get model checkpoint
     if ckpt_path is not None and ckpt_path != "None":
-        ckpt_path = pkg_resources.resource_filename("dflat", ckpt_path)
+        ## Try and Except to handle relative pathing vs absolute pathing
+        try:
+            use_path = pkg_resources.resource_filename("dflat", ckpt_path)
+            sd = torch.load(use_path, map_location="cpu")["state_dict"]
+        except Exception as e1:
+            try:
+                use_path = ckpt_path
+                sd = torch.load(use_path, map_location="cpu")["state_dict"]
+            except Exception as e2:
+                print(f"Failed absolute path identification. Errors \n {e1} \n {e2}.")
+
         print(
-            f"Target: {config_model['target']} Loading from checkpoint {ckpt_path} as strict={strict}"
+            f"Target: {config_model['target']} Loading from checkpoint {use_path} as strict={strict}"
         )
-        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
 
         missing, unexpected = loaded_module.load_state_dict(sd, strict=strict)
         print(
             f"Restored {target_str} with {len(missing)} missing and {len(unexpected)} unexpected keys"
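Taken together, these changes let the same entry points accept either a package-relative path or an ordinary absolute path. Below is a short usage sketch under that reading; the config paths are placeholders rather than files shipped with the package, and the functions are imported directly from the module edited above.

from dflat.metasurface.load_utils import load_optical_model, load_trainer

# Package-relative path: tried first against the installed dflat package via
# pkg_resources.resource_filename("dflat", ...), as before this commit.
model = load_optical_model("metasurface/config/nanocylinders_example.yaml")  # placeholder path

# Absolute path: if the package-relative lookup fails, load_config_from_path and
# instantiate_from_config now fall back to using the path exactly as given.
trainer = load_trainer("/home/user/experiments/my_config.yaml")  # placeholder path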