Skip to content

Commit

Permalink
Added utils to load/save crns, implement __call__ dunder for evaluato…
Browse files Browse the repository at this point in the history
…rs and ReactionNetwork
  • Loading branch information
SantiagoMorandi committed Oct 20, 2024
1 parent bd07592 commit 7612791
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/care/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# import juliacall # to avoid segfaults
from pickle import load, dump

from care.constants import *
from care.crn.surface import Surface
Expand All @@ -8,6 +9,14 @@
from care.crn.utils.blueprint import gen_blueprint
from care.crn.templates.dissociation import dissociate

def load_crn(file_path: str) -> ReactionNetwork:
with open(file_path, "rb") as f:
return load(f)

def save_crn(crn: ReactionNetwork, file_path: str):
with open(file_path, "wb") as f:
dump(crn, f)

__all__ = [
"Intermediate",
"ElementaryReaction",
Expand All @@ -16,5 +25,7 @@
"ReactionMechanism",
"gen_blueprint",
"dissociate",
"load_crn",
"save_crn",
]
__version__ = "1.0.0"
3 changes: 3 additions & 0 deletions src/care/crn/reaction_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def __add__(self, other):
return new_net
else:
raise TypeError("Can only add ReactionNetwork to ReactionNetwork")

def __call__(self, *args, **kwargs):
return self.run_microkinetic(*args, **kwargs)

@classmethod
def from_dict(cls, net_dict: dict):
Expand Down
8 changes: 8 additions & 0 deletions src/care/evaluators/energy_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class IntermediateEnergyEstimator(ABC):
@abstractmethod
def __init__(self):
pass

@abstractmethod
def __call__(self, inter: Intermediate, surf: Optional[Surface] = None) -> None:
self.eval(inter, surf)

@property
@abstractmethod
Expand Down Expand Up @@ -54,6 +58,10 @@ class ReactionEnergyEstimator(ABC):
@abstractmethod
def __init__(self):
pass

@abstractmethod
def __call__(self, reaction: ElementaryReaction) -> None:
self.eval(reaction)

@property
@abstractmethod
Expand Down
12 changes: 12 additions & 0 deletions src/care/evaluators/gamenet_uq/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ def __init__(
self.db = connect(dft_db_path)
else:
self.db = None

def __call__(self,
intermediate: Intermediate,
**kwargs) -> None:
if isinstance(intermediate, Intermediate):
self.eval(intermediate, **kwargs)
else:
return NotImplementedError("Input must be an Intermediate object.")

def adsorbate_domain(self):
return ADSORBATE_ELEMS
Expand Down Expand Up @@ -244,6 +252,10 @@ def __repr__(self) -> str:
return (
f"GAME-Net-UQ ({int(self.num_params/1000)}K params, device={self.device})"
)

def __call__(self,
rxn: ElementaryReaction) -> None:
self.eval(rxn)

def calc_reaction_energy(self, reaction: ElementaryReaction) -> None:
"""
Expand Down

0 comments on commit 7612791

Please sign in to comment.