Skip to content

Commit fca2788

Browse files
[dev] : support stackbar; change default color
1 parent 1a44659 commit fca2788

File tree

8 files changed

+256
-80
lines changed

8 files changed

+256
-80
lines changed

examples/bar1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
# 随机生成一个 5 x 7 的数据
99
a = 5
10-
b = 7
10+
b = 3
1111
y = np.random.randint(10, 100, size=(a, b))
1212

1313
# 初始化一个对象

examples/stackbar_graph.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import sys
2+
import os
3+
4+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
5+
6+
7+
import paperplotlib as ppl
8+
import unittest
9+
import numpy as np
10+
11+
12+
import paperplotlib as ppl
13+
14+
# 创建一个堆叠条形图
15+
stackbar_graph = ppl.StackBarGraph()
16+
17+
# 设置数据
18+
labels = [
19+
"alloc_migration_target",
20+
"try_to_migrate",
21+
"move_to_new_folio",
22+
"folio_add_lru" "remove_migration_ptes",
23+
"migrate_folio_done",
24+
]
25+
26+
# move_to_new_folio(30.311% 545/1798)
27+
# try_to_migrate(19.188% 345/1798)
28+
# migrate_folio_done(8.732% 157/1798)
29+
# alloc_migration_target(8.676% 156/1798)
30+
# remove_migration_ptes(8.287% 149/1798)
31+
# folio_add_lru(6.897% 124/1798)
32+
33+
# 百分比数据(按图片中的顺序)
34+
percentages = [8.676, 19.188, 30.311, 6.897, 8.287, 8.732]
35+
36+
# 绘制堆叠条形图
37+
stackbar_graph.direction = "horizontal"
38+
stackbar_graph.thinkness = 0.2
39+
stackbar_graph.plot(percentages, labels, name="migrate_page_batch")
40+
stackbar_graph.adjust_legend(alignment=3, font_size=20)
41+
# 保存图像
42+
stackbar_graph.save("stackbar_graph.png")

paperplotlib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

22
from .line_graph import LineGraph
33
from .bar_graph import BarGraph
4+
from .stackbar_graph import StackBarGraph
45
from .color import *

paperplotlib/bar_graph.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,7 @@ def plot_2d(
6262
self.ax.set_xticks(range(group_len), group_names)
6363
self.ax.tick_params(bottom=False)
6464

65-
# https://matplotlib.org/stable/api/legend_api.html#module-matplotlib.legend
66-
self.legend = self.ax.legend(
67-
column_names,
68-
loc="upper center", # 居中置顶
69-
ncols=column_len, # 横向排布
70-
bbox_to_anchor=(0.5, 1.15), # 置于图外侧
71-
handlelength=1, # 图例长宽, 修改为正方形
72-
handleheight=1, # 图例长宽, 修改为正方形
73-
handletextpad=0.4, # 缩短文字和图例的间距
74-
fontsize="x-small" if column_len >= 7 else "medium", # 图例文字大小
75-
)
65+
self.set_label_legend(column_names)
7666

7767
def add_line(self, y: int, line_style="-"):
7868
self.ax.axhline(y, linestyle=line_style, linewidth=0.5, color="black")

paperplotlib/color.css

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,13 @@
22
color: #acbcdb #012790;
33
}
44

5-
/* Demystifying CXL Memory with Genuine CXL-Ready Systems and Devices */
5+
/* Harnessing Integrated CPU-GPU System Memory for HPC: a first look into Grace Hopper */
66
.style-1 {
7-
color: #0070c0;
8-
color: #ffc000 #0070c0;
9-
color: #ffc000 #4472c4 #63c0cf;
10-
color: #ffc000 #ed7d31 #0070c0 #63c0cf;
11-
color: #ffc000 #ed7d31 #5b9bd5 #4bbabd #358d8f;
12-
color: #ffc000 #ed7d31 #5b9bd5 #0070c0 #4bbabd #358d8f;
13-
color: #ffc000 #ed7d31 #2f5597 #5b9bd5 #0070c0 #4bbabd #358d8f;
7+
color: #4184f3;
8+
color: #4184f3 #e94234;
9+
color: #4184f3 #e94234 #fabb03;
10+
color: #4184f3 #e94234 #fabb03 #33a852;
11+
color: #4184f3 #e94234 #fabb03 #33a852 #46bdc5;
1412
}
1513

1614
.style-2 {

paperplotlib/graph.py

Lines changed: 114 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,33 @@
11
import matplotlib
2+
23
# 非交互式 GUI 使用 Agg
3-
matplotlib.use('Agg')
4+
matplotlib.use("Agg")
45
import matplotlib.pyplot as plt
56
import os
67
from typing import List, Optional, Union, Tuple
78
from matplotlib.font_manager import FontProperties
89
import matplotlib.font_manager as fm
910

11+
1012
class Graph:
1113
"""
1214
图表
1315
"""
1416

15-
def __init__(self, style_id: int = 1) -> None:
17+
def __init__(self, style_id: int = 1, subplots: Tuple[int, int] = None) -> None:
1618
self.style_id = style_id
17-
self.fig = plt.figure(figsize=(8, 4))
18-
self.ax = self.fig.add_subplot(111)
19-
19+
self.subplots = subplots
20+
21+
if subplots is not None:
22+
self.fig, self.ax = plt.subplots(subplots[0], subplots[1])
23+
else:
24+
self.fig = plt.figure(figsize=(8, 4))
25+
self.ax = self.fig.add_subplot(111)
26+
2027
# -- configuation --
21-
self.x_label: Optional[str] = None # x轴标签
22-
self.y_label: Optional[str] = None # y轴标签
23-
self.width_picture = False # 是否是宽图
28+
self.x_label: Optional[str] = None # x轴标签
29+
self.y_label: Optional[str] = None # y轴标签
30+
self.width_picture = False # 是否是宽图
2431
self.grid = "y" # 网格线 x | y | xy | None
2532
self.y_lim: Optional[Tuple[float, float]] = None
2633

@@ -33,25 +40,29 @@ def __init__(self, style_id: int = 1) -> None:
3340
# 保存图片
3441
self.dpi = 300
3542
self.bbox_inches = "tight" # 适当上下左右留白
36-
37-
self.title: Optional[str] = None # 图表标题
38-
39-
font_path = f'{os.path.dirname(__file__)}/consola-1.ttf'
43+
44+
self.title: Optional[str] = None # 图表标题
45+
46+
font_path = f"{os.path.dirname(__file__)}/consola-1.ttf"
4047
fm.fontManager.addfont(font_path)
41-
plt.rcParams['font.family'] = 'Consolas'
42-
# self.font_family = "Consolas" # 字体
43-
# 如果是 linux
44-
# if os.name == 'posix':
45-
# self.font_family = consolas_font.get_name()
46-
# self.colors: Optional[List[str]] = None
47-
48-
def plot(self, x_data: List[float], y_data: List[float]): # pragma: no cover
48+
plt.rcParams["font.family"] = "Consolas"
49+
50+
# legend
51+
self.legend_labels = None
52+
self.legend_loc = None
53+
self.legend_bbox_to_anchor = None
54+
self.legend_ncols = None
55+
self.legend_font_size = 'medium'
56+
57+
def plot(self, x_data: List[float], y_data: List[float]): # pragma: no cover
4958
"""
5059
填入数据
5160
"""
5261
raise NotImplementedError("请在子类中实现此方法")
5362

54-
def plot_2d(self, y_data: List[List[float]], group_names: List[str], column_names: List[str], emphasize_index: int = -1): # pragma: no cover
63+
def plot_2d(
64+
self, y_data: List[List[float]], group_names: List[str], column_names: List[str], emphasize_index: int = -1
65+
): # pragma: no cover
5566
"""
5667
绘制二维柱状图
5768
@@ -62,9 +73,9 @@ def plot_2d(self, y_data: List[List[float]], group_names: List[str], column_name
6273
"""
6374
raise NotImplementedError("请在子类中实现此方法")
6475

65-
def _create_graph(self): # pragma: no cover
76+
def _create_graph(self): # pragma: no cover
6677
self._check_config()
67-
78+
6879
if self.width_picture:
6980
self.fig.set_size_inches(16, 4)
7081
self.ax.set_xlabel(self.x_label)
@@ -88,30 +99,102 @@ def _create_graph(self): # pragma: no cover
8899
alpha=self.grid_alpha,
89100
)
90101
self.ax.set_axisbelow(True)
91-
102+
92103
if self.y_lim is not None:
93104
self.ax.set_ylim(self.y_lim)
94105

95106
if self.title is not None:
96-
self.fig.text(0.5, -0.02, self.title, ha='center', fontsize=14, weight='bold')
107+
self.fig.text(0.5, -0.02, self.title, ha="center", fontsize=14, weight="bold")
108+
109+
if self.legend_labels is not None:
110+
self.legend = self.ax.legend(
111+
self.legend_labels,
112+
loc=self.legend_loc, # 居中置顶
113+
ncols=self.legend_ncols, # 横向排布
114+
bbox_to_anchor=self.legend_bbox_to_anchor, # 置于图外侧
115+
handlelength=1, # 图例长宽, 修改为正方形
116+
handleheight=1, # 图例长宽, 修改为正方形
117+
handletextpad=0.4, # 缩短文字和图例的间距
118+
fontsize=self.legend_font_size, # 图例文字大小
119+
)
97120

98-
def _adjust_graph(self):
99-
'''
121+
def adjust_graph(self):
122+
"""
100123
子类中可以重写该函数来调整图表
101-
'''
124+
"""
102125

103126
def save(self, path: str = "result.png"):
104127
"""
105128
保存图片
106129
"""
107130
self._create_graph()
108-
self._adjust_graph()
131+
self.adjust_graph()
132+
plt.tight_layout()
109133
plt.savefig(path, dpi=self.dpi, bbox_inches=self.bbox_inches)
110-
print(f"保存成功:{path}")
134+
print(f"save picture in {path}")
111135

112136
def _check_config(self):
113137
"""
114138
检查配置的属性是否设置的合理
115139
"""
116140
assert self.grid in ["x", "y", "xy", None], "grid 参数值只能是 x | y | xy | None"
117-
assert self.width_picture in [True, False], "width_picture 参数值只能是 True | False"
141+
assert self.width_picture in [True, False], "width_picture 参数值只能是 True | False"
142+
143+
def set_label_legend(self, column_names, position: str = "w", alignment: str = "-"):
144+
"""
145+
position should be 1 or 2 of 'wasd'
146+
147+
w/a/s/d means up/left/down/right in keyboard
148+
"""
149+
self.legend_labels = column_names
150+
151+
# https://matplotlib.org/stable/api/legend_api.html#module-matplotlib.legend
152+
self.legend_loc = "upper center"
153+
self.legend_bbox_to_anchor = (0.5, 1.15)
154+
self.legend_ncols = len(column_names)
155+
156+
# legend position
157+
158+
# bbox_to_anchor
159+
# x:相对于图形的水平位置(通常 0 到 1 的值,1 表示图的最右边).
160+
# y:相对于图形的垂直位置(通常 0 到 1 的值,1 表示图的顶部)
161+
if position == "w":
162+
self.legend_loc = "upper center"
163+
self.legend_bbox_to_anchor = (0.5, 1.15)
164+
165+
elif position == "d":
166+
self.legend_loc = "upper left"
167+
self.legend_bbox_to_anchor = (1.05, 1)
168+
169+
elif position == "wd":
170+
self.legend_loc = "upper right"
171+
self.legend_bbox_to_anchor = None
172+
173+
# legend alignment
174+
if alignment == "-":
175+
self.legend_ncols = len(column_names)
176+
elif alignment == "|":
177+
self.legend_ncols = 1
178+
elif type(alignment) == int:
179+
self.legend_ncols = alignment
180+
181+
def adjust_legend(self, position: str = None, alignment: str = None, bbox_to_anchor: Tuple[float, float] = None, font_size: int = None):
182+
183+
if position:
184+
self.legend_loc(position)
185+
186+
if alignment:
187+
if alignment == "-":
188+
self.legend_ncols = len(self.legend_labels)
189+
elif alignment == "|":
190+
self.legend_ncols = 1
191+
elif type(alignment) == int:
192+
self.legend_ncols = alignment
193+
else:
194+
raise ValueError("alignment should be int or '-' or '|'")
195+
196+
if bbox_to_anchor:
197+
self.legend_bbox_to_anchor = bbox_to_anchor
198+
199+
if font_size:
200+
self.legend_font_size = font_size

paperplotlib/line_graph.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,46 +9,28 @@ def __init__(self) -> None:
99
super().__init__()
1010
self.grid = "xy"
1111
# https://matplotlib.org/stable/api/markers_api.html
12-
self.all_markers = [
13-
"o",
14-
"^",
15-
"x",
16-
"s",
17-
"D",
18-
"*",
19-
"+",
20-
"v",
21-
"p",
22-
"P",
23-
"h",
24-
"H",
25-
"1",
26-
"2",
27-
"3",
28-
"4",
29-
"X"
30-
]
31-
self.disable_x_ticks = False # 是否禁用 x 轴刻度
32-
self.disable_points = False # 是否禁用点
12+
self.all_markers = ["o", "^", "x", "s", "D", "*", "+", "v", "p", "P", "h", "H", "1", "2", "3", "4", "X"]
13+
self.disable_x_ticks = False # 是否禁用 x 轴刻度
14+
self.disable_points = False # 是否禁用点
3315
self.line_width = 1.5
3416

35-
def _adjust_graph(self):
36-
17+
def adjust_graph(self):
18+
3719
# 线条宽度
3820
for line in self.ax.get_lines():
3921
line.set_linewidth(self.line_width)
40-
22+
4123
if self.disable_x_ticks:
4224
self.ax.xaxis.set_major_locator(ticker.NullLocator())
4325
if self.disable_points:
4426
# 修改 marker
4527
for line in self.ax.get_lines():
46-
line.set_marker('')
28+
line.set_marker("")
4729
# 修改图例中的 marker
4830
legend = self.ax.get_legend()
4931
if legend is not None: # 检查图例是否存在
5032
for legend_line in legend.get_lines():
51-
legend_line.set_marker('')
33+
legend_line.set_marker("")
5234

5335
def plot(self, x_data: List[int], y_data: List[float]):
5436
"""
@@ -62,7 +44,9 @@ def plot(self, x_data: List[int], y_data: List[float]):
6244
if x_data is None:
6345
x_data = range(len(y_data))
6446
x_ticks = range(len(x_data))
65-
self.ax.plot(x_ticks, y_data, linewidth=2, marker="o", markersize=5, color=COLOR.get_colors(1, self.style_id)[0])
47+
self.ax.plot(
48+
x_ticks, y_data, linewidth=2, marker="o", markersize=5, color=COLOR.get_colors(1, self.style_id)[0]
49+
)
6650
# x 轴标签和位置的映射
6751
self.ax.set_xticks(x_ticks, x_data)
6852

@@ -72,8 +56,8 @@ def plot_2d(self, x_data: List[int], y_data: List[List[float]], line_names: List
7256
x_data = range(len(y_data[0]))
7357
x_ticks = range(len(x_data))
7458
line_number = len(line_names)
75-
76-
assert line_number <= len(self.all_markers), "markers 数量不足"
59+
60+
assert line_number <= len(self.all_markers), "markers 数量不足"
7761
markers = self.all_markers[:line_number]
7862
colors = COLOR.get_colors(line_number, self.style_id, emphasize_index)
7963
for i, y in enumerate(y_data):

0 commit comments

Comments
 (0)