
Commit e26dda5

Discovered an issue with parallel execution. Added a workaround and filed a TODO.
1 parent 41c6138 commit e26dda5

File tree: 2 files changed (+96, -83 lines)


src/lib.rs

Lines changed: 89 additions & 82 deletions
@@ -23,10 +23,19 @@
 //! for each batch read, allowing DataFusion to process and potentially filter batches
 //! incrementally. This enables processing of larger-than-memory datasets when combined
 //! with DataFusion's streaming execution.
+//!
+//! ## Parallel Execution Note
+//!
+//! When using DataFusion's parallel execution (multiple partitions), aggregation queries
+//! without ORDER BY may return partial results due to how our stream interacts with
+//! DataFusion's async runtime. To ensure complete results:
+//! - Add ORDER BY to aggregation queries, or
+//! - Use `SessionConfig().with_target_partitions(1)` for single-threaded execution
+//! TODO(#106): Implement proper parallelism and partition handling.
 
 use std::ffi::CString;
 use std::fmt::Debug;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use arrow::array::RecordBatch;
 use arrow::datatypes::SchemaRef;
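
To make the workaround in this doc comment concrete, here is a minimal sketch from the Python caller's side, assuming the `datafusion` Python package's `SessionConfig`/`SessionContext` API; table registration is elided:

# Minimal sketch of the documented workaround, assuming the `datafusion`
# Python package; table setup and registration are elided.
from datafusion import SessionConfig, SessionContext

# Option 1: force single-threaded execution so a single partition
# consumes the whole stream (until TODO(#106) lands).
config = SessionConfig().with_target_partitions(1)
ctx = SessionContext(config)

# Option 2: keep parallelism, but add ORDER BY so DataFusion gathers
# all partial aggregates before returning, e.g.:
#   SELECT lat, AVG(t) AS avg_t FROM tbl GROUP BY lat ORDER BY lat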
@@ -72,15 +81,14 @@ impl Debug for PyArrowStreamPartition {
     }
 }
 
-/// State for the lazy stream - holds a Python iterator that yields RecordBatches.
-/// Using Py<PyAny> which is Send, allowing this to be used across async boundaries.
-enum StreamState {
-    /// Initial state: factory not yet called
-    NotStarted(Py<PyAny>),
-    /// Active state: iterator ready to yield batches
-    Active(Py<PyAny>),
-    /// Terminal state: stream exhausted or errored
-    Done,
+/// Shared state for the lazy stream, protected by Mutex for thread safety.
+struct SharedStreamState {
+    /// The PyArrow RecordBatchReader (None until first batch is requested)
+    reader: Option<Py<PyAny>>,
+    /// The factory to create the reader (consumed on first use)
+    factory: Option<Py<PyAny>>,
+    /// Whether the stream has ended
+    done: bool,
 }
 
 impl PartitionStream for PyArrowStreamPartition {
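
The removed enum threaded reader ownership through the unfold seed; the new struct shares one lock-guarded state across all polls. A rough Python analogue of `SharedStreamState` (illustrative only; the real implementation is the Rust in the next hunk):

# Rough Python analogue of SharedStreamState; names are illustrative,
# not this repo's API.
import threading

class SharedStreamState:
    def __init__(self, factory):
        self.lock = threading.Lock()
        self.reader = None      # created lazily on first poll
        self.factory = factory  # consumed on first use
        self.done = False

    def next_batch(self):
        with self.lock:  # plays the role of the Rust Mutex
            if self.done:
                return None
            if self.reader is None:
                self.reader = self.factory()
                self.factory = None
            try:
                return self.reader.read_next_batch()
            except StopIteration:
                self.done = True
                return None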
@@ -94,88 +102,87 @@ impl PartitionStream for PyArrowStreamPartition {
         // Clone the factory with the GIL held
         let factory = Python::attach(|py| self.stream_factory.clone_ref(py));
 
-        // Create a lazy stream using unfold. Each poll acquires the GIL and reads one batch.
-        // This allows DataFusion to process batches incrementally rather than loading all into memory.
-        //
-        // The stream state transitions: NotStarted -> Active -> Done
-        // - NotStarted: Factory hasn't been called yet
-        // - Active: Iterator is yielding batches
-        // - Done: Iterator exhausted or error occurred
-        let batch_stream = unfold(StreamState::NotStarted(factory), |state| async move {
-            match state {
-                StreamState::Done => None,
-                StreamState::NotStarted(factory) => {
-                    // First poll: call factory to get a PyArrow RecordBatchReader
-                    Python::attach(|py| {
+        // TODO(alxmrs/CC): I think we need to do something datafusion-native here;
+        // I suspect that adding a mutex will significantly impact performance.
+        // This is OK for now.
+        // Create shared state protected by Mutex
+        let shared_state = Arc::new(Mutex::new(SharedStreamState {
+            reader: None,
+            factory: Some(factory),
+            done: false,
+        }));
+
+        // Create a lazy stream using unfold.
+        // The Arc<Mutex<...>> is cloned for each iteration, ensuring thread-safe access.
+        let batch_stream = unfold(shared_state, |state| async move {
+            // Clone Arc for potential return
+            let state_clone = Arc::clone(&state);
+
+            // Lock the mutex to access state
+            let mut guard = state.lock().unwrap();
+
+            if guard.done {
+                return None;
+            }
+
+            // Acquire GIL and process
+            let result = Python::attach(|py| {
+                // Initialize reader on first poll
+                if guard.reader.is_none() {
+                    if let Some(factory) = guard.factory.take() {
                         match factory.call0(py) {
                             Ok(reader) => {
-                                // Factory returns a RecordBatchReader directly
-                                read_next_batch(py, reader)
+                                guard.reader = Some(reader);
                             }
-                            Err(e) => Some((
-                                Err(DataFusionError::Execution(format!(
+                            Err(e) => {
+                                guard.done = true;
+                                return Some(Err(DataFusionError::Execution(format!(
                                     "Failed to call stream factory: {e}"
-                                ))),
-                                StreamState::Done,
-                            )),
+                                ))));
+                            }
                         }
-                    })
-                }
-                StreamState::Active(iterator) => {
-                    // Subsequent polls: read next batch from iterator
-                    Python::attach(|py| read_next_batch(py, iterator))
+                    }
                 }
-            }
-        });
 
-        Box::pin(RecordBatchStreamAdapter::new(schema, batch_stream))
-    }
-}
+                // Read next batch from reader
+                if let Some(ref reader) = guard.reader {
+                    let bound_reader = reader.bind(py);
+                    match bound_reader.call_method0("read_next_batch") {
+                        Ok(batch_obj) => match RecordBatch::from_pyarrow_bound(&batch_obj) {
+                            Ok(batch) => Some(Ok(batch)),
+                            Err(e) => {
+                                guard.done = true;
+                                Some(Err(DataFusionError::Execution(format!(
+                                    "Failed to convert batch from PyArrow: {e}"
+                                ))))
+                            }
+                        },
+                        Err(e) => {
+                            if e.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
+                                guard.done = true;
+                                None // Stream exhausted normally
+                            } else {
+                                guard.done = true;
+                                Some(Err(DataFusionError::Execution(format!(
+                                    "Error reading batch from stream: {e}"
+                                ))))
+                            }
+                        }
+                    }
+                } else {
+                    guard.done = true;
+                    None
+                }
+            });
 
-/// Read the next batch from a PyArrow RecordBatchReader, returning the stream state transition.
-///
-/// Handles both:
-/// - `read_next_batch()` returning None (PyArrow RecordBatchReader exhausted)
-/// - `StopIteration` exception (Python iterator protocol)
-fn read_next_batch(
-    py: Python<'_>,
-    reader: Py<PyAny>,
-) -> Option<(Result<RecordBatch, DataFusionError>, StreamState)> {
-    let bound_reader = reader.bind(py);
-
-    // Call read_next_batch() which returns None when exhausted
-    match bound_reader.call_method0("read_next_batch") {
-        Ok(batch_obj) => {
-            // Check if None (stream exhausted)
-            if batch_obj.is_none() {
-                return None; // Stream exhausted normally
-            }
+            // Release lock before returning
+            drop(guard);
 
-            // Convert PyArrow RecordBatch to Arrow RecordBatch
-            match RecordBatch::from_pyarrow_bound(&batch_obj) {
-                Ok(batch) => Some((Ok(batch), StreamState::Active(reader))),
-                Err(e) => Some((
-                    Err(DataFusionError::Execution(format!(
-                        "Failed to convert batch from PyArrow: {e}"
-                    ))),
-                    StreamState::Done,
-                )),
-            }
-        }
-        Err(e) => {
-            // Handle StopIteration as normal end of iteration (Python iterator protocol)
-            if e.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
-                return None; // Stream exhausted normally
-            }
+            // Map result to include state for next iteration
+            result.map(|batch_result| (batch_result, state_clone))
+        });
 
-            // Actual error reading batch
-            Some((
-                Err(DataFusionError::Execution(format!(
-                    "Error reading batch from stream: {e}"
-                ))),
-                StreamState::Done,
-            ))
-        }
+        Box::pin(RecordBatchStreamAdapter::new(schema, batch_stream))
     }
 }
 
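
For reference, the Python half of the contract consumed above is a zero-argument factory returning a `pyarrow.RecordBatchReader`; the stream calls `read_next_batch()` until it raises `StopIteration`. A self-contained sketch (`make_stream_factory` is an illustrative name, not this repo's API):

# Sketch of a stream factory as consumed by PyArrowStreamPartition;
# `make_stream_factory` is an illustrative name, not this repo's API.
import pyarrow as pa

def make_stream_factory(schema: pa.Schema, batches):
    def factory():
        # Any iterable of RecordBatches matching `schema` works here.
        return pa.RecordBatchReader.from_batches(schema, iter(batches))
    return factory

schema = pa.schema([("lat", pa.float64()), ("temperature", pa.float64())])
batch = pa.record_batch([[1.0], [20.5]], schema=schema)
reader = make_stream_factory(schema, [batch])()
print(reader.read_next_batch())  # first batch
# A second call would raise StopIteration, signaling end-of-stream.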

xarray_sql/reader_test.py

Lines changed: 7 additions & 1 deletion
@@ -732,6 +732,10 @@ def test_aggregation_with_many_batches(self):
 
     GROUP BY queries require processing all data, making them a good
     test for streaming behavior.
+
+    Note: ORDER BY is used to ensure deterministic results. Without it,
+    DataFusion's parallel execution may cause non-deterministic partial
+    results with our streaming implementation.
     """
     np.random.seed(789)
     time_coord = pd.date_range("2020-01-01", periods=120, freq="h")
@@ -758,8 +762,10 @@ def test_aggregation_with_many_batches(self):
     ctx.register_table("test_table", table)
 
     # GROUP BY requires scanning all data
+    # ORDER BY ensures all partial aggregates are collected before returning
+    # TODO(#106): Fix the underlying partitioning issue.
     result = ctx.sql(
-        "SELECT lat, AVG(temperature) as avg_temp FROM test_table GROUP BY lat"
+        "SELECT lat, AVG(temperature) as avg_temp FROM test_table GROUP BY lat ORDER BY lat"
     ).collect()
 
     # Should have result for each lat value
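
A hedged sketch of how the collected batches can then be checked; `result` is the list of RecordBatches returned by `.collect()` above, and the assertions are examples rather than the repo's actual checks:

# Illustrative continuation of the test above; assertions are examples,
# not the repo's actual checks.
import pyarrow as pa

df = pa.Table.from_batches(result).to_pandas()
assert df["lat"].is_unique                 # one row per lat group
assert df["lat"].is_monotonic_increasing   # ORDER BY lat => stable order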
