|
17 | 17 | #include <csp/python/PyCppNode.h>
|
18 | 18 | #include <csp/python/PyNodeWrapper.h>
|
19 | 19 | #include <csp/python/NumpyConversions.h>
|
| 20 | +#include <arrow/c/abi.h> |
| 21 | +#include <arrow/c/bridge.h> |
20 | 22 | #include <arrow/io/memory.h>
|
21 |
| -#include <arrow/ipc/reader.h> |
| 23 | +#include <arrow/table.h> |
22 | 24 | #include <locale>
|
23 | 25 | #include <codecvt>
|
24 | 26 |
|
@@ -156,34 +158,30 @@ class ArrowTableGenerator : public csp::Generator<std::shared_ptr<arrow::Table>,
|
156 | 158 | {
|
157 | 159 | CSP_THROW( csp::python::PythonPassthrough, "" );
|
158 | 160 | }
|
159 |
| - if( nextVal == nullptr ) |
| 161 | + if( nextValPtr.get() == nullptr ) |
160 | 162 | {
|
161 | 163 | return false;
|
162 | 164 | }
|
163 | 165 |
|
164 |
| - if(!PyBytes_Check( nextVal )) |
| 166 | + if( !PyCapsule_IsValid( nextValPtr.get(), "arrow_array_stream" ) ) |
165 | 167 | {
|
166 |
| - CSP_THROW( csp::TypeError, "Invalid arrow buffer type, expected bytes got " << Py_TYPE( nextVal ) -> tp_name ); |
| 168 | + CSP_THROW( csp::TypeError, "Invalid arrow data, expected PyCapsule got " << Py_TYPE( nextValPtr.get() ) -> tp_name ); |
167 | 169 | }
|
168 |
| - const char * data = PyBytes_AsString( nextVal ); |
169 |
| - if( !data ) |
170 |
| - CSP_THROW( csp::python::PythonPassthrough, "" ); |
171 |
| - auto size = PyBytes_Size(nextVal); |
172 |
| - m_data = csp::python::PyObjectPtr::incref(nextVal); |
173 |
| - std::shared_ptr<arrow::io::BufferReader> bufferReader = std::make_shared<arrow::io::BufferReader>( |
174 |
| - reinterpret_cast<const uint8_t *>(data), size ); |
175 |
| - std::shared_ptr<arrow::ipc::RecordBatchStreamReader> reader = arrow::ipc::RecordBatchStreamReader::Open(bufferReader.get()).ValueOrDie(); |
176 |
| - auto result = reader->ToTable(); |
177 |
| - if (!(result.ok())) |
178 |
| - CSP_THROW(csp::RuntimeException, "Failed read arrow table from buffer"); |
179 |
| - value = std::move(result.ValueUnsafe()); |
| 170 | + // Extract the record batch |
| 171 | + struct ArrowArrayStream * c_stream = reinterpret_cast<struct ArrowArrayStream*>( PyCapsule_GetPointer( nextValPtr.get(), "arrow_array_stream" ) ); |
| 172 | + auto reader_result = arrow::ImportRecordBatchReader( c_stream ); |
| 173 | + if( !reader_result.ok() ) |
| 174 | + CSP_THROW( csp::ValueError, "Failed to load record batches through PyCapsule C Data interface: " << reader_result.status().ToString() ); |
| 175 | + auto reader = std::move( reader_result.ValueUnsafe() ); |
| 176 | + auto table_result = arrow::Table::FromRecordBatchReader( reader.get() ); |
| 177 | + if( !table_result.ok() ) |
| 178 | + CSP_THROW( csp::ValueError, "Failed to load table from record batches " << table_result.status().ToString() ); |
| 179 | + value = std::move( table_result.ValueUnsafe() ); |
180 | 180 | return true;
|
181 | 181 | }
|
182 | 182 | private:
|
183 | 183 | csp::python::PyObjectPtr m_wrappedGenerator;
|
184 | 184 | csp::python::PyObjectPtr m_iter;
|
185 |
| - // We need to keep the last buffer in memory since arrow doesn't copy it but can refer to strings in it |
186 |
| - csp::python::PyObjectPtr m_data; |
187 | 185 | };
|
188 | 186 |
|
189 | 187 | template< typename CspCType>
|
|
0 commit comments