 //! for each batch read, allowing DataFusion to process and potentially filter batches
 //! incrementally. This enables processing of larger-than-memory datasets when combined
 //! with DataFusion's streaming execution.
+//!
+//! ## Parallel Execution Note
+//!
+//! When using DataFusion's parallel execution (multiple partitions), aggregation queries
+//! without ORDER BY may return partial results due to how our stream interacts with
+//! DataFusion's async runtime. To ensure complete results:
+//! - Add ORDER BY to aggregation queries, or
+//! - Use `SessionConfig().with_target_partitions(1)` for single-threaded execution (see
+//!   the example below)
+//!
+//! TODO(#106): Implement proper parallelism and partition handling.
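+//!
+//! For example, a minimal sketch of the single-partition workaround, assuming the
+//! upstream DataFusion Rust API (`SessionConfig::new()`; the Python bindings expose
+//! an equivalent `SessionConfig()`):
+//!
+//! ```rust,ignore
+//! use datafusion::prelude::{SessionConfig, SessionContext};
+//!
+//! // Force a single partition so the stream is polled from one task.
+//! let config = SessionConfig::new().with_target_partitions(1);
+//! let ctx = SessionContext::new_with_config(config);
+//! ```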
 
 use std::ffi::CString;
 use std::fmt::Debug;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use arrow::array::RecordBatch;
 use arrow::datatypes::SchemaRef;
@@ -72,15 +81,14 @@ impl Debug for PyArrowStreamPartition {
     }
 }
 
-/// State for the lazy stream - holds a Python iterator that yields RecordBatches.
-/// Using Py<PyAny> which is Send, allowing this to be used across async boundaries.
-enum StreamState {
-    /// Initial state: factory not yet called
-    NotStarted(Py<PyAny>),
-    /// Active state: iterator ready to yield batches
-    Active(Py<PyAny>),
-    /// Terminal state: stream exhausted or errored
-    Done,
+/// Shared state for the lazy stream, protected by a Mutex for thread safety.
+struct SharedStreamState {
+    /// The PyArrow RecordBatchReader (None until the first batch is requested)
+    reader: Option<Py<PyAny>>,
+    /// The factory to create the reader (consumed on first use)
+    factory: Option<Py<PyAny>>,
+    /// Whether the stream has ended
+    done: bool,
 }
 
 impl PartitionStream for PyArrowStreamPartition {
@@ -94,88 +102,87 @@ impl PartitionStream for PyArrowStreamPartition {
         // Clone the factory with the GIL held
         let factory = Python::attach(|py| self.stream_factory.clone_ref(py));
 
-        // Create a lazy stream using unfold. Each poll acquires the GIL and reads one batch.
-        // This allows DataFusion to process batches incrementally rather than loading all into memory.
-        //
-        // The stream state transitions: NotStarted -> Active -> Done
-        // - NotStarted: Factory hasn't been called yet
-        // - Active: Iterator is yielding batches
-        // - Done: Iterator exhausted or error occurred
-        let batch_stream = unfold(StreamState::NotStarted(factory), |state| async move {
-            match state {
-                StreamState::Done => None,
-                StreamState::NotStarted(factory) => {
-                    // First poll: call factory to get a PyArrow RecordBatchReader
-                    Python::attach(|py| {
+        // TODO(alxmrs/CC): I think we need to do something datafusion-native here;
+        // I suspect that adding a mutex will significantly impact performance.
+        // This is OK for now.
+        // Create shared state protected by a Mutex
+        let shared_state = Arc::new(Mutex::new(SharedStreamState {
+            reader: None,
+            factory: Some(factory),
+            done: false,
+        }));
+
+        // Create a lazy stream using unfold.
+        // The Arc<Mutex<...>> is cloned for each iteration, ensuring thread-safe access.
+        let batch_stream = unfold(shared_state, |state| async move {
+            // Clone the Arc so it can be returned as the state for the next iteration
+            let state_clone = Arc::clone(&state);
+
+            // Lock the mutex to access state
+            let mut guard = state.lock().unwrap();
+
+            if guard.done {
+                return None;
+            }
+
+            // Acquire the GIL and process
+            let result = Python::attach(|py| {
+                // Initialize the reader on first poll
+                if guard.reader.is_none() {
+                    if let Some(factory) = guard.factory.take() {
                         match factory.call0(py) {
                             Ok(reader) => {
-                                // Factory returns a RecordBatchReader directly
-                                read_next_batch(py, reader)
+                                guard.reader = Some(reader);
                             }
-                            Err(e) => Some((
-                                Err(DataFusionError::Execution(format!(
+                            Err(e) => {
+                                guard.done = true;
+                                return Some(Err(DataFusionError::Execution(format!(
                                     "Failed to call stream factory: {e}"
-                                ))),
-                                StreamState::Done,
-                            )),
+                                ))));
+                            }
                         }
-                    })
-                }
-                StreamState::Active(iterator) => {
-                    // Subsequent polls: read next batch from iterator
-                    Python::attach(|py| read_next_batch(py, iterator))
+                    }
                 }
-            }
-        });
 
-        Box::pin(RecordBatchStreamAdapter::new(schema, batch_stream))
-    }
-}
+                // Read next batch from reader
+                if let Some(ref reader) = guard.reader {
+                    let bound_reader = reader.bind(py);
+                    match bound_reader.call_method0("read_next_batch") {
+                        Ok(batch_obj) => match RecordBatch::from_pyarrow_bound(&batch_obj) {
+                            Ok(batch) => Some(Ok(batch)),
+                            Err(e) => {
+                                guard.done = true;
+                                Some(Err(DataFusionError::Execution(format!(
+                                    "Failed to convert batch from PyArrow: {e}"
+                                ))))
+                            }
+                        },
+                        Err(e) => {
+                            if e.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
+                                guard.done = true;
+                                None // Stream exhausted normally
+                            } else {
+                                guard.done = true;
+                                Some(Err(DataFusionError::Execution(format!(
+                                    "Error reading batch from stream: {e}"
+                                ))))
+                            }
+                        }
+                    }
+                } else {
+                    guard.done = true;
+                    None
+                }
+            });
 
-/// Read the next batch from a PyArrow RecordBatchReader, returning the stream state transition.
-///
-/// Handles both:
-/// - `read_next_batch()` returning None (PyArrow RecordBatchReader exhausted)
-/// - `StopIteration` exception (Python iterator protocol)
-fn read_next_batch(
-    py: Python<'_>,
-    reader: Py<PyAny>,
-) -> Option<(Result<RecordBatch, DataFusionError>, StreamState)> {
-    let bound_reader = reader.bind(py);
-
-    // Call read_next_batch() which returns None when exhausted
-    match bound_reader.call_method0("read_next_batch") {
-        Ok(batch_obj) => {
-            // Check if None (stream exhausted)
-            if batch_obj.is_none() {
-                return None; // Stream exhausted normally
-            }
+            // Release the lock before returning
+            drop(guard);
 
-            // Convert PyArrow RecordBatch to Arrow RecordBatch
-            match RecordBatch::from_pyarrow_bound(&batch_obj) {
-                Ok(batch) => Some((Ok(batch), StreamState::Active(reader))),
-                Err(e) => Some((
-                    Err(DataFusionError::Execution(format!(
-                        "Failed to convert batch from PyArrow: {e}"
-                    ))),
-                    StreamState::Done,
-                )),
-            }
-        }
-        Err(e) => {
-            // Handle StopIteration as normal end of iteration (Python iterator protocol)
-            if e.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
-                return None; // Stream exhausted normally
-            }
+            // Map the result to include the state for the next iteration
+            result.map(|batch_result| (batch_result, state_clone))
+        });
 
-            // Actual error reading batch
-            Some((
-                Err(DataFusionError::Execution(format!(
-                    "Error reading batch from stream: {e}"
-                ))),
-                StreamState::Done,
-            )),
-        }
+        Box::pin(RecordBatchStreamAdapter::new(schema, batch_stream))
     }
 }
 
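For reference, the unfold-over-shared-state pattern introduced above can be exercised in isolation. The sketch below is illustrative only (the `Counter` type is hypothetical, and `futures` and `tokio` are assumed as dependencies): each poll clones the `Arc`, locks the `Mutex`, and either yields an item together with the state for the next poll or returns `None` to end the stream.

```rust
use std::sync::{Arc, Mutex};

use futures::stream::{unfold, StreamExt};

/// Hypothetical shared state, standing in for SharedStreamState.
struct Counter {
    remaining: u32,
    done: bool,
}

#[tokio::main]
async fn main() {
    let shared = Arc::new(Mutex::new(Counter { remaining: 3, done: false }));

    let stream = unfold(shared, |state| async move {
        // Clone the Arc so it can be handed back as the next iteration's state.
        let next_state = Arc::clone(&state);
        let mut guard = state.lock().unwrap();
        if guard.done || guard.remaining == 0 {
            guard.done = true;
            return None; // Terminal: the stream yields nothing further.
        }
        guard.remaining -= 1;
        let item = guard.remaining;
        drop(guard); // Release the lock before yielding.
        Some((item, next_state))
    });

    let items: Vec<u32> = stream.collect().await;
    assert_eq!(items, vec![2, 1, 0]);
}
```

As in the patch, the lock is held only for the duration of a single poll and is released before the item is yielded, so no guard is ever held across an `await`.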