diff --git a/mesa/experimental/__init__.py b/mesa/experimental/__init__.py index 964dc5d19a3..71d6a5dd53f 100644 --- a/mesa/experimental/__init__.py +++ b/mesa/experimental/__init__.py @@ -1 +1 @@ -from .jupyter_viz import JupyterViz, make_text # noqa +from .jupyter_viz import JupyterViz, make_text, prepare_matplotlib_space # noqa diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index a7eefb51624..8fc31039f15 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -219,7 +219,19 @@ def change_handler(value, name=name): raise ValueError(f"{input_type} is not a supported input type") -def make_space(model, agent_portrayal): +def prepare_matplotlib_space(drawer): + def wrapped_drawer(model, agent_portrayal): + space_fig = Figure() + space_ax = space_fig.subplots() + drawer(model, agent_portrayal, space_fig, space_ax) + space_ax.set_axis_off() + solara.FigureMatplotlib(space_fig) + + return wrapped_drawer + + +@prepare_matplotlib_space +def make_space(model, agent_portrayal, space_fig, space_ax): def portray(g): x = [] y = [] @@ -248,14 +260,10 @@ def portray(g): out["c"] = c return out - space_fig = Figure() - space_ax = space_fig.subplots() if isinstance(model.grid, mesa.space.NetworkGrid): _draw_network_grid(model, space_ax, agent_portrayal) else: space_ax.scatter(**portray(model.grid)) - space_ax.set_axis_off() - solara.FigureMatplotlib(space_fig) def _draw_network_grid(model, space_ax, agent_portrayal):