 //! - **Reusable tables**: The factory pattern allows the same table to be queried
 //!   multiple times, with fresh data streams created for each query.
 //!
-//! ## Memory Considerations
+//! ## Streaming Behavior
 //!
-//! Currently, all batches are loaded into memory when a query executes. This is
-//! due to Python GIL constraints that prevent true streaming with background threads.
-//! For very large datasets, consider using smaller chunks or processing in stages.
+//! Batches are read lazily one at a time during query execution. The GIL is acquired
+//! for each batch read, allowing DataFusion to process and potentially filter batches
+//! incrementally. This enables processing of larger-than-memory datasets when combined
+//! with DataFusion's streaming execution.
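+//!
+//! For example, once a table backed by this provider is registered, results can be
+//! drained batch by batch instead of collected up front. This is only a sketch: the
+//! `ctx` session and the `"xarray_table"` name are assumed for illustration and are
+//! not defined by this module.
+//!
+//! ```ignore
+//! use futures::StreamExt;
+//!
+//! let df = ctx.table("xarray_table").await?;
+//! let mut stream = df.execute_stream().await?;
+//! while let Some(batch) = stream.next().await {
+//!     let batch = batch?; // each poll reads one batch from the Python iterator
+//!     // ... process `batch`, then drop it before polling for the next one
+//! }
+//! ```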

 use std::ffi::CString;
 use std::fmt::Debug;
 use std::sync::Arc;

 use arrow::array::RecordBatch;
 use arrow::datatypes::SchemaRef;
-use arrow::ffi_stream::ArrowArrayStreamReader;
 use arrow::pyarrow::FromPyArrow;
 use datafusion::catalog::streaming::StreamingTable;
 use datafusion::common::DataFusionError;
@@ -39,7 +39,7 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::streaming::PartitionStream;
 use datafusion::physical_plan::SendableRecordBatchStream;
 use datafusion_ffi::table_provider::FFI_TableProvider;
-use futures::stream;
+use futures::stream::unfold;
 use pyo3::prelude::*;
 use pyo3::types::PyCapsule;
 use tokio::runtime::Handle;
@@ -72,6 +72,17 @@ impl Debug for PyArrowStreamPartition {
     }
 }

+/// State for the lazy stream - holds a Python iterator that yields RecordBatches.
+/// Uses `Py<PyAny>`, which is `Send`, so the state can be held across async boundaries.
+enum StreamState {
+    /// Initial state: factory not yet called
+    NotStarted(Py<PyAny>),
+    /// Active state: iterator ready to yield batches
+    Active(Py<PyAny>),
+    /// Terminal state: stream exhausted or errored
+    Done,
+}
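+
+// A minimal compile-time sketch (illustrative; the `_assert_stream_state_is_send`
+// helper is not part of the original code): the unfold state below is captured by the
+// stream returned from `execute`, which is boxed as a `SendableRecordBatchStream` and
+// therefore must be `Send`. This fails to compile if `StreamState` ever stops being `Send`.
+fn _assert_stream_state_is_send() {
+    fn assert_send<T: Send>() {}
+    assert_send::<StreamState>();
+}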
+
 impl PartitionStream for PyArrowStreamPartition {
     fn schema(&self) -> &SchemaRef {
         &self.schema
@@ -80,58 +91,94 @@ impl PartitionStream for PyArrowStreamPartition {
     fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
         let schema = Arc::clone(&self.schema);

-        // NOTE: We collect all batches synchronously here rather than using a background
-        // thread with streaming. This is a deliberate design choice to avoid a GIL deadlock:
-        //
-        // When Python calls DataFusion's collect(), the Python thread holds the GIL.
-        // If we spawned a background thread to read from Python, that thread would need
-        // to acquire the GIL, but the main thread holds it while waiting for data from
-        // the channel, causing a deadlock.
-        //
-        // The tradeoff: All batches are loaded into memory at query execution time.
-        // For very large datasets, consider using smaller chunks or processing in stages.
-        //
-        // Future improvement: If datafusion-python releases the gil during async operations,
-        // we could revisit the background thread approach for true streaming.
-        let results: Vec<Result<RecordBatch, DataFusionError>> = Python::attach(|py| {
-            // Call the factory to get a fresh stream
-            let stream_result = self.stream_factory.call0(py);
+        // Clone the factory with the GIL held
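+        // (`clone_ref` only bumps the Python reference count, which requires the GIL;
+        // the resulting `Py<PyAny>` is then moved into the unfold state below and not
+        // used again until the stream is first polled, under a fresh GIL acquisition.)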
+        let factory = Python::attach(|py| self.stream_factory.clone_ref(py));

-            match stream_result {
-                Ok(stream_obj) => {
-                    let bound = stream_obj.bind(py);
-
-                    match ArrowArrayStreamReader::from_pyarrow_bound(bound) {
-                        Ok(reader) => {
-                            // Collect all batches, converting errors to DataFusionError
-                            reader
-                                .map(|result| {
-                                    result.map_err(|e| {
-                                        DataFusionError::Execution(format!(
-                                            "Failed to read batch from xarray stream: {e}"
-                                        ))
-                                    })
-                                })
-                                .collect()
-                        }
-                        Err(e) => {
-                            vec![Err(DataFusionError::Execution(format!(
-                                "Failed to create Arrow stream reader: {e}"
-                            )))]
+        // Create a lazy stream using unfold. Each poll acquires the GIL and reads one batch.
+        // This allows DataFusion to process batches incrementally rather than loading all into memory.
+        //
+        // The stream state transitions: NotStarted -> Active -> Done
+        // - NotStarted: Factory hasn't been called yet
+        // - Active: Iterator is yielding batches
+        // - Done: Iterator exhausted or error occurred
+        let batch_stream = unfold(StreamState::NotStarted(factory), |state| async move {
+            match state {
+                StreamState::Done => None,
+                StreamState::NotStarted(factory) => {
+                    // First poll: call factory to get iterator
+                    Python::attach(|py| {
+                        match factory.call0(py) {
+                            Ok(stream_obj) => {
+                                // Get Python iterator from the stream object
+                                let bound = stream_obj.bind(py);
+                                match bound.call_method0("__iter__") {
+                                    Ok(iter) => {
+                                        // Get first batch
+                                        let iter_py: Py<PyAny> = iter.unbind();
+                                        read_next_batch(py, iter_py)
+                                    }
+                                    Err(e) => Some((
+                                        Err(DataFusionError::Execution(format!(
+                                            "Failed to get iterator from stream: {e}"
+                                        ))),
+                                        StreamState::Done,
+                                    )),
+                                }
+                            }
+                            Err(e) => Some((
+                                Err(DataFusionError::Execution(format!(
+                                    "Failed to call stream factory: {e}"
+                                ))),
+                                StreamState::Done,
+                            )),
                         }
-                    }
+                    })
                 }
-                Err(e) => {
-                    vec![Err(DataFusionError::Execution(format!(
-                        "Failed to call xarray stream factory: {e}"
-                    )))]
+                StreamState::Active(iterator) => {
+                    // Subsequent polls: read next batch from iterator
+                    Python::attach(|py| read_next_batch(py, iterator))
                 }
             }
         });

-        // Create a stream from the collected results
-        let result_stream = stream::iter(results);
-        Box::pin(RecordBatchStreamAdapter::new(schema, result_stream))
+        Box::pin(RecordBatchStreamAdapter::new(schema, batch_stream))
+    }
+}
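+
+// A hedged, standalone illustration of the `unfold` contract used in `execute` above
+// (the `_unfold_countdown_example` helper is illustrative only and is never called):
+// the closure maps the current state to `Some((item, next_state))` to keep yielding,
+// or `None` to end the stream.
+fn _unfold_countdown_example() -> impl futures::Stream<Item = u32> {
+    // Yields 3, 2, 1, then terminates when the state reaches 0.
+    unfold(3u32, |n| async move { if n == 0 { None } else { Some((n, n - 1)) } })
+}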
+
+/// Read the next batch from a Python iterator, returning the stream state transition.
+fn read_next_batch(
+    py: Python<'_>,
+    iterator: Py<PyAny>,
+) -> Option<(Result<RecordBatch, DataFusionError>, StreamState)> {
+    let bound_iter = iterator.bind(py);
+
+    match bound_iter.call_method0("__next__") {
+        Ok(batch_obj) => {
+            // Convert PyArrow RecordBatch to Arrow RecordBatch
+            match RecordBatch::from_pyarrow_bound(&batch_obj) {
+                Ok(batch) => Some((Ok(batch), StreamState::Active(iterator))),
+                Err(e) => Some((
+                    Err(DataFusionError::Execution(format!(
+                        "Failed to convert batch from PyArrow: {e}"
+                    ))),
+                    StreamState::Done,
+                )),
+            }
+        }
+        Err(e) => {
+            // Check if this is StopIteration (normal end of iterator)
+            if e.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
+                None // Stream exhausted normally
+            } else {
+                // Actual error
+                Some((
+                    Err(DataFusionError::Execution(format!(
+                        "Error reading batch from stream: {e}"
+                    ))),
+                    StreamState::Done,
+                ))
+            }
+        }
     }
 }
