diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py
index b67fc74d..fe970587 100644
--- a/src/databricks/sql/backend/sea/result_set.py
+++ b/src/databricks/sql/backend/sea/result_set.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 from typing import Any, List, Optional, TYPE_CHECKING
 import logging
 
@@ -82,6 +83,41 @@ def __init__(
             arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )
 
+    def _convert_complex_types_to_string(
+        self, rows: "pyarrow.Table"
+    ) -> "pyarrow.Table":
+        """
+        Convert complex types (array, struct, map) to string representation.
+        Args:
+            rows: Input PyArrow table
+        Returns:
+            PyArrow table with complex types converted to strings
+        """
+
+        if not pyarrow:
+            return rows
+
+        def convert_complex_column_to_string(col: "pyarrow.Array") -> "pyarrow.Array":
+            python_values = col.to_pylist()
+            json_strings = [
+                (None if val is None else json.dumps(val)) for val in python_values
+            ]
+            return pyarrow.array(json_strings, type=pyarrow.string())
+
+        converted_columns = []
+        for col in rows.columns:
+            converted_col = col
+            if (
+                pyarrow.types.is_list(col.type)
+                or pyarrow.types.is_large_list(col.type)
+                or pyarrow.types.is_struct(col.type)
+                or pyarrow.types.is_map(col.type)
+            ):
+                converted_col = convert_complex_column_to_string(col)
+            converted_columns.append(converted_col)
+
+        return pyarrow.Table.from_arrays(converted_columns, names=rows.column_names)
+
     def _convert_json_types(self, row: List[str]) -> List[Any]:
         """
         Convert string values in the row to appropriate Python types based on column metadata.
@@ -200,6 +236,9 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         if isinstance(self.results, JsonQueue):
             results = self._convert_json_to_arrow_table(results)
 
+        if not self.backend._use_arrow_native_complex_types:
+            results = self._convert_complex_types_to_string(results)
+
         self._next_row_index += results.num_rows
 
         return results
@@ -213,6 +252,9 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         if isinstance(self.results, JsonQueue):
             results = self._convert_json_to_arrow_table(results)
 
+        if not self.backend._use_arrow_native_complex_types:
+            results = self._convert_complex_types_to_string(results)
+
         self._next_row_index += results.num_rows
 
         return results
diff --git a/tests/e2e/test_complex_types.py b/tests/e2e/test_complex_types.py
index c8a3a078..7c79600f 100644
--- a/tests/e2e/test_complex_types.py
+++ b/tests/e2e/test_complex_types.py
@@ -54,10 +54,19 @@ def table_fixture(self, connection_details):
             ("map_array_col", list),
         ],
     )
-    def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
+    @pytest.mark.parametrize(
+        "extra_params",
+        [
+            {},
+            {"use_sea": True},
+        ],
+    )
+    def test_read_complex_types_as_arrow(
+        self, field, expected_type, table_fixture, extra_params
+    ):
         """Confirms the return types of a complex type field when reading as arrow"""
 
-        with self.cursor() as cursor:
+        with self.cursor(extra_params=extra_params) as cursor:
             result = cursor.execute(
                 "SELECT * FROM pysql_test_complex_types_table LIMIT 1"
             ).fetchone()
@@ -75,10 +84,18 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
             ("map_array_col"),
         ],
     )
-    def test_read_complex_types_as_string(self, field, table_fixture):
+    @pytest.mark.parametrize(
+        "extra_params",
+        [
+            {},
+            {"use_sea": True},
+        ],
+    )
+    def test_read_complex_types_as_string(self, field, table_fixture, extra_params):
         """Confirms the return type of a complex type that is returned as a string"""
+        extra_params = {**extra_params, "_use_arrow_native_complex_types": False}
         with self.cursor(
-            extra_params={"_use_arrow_native_complex_types": False}
+            extra_params=extra_params,
         ) as cursor:
             result = cursor.execute(
                 "SELECT * FROM pysql_test_complex_types_table LIMIT 1"
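For reviewers, a rough standalone sketch of what the new _convert_complex_types_to_string helper does to a PyArrow table when the SEA backend is told not to use Arrow-native complex types. The wrapper function name and the sample table below are illustrative only, not part of the patch:

import json
import pyarrow

def convert_complex_types_to_string(rows: pyarrow.Table) -> pyarrow.Table:
    # Simplified mirror of the helper added in result_set.py: list/struct/map
    # columns are serialized to JSON strings, other columns pass through as-is.
    def to_json_strings(col):
        return pyarrow.array(
            [None if v is None else json.dumps(v) for v in col.to_pylist()],
            type=pyarrow.string(),
        )

    converted = [
        to_json_strings(col)
        if (
            pyarrow.types.is_list(col.type)
            or pyarrow.types.is_large_list(col.type)
            or pyarrow.types.is_struct(col.type)
            or pyarrow.types.is_map(col.type)
        )
        else col
        for col in rows.columns
    ]
    return pyarrow.Table.from_arrays(converted, names=rows.column_names)

table = pyarrow.table({"id": [1, 2], "tags": [["a", "b"], None]})
print(convert_complex_types_to_string(table).column("tags").to_pylist())
# -> ['["a", "b"]', None]; nulls stay null instead of becoming the string "null"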