Skip to content

Commit

Permalink
Implement defining a single "color" aside from a colormap
Browse files Browse the repository at this point in the history
  • Loading branch information
EwoutH committed Oct 3, 2024
1 parent 8b34f9e commit cb4a238
Showing 1 changed file with 79 additions and 115 deletions.
194 changes: 79 additions & 115 deletions mesa/visualization/components/matplotlib.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Matplotlib based solara components for visualization MESA spaces and plots."""

import warnings

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import solara
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import LinearSegmentedColormap, to_rgba
from matplotlib.figure import Figure

import mesa
Expand Down Expand Up @@ -44,7 +48,7 @@ def SpaceMatplotlib(
space_ax = space_fig.subplots()
space = getattr(model, "grid", None)
if space is None:
space = model.space
space = getattr(model, "space", None)

if isinstance(space, mesa.space._Grid):
_draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)
Expand All @@ -56,6 +60,8 @@ def SpaceMatplotlib(
_draw_network_grid(space, space_ax, agent_portrayal)
elif isinstance(space, VoronoiGrid):
_draw_voronoi(space, space_ax, agent_portrayal, propertylayer_portrayal, model)
elif space is None and propertylayer_portrayal:
draw_property_layers(space_ax, space, propertylayer_portrayal, model)

solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)

Expand All @@ -74,112 +80,55 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model):
if layer is None or not isinstance(layer, PropertyLayer):
continue

cmap = portrayal.get("colormap", "viridis")
alpha = portrayal.get("alpha", 0.5)
vmin = portrayal.get("vmin", np.min(layer.data))
vmax = portrayal.get("vmax", np.max(layer.data))

if isinstance(cmap, list):
cmap = LinearSegmentedColormap.from_list(layer_name, cmap)

im = ax.imshow(
layer.data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
extent=(0, space.width, 0, space.height),
origin="lower",
)
plt.colorbar(im, ax=ax, label=layer_name)


import matplotlib.pyplot as plt
import numpy as np
import solara


def make_space_matplotlib(agent_portrayal=None, propertylayer_portrayal=None):
"""Create a Matplotlib-based space visualization component.
Args:
agent_portrayal (function): Function to portray agents
propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications
Returns:
function: A function that creates a SpaceMatplotlib component
"""
if agent_portrayal is None:

def agent_portrayal(a):
return {"id": a.unique_id}

def MakeSpaceMatplotlib(model):
return SpaceMatplotlib(model, agent_portrayal, propertylayer_portrayal)

return MakeSpaceMatplotlib


@solara.component
def SpaceMatplotlib(
model,
agent_portrayal,
propertylayer_portrayal,
dependencies: list[any] | None = None,
):
update_counter.get()
space_fig = Figure()
space_ax = space_fig.subplots()
space = getattr(model, "grid", None)
if space is None:
space = model.space

if isinstance(space, mesa.space._Grid):
_draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)
elif isinstance(space, mesa.space.ContinuousSpace):
_draw_continuous_space(
space, space_ax, agent_portrayal, propertylayer_portrayal, model
)
elif isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)
elif isinstance(space, VoronoiGrid):
_draw_voronoi(space, space_ax, agent_portrayal, propertylayer_portrayal, model)

solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)


def draw_property_layers(ax, space, propertylayer_portrayal, model):
"""Draw PropertyLayers on the given axes.
Args:
ax (matplotlib.axes.Axes): The axes to draw on.
space (mesa.space._Grid): The space containing the PropertyLayers.
propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications.
model (mesa.Model): The model instance.
"""
for layer_name, portrayal in propertylayer_portrayal.items():
layer = getattr(model, layer_name, None)
if layer is None or not isinstance(layer, PropertyLayer):
continue

cmap = portrayal.get("colormap", "viridis")
alpha = portrayal.get("alpha", 0.5)
vmin = portrayal.get("vmin", np.min(layer.data))
vmax = portrayal.get("vmax", np.max(layer.data))

if isinstance(cmap, list):
cmap = LinearSegmentedColormap.from_list(layer_name, cmap)

im = ax.imshow(
layer.data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
extent=(0, space.width, 0, space.height),
origin="lower",
)
plt.colorbar(im, ax=ax, label=layer_name)
# Convert boolean arrays to float
data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

vmin = portrayal.get("vmin", np.min(data))
vmax = portrayal.get("vmax", np.max(data))

if space is not None:
# Check if the dimensions align
if data.shape != (space.width, space.height):
warnings.warn(
f"Layer {layer_name} does not have the same dimensions as the space. "
"Skipping visualization.",
UserWarning,
)
continue
width, height = space.width, space.height
else:
width, height = data.shape

if "color" in portrayal:
# Use color with alpha scaling
color = portrayal["color"]
rgba_color = to_rgba(color) # Convert any color format to RGBA
normalized_data = (data - vmin) / (vmax - vmin)
rgba_data = np.zeros((*data.shape, 4))
rgba_data[..., :3] = rgba_color[:3] # RGB channels
rgba_data[..., 3] = normalized_data * alpha * rgba_color[3] # Alpha channel
im = ax.imshow(
rgba_data.transpose(1, 0, 2), # Transpose to (height, width, 4)
extent=(0, width, 0, height),
origin="lower",
)
else:
# Use colormap
cmap = portrayal.get("colormap", "viridis")
if isinstance(cmap, list):
cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
im = ax.imshow(
data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
extent=(0, width, 0, height),
origin="lower",
)
plt.colorbar(im, ax=ax, label=layer_name)


def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model):
Expand Down Expand Up @@ -257,15 +206,13 @@ def portray(space):

# This is matplotlib's default marker size
default_size = 20
# establishing a default prevents misalignment if some agents are not given size, color, etc.
size = data.get("size", default_size)
s.append(size)
color = data.get("color", "tab:blue")
c.append(color)
mark = data.get("shape", "o")
m.append(mark)
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
return out
return {"x": x, "y": y, "s": s, "c": c, "m": m}

# Determine border style based on space.torus
border_style = "solid" if not space.torus else (0, (5, 10))
Expand Down Expand Up @@ -304,8 +251,6 @@ def portray(g):
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
# This is the default value for the marker size, which auto-scales
# according to the grid area.
out["s"] = s
if len(c) > 0:
out["c"] = c
Expand Down Expand Up @@ -334,18 +279,37 @@ def portray(g):
alpha=min(1, cell.properties[space.cell_coloring_property]),
c="red",
) # Plot filled polygon
space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in red
space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black


def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]):
"""Create a plotting function for a specified measure.
Args:
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
Returns:
function: A function that creates a PlotMatplotlib component.
"""

def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): # noqa: D103
def MakePlotMeasure(model):
return PlotMatplotlib(model, measure)

return MakePlotMeasure


@solara.component
def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # noqa: D103
def PlotMatplotlib(model, measure, dependencies: list[any] | None = None):
"""Create a Matplotlib-based plot for a measure or measures.
Args:
model (mesa.Model): The model instance.
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
dependencies (list[any] | None): Optional dependencies for the plot.
Returns:
solara.FigureMatplotlib: A component for rendering the plot.
"""
update_counter.get()
fig = Figure()
ax = fig.subplots()
Expand All @@ -362,5 +326,5 @@ def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # no
ax.plot(df.loc[:, m], label=m)
fig.legend()
# Set integer x axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
solara.FigureMatplotlib(fig, dependencies=dependencies)

0 comments on commit cb4a238

Please sign in to comment.