Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
EwoutH committed Oct 3, 2024
1 parent cb4a238 commit 1b1dedd
Showing 1 changed file with 20 additions and 26 deletions.
46 changes: 20 additions & 26 deletions mesa/visualization/components/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,45 +77,35 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model):
"""
for layer_name, portrayal in propertylayer_portrayal.items():
layer = getattr(model, layer_name, None)
if layer is None or not isinstance(layer, PropertyLayer):
if not isinstance(layer, PropertyLayer):
continue

alpha = portrayal.get("alpha", 0.5)

# Convert boolean arrays to float
data = layer.data.astype(float) if layer.data.dtype == bool else layer.data
width, height = data.shape if space is None else (space.width, space.height)

if space and data.shape != (width, height):
warnings.warn(
f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({width}, {height}).",
UserWarning,
)

# Get portrayal properties, or use defaults
alpha = portrayal.get("alpha", 1)
vmin = portrayal.get("vmin", np.min(data))
vmax = portrayal.get("vmax", np.max(data))

if space is not None:
# Check if the dimensions align
if data.shape != (space.width, space.height):
warnings.warn(
f"Layer {layer_name} does not have the same dimensions as the space. "
"Skipping visualization.",
UserWarning,
)
continue
width, height = space.width, space.height
else:
width, height = data.shape

# Draw the layer
if "color" in portrayal:
# Use color with alpha scaling
color = portrayal["color"]
rgba_color = to_rgba(color) # Convert any color format to RGBA
rgba_color = to_rgba(portrayal["color"])
normalized_data = (data - vmin) / (vmax - vmin)
rgba_data = np.zeros((*data.shape, 4))
rgba_data[..., :3] = rgba_color[:3] # RGB channels
rgba_data[..., 3] = normalized_data * alpha * rgba_color[3] # Alpha channel
rgba_data = np.full((*data.shape, 4), rgba_color)
rgba_data[..., 3] *= normalized_data * alpha
im = ax.imshow(
rgba_data.transpose(1, 0, 2), # Transpose to (height, width, 4)
rgba_data.transpose(1, 0, 2),
extent=(0, width, 0, height),
origin="lower",
)
else:
# Use colormap
elif "colormap" in portrayal:
cmap = portrayal.get("colormap", "viridis")
if isinstance(cmap, list):
cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
Expand All @@ -129,6 +119,10 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model):
origin="lower",
)
plt.colorbar(im, ax=ax, label=layer_name)
else:
raise ValueError(
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)


def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model):
Expand Down

0 comments on commit 1b1dedd

Please sign in to comment.