
Commit 57da64c

Add script to collect and plot benchmarks
1 parent a7196c8 commit 57da64c

File tree

3 files changed: +341 -53 lines changed


tools/collect-benchmarks.py

Lines changed: 340 additions & 0 deletions
@@ -0,0 +1,340 @@
#!/usr/bin/env python3

"""
Combine snakemake per-sample benchmark files and create scatter plots
relating sample sizes to the benchmark columns.

This helps to fine-tune the runtime and memory requirements for each rule,
in order to scale up to very large dataset analyses.

Usage:

    python collect-benchmarks.py <analysis_dir>

for a given grenepipe analysis run.
"""

import os
import sys
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import yaml


def get_samples_table_path(config_file_path):
    """
    Load the config yaml file and retrieve the value of the 'samples-table' key
    under the 'data' entry.
    """
    with open(config_file_path, "r", encoding="utf-8") as file:
        try:
            content = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            raise yaml.YAMLError(f"Error parsing config yaml file: {exc}") from exc

    if "data" not in content or "samples-table" not in content["data"]:
        raise KeyError(
            "config yaml file does not contain the expected structure: 'data' -> 'samples-table'"
        )

    return content["data"]["samples-table"]


def parse_samples_table(samples_path):
    """
    Read the samples table and compute the total size of fq1 + fq2 (if present).
    Return two dictionaries:
    1) sample_unit_size: keyed by 'sample-unit'
    2) sample_size: keyed by 'sample'
    """
    sample_unit_size = {}
    sample_size = {}

    with open(samples_path, "r", encoding="utf-8") as f:
        header = f.readline().rstrip("\n").split("\t")
        # We expect at least the columns:
        # sample, unit, platform, fq1, fq2
        # Let's just find their indices dynamically.
        col_idx = {colname: i for i, colname in enumerate(header)}

        # We require the 'sample', 'unit', 'fq1', and 'fq2' columns
        for required_col in ["sample", "unit", "fq1", "fq2"]:
            if required_col not in col_idx:
                raise ValueError(f"Missing required column '{required_col}' in samples table!")

        # Process each row
        for line in f:
            line = line.strip()
            if not line:
                continue
            row = line.split("\t")
            sample_name = row[col_idx["sample"]]
            unit_name = row[col_idx["unit"]]
            fq1_path = row[col_idx["fq1"]]
            fq2_path = row[col_idx["fq2"]] if col_idx["fq2"] < len(row) else None

            # Compute file sizes (some entries might be empty).
            # Paths that are empty, missing, or invalid are treated as size 0.
            size_total = 0
            for fq_path in [fq1_path, fq2_path]:
                if fq_path and os.path.isfile(fq_path):
                    size_total += os.path.getsize(fq_path)

            # Build the "sample-unit" key
            sample_unit_key = f"{sample_name}-{unit_name}"

            # Store in sample_unit_size
            if sample_unit_key not in sample_unit_size:
                sample_unit_size[sample_unit_key] = 0
            sample_unit_size[sample_unit_key] += size_total

            # Also store in sample_size (sum across all units)
            if sample_name not in sample_size:
                sample_size[sample_name] = 0
            sample_size[sample_name] += size_total

    return sample_unit_size, sample_size


def combine_benchmarks(top_level_dir, output_dir):
    """
    Traverse top_level_dir, and for each subdirectory (including nested),
    combine all files into a single TSV in output_dir. The name of the output
    file is derived from the subdirectory path with slashes replaced by underscores.
    The first column is 'filename', added for each data row.
    Return a list of paths to the newly created combined TSV files.
    """
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    combined_files = []  # to store the paths to the combined TSVs

    for root, dirs, files in os.walk(top_level_dir):
        print("Processing:", root)

        # If there are no files in this directory, skip it
        if not files:
            continue

        # Create a name for this subdirectory, relative to top_level_dir
        rel_dir = os.path.relpath(root, top_level_dir)

        # Convert slashes to underscores
        if rel_dir == ".":
            subdir_name = "top_level"
        else:
            subdir_name = rel_dir.replace(os.sep, "_")

        # Full path of the output file for this subdirectory
        output_file_name = subdir_name + ".tsv"
        output_file_path = os.path.join(output_dir, output_file_name)

        # We open the output file once and append lines as we go
        did_write_header = False
        with open(output_file_path, "w", encoding="utf-8") as outfile:
            # Process each regular file in the current directory
            for f in files:
                file_path = os.path.join(root, f)
                if not os.path.isfile(file_path):
                    continue  # skip directories, symlinks, etc.

                # Derive the base name by removing the extension
                base_name = os.path.splitext(f)[0]

                with open(file_path, "r", encoding="utf-8") as infile:
                    for i, line in enumerate(infile):
                        line = line.rstrip("\n")
                        if i == 0:
                            # Header line. We only write it once per output file;
                            # headers of subsequent input files are skipped.
                            if not did_write_header:
                                # Prepend "filename" to the header
                                final_header = "filename\t" + line
                                outfile.write(final_header + "\n")
                                did_write_header = True
                        else:
                            # Data line => prepend the filename
                            final_data = base_name + "\t" + line
                            outfile.write(final_data + "\n")

        # Check if we ended up writing anything
        if os.path.exists(output_file_path) and os.path.getsize(output_file_path) == 0:
            os.remove(output_file_path)
        else:
            print(f"Created: {output_file_path}")
            combined_files.append(output_file_path)

    return combined_files


def seconds_to_hms(value, pos=None):
    """Convert 'value' (in seconds) to a short human-readable form."""
    hours = int(value // 3600)
    minutes = int((value % 3600) // 60)
    seconds = int(value % 60)
    value = round(value, 1)
    if hours > 0:
        return f"{hours}h{minutes}m"
    elif minutes > 0:
        return f"{minutes}m{seconds}s"
    else:
        return f"{value}s"


def bytes_to_human_readable(num_bytes, pos=None):
    """
    Convert a number of bytes into a human-readable string.
    E.g., 1048576 -> '1.0 MB'.
    """
    # Use 1024-based units
    for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
        if abs(num_bytes) < 1024.0:
            return f"{num_bytes:3.1f} {unit}"
        num_bytes /= 1024.0
    return f"{num_bytes:.1f} PB"


def plot_rule_data(combined_file, sample_unit_size, sample_size):
    """
    Given a combined benchmark TSV (one rule's data),
    parse the columns for 'filename', 'max_rss' and 's'.
    Determine the combined FASTQ size by matching 'filename':
    - If 'filename' contains a 'sample-unit' key, use that
    - Otherwise match it against the 'sample' keys
    Then create two scatter subplots:
    1) x-axis = combined FASTQ size, y-axis = s
    2) x-axis = combined FASTQ size, y-axis = max_rss
    Returns a matplotlib figure object, or None if no data could be plotted.
    """

    # We parse the file in a straightforward manner:
    # read the header, find the columns of interest, then parse each line.
    x_vals_maxrss = []
    y_vals_maxrss = []
    x_vals_s = []
    y_vals_s = []

    with open(combined_file, "r", encoding="utf-8") as f:
        header = f.readline().rstrip("\n").split("\t")
        col_idx = {name: i for i, name in enumerate(header)}

        # We expect at least: filename, max_rss, s
        for required in ["filename", "max_rss", "s"]:
            if required not in col_idx:
                # If the file doesn't have these columns, skip plotting
                print("Warning: Wrong columns in", combined_file)
                return None

        for line in f:
            row = line.rstrip("\n").split("\t")
            filename_val = row[col_idx["filename"]]
            maxrss_val_str = row[col_idx["max_rss"]]
            s_val_str = row[col_idx["s"]]

            # If max_rss or s is blank (or not numeric), skip
            try:
                maxrss_val = float(maxrss_val_str)
                s_val = float(s_val_str)
            except ValueError:
                continue

            # Above, we created the "filename" column of the summary table from the rule name
            # and the sample (and unit, if present). As this script has no knowledge of our
            # internal snakemake rules (and should not have, to keep it independent),
            # we need to take those apart again here.
            # Figure out whether a "sample" or a "sample-unit" key appears in the filename.
            # The only assumption is that we always use a dash in our rule wildcards.
            fastq_size = [sample_unit_size[key] for key in sample_unit_size if key in filename_val]
            if not fastq_size:
                fastq_size = [sample_size[key] for key in sample_size if key in filename_val]
            if not fastq_size:
                continue  # no match found
            if len(fastq_size) > 1:
                print("Warning: Multiple entries for", filename_val, "in", combined_file)
            fastq_size = fastq_size[0]

            x_vals_maxrss.append(fastq_size)
            # max_rss is reported in MB; convert to bytes for the formatter
            y_vals_maxrss.append(maxrss_val * 1024 * 1024)
            x_vals_s.append(fastq_size)
            y_vals_s.append(s_val)

    # If no data was collected, return None
    if not x_vals_maxrss:
        return None

    # Make a figure with 2 subplots
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Plot 1: s vs combined FASTQ size
    axes[0].scatter(x_vals_s, y_vals_s, alpha=0.7)
    axes[0].set_xlabel("Sample size")
    axes[0].set_ylabel("Runtime")
    axes[0].set_title("Sample FASTQ size vs runtime")
    axes[0].xaxis.set_major_formatter(ticker.FuncFormatter(bytes_to_human_readable))
    axes[0].yaxis.set_major_formatter(ticker.FuncFormatter(seconds_to_hms))

    # Plot 2: max_rss vs combined FASTQ size
    axes[1].scatter(x_vals_maxrss, y_vals_maxrss, alpha=0.7)
    axes[1].set_xlabel("Sample size")
    axes[1].set_ylabel("Memory")
    axes[1].set_title("Sample FASTQ size vs memory")
    axes[1].xaxis.set_major_formatter(ticker.FuncFormatter(bytes_to_human_readable))
    axes[1].yaxis.set_major_formatter(ticker.FuncFormatter(bytes_to_human_readable))

    fig.tight_layout()
    return fig


def main():
    if len(sys.argv) != 2:
        print("Usage: python collect-benchmarks.py <analysis_dir>")
        sys.exit(1)

    # Get the paths as needed
    analysis_dir = sys.argv[1]
    benchmarks_dir = os.path.join(analysis_dir, "benchmarks")
    output_dir = os.path.join(analysis_dir, "benchmarks-summary")

    # Check that they are in order
    if not os.path.isdir(analysis_dir):
        print(f"Error: '{analysis_dir}' is not a valid directory.")
        sys.exit(1)
    if not os.path.isdir(benchmarks_dir):
        print(
            f"Error: Benchmarks directory '{benchmarks_dir}' is not a valid directory.",
            "Did you run grenepipe here before?"
        )
        sys.exit(1)

    # 0) Read the config file
    samples_path = get_samples_table_path(os.path.join(analysis_dir, "config.yaml"))

    # 1) Read & parse the samples table to get FASTQ sizes
    sample_unit_size, sample_size = parse_samples_table(samples_path)
    print("Parsed sample FASTQ sizes:")
    print("  sample_unit_size:", dict(list(sample_unit_size.items())[:5]), "...")
    print("  sample_size:", dict(list(sample_size.items())[:5]), "...")

    # 2) Combine benchmarks
    combined_files = combine_benchmarks(benchmarks_dir, output_dir)

    # 3) For each combined file, create scatter plots
    for cf in combined_files:
        fig = plot_rule_data(cf, sample_unit_size, sample_size)
        if fig is not None:
            # Name the plot based on the TSV filename,
            # e.g. "ruleA_subdir.tsv" -> "ruleA_subdir_plots.png"
            base_name = os.path.splitext(os.path.basename(cf))[0]
            plot_path = os.path.join(output_dir, base_name + "_plots.png")
            fig.savefig(plot_path, dpi=150)
            plt.close(fig)  # free memory
            print(f"Created scatter plot: {plot_path}")
        else:
            print(f"Skipping plot for {cf} (missing columns or no data).")


if __name__ == "__main__":
    main()
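
Note on the input format: the per-rule files that this script combines are the TSVs written by snakemake's benchmark directive. In recent snakemake versions these typically contain columns such as s, h:m:s, max_rss, max_vms, max_uss, max_pss, io_in, io_out, and mean_load, with max_rss given in MB (hence the * 1024 * 1024 conversion to bytes in plot_rule_data). A minimal example with illustrative values:

    s	h:m:s	max_rss	max_vms	max_uss	max_pss	io_in	io_out	mean_load
    123.45	0:02:03	1024.21	2048.53	980.10	1000.33	5.2	3.1	98.7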

tools/summarize-benchmarks.py

Lines changed: 0 additions & 53 deletions
This file was deleted.
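
To turn the scatter plots into concrete resource requirements, one can fit a simple linear model to a combined TSV. The following is a minimal sketch, not part of this commit: it assumes the input sizes and max_rss values have already been extracted as in plot_rule_data above, and the function name and safety factor are illustrative choices.

    import numpy as np

    def fit_resource_model(sizes_bytes, max_rss_mb, safety_factor=1.5):
        """Least-squares fit of max_rss (MB) against input size (bytes),
        returning a callable that predicts a padded memory requirement."""
        slope, intercept = np.polyfit(sizes_bytes, max_rss_mb, 1)

        def predict_mb(size_bytes):
            # Pad the prediction so that outliers do not exceed the limit.
            return safety_factor * (slope * size_bytes + intercept)

        return predict_mb

    # Example with made-up numbers: three samples of 1, 2, and 4 GB
    sizes = [1e9, 2e9, 4e9]
    rss = [1500.0, 2700.0, 5100.0]
    predict = fit_resource_model(sizes, rss)
    print(f"Suggested mem_mb for an 8 GB sample: {predict(8e9):.0f}")

The padded estimate can then be used for the rule's memory resource in the cluster profile.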
