Skip to content

Commit

Permalink
move more from shim to jinja
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed Oct 28, 2024
1 parent 0e09804 commit 530e100
Show file tree
Hide file tree
Showing 20 changed files with 898 additions and 1,015 deletions.
52 changes: 29 additions & 23 deletions autotest/test_codegen.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest

from autotest.conftest import get_project_root_path
from flopy.mf6.utils.codegen import make_all, make_targets
from flopy.mf6.utils.codegen.context import Context
from flopy.mf6.utils.codegen.dfn import Dfn
from flopy.mf6.utils.codegen.make import make_all, make_targets

PROJ_ROOT = get_project_root_path()
MF6_PATH = PROJ_ROOT / "flopy" / "mf6"
Expand All @@ -17,34 +17,40 @@

@pytest.mark.parametrize("dfn_name", DFN_NAMES)
def test_dfn_load(dfn_name):
dfn_path = DFN_PATH / f"{dfn_name}.dfn"

common_path = DFN_PATH / "common.dfn"
with open(common_path, "r") as f:
common, _ = Dfn._load(f)

with open(dfn_path, "r") as f:
dfn = Dfn.load(f, name=Dfn.Name(*dfn_name.split("-")), common=common)
if dfn_name in ["sln-ems", "exg-gwfprt", "exg-gwfgwe", "exg-gwfgwt"]:
assert not any(dfn)
else:
assert any(dfn)
with (
open(DFN_PATH / "common.dfn", "r") as common_file,
open(DFN_PATH / f"{dfn_name}.dfn", "r") as dfn_file,
):
name = Dfn.Name.parse(dfn_name)
common, _ = Dfn._load(common_file)
dfn = Dfn.load(dfn_file, name=name, common=common)

if name in [
("sln", "ems"),
("exg", "gwfprt"),
("exg", "gwfgwe"),
("exg", "gwfgwt"),
]:
assert not any(dfn)
else:
assert any(dfn)


@pytest.mark.parametrize("dfn_name", DFN_NAMES)
def test_make_targets(dfn_name, function_tmpdir):
common_path = DFN_PATH / "common.dfn"
with open(common_path, "r") as f:
common, _ = Dfn._load(f)

with open(DFN_PATH / f"{dfn_name}.dfn", "r") as f:
dfn = Dfn.load(f, name=Dfn.Name(*dfn_name.split("-")), common=common)
with (
open(DFN_PATH / "common.dfn", "r") as common_file,
open(DFN_PATH / f"{dfn_name}.dfn", "r") as dfn_file,
):
name = Dfn.Name.parse(dfn_name)
common, _ = Dfn._load(common_file)
dfn = Dfn.load(dfn_file, name=name, common=common)

make_targets(dfn, function_tmpdir, verbose=True)

for name in Context.Name.from_dfn(dfn):
source_path = function_tmpdir / name.target
assert source_path.is_file()
assert all(
(function_tmpdir / name.target).is_file()
for name in Context.Name.from_dfn(dfn)
)


def test_make_all(function_tmpdir):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@
from jinja2 import Environment, PackageLoader

from flopy.mf6.utils.codegen.context import Context
from flopy.mf6.utils.codegen.dfn import Dfn, Dfns
from flopy.mf6.utils.codegen.ref import Ref, Refs
from flopy.mf6.utils.codegen.dfn import Dfn, Dfns, Ref, Refs

__all__ = ["make_targets", "make_all"]

_TEMPLATE_LOADER = PackageLoader("flopy", "mf6/utils/codegen/templates/")
_TEMPLATE_ENV = Environment(loader=_TEMPLATE_LOADER)
_TEMPLATE_NAME = "context.py.jinja"
_TEMPLATE = _TEMPLATE_ENV.get_template(_TEMPLATE_NAME)


def make_targets(dfn: Dfn, outdir: Path, verbose: bool = False):
"""Generate Python source file(s) from the given input definition."""

for context in Context.from_dfn(dfn):
target = outdir / context.name.target
name = context.name
target = outdir / name.target
template = _TEMPLATE_ENV.get_template(name.template)
with open(target, "w") as f:
f.write(_TEMPLATE.render(**context.render()))
f.write(template.render(**context.render()))
if verbose:
print(f"Wrote {target}")

Expand Down
59 changes: 24 additions & 35 deletions flopy/mf6/utils/codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
Optional,
)

from flopy.mf6.utils.codegen.dfn import Dfn, Vars
from flopy.mf6.utils.codegen.ref import Ref
from flopy.mf6.utils.codegen.render import renderable
from flopy.mf6.utils.codegen.dfn import Dfn, Ref, Vars
from flopy.mf6.utils.codegen.renderable import renderable
from flopy.mf6.utils.codegen.shim import SHIM


Expand Down Expand Up @@ -39,7 +38,7 @@ class Context:

class Name(NamedTuple):
"""
Uniquely identifies an input context. A context
Uniquely identifies an input context. The name
consists of a left term and optional right term.
Notes
Expand All @@ -50,10 +49,11 @@ class Name(NamedTuple):
From the context name several other things are derived:
- a description of the context
- the input context class' name
- a description of the context class
- the name of the source file to write
- the template the context will populate
- the base class the context inherits from
- the name of the source file the context is in
- the name of the parent parameter in the context
class' `__init__` method, if it can have a parent
Expand All @@ -70,7 +70,6 @@ def title(self) -> str:
remains unique. The title is substituted into
the file name and class name.
"""

l, r = self
if self == ("sim", "nam"):
return "simulation"
Expand All @@ -82,7 +81,7 @@ def title(self) -> str:
return r
if l in ["sln", "exg"]:
return r
return f"{l}{r}"
return l + r

@property
def base(self) -> str:
Expand All @@ -99,6 +98,18 @@ def target(self) -> str:
"""The source file name to generate."""
return f"mf{self.title}.py"

@property
def template(self) -> str:
"""The template file to use."""
if self.base == "MFSimulationBase":
return "simulation.py.jinja"
elif self.base == "MFModel":
return "model.py.jinja"
elif self.base == "MFPackage":
if self.l == "exg":
return "exchange.py.jinja"
return "package.py.jinja"

@property
def description(self) -> str:
"""A description of the input context."""
Expand All @@ -109,29 +120,11 @@ def description(self) -> str:
elif self.base == "MFModel":
return f"Modflow{title} defines a {l.upper()} model."
elif self.base == "MFSimulationBase":
return """
MFSimulation is used to load, build, and/or save a MODFLOW 6 simulation.
A MFSimulation object must be created before creating any of the MODFLOW 6
model objects."""

def parent(self, ref: Optional[Ref] = None) -> Optional[str]:
"""
Return the name of the parent `__init__` method parameter,
or `None` if the context cannot have parents. Contexts can
have more than one possible parent, in which case the name
of the parameter is of the pattern `name1_or_..._or_nameN`.
"""
if ref:
return ref.parent
if self == ("sim", "nam"):
return None
elif (
self.l is None
or self.r is None
or self.l in ["sim", "exg", "sln"]
):
return "simulation"
return "model"
return (
"MFSimulation is used to load, build, and/or save a MODFLOW 6 simulation."
" A MFSimulation object must be created before creating any of the MODFLOW"
" 6 model objects."
)

@staticmethod
def from_dfn(dfn: Dfn) -> List["Context.Name"]:
Expand Down Expand Up @@ -172,7 +165,6 @@ def from_dfn(dfn: Dfn) -> List["Context.Name"]:
name: Name
vars: Vars
base: Optional[type] = None
parent: Optional[str] = None
description: Optional[str] = None
meta: Optional[Dict[str, Any]] = None

Expand All @@ -183,18 +175,15 @@ def from_dfn(cls, dfn: Dfn) -> Iterator["Context"]:
These are structured representations of input context classes.
Each input definition yields one or more input contexts.
"""

meta = dfn.meta.copy()
ref = Ref.from_dfn(dfn)
if ref:
meta["ref"] = ref

for name in Context.Name.from_dfn(dfn):
yield Context(
name=name,
vars=dfn.data,
base=name.base,
parent=name.parent(ref),
description=name.description,
meta=meta,
)
Loading

0 comments on commit 530e100

Please sign in to comment.