Skip to content

Commit

Permalink
Improved interface for GAME-Net_UQ. Adds UQ as selection criteria
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverlovros committed Sep 13, 2024
1 parent 3ca7fd3 commit fa60c8f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 18 deletions.
77 changes: 59 additions & 18 deletions src/care/evaluators/gamenet_uq/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
surface: Surface,
dft_db_path: Optional[str] = None,
num_configs: int = 3,
use_uq: bool = False,
**kwargs
):
"""Interface for GAME-Net-UQ for intermediates.
Expand All @@ -51,6 +52,7 @@ def __init__(
self.model.to(self.device)
self.surface = surface
self.num_configs = num_configs
self.use_uq = use_uq

if dft_db_path is not None and os.path.exists(dft_db_path):
self.db = connect(dft_db_path)
Expand Down Expand Up @@ -178,7 +180,7 @@ def eval(
if self.db and self.retrieve_from_db(intermediate):
return
else:
adsorptions = place_adsorbate(intermediate, self.surface)[:self.num_configs]
adsorptions = place_adsorbate(intermediate, self.surface)
ads_config_dict = {}
for i, adsorption in enumerate(adsorptions):
with no_grad():
Expand All @@ -196,8 +198,10 @@ def eval(
y.scale * self.model.y_scale_params["std"]
).item() # eV

# Select best configurations based on the mean (mu) or the uncertainty (s)
criterion = 's' if self.use_uq else 'mu'
ads_config_dict = dict(
sorted(ads_config_dict.items(), key=lambda item: item[1]["mu"])
sorted(ads_config_dict.items(), key=lambda item: item[1][criterion])[:self.num_configs]
)
intermediate.ads_configs = ads_config_dict
else:
Expand All @@ -217,12 +221,14 @@ def __init__(
T: float = None,
pH: float = None,
U: float = None,
use_uq: bool = False,
**kwargs
):
self.model = load_model(MODEL_PATH)
self.device = "cuda" if cuda.is_available() else "cpu"
self.model.to(self.device)
self.num_params = sum(p.numel() for p in self.model.parameters())
self.use_uq = use_uq
self.intermediates = intermediates
self.pH = pH
self.U = U
Expand All @@ -247,6 +253,7 @@ def calc_reaction_energy(self, reaction: ElementaryReaction) -> None:
reaction (ElementaryReaction): Elementary reaction.
"""
mu_is, var_is, mu_fs, var_fs = 0.0, 0.0, 0.0, 0.0
criterion = 's' if self.use_uq else 'mu'
if reaction.r_type == "PCET":
for reactant in reaction.reactants:
if reactant.is_surface or isinstance(reactant, Electron):
Expand All @@ -266,8 +273,12 @@ def calc_reaction_energy(self, reaction: ElementaryReaction) -> None:
H2O_gas.code
].ads_configs.values()
]
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
if criterion == 'mu':
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
else:
s_min_config = min(s_list)
e_min_config = energy_list[s_list.index(s_min_config)]

mu_is += abs(reaction.stoic[reactant.code]) * e_min_config
var_is += abs(reaction.stoic[reactant.code]) * s_min_config**2
Expand All @@ -286,8 +297,13 @@ def calc_reaction_energy(self, reaction: ElementaryReaction) -> None:
H2_gas.code
].ads_configs.values()
]
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]

if criterion == 'mu':
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
else:
s_min_config = min(s_list)
e_min_config = energy_list[s_list.index(s_min_config)]

mu_is += abs(reaction.stoic[reactant.code]) * e_min_config
var_is += abs(reaction.stoic[reactant.code]) * s_min_config**2
Expand All @@ -304,8 +320,13 @@ def calc_reaction_energy(self, reaction: ElementaryReaction) -> None:
reactant.code
].ads_configs.values()
]
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]

if criterion == 'mu':
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
else:
s_min_config = min(s_list)
e_min_config = energy_list[s_list.index(s_min_config)]

mu_is += abs(reaction.stoic[reactant.code]) * e_min_config
var_is += abs(reaction.stoic[reactant.code]) * s_min_config**2
Expand All @@ -327,8 +348,12 @@ def calc_reaction_energy(self, reaction: ElementaryReaction) -> None:
H2O_gas.code
].ads_configs.values()
]
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
if criterion == 'mu':
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
else:
s_min_config = min(s_list)
e_min_config = energy_list[s_list.index(s_min_config)]

mu_fs += abs(reaction.stoic[product.code]) * e_min_config
var_fs += abs(reaction.stoic[product.code]) * s_min_config**2
Expand All @@ -347,8 +372,12 @@ def calc_reaction_energy(self, reaction: ElementaryReaction) -> None:
H2_gas.code
].ads_configs.values()
]
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
if criterion == 'mu':
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
else:
s_min_config = min(s_list)
e_min_config = energy_list[s_list.index(s_min_config)]

mu_fs += abs(reaction.stoic[product.code]) * e_min_config
var_fs += abs(reaction.stoic[product.code]) * s_min_config**2
Expand All @@ -365,8 +394,12 @@ def calc_reaction_energy(self, reaction: ElementaryReaction) -> None:
product.code
].ads_configs.values()
]
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
if criterion == 'mu':
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
else:
s_min_config = min(s_list)
e_min_config = energy_list[s_list.index(s_min_config)]

mu_fs += abs(reaction.stoic[product.code]) * e_min_config
var_fs += abs(reaction.stoic[product.code]) * s_min_config**2
Expand Down Expand Up @@ -395,8 +428,12 @@ def calc_reaction_energy(self, reaction: ElementaryReaction) -> None:
config["s"]
for config in self.intermediates[reactant.code].ads_configs.values()
]
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
if criterion == 'mu':
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
else:
s_min_config = min(s_list)
e_min_config = energy_list[s_list.index(s_min_config)]

mu_is += abs(reaction.stoic[reactant.code]) * e_min_config
var_is += abs(reaction.stoic[reactant.code]) * s_min_config**2
Expand All @@ -411,8 +448,12 @@ def calc_reaction_energy(self, reaction: ElementaryReaction) -> None:
config["s"]
for config in self.intermediates[product.code].ads_configs.values()
]
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
if criterion == 'mu':
e_min_config = min(energy_list)
s_min_config = s_list[energy_list.index(e_min_config)]
else:
s_min_config = min(s_list)
e_min_config = energy_list[s_list.index(s_min_config)]

mu_fs += abs(reaction.stoic[product.code]) * e_min_config
var_fs += abs(reaction.stoic[product.code]) * s_min_config**2
Expand Down
1 change: 1 addition & 0 deletions src/care/scripts/example_eval.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ max_steps = 2 # max iterations allowed during relaxation
num_configs = 1 # number of screened configurations for each adsorbate/surface pair
size = 'small' # MACE-MP-0 size
dtype = 'float32' # MACE-MP-0
use_uq = false # Criterion for GAME-Net-UQ: if True, selects best configurations based on uncertainty (s), else based on lowest energy (mu)

[reaction_args]

0 comments on commit fa60c8f

Please sign in to comment.