
Commit b5cbe40

experiment: GIL-aware iterator straddling Rust and Python.
1 parent 3e5ea89 commit b5cbe40

File tree

1 file changed: +98 -51 lines

src/lib.rs

Lines changed: 98 additions & 51 deletions
@@ -17,19 +17,19 @@
 //! - **Reusable tables**: The factory pattern allows the same table to be queried
 //!   multiple times, with fresh data streams created for each query.
 //!
-//! ## Memory Considerations
+//! ## Streaming Behavior
 //!
-//! Currently, all batches are loaded into memory when a query executes. This is
-//! due to Python GIL constraints that prevent true streaming with background threads.
-//! For very large datasets, consider using smaller chunks or processing in stages.
+//! Batches are read lazily one at a time during query execution. The GIL is acquired
+//! for each batch read, allowing DataFusion to process and potentially filter batches
+//! incrementally. This enables processing of larger-than-memory datasets when combined
+//! with DataFusion's streaming execution.
 
 use std::ffi::CString;
 use std::fmt::Debug;
 use std::sync::Arc;
 
 use arrow::array::RecordBatch;
 use arrow::datatypes::SchemaRef;
-use arrow::ffi_stream::ArrowArrayStreamReader;
 use arrow::pyarrow::FromPyArrow;
 use datafusion::catalog::streaming::StreamingTable;
 use datafusion::common::DataFusionError;
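
The new "Streaming Behavior" section only pays off when the query is also consumed as a stream on the DataFusion side. As a rough sketch of what that looks like (not part of this commit; "py_table" is a placeholder name, and registering the provider with a SessionContext is assumed to have happened elsewhere), execute_stream() pulls one batch at a time instead of materializing the full result the way collect() would:

    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;
    use futures::StreamExt;

    // Sketch: consume the query incrementally so only the current batch is resident.
    // "py_table" is a hypothetical table name; registration is out of scope here.
    async fn consume_lazily(ctx: &SessionContext) -> Result<()> {
        let df = ctx.sql("SELECT * FROM py_table").await?;
        let mut stream = df.execute_stream().await?;
        while let Some(batch) = stream.next().await {
            let batch = batch?;
            println!("received a batch with {} rows", batch.num_rows());
        }
        Ok(())
    }

Under the old eager implementation, every batch was materialized inside execute() itself, so even execute_stream() could not keep memory bounded; with the lazy reader introduced here it can.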
@@ -39,7 +39,7 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::streaming::PartitionStream;
 use datafusion::physical_plan::SendableRecordBatchStream;
 use datafusion_ffi::table_provider::FFI_TableProvider;
-use futures::stream;
+use futures::stream::unfold;
 use pyo3::prelude::*;
 use pyo3::types::PyCapsule;
 use tokio::runtime::Handle;
@@ -72,6 +72,17 @@ impl Debug for PyArrowStreamPartition {
     }
 }
 
+/// State for the lazy stream - holds a Python iterator that yields RecordBatches.
+/// Using Py<PyAny>, which is Send, allows this to be used across async boundaries.
+enum StreamState {
+    /// Initial state: factory not yet called
+    NotStarted(Py<PyAny>),
+    /// Active state: iterator ready to yield batches
+    Active(Py<PyAny>),
+    /// Terminal state: stream exhausted or errored
+    Done,
+}
+
 impl PartitionStream for PyArrowStreamPartition {
     fn schema(&self) -> &SchemaRef {
         &self.schema
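
For readers who have not used futures::stream::unfold (imported above), here is a minimal, self-contained illustration, a toy counter unrelated to the Python interop, of how unfold drives a state enum of this shape: the closure receives the current state and returns either None to end the stream or Some((item, next_state)) to yield an item and carry the new state into the next poll.

    use futures::stream::{unfold, StreamExt};

    // A toy two-state machine in the spirit of StreamState: yield 0, 1, 2, then stop.
    enum Count {
        Running(u32),
        Done,
    }

    // tokio is used only to drive the example; any async runtime works.
    #[tokio::main]
    async fn main() {
        let counter = unfold(Count::Running(0), |state| async move {
            match state {
                Count::Done => None,
                Count::Running(n) => {
                    let next = if n + 1 < 3 { Count::Running(n + 1) } else { Count::Done };
                    // Yield the current value and hand the next state to the following poll.
                    Some((n, next))
                }
            }
        });
        let items: Vec<u32> = counter.collect().await;
        assert_eq!(items, vec![0, 1, 2]);
    }

Keeping the Python iterator inside the state as a Py<PyAny> matters for the real stream because Py<PyAny> is Send, so the unfold future can be polled from any thread, which is the point made in the doc comment above.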
@@ -80,58 +91,94 @@ impl PartitionStream for PyArrowStreamPartition {
     fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
         let schema = Arc::clone(&self.schema);
 
-        // NOTE: We collect all batches synchronously here rather than using a background
-        // thread with streaming. This is a deliberate design choice to avoid a GIL deadlock:
-        //
-        // When Python calls DataFusion's collect(), the Python thread holds the GIL.
-        // If we spawned a background thread to read from Python, that thread would need
-        // to acquire the GIL, but the main thread holds it while waiting for data from
-        // the channel, causing a deadlock.
-        //
-        // The tradeoff: All batches are loaded into memory at query execution time.
-        // For very large datasets, consider using smaller chunks or processing in stages.
-        //
-        // Future improvement: If datafusion-python releases the gil during async operations,
-        // we could revisit the background thread approach for true streaming.
-        let results: Vec<Result<RecordBatch, DataFusionError>> = Python::attach(|py| {
-            // Call the factory to get a fresh stream
-            let stream_result = self.stream_factory.call0(py);
+        // Clone the factory with the GIL held
+        let factory = Python::attach(|py| self.stream_factory.clone_ref(py));
 
-            match stream_result {
-                Ok(stream_obj) => {
-                    let bound = stream_obj.bind(py);
-
-                    match ArrowArrayStreamReader::from_pyarrow_bound(bound) {
-                        Ok(reader) => {
-                            // Collect all batches, converting errors to DataFusionError
-                            reader
-                                .map(|result| {
-                                    result.map_err(|e| {
-                                        DataFusionError::Execution(format!(
-                                            "Failed to read batch from xarray stream: {e}"
-                                        ))
-                                    })
-                                })
-                                .collect()
-                        }
-                        Err(e) => {
-                            vec![Err(DataFusionError::Execution(format!(
-                                "Failed to create Arrow stream reader: {e}"
-                            )))]
+        // Create a lazy stream using unfold. Each poll acquires the GIL and reads one batch.
+        // This allows DataFusion to process batches incrementally rather than loading all into memory.
+        //
+        // The stream state transitions: NotStarted -> Active -> Done
+        // - NotStarted: Factory hasn't been called yet
+        // - Active: Iterator is yielding batches
+        // - Done: Iterator exhausted or error occurred
+        let batch_stream = unfold(StreamState::NotStarted(factory), |state| async move {
+            match state {
+                StreamState::Done => None,
+                StreamState::NotStarted(factory) => {
+                    // First poll: call factory to get iterator
+                    Python::attach(|py| {
+                        match factory.call0(py) {
+                            Ok(stream_obj) => {
+                                // Get Python iterator from the stream object
+                                let bound = stream_obj.bind(py);
+                                match bound.call_method0("__iter__") {
+                                    Ok(iter) => {
+                                        // Get first batch
+                                        let iter_py: Py<PyAny> = iter.unbind();
+                                        read_next_batch(py, iter_py)
+                                    }
+                                    Err(e) => Some((
+                                        Err(DataFusionError::Execution(format!(
+                                            "Failed to get iterator from stream: {e}"
+                                        ))),
+                                        StreamState::Done,
+                                    )),
+                                }
+                            }
+                            Err(e) => Some((
+                                Err(DataFusionError::Execution(format!(
+                                    "Failed to call stream factory: {e}"
+                                ))),
+                                StreamState::Done,
+                            )),
                         }
-                    }
+                    })
                 }
-                Err(e) => {
-                    vec![Err(DataFusionError::Execution(format!(
-                        "Failed to call xarray stream factory: {e}"
-                    )))]
+                StreamState::Active(iterator) => {
+                    // Subsequent polls: read next batch from iterator
+                    Python::attach(|py| read_next_batch(py, iterator))
                 }
             }
         });
 
-        // Create a stream from the collected results
-        let result_stream = stream::iter(results);
-        Box::pin(RecordBatchStreamAdapter::new(schema, result_stream))
+        Box::pin(RecordBatchStreamAdapter::new(schema, batch_stream))
+    }
+}
+
+/// Read the next batch from a Python iterator, returning the stream state transition.
+fn read_next_batch(
+    py: Python<'_>,
+    iterator: Py<PyAny>,
+) -> Option<(Result<RecordBatch, DataFusionError>, StreamState)> {
+    let bound_iter = iterator.bind(py);
+
+    match bound_iter.call_method0("__next__") {
+        Ok(batch_obj) => {
+            // Convert PyArrow RecordBatch to Arrow RecordBatch
+            match RecordBatch::from_pyarrow_bound(&batch_obj) {
+                Ok(batch) => Some((Ok(batch), StreamState::Active(iterator))),
+                Err(e) => Some((
+                    Err(DataFusionError::Execution(format!(
+                        "Failed to convert batch from PyArrow: {e}"
+                    ))),
+                    StreamState::Done,
+                )),
+            }
+        }
+        Err(e) => {
+            // Check if this is StopIteration (normal end of iterator)
+            if e.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
+                None // Stream exhausted normally
+            } else {
+                // Actual error
+                Some((
+                    Err(DataFusionError::Execution(format!(
+                        "Error reading batch from stream: {e}"
+                    ))),
+                    StreamState::Done,
+                ))
+            }
+        }
     }
 }
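
The part of read_next_batch that straddles the language boundary, driving a Python iterator from Rust and treating StopIteration as normal exhaustion rather than an error, can be exercised on its own. A minimal sketch, assuming the same pyo3 version as this crate (it relies on Python::attach, as above) and an already-initialized interpreter (for example via pyo3's auto-initialize feature), iterating over a plain Python list instead of a PyArrow batch stream:

    use pyo3::exceptions::PyStopIteration;
    use pyo3::prelude::*;

    fn main() -> PyResult<()> {
        Python::attach(|py| {
            // Build a throwaway Python iterator: iter([10, 20, 30]).
            let builtins = py.import("builtins")?;
            let iterator = builtins.call_method1("iter", (vec![10, 20, 30],))?;

            loop {
                match iterator.call_method0("__next__") {
                    // A yielded item; in lib.rs this is where the PyArrow conversion happens.
                    Ok(item) => println!("next item: {}", item.extract::<i32>()?),
                    // StopIteration signals normal exhaustion, mirroring the None return above.
                    Err(e) if e.is_instance_of::<PyStopIteration>(py) => break,
                    // Any other Python exception is a genuine error.
                    Err(e) => return Err(e),
                }
            }
            Ok(())
        })
    }

The three match arms correspond one-to-one to the branches of read_next_batch: a converted batch, normal exhaustion, and a real Python error that ends the stream.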
