diff --git a/apis/python/tests/test_dataframe.py b/apis/python/tests/test_dataframe.py index bec5c5fc36..5cb86e8fd8 100644 --- a/apis/python/tests/test_dataframe.py +++ b/apis/python/tests/test_dataframe.py @@ -1979,24 +1979,74 @@ def test_arrow_table_sliced_writer(tmp_path): [ ("myint", pa.int32()), ("mystring", pa.large_string()), + ("mybool", pa.bool_()), + ("myenumint", pa.dictionary(pa.int64(), pa.int32())), + ("myenumstr", pa.dictionary(pa.int64(), pa.large_string())), + ("myenumbool", pa.dictionary(pa.int64(), pa.bool_())), ] ) pydict = { "soma_joinid": list(range(num_rows)), - "myint": [(e + 1) * 10 for e in range(num_rows)], - "mystring": ["s_%08d" % e for e in range(num_rows)], + "myint": np.random.randint(10, 100, size=num_rows), + "mystring": [f"s_{np.random.randint(1, 100000):08d}" for _ in range(num_rows)], + "mybool": np.random.choice([False, True], size=num_rows), + "myenumint": pd.Categorical( + np.random.choice([1, 2, 3], size=num_rows, replace=True) + ), + "myenumstr": pd.Categorical( + np.random.choice(["a", "bb", "ccc"], size=num_rows, replace=True) + ), + "myenumbool": pd.Categorical( + np.random.choice([False, True], size=num_rows, replace=True) + ), } + + pydict["myenumint"] = pa.DictionaryArray.from_arrays( + pa.array(pydict["myenumint"].codes, type=pa.int32()), + pa.array([1, 2, 3], type=pa.int32()), + ) + + pydict["myenumstr"] = pa.DictionaryArray.from_arrays( + pa.array(pydict["myenumstr"].codes, type=pa.int32()), + pa.array(["a", "bb", "ccc"], type=pa.large_string()), + ) + + pydict["myenumbool"] = pa.DictionaryArray.from_arrays( + pa.array(pydict["myenumbool"].codes, type=pa.int32()), + pa.array([False, True], type=pa.bool_()), + ) + table = pa.Table.from_pydict(pydict) domain = [[0, len(table) - 1]] with soma.DataFrame.create(uri, schema=schema, domain=domain) as sdf: + sdf.write(table[:]) + + with soma.DataFrame.open(uri) as sdf: + pdf = sdf.read().concat().to_pandas() + + np.testing.assert_array_equal(pdf["myint"], pydict["myint"]) + np.testing.assert_array_equal(pdf["mystring"], pydict["mystring"]) + np.testing.assert_array_equal(pdf["mybool"], pydict["mybool"]) + + np.testing.assert_array_equal(pdf["myenumint"], pydict["myenumint"]) + np.testing.assert_array_equal(pdf["myenumstr"], pydict["myenumstr"]) + np.testing.assert_array_equal(pdf["myenumbool"], pydict["myenumbool"]) + + with soma.DataFrame.open(uri, mode="w") as sdf: mid = num_rows // 2 sdf.write(table[:mid]) sdf.write(table[mid:]) with soma.DataFrame.open(uri) as sdf: pdf = sdf.read().concat().to_pandas() - assert list(pdf["myint"]) == pydict["myint"] - assert list(pdf["mystring"]) == pydict["mystring"] + + np.testing.assert_array_equal(pdf["myint"], pydict["myint"]) + np.testing.assert_array_equal(pdf["mystring"], pydict["mystring"]) + np.testing.assert_array_equal(pdf["mybool"], pydict["mybool"]) + + np.testing.assert_array_equal(pdf["myenumint"], pydict["myenumint"]) + np.testing.assert_array_equal(pdf["myenumstr"], pydict["myenumstr"]) + np.testing.assert_array_equal(pdf["myenumbool"], pydict["myenumbool"]) diff --git a/libtiledbsoma/src/soma/managed_query.cc b/libtiledbsoma/src/soma/managed_query.cc index 07ce3207ff..bf3975170b 100644 --- a/libtiledbsoma/src/soma/managed_query.cc +++ b/libtiledbsoma/src/soma/managed_query.cc @@ -1014,12 +1014,17 @@ bool ManagedQuery::_cast_column_aux( (void)se; // se is unused in bool specialization auto casted = util::cast_bit_to_uint8(schema, array); + uint8_t* validity = (uint8_t*)array->buffers[0]; + if (validity != nullptr) { + validity += array->offset; + } + setup_write_column( schema->name, array->length, (const void*)casted.data(), (uint64_t*)nullptr, - (uint8_t*)array->buffers[0]); + (uint8_t*)validity); return false; } @@ -1097,15 +1102,14 @@ bool ManagedQuery::_extend_and_evolve_schema( // Specially handle Boolean types as their representation in Arrow (bit) // is different from what is in TileDB (uint8_t) auto casted = util::cast_bit_to_uint8(value_schema, value_array); - enums_in_write.assign( - (ValueType*)casted.data(), (ValueType*)casted.data() + num_elems); + enums_in_write.assign(casted.data(), casted.data() + num_elems); } else { // General case - const void* data; + ValueType* data; if (value_array->n_buffers == 3) { - data = value_array->buffers[2]; + data = (ValueType*)value_array->buffers[2] + value_array->offset; } else { - data = value_array->buffers[1]; + data = (ValueType*)value_array->buffers[1] + value_array->offset; } enums_in_write.assign((ValueType*)data, (ValueType*)data + num_elems); } @@ -1252,4 +1256,4 @@ bool ManagedQuery::_extend_and_evolve_schema( } return false; } -}; // namespace tiledbsoma \ No newline at end of file +}; // namespace tiledbsoma diff --git a/libtiledbsoma/src/soma/managed_query.h b/libtiledbsoma/src/soma/managed_query.h index b3e6ad9b30..97821f9d50 100644 --- a/libtiledbsoma/src/soma/managed_query.h +++ b/libtiledbsoma/src/soma/managed_query.h @@ -725,9 +725,9 @@ class ManagedQuery { // Get the user passed-in dictionary indexes IndexType* idxbuf; if (index_array->n_buffers == 3) { - idxbuf = (IndexType*)index_array->buffers[2]; + idxbuf = (IndexType*)index_array->buffers[2] + index_array->offset; } else { - idxbuf = (IndexType*)index_array->buffers[1]; + idxbuf = (IndexType*)index_array->buffers[1] + index_array->offset; } std::vector original_indexes( idxbuf, idxbuf + index_array->length); @@ -794,12 +794,17 @@ class ManagedQuery { std::vector casted_indexes( shifted_indexes.begin(), shifted_indexes.end()); + uint8_t* validity = (uint8_t*)index_array->buffers[0]; + if (validity != nullptr) { + validity += index_array->offset; + } + setup_write_column( column_name, casted_indexes.size(), (const void*)casted_indexes.data(), (uint64_t*)nullptr, - (uint8_t*)index_array->buffers[0]); + (uint8_t*)validity); } bool _extend_enumeration( diff --git a/libtiledbsoma/src/soma/soma_array.cc b/libtiledbsoma/src/soma/soma_array.cc index e7d7e25fd9..5b25a26376 100644 --- a/libtiledbsoma/src/soma/soma_array.cc +++ b/libtiledbsoma/src/soma/soma_array.cc @@ -327,7 +327,10 @@ void SOMAArray::write(bool sort_coords) { } mq_->submit_write(sort_coords); - mq_->reset(); + // When we evolve the schema, the ArraySchema needs to be updated to the + // latest version so re-open the Array + arr_ = std::make_shared(*ctx_->tiledb_ctx(), uri_, TILEDB_WRITE); + mq_ = std::make_unique(arr_, ctx_->tiledb_ctx(), name_); } void SOMAArray::consolidate_and_vacuum(std::vector modes) { diff --git a/libtiledbsoma/src/utils/util.cc b/libtiledbsoma/src/utils/util.cc index 550aa593d6..23adfd274d 100644 --- a/libtiledbsoma/src/utils/util.cc +++ b/libtiledbsoma/src/utils/util.cc @@ -81,20 +81,19 @@ std::vector cast_bit_to_uint8(ArrowSchema* schema, ArrowArray* array) { schema->format)); } - const void* data; + uint8_t* data; if (array->n_buffers == 3) { - data = array->buffers[2]; + data = (uint8_t*)array->buffers[2]; } else { - data = array->buffers[1]; + data = (uint8_t*)array->buffers[1]; } - std::vector casted; - for (int64_t i = 0; i * 8 < array->length; ++i) { - uint8_t byte = ((uint8_t*)data)[i]; - for (int64_t j = 0; j < 8; ++j) { - casted.push_back((uint8_t)((byte >> j) & 0x01)); - } - } + std::vector casted(array->length); + ArrowBitsUnpackInt8( + data, + array->offset, + array->length, + reinterpret_cast(casted.data())); return casted; }