Skip to content

Commit 5ff7e3f

Browse files
committed
cache report
1 parent 8c5e127 commit 5ff7e3f

10 files changed

+33
-12
lines changed

src/column.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ impl std::fmt::Display for ColumnError {
5656
impl std::error::Error for ColumnError {}
5757

5858
/// A ColumnReporter tells bedder how to report a column in the output.
59-
pub trait ColumnReporter {
59+
pub trait ColumnReporter: std::fmt::Debug {
6060
/// report the name, e.g. `count` for the INFO field of the VCF
6161
fn name(&self) -> &str;
6262
/// report the type, for the INFO field of the VCF

src/intersection.rs

+6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use crate::chrom_ordering::Chromosome;
2+
use crate::report::Report;
3+
use crate::report_options::ReportOptions;
24
use crate::string::String;
35
use hashbrown::HashMap;
46
use parking_lot::Mutex;
@@ -57,6 +59,9 @@ impl Clone for Intersection {
5759
pub struct Intersections {
5860
pub base_interval: Arc<Mutex<Position>>,
5961
pub overlapping: Vec<Intersection>,
62+
63+
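// Single-entry report cache: the most recently built report, paired with the ReportOptions it was built for. Arc<Mutex<..>> provides interior mutability so the cache can be filled through `&self`.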
64+
pub(crate) cached_report: Arc<Mutex<Option<(ReportOptions, Arc<Report>)>>>,
6065
}
6166

6267
struct ReverseOrderPosition {
@@ -200,6 +205,7 @@ impl<P: PositionedIterator> Iterator for IntersectionIterator<'_, P> {
200205
Some(Ok(Intersections {
201206
base_interval: Arc::clone(&base_interval),
202207
overlapping: overlapping_positions,
208+
cached_report: Arc::new(Mutex::new(None)),
203209
}))
204210
}
205211
}

src/intersections.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,16 @@ fn inverse(base_interval: &Position, overlaps: &[Intersection]) -> Vec<Arc<Mutex
3232
impl Intersections {
3333
// Building the report is expensive, so the result is cached together with the report options that produced it;
3434
// the cache lives in the `cached_report` field of Intersections.
35-
pub fn report(&self, report_options: &ReportOptions) -> Report {
35+
pub fn report(&self, report_options: &ReportOptions) -> Arc<Report> {
36+
let mut cached_report = self
37+
.cached_report
38+
.try_lock()
39+
.expect("failed to lock cached report");
40+
if let Some((ro, report)) = &*cached_report {
41+
if ro == report_options {
42+
return report.clone();
43+
}
44+
}
3645
// usually the result is [query, [[b1-part, b1-part2, ...], [b2-part, ...]]]],
3746
// in fact, usually, there's only a single b and a single interval from b, so it's:
3847
// [query, [[b1-part]]]
@@ -131,7 +140,9 @@ impl Intersections {
131140
);
132141
}
133142

134-
Report::new(result)
143+
let report = Arc::new(Report::new(result));
144+
*cached_report = Some((report_options.clone(), report.clone()));
145+
report
135146
}
136147

137148
fn satisfies_requirements(

src/main.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ struct Args {
3030
#[arg(
3131
help = "columns to output (format: name:type:description:number:value_parser)",
3232
short = 'c',
33-
long = "columns",
34-
required = true
33+
long = "columns"
3534
)]
3635
columns: Vec<String>,
3736

src/py.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ impl From<crate::report::ReportFragment> for PyReportFragment {
172172
#[pyclass]
173173
#[derive(Clone, Debug)]
174174
pub struct PyReport {
175-
inner: crate::report::Report,
175+
inner: Arc<crate::report::Report>,
176176
}
177177

178178
#[pymethods]
@@ -181,14 +181,14 @@ impl PyReport {
181181
/// Create a new empty Report
182182
fn new() -> Self {
183183
PyReport {
184-
inner: crate::report::Report::new(Vec::new()),
184+
inner: Arc::new(crate::report::Report::new(Vec::new())),
185185
}
186186
}
187187

188188
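/// Set the report to contain only the given fragment. NOTE(review): despite the name, this does not append — the report is rebuilt from `frag` alone and any existing fragments are discarded; confirm that is intended.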
189189
fn add_fragment(&mut self, frag: PyReportFragment) -> PyResult<()> {
190190
let inner_frags = vec![frag.inner];
191-
self.inner = crate::report::Report::new(inner_frags);
191+
self.inner = Arc::new(crate::report::Report::new(inner_frags));
192192
Ok(())
193193
}
194194

@@ -483,7 +483,7 @@ impl PyIntersections {
483483
}
484484

485485
/// Report intersections based on specified modes and requirements
486-
fn report(&self) -> PyResult<PyReport> {
486+
fn report(&mut self) -> PyResult<PyReport> {
487487
Ok(PyReport {
488488
inner: self.inner.report(&self.report_options),
489489
})

src/py_test.rs

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ mod tests {
2525
interval: Arc::new(Mutex::new(overlap_pos)),
2626
id: 0,
2727
}],
28+
cached_report: Arc::new(Mutex::new(None)),
2829
};
2930

3031
PyIntersections::new(intersections, Arc::new(ReportOptions::default()))

src/report.rs

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::sync::Arc;
55
pub struct ReportFragment {
66
pub a: Option<Arc<Mutex<Position>>>,
77
pub b: Vec<Arc<Mutex<Position>>>,
8+
// id is the file index of the source
89
pub id: usize,
910
}
1011

src/report_options.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ pub enum OverlapAmount {
122122
Fraction(f32),
123123
}
124124

125+
// NOTE(review): `OverlapAmount::Fraction` wraps an `f32`, and `f32` itself is not `Eq`
// because NaN != NaN. This marker impl asserts total, reflexive equality anyway —
// presumably so that `ReportOptions` can derive `Eq`. Confirm that a NaN fraction can
// never be constructed, otherwise equality-based comparisons of options will misbehave.
impl Eq for OverlapAmount {}
126+
125127
impl FromStr for OverlapAmount {
126128
type Err = ParseFloatError;
127129

@@ -179,7 +181,7 @@ impl Default for &OverlapAmount {
179181
/// .b_requirements(OverlapAmount::Fraction(0.5))
180182
/// .build();
181183
/// ```
182-
#[derive(Debug, Clone, Default)]
184+
#[derive(Debug, Clone, Default, PartialEq, Eq)]
183185
pub struct ReportOptions {
184186
pub a_mode: IntersectionMode,
185187
pub b_mode: IntersectionMode,

src/tests/parse_intersections.rs

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ pub(crate) fn parse_intersections(input: &str) -> Intersections {
7272
Intersections {
7373
base_interval: Arc::new(Mutex::new(Position::Interval(base_interval))),
7474
overlapping: intersections,
75+
cached_report: Arc::new(Mutex::new(None)),
7576
}
7677
}
7778

src/writer.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ impl Writer {
210210
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
211211
}
212212

213-
fn update<T: ColumnReporter>(
213+
fn apply_report<T: ColumnReporter>(
214214
format: Format,
215215
intersections: &mut Intersections,
216216
report_options: Arc<ReportOptions>,
@@ -319,7 +319,7 @@ impl Writer {
319319
}
320320
};
321321

322-
Self::update(self.format, intersections, report_options, crs)?;
322+
Self::apply_report(self.format, intersections, report_options, crs)?;
323323

324324
// Use the current value of the Arc without modifying it
325325
if let Position::Bed(ref bed_record) = *intersections.base_interval.lock() {

0 commit comments

Comments
 (0)