From 1b1dedd40ae3eb0a207581bf341825e40f0c91fd Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 3 Oct 2024 15:09:46 +0200 Subject: [PATCH] Cleanup --- mesa/visualization/components/matplotlib.py | 46 +++++++++------------ 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 79386d815dc..497183e9507 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -77,45 +77,35 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model): """ for layer_name, portrayal in propertylayer_portrayal.items(): layer = getattr(model, layer_name, None) - if layer is None or not isinstance(layer, PropertyLayer): + if not isinstance(layer, PropertyLayer): continue - alpha = portrayal.get("alpha", 0.5) - - # Convert boolean arrays to float data = layer.data.astype(float) if layer.data.dtype == bool else layer.data + width, height = data.shape if space is None else (space.width, space.height) + + if space and data.shape != (width, height): + warnings.warn( + f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({width}, {height}).", + UserWarning, + ) + # Get portrayal properties, or use defaults + alpha = portrayal.get("alpha", 1) vmin = portrayal.get("vmin", np.min(data)) vmax = portrayal.get("vmax", np.max(data)) - if space is not None: - # Check if the dimensions align - if data.shape != (space.width, space.height): - warnings.warn( - f"Layer {layer_name} does not have the same dimensions as the space. " - "Skipping visualization.", - UserWarning, - ) - continue - width, height = space.width, space.height - else: - width, height = data.shape - + # Draw the layer if "color" in portrayal: - # Use color with alpha scaling - color = portrayal["color"] - rgba_color = to_rgba(color) # Convert any color format to RGBA + rgba_color = to_rgba(portrayal["color"]) normalized_data = (data - vmin) / (vmax - vmin) - rgba_data = np.zeros((*data.shape, 4)) - rgba_data[..., :3] = rgba_color[:3] # RGB channels - rgba_data[..., 3] = normalized_data * alpha * rgba_color[3] # Alpha channel + rgba_data = np.full((*data.shape, 4), rgba_color) + rgba_data[..., 3] *= normalized_data * alpha im = ax.imshow( - rgba_data.transpose(1, 0, 2), # Transpose to (height, width, 4) + rgba_data.transpose(1, 0, 2), extent=(0, width, 0, height), origin="lower", ) - else: - # Use colormap + elif "colormap" in portrayal: cmap = portrayal.get("colormap", "viridis") if isinstance(cmap, list): cmap = LinearSegmentedColormap.from_list(layer_name, cmap) @@ -129,6 +119,10 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model): origin="lower", ) plt.colorbar(im, ax=ax, label=layer_name) + else: + raise ValueError( + f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'." + ) def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model):