1
1
import matplotlib
2
+
2
3
# 非交互式 GUI 使用 Agg
3
- matplotlib .use (' Agg' )
4
+ matplotlib .use (" Agg" )
4
5
import matplotlib .pyplot as plt
5
6
import os
6
7
from typing import List , Optional , Union , Tuple
7
8
from matplotlib .font_manager import FontProperties
8
9
import matplotlib .font_manager as fm
9
10
11
+
10
12
class Graph :
11
13
"""
12
14
图表
13
15
"""
14
16
15
- def __init__ (self , style_id : int = 1 ) -> None :
17
+ def __init__ (self , style_id : int = 1 , subplots : Tuple [ int , int ] = None ) -> None :
16
18
self .style_id = style_id
17
- self .fig = plt .figure (figsize = (8 , 4 ))
18
- self .ax = self .fig .add_subplot (111 )
19
-
19
+ self .subplots = subplots
20
+
21
+ if subplots is not None :
22
+ self .fig , self .ax = plt .subplots (subplots [0 ], subplots [1 ])
23
+ else :
24
+ self .fig = plt .figure (figsize = (8 , 4 ))
25
+ self .ax = self .fig .add_subplot (111 )
26
+
20
27
# -- configuation --
21
- self .x_label : Optional [str ] = None # x轴标签
22
- self .y_label : Optional [str ] = None # y轴标签
23
- self .width_picture = False # 是否是宽图
28
+ self .x_label : Optional [str ] = None # x轴标签
29
+ self .y_label : Optional [str ] = None # y轴标签
30
+ self .width_picture = False # 是否是宽图
24
31
self .grid = "y" # 网格线 x | y | xy | None
25
32
self .y_lim : Optional [Tuple [float , float ]] = None
26
33
@@ -33,25 +40,29 @@ def __init__(self, style_id: int = 1) -> None:
33
40
# 保存图片
34
41
self .dpi = 300
35
42
self .bbox_inches = "tight" # 适当上下左右留白
36
-
37
- self .title : Optional [str ] = None # 图表标题
38
-
39
- font_path = f' { os .path .dirname (__file__ )} /consola-1.ttf'
43
+
44
+ self .title : Optional [str ] = None # 图表标题
45
+
46
+ font_path = f" { os .path .dirname (__file__ )} /consola-1.ttf"
40
47
fm .fontManager .addfont (font_path )
41
- plt .rcParams ['font.family' ] = 'Consolas'
42
- # self.font_family = "Consolas" # 字体
43
- # 如果是 linux
44
- # if os.name == 'posix':
45
- # self.font_family = consolas_font.get_name()
46
- # self.colors: Optional[List[str]] = None
47
-
48
- def plot (self , x_data : List [float ], y_data : List [float ]): # pragma: no cover
48
+ plt .rcParams ["font.family" ] = "Consolas"
49
+
50
+ # legend
51
+ self .legend_labels = None
52
+ self .legend_loc = None
53
+ self .legend_bbox_to_anchor = None
54
+ self .legend_ncols = None
55
+ self .legend_font_size = 'medium'
56
+
57
+ def plot (self , x_data : List [float ], y_data : List [float ]): # pragma: no cover
49
58
"""
50
59
填入数据
51
60
"""
52
61
raise NotImplementedError ("请在子类中实现此方法" )
53
62
54
- def plot_2d (self , y_data : List [List [float ]], group_names : List [str ], column_names : List [str ], emphasize_index : int = - 1 ): # pragma: no cover
63
+ def plot_2d (
64
+ self , y_data : List [List [float ]], group_names : List [str ], column_names : List [str ], emphasize_index : int = - 1
65
+ ): # pragma: no cover
55
66
"""
56
67
绘制二维柱状图
57
68
@@ -62,9 +73,9 @@ def plot_2d(self, y_data: List[List[float]], group_names: List[str], column_name
62
73
"""
63
74
raise NotImplementedError ("请在子类中实现此方法" )
64
75
65
- def _create_graph (self ): # pragma: no cover
76
+ def _create_graph (self ): # pragma: no cover
66
77
self ._check_config ()
67
-
78
+
68
79
if self .width_picture :
69
80
self .fig .set_size_inches (16 , 4 )
70
81
self .ax .set_xlabel (self .x_label )
@@ -88,30 +99,102 @@ def _create_graph(self): # pragma: no cover
88
99
alpha = self .grid_alpha ,
89
100
)
90
101
self .ax .set_axisbelow (True )
91
-
102
+
92
103
if self .y_lim is not None :
93
104
self .ax .set_ylim (self .y_lim )
94
105
95
106
if self .title is not None :
96
- self .fig .text (0.5 , - 0.02 , self .title , ha = 'center' , fontsize = 14 , weight = 'bold' )
107
+ self .fig .text (0.5 , - 0.02 , self .title , ha = "center" , fontsize = 14 , weight = "bold" )
108
+
109
+ if self .legend_labels is not None :
110
+ self .legend = self .ax .legend (
111
+ self .legend_labels ,
112
+ loc = self .legend_loc , # 居中置顶
113
+ ncols = self .legend_ncols , # 横向排布
114
+ bbox_to_anchor = self .legend_bbox_to_anchor , # 置于图外侧
115
+ handlelength = 1 , # 图例长宽, 修改为正方形
116
+ handleheight = 1 , # 图例长宽, 修改为正方形
117
+ handletextpad = 0.4 , # 缩短文字和图例的间距
118
+ fontsize = self .legend_font_size , # 图例文字大小
119
+ )
97
120
98
- def _adjust_graph (self ):
99
- '''
121
+ def adjust_graph (self ):
122
+ """
100
123
子类中可以重写该函数来调整图表
101
- '''
124
+ """
102
125
103
126
def save (self , path : str = "result.png" ):
104
127
"""
105
128
保存图片
106
129
"""
107
130
self ._create_graph ()
108
- self ._adjust_graph ()
131
+ self .adjust_graph ()
132
+ plt .tight_layout ()
109
133
plt .savefig (path , dpi = self .dpi , bbox_inches = self .bbox_inches )
110
- print (f"保存成功: { path } " )
134
+ print (f"save picture in { path } " )
111
135
112
136
def _check_config (self ):
113
137
"""
114
138
检查配置的属性是否设置的合理
115
139
"""
116
140
assert self .grid in ["x" , "y" , "xy" , None ], "grid 参数值只能是 x | y | xy | None"
117
- assert self .width_picture in [True , False ], "width_picture 参数值只能是 True | False"
141
+ assert self .width_picture in [True , False ], "width_picture 参数值只能是 True | False"
142
+
143
+ def set_label_legend (self , column_names , position : str = "w" , alignment : str = "-" ):
144
+ """
145
+ position should be 1 or 2 of 'wasd'
146
+
147
+ w/a/s/d means up/left/down/right in keyboard
148
+ """
149
+ self .legend_labels = column_names
150
+
151
+ # https://matplotlib.org/stable/api/legend_api.html#module-matplotlib.legend
152
+ self .legend_loc = "upper center"
153
+ self .legend_bbox_to_anchor = (0.5 , 1.15 )
154
+ self .legend_ncols = len (column_names )
155
+
156
+ # legend position
157
+
158
+ # bbox_to_anchor
159
+ # x:相对于图形的水平位置(通常 0 到 1 的值,1 表示图的最右边).
160
+ # y:相对于图形的垂直位置(通常 0 到 1 的值,1 表示图的顶部)
161
+ if position == "w" :
162
+ self .legend_loc = "upper center"
163
+ self .legend_bbox_to_anchor = (0.5 , 1.15 )
164
+
165
+ elif position == "d" :
166
+ self .legend_loc = "upper left"
167
+ self .legend_bbox_to_anchor = (1.05 , 1 )
168
+
169
+ elif position == "wd" :
170
+ self .legend_loc = "upper right"
171
+ self .legend_bbox_to_anchor = None
172
+
173
+ # legend alignment
174
+ if alignment == "-" :
175
+ self .legend_ncols = len (column_names )
176
+ elif alignment == "|" :
177
+ self .legend_ncols = 1
178
+ elif type (alignment ) == int :
179
+ self .legend_ncols = alignment
180
+
181
+ def adjust_legend (self , position : str = None , alignment : str = None , bbox_to_anchor : Tuple [float , float ] = None , font_size : int = None ):
182
+
183
+ if position :
184
+ self .legend_loc (position )
185
+
186
+ if alignment :
187
+ if alignment == "-" :
188
+ self .legend_ncols = len (self .legend_labels )
189
+ elif alignment == "|" :
190
+ self .legend_ncols = 1
191
+ elif type (alignment ) == int :
192
+ self .legend_ncols = alignment
193
+ else :
194
+ raise ValueError ("alignment should be int or '-' or '|'" )
195
+
196
+ if bbox_to_anchor :
197
+ self .legend_bbox_to_anchor = bbox_to_anchor
198
+
199
+ if font_size :
200
+ self .legend_font_size = font_size
0 commit comments