Skip to content

Commit 60fc1c8

Browse files
Add support for alpha, linewidths and edgecolors to agent_portrayal (#2468)
* add some more support to agent_portrayal * Update mpl_space_drawing.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update mpl_space_drawing.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update mpl_space_drawing.py * some additional docs * Update test_components_matplotlib.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 46ff9fb commit 60fc1c8

File tree

3 files changed

+122
-9
lines changed

3 files changed

+122
-9
lines changed

mesa/visualization/components/matplotlib_components.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ def make_mpl_space_component(
3838
the functions for drawing the various spaces for further details.
3939
4040
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
41-
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
42-
41+
"size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning.
4342
4443
Returns:
4544
function: A function that creates a SpaceMatplotlib component

mesa/visualization/mpl_space_drawing.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
77
"""
88

9+
import contextlib
910
import itertools
1011
import math
1112
import warnings
@@ -61,10 +62,19 @@ def collect_agent_data(
6162
zorder: default zorder
6263
6364
agent_portrayal should return a dict, limited to size (size of marker), color (color of marker), zorder (z-order),
64-
and marker (marker style)
65+
marker (marker style), alpha, linewidths, and edgecolors
6566
6667
"""
67-
arguments = {"s": [], "c": [], "marker": [], "zorder": [], "loc": []}
68+
arguments = {
69+
"s": [],
70+
"c": [],
71+
"marker": [],
72+
"zorder": [],
73+
"loc": [],
74+
"alpha": [],
75+
"edgecolors": [],
76+
"linewidths": [],
77+
}
6878

6979
for agent in space.agents:
7080
portray = agent_portrayal(agent)
@@ -78,6 +88,10 @@ def collect_agent_data(
7888
arguments["marker"].append(portray.pop("marker", marker))
7989
arguments["zorder"].append(portray.pop("zorder", zorder))
8090

91+
for entry in ["alpha", "edgecolors", "linewidths"]:
92+
with contextlib.suppress(KeyError):
93+
arguments[entry].append(portray.pop(entry))
94+
8195
if len(portray) > 0:
8296
ignored_fields = list(portray.keys())
8397
msg = ", ".join(ignored_fields)
@@ -110,24 +124,32 @@ def draw_space(
110124
Returns the Axes object with the plot drawn onto it.
111125
112126
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
113-
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
127+
"size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning.
114128
115129
"""
116130
if ax is None:
117131
fig, ax = plt.subplots()
118132

119133
# https://stackoverflow.com/questions/67524641/convert-multiple-isinstance-checks-to-structural-pattern-matching
120134
match space:
121-
case mesa.space._Grid() | OrthogonalMooreGrid() | OrthogonalVonNeumannGrid():
122-
draw_orthogonal_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
135+
# order matters here given the class structure of old-style grid spaces
123136
case HexSingleGrid() | HexMultiGrid() | mesa.experimental.cell_space.HexGrid():
124137
draw_hex_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
138+
case (
139+
mesa.space.SingleGrid()
140+
| OrthogonalMooreGrid()
141+
| OrthogonalVonNeumannGrid()
142+
| mesa.space.MultiGrid()
143+
):
144+
draw_orthogonal_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
125145
case mesa.space.NetworkGrid() | mesa.experimental.cell_space.Network():
126146
draw_network(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
127147
case mesa.space.ContinuousSpace():
128148
draw_continuous_space(space, agent_portrayal, ax=ax)
129149
case VoronoiGrid():
130150
draw_voronoi_grid(space, agent_portrayal, ax=ax)
151+
case _:
152+
raise ValueError(f"Unknown space type: {type(space)}")
131153

132154
if propertylayer_portrayal:
133155
draw_property_layers(space, propertylayer_portrayal, ax=ax)
@@ -543,11 +565,24 @@ def _scatter(ax: Axes, arguments, **kwargs):
543565
marker = arguments.pop("marker")
544566
zorder = arguments.pop("zorder")
545567

568+
# we check if edgecolor, linewidth, and alpha are specified
569+
# at the agent level, if not, we remove them from the arguments dict
570+
# and fallback to the default value in ax.scatter / use what is passed via **kwargs
571+
for entry in ["edgecolors", "linewidths", "alpha"]:
572+
if len(arguments[entry]) == 0:
573+
arguments.pop(entry)
574+
else:
575+
if entry in kwargs:
576+
raise ValueError(
577+
f"{entry} is specified in agent portrayal and via plotting kwargs, you can only use one or the other"
578+
)
579+
546580
for mark in np.unique(marker):
547581
mark_mask = marker == mark
548582
for z_order in np.unique(zorder):
549583
zorder_mask = z_order == zorder
550584
logical = mark_mask & zorder_mask
585+
551586
ax.scatter(
552587
x[logical],
553588
y[logical],

tests/test_components_matplotlib.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
draw_network,
2424
draw_orthogonal_grid,
2525
draw_property_layers,
26+
draw_space,
2627
draw_voronoi_grid,
2728
)
2829

@@ -41,6 +42,84 @@ def agent_portrayal(agent):
4142
}
4243

4344

45+
def test_draw_space():
46+
"""Test draw_space helper method."""
47+
import networkx as nx
48+
49+
def my_portrayal(agent):
50+
"""Simple portrayal of an agent.
51+
52+
Args:
53+
agent (Agent): The agent to portray
54+
55+
"""
56+
return {
57+
"s": 10,
58+
"c": "tab:blue",
59+
"marker": "s" if (agent.unique_id % 2) == 0 else "o",
60+
"alpha": 0.5,
61+
"linewidths": 1,
62+
"linecolors": "tab:orange",
63+
}
64+
65+
# draw space for hexgrid
66+
model = Model(seed=42)
67+
grid = HexSingleGrid(10, 10, torus=True)
68+
for _ in range(10):
69+
agent = Agent(model)
70+
grid.move_to_empty(agent)
71+
72+
fig, ax = plt.subplots()
73+
draw_space(grid, my_portrayal, ax=ax)
74+
75+
# draw space for voroinoi
76+
model = Model(seed=42)
77+
coordinates = model.rng.random((100, 2)) * 10
78+
grid = VoronoiGrid(coordinates.tolist(), random=model.random, capacity=1)
79+
for _ in range(10):
80+
agent = CellAgent(model)
81+
agent.cell = grid.select_random_empty_cell()
82+
83+
fig, ax = plt.subplots()
84+
draw_space(grid, my_portrayal, ax=ax)
85+
86+
# draw orthogonal grid
87+
model = Model(seed=42)
88+
grid = OrthogonalMooreGrid((10, 10), torus=True, random=model.random, capacity=1)
89+
for _ in range(10):
90+
agent = CellAgent(model)
91+
agent.cell = grid.select_random_empty_cell()
92+
fig, ax = plt.subplots()
93+
draw_space(grid, my_portrayal, ax=ax)
94+
95+
# draw network
96+
n = 10
97+
m = 20
98+
seed = 42
99+
graph = nx.gnm_random_graph(n, m, seed=seed)
100+
101+
model = Model(seed=42)
102+
grid = NetworkGrid(graph)
103+
for _ in range(10):
104+
agent = Agent(model)
105+
pos = agent.random.randint(0, len(graph.nodes) - 1)
106+
grid.place_agent(agent, pos)
107+
fig, ax = plt.subplots()
108+
draw_space(grid, my_portrayal, ax=ax)
109+
110+
# draw continuous space
111+
model = Model(seed=42)
112+
space = ContinuousSpace(10, 10, torus=True)
113+
for _ in range(10):
114+
x = model.random.random() * 10
115+
y = model.random.random() * 10
116+
agent = Agent(model)
117+
space.place_agent(agent, (x, y))
118+
119+
fig, ax = plt.subplots()
120+
draw_space(space, my_portrayal, ax=ax)
121+
122+
44123
def test_draw_hex_grid():
45124
"""Test drawing hexgrids."""
46125
model = Model(seed=42)
@@ -62,8 +141,8 @@ def test_draw_hex_grid():
62141
draw_hex_grid(grid, agent_portrayal, ax)
63142

64143

65-
def test_draw_voroinoi_grid():
66-
"""Test drawing voroinoi grids."""
144+
def test_draw_voronoi_grid():
145+
"""Test drawing voronoi grids."""
67146
model = Model(seed=42)
68147

69148
coordinates = model.rng.random((100, 2)) * 10

0 commit comments

Comments
 (0)