Skip to content

Commit efc25da

Browse files
fix: Handle restricted output columns in Arrow Page Source
1 parent 5205309 commit efc25da

22 files changed

+601
-33
lines changed

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import com.fasterxml.jackson.annotation.JsonCreator;
2020
import com.fasterxml.jackson.annotation.JsonProperty;
2121

22+
import java.util.Objects;
23+
2224
import static java.util.Objects.requireNonNull;
2325

2426
public class ArrowColumnHandle
@@ -61,4 +63,23 @@ public String toString()
6163
{
6264
return columnName + ":" + columnType;
6365
}
66+
67+
@Override
68+
public boolean equals(Object o)
69+
{
70+
if (this == o) {
71+
return true;
72+
}
73+
if (o == null || getClass() != o.getClass()) {
74+
return false;
75+
}
76+
ArrowColumnHandle that = (ArrowColumnHandle) o;
77+
return Objects.equals(columnName, that.columnName) && Objects.equals(columnType, that.columnType);
78+
}
79+
80+
@Override
81+
public int hashCode()
82+
{
83+
return Objects.hash(columnName, columnType);
84+
}
6485
}

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import java.util.Map;
4141
import java.util.Optional;
4242
import java.util.Set;
43+
import java.util.stream.Collectors;
4344

4445
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR;
4546
import static com.google.common.base.Preconditions.checkArgument;
@@ -83,7 +84,7 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable
8384
if (!listTables(session, Optional.ofNullable(tableName.getSchemaName())).contains(tableName)) {
8485
return null;
8586
}
86-
return new ArrowTableHandle(tableName.getSchemaName(), tableName.getTableName());
87+
return new ArrowTableHandle(tableName.getSchemaName(), tableName.getTableName(), Optional.empty());
8788
}
8889

8990
public List<Field> getColumnsList(ConnectorSession connectorSession, String schema, String table)
@@ -148,15 +149,29 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa
148149
@Override
149150
public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table)
150151
{
151-
List<ColumnMetadata> meta = new ArrayList<>();
152-
List<Field> columnList = getColumnsList(session, ((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable());
153-
154-
for (Field field : columnList) {
155-
String columnName = field.getName();
156-
Type fieldType = getPrestoTypeFromArrowField(field);
157-
meta.add(ColumnMetadata.builder().setName(normalizeIdentifier(session, columnName)).setType(fieldType).build());
152+
checkArgument(table instanceof ArrowTableHandle, "Table handle should be of type ArrowTableHandle");
153+
ArrowTableHandle arrowTableHandle = (ArrowTableHandle) table;
154+
155+
List<ColumnMetadata> meta;
156+
if (arrowTableHandle.getColumns().isPresent()) {
157+
meta = ImmutableList.copyOf(arrowTableHandle.getColumns()
158+
.get()
159+
.stream()
160+
.map(column -> {
161+
return ColumnMetadata.builder().setName(normalizeIdentifier(session, column.getColumnName())).setType(column.getColumnType()).build();
162+
})
163+
.collect(Collectors.toList()));
164+
}
165+
else {
166+
List<Field> columnList = getColumnsList(session, arrowTableHandle.getSchema(), arrowTableHandle.getTable());
167+
meta = new ArrayList<>();
168+
for (Field field : columnList) {
169+
String columnName = field.getName();
170+
Type fieldType = getPrestoTypeFromArrowField(field);
171+
meta.add(ColumnMetadata.builder().setName(normalizeIdentifier(session, columnName)).setType(fieldType).build());
172+
}
158173
}
159-
return new ConnectorTableMetadata(new SchemaTableName(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable()), meta);
174+
return new ConnectorTableMetadata(new SchemaTableName(arrowTableHandle.getSchema(), arrowTableHandle.getTable()), meta);
160175
}
161176

162177
@Override

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,16 @@
2222
import org.apache.arrow.vector.FieldVector;
2323

2424
import java.util.ArrayList;
25+
import java.util.HashMap;
2526
import java.util.List;
27+
import java.util.Map;
28+
import java.util.Optional;
29+
import java.util.stream.Collectors;
30+
import java.util.stream.IntStream;
2631

2732
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR;
33+
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_INTERNAL_ERROR;
34+
import static com.facebook.presto.spi.function.table.Preconditions.checkArgument;
2835
import static java.util.Objects.requireNonNull;
2936

3037
public class ArrowPageSource
@@ -34,6 +41,7 @@ public class ArrowPageSource
3441
private final List<ArrowColumnHandle> columnHandles;
3542
private final ArrowBlockBuilder arrowBlockBuilder;
3643
private final ClientClosingFlightStream flightStreamAndClient;
44+
private final List<Integer> outputColumnIndices;
3745
private boolean completed;
3846
private int currentPosition;
3947

@@ -42,13 +50,16 @@ public ArrowPageSource(
4250
List<ArrowColumnHandle> columnHandles,
4351
BaseArrowFlightClientHandler clientHandler,
4452
ConnectorSession connectorSession,
45-
ArrowBlockBuilder arrowBlockBuilder)
53+
ArrowBlockBuilder arrowBlockBuilder,
54+
ArrowTableLayoutHandle arrowTableLayoutHandle)
4655
{
4756
requireNonNull(split, "split is null");
4857
this.columnHandles = requireNonNull(columnHandles, "columnHandles is null");
4958
requireNonNull(clientHandler, "clientHandler is null");
5059
this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null");
5160
this.flightStreamAndClient = clientHandler.getFlightStream(connectorSession, split);
61+
requireNonNull(arrowTableLayoutHandle, "arrowTableLayoutHandle is null");
62+
outputColumnIndices = getOutputColumnIndices(arrowTableLayoutHandle.getTable().getColumns());
5263
}
5364

5465
@Override
@@ -98,8 +109,10 @@ public Page getNextPage()
98109
// Create blocks from the loaded Arrow record batch
99110
List<Block> blocks = new ArrayList<>();
100111
List<FieldVector> vectors = flightStreamAndClient.getRoot().getFieldVectors();
101-
for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) {
102-
FieldVector vector = vectors.get(columnIndex);
112+
113+
for (int columnIndex = 0; columnIndex < outputColumnIndices.size(); columnIndex++) {
114+
checkArgument(outputColumnIndices.get(columnIndex) < vectors.size(), "Column index " + outputColumnIndices.get(columnIndex) + " is out of bounds for list of vectors with " + vectors.size() + " elements");
115+
FieldVector vector = vectors.get(outputColumnIndices.get(columnIndex));
103116
Type type = columnHandles.get(columnIndex).getColumnType();
104117
Block block = arrowBlockBuilder.buildBlockFromFieldVector(vector, type, flightStreamAndClient.getDictionaryProvider());
105118
blocks.add(block);
@@ -122,4 +135,28 @@ public void close()
122135
throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, e.getMessage(), e);
123136
}
124137
}
138+
139+
private List<Integer> getOutputColumnIndices(Optional<List<ArrowColumnHandle>> tableHandleColumns)
140+
{
141+
List<Integer> outputColumnIndices;
142+
// Compute the indices of the output columns from the columns retrieved from the flight server
143+
if (tableHandleColumns.isPresent()) {
144+
outputColumnIndices = new ArrayList<>();
145+
Map<String, Integer> tableColumnNameToIndexMap = new HashMap<>();
146+
IntStream.range(0, tableHandleColumns.get().size()).forEach(
147+
i -> tableColumnNameToIndexMap.put(tableHandleColumns.get().get(i).getColumnName(), i));
148+
for (ArrowColumnHandle columnHandle : columnHandles) {
149+
if (tableColumnNameToIndexMap.containsKey(columnHandle.getColumnName())) {
150+
outputColumnIndices.add(tableColumnNameToIndexMap.get(columnHandle.getColumnName()));
151+
}
152+
else {
153+
throw new ArrowException(ARROW_INTERNAL_ERROR, "Unable to find column " + columnHandle.getColumnName() + " in the list of columns in table handle: " + String.join(",", tableColumnNameToIndexMap.keySet()));
154+
}
155+
}
156+
}
157+
else {
158+
outputColumnIndices = IntStream.range(0, columnHandles.size()).boxed().collect(Collectors.toList());
159+
}
160+
return outputColumnIndices;
161+
}
125162
}

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
*/
1414
package com.facebook.plugin.arrow;
1515

16+
import com.facebook.presto.common.RuntimeStats;
1617
import com.facebook.presto.spi.ColumnHandle;
1718
import com.facebook.presto.spi.ConnectorPageSource;
1819
import com.facebook.presto.spi.ConnectorSession;
1920
import com.facebook.presto.spi.ConnectorSplit;
21+
import com.facebook.presto.spi.ConnectorTableLayoutHandle;
2022
import com.facebook.presto.spi.SplitContext;
2123
import com.facebook.presto.spi.connector.ConnectorPageSourceProvider;
2224
import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
@@ -41,13 +43,14 @@ public ArrowPageSourceProvider(BaseArrowFlightClientHandler clientHandler, Arrow
4143
}
4244

4345
@Override
44-
public ConnectorPageSource createPageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorSplit split, List<ColumnHandle> columns, SplitContext splitContext)
46+
public ConnectorPageSource createPageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorSplit split, ConnectorTableLayoutHandle layout, List<ColumnHandle> columns, SplitContext splitContext, RuntimeStats runtimeStats)
4547
{
4648
ImmutableList.Builder<ArrowColumnHandle> columnHandles = ImmutableList.builder();
4749
for (ColumnHandle handle : columns) {
4850
columnHandles.add((ArrowColumnHandle) handle);
4951
}
5052
ArrowSplit arrowSplit = (ArrowSplit) split;
51-
return new ArrowPageSource(arrowSplit, columnHandles.build(), clientHandler, session, arrowBlockBuilder);
53+
ArrowTableLayoutHandle arrowTableLayoutHandle = (ArrowTableLayoutHandle) layout;
54+
return new ArrowPageSource(arrowSplit, columnHandles.build(), clientHandler, session, arrowBlockBuilder, arrowTableLayoutHandle);
5255
}
5356
}

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,35 @@
1717
import com.fasterxml.jackson.annotation.JsonCreator;
1818
import com.fasterxml.jackson.annotation.JsonProperty;
1919

20+
import java.util.List;
2021
import java.util.Objects;
22+
import java.util.Optional;
2123

2224
public class ArrowTableHandle
2325
implements ConnectorTableHandle
2426
{
2527
private final String schema;
2628
private final String table;
29+
private final Optional<List<ArrowColumnHandle>> columns;
2730

31+
/***
32+
* Create instance of ArrowTableHandle
33+
* @param schema schema name
34+
* @param table table name
35+
* @param columns If present, this list should be the list of columns in the table,
36+
* which can be a larger list than the list of output columns. This list should match the list of vectors
37+
* returned by Arrow Flight stream for this table. This value needs to be set only
38+
* in certain scenarios like when using TVF. In other common scenarios, this can be empty.
39+
*/
2840
@JsonCreator
2941
public ArrowTableHandle(
3042
@JsonProperty("schema") String schema,
31-
@JsonProperty("table") String table)
43+
@JsonProperty("table") String table,
44+
@JsonProperty("columns") Optional<List<ArrowColumnHandle>> columns)
3245
{
3346
this.schema = schema;
3447
this.table = table;
48+
this.columns = columns;
3549
}
3650

3751
@JsonProperty("schema")
@@ -46,10 +60,16 @@ public String getTable()
4660
return table;
4761
}
4862

63+
@JsonProperty("columns")
64+
public Optional<List<ArrowColumnHandle>> getColumns()
65+
{
66+
return columns;
67+
}
68+
4969
@Override
5070
public String toString()
5171
{
52-
return schema + ":" + table;
72+
return schema + ":" + table + ":" + columns;
5373
}
5474

5575
@Override
@@ -62,7 +82,7 @@ public boolean equals(Object o)
6282
return false;
6383
}
6484
ArrowTableHandle that = (ArrowTableHandle) o;
65-
return Objects.equals(schema, that.schema) && Objects.equals(table, that.table);
85+
return Objects.equals(schema, that.schema) && Objects.equals(table, that.table) && Objects.equals(columns, that.columns);
6686
}
6787

6888
@Override

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
*/
1414
package com.facebook.plugin.arrow;
1515

16+
import com.facebook.airlift.testing.EquivalenceTester;
17+
import com.facebook.presto.common.type.BigintType;
1618
import com.facebook.presto.common.type.IntegerType;
1719
import com.facebook.presto.spi.ColumnMetadata;
1820
import org.testng.annotations.Test;
1921

22+
import static com.facebook.plugin.arrow.ArrowMetadataUtil.COLUMN_CODEC;
23+
import static com.facebook.plugin.arrow.ArrowMetadataUtil.assertJsonRoundTrip;
2024
import static com.facebook.presto.testing.assertions.Assert.assertEquals;
2125
import static org.testng.Assert.assertNotNull;
2226

@@ -78,4 +82,18 @@ public void testToString()
7882
String expected = columnName + ":" + IntegerType.INTEGER;
7983
assertEquals(result, expected, "toString() should return the correct string representation");
8084
}
85+
86+
@Test
87+
public void testJsonRoundTrip()
88+
{
89+
assertJsonRoundTrip(COLUMN_CODEC, new ArrowColumnHandle("column1", BigintType.BIGINT));
90+
}
91+
92+
@Test
93+
public void testEquivalence()
94+
{
95+
EquivalenceTester.equivalenceTester()
96+
.addEquivalentGroup(
97+
new ArrowColumnHandle("column1", BigintType.BIGINT)).check();
98+
}
8199
}

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueries.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,18 @@ protected FeaturesConfig createFeaturesConfig()
118118
return new FeaturesConfig().setNativeExecutionEnabled(true);
119119
}
120120

121+
@Test
122+
public void testQueryFunctionWithRestrictedColumns()
123+
{
124+
assertQuery("SELECT NAME FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME FROM nation WHERE NATIONKEY = 4");
125+
}
126+
127+
@Test
128+
public void testQueryFunctionWithoutRestrictedColumns() throws InterruptedException
129+
{
130+
assertQuery("SELECT NAME, NATIONKEY FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME, NATIONKEY FROM nation WHERE NATIONKEY = 4");
131+
}
132+
121133
@Test
122134
public void testFiltersAndProjections1()
123135
{

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,18 @@ public void testDescribeUnknownTable()
166166
assertEquals(actualRows, expectedRows);
167167
}
168168

169+
@Test
170+
public void testQueryFunctionWithRestrictedColumns()
171+
{
172+
assertQuery("SELECT NAME FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME FROM nation WHERE NATIONKEY = 4");
173+
}
174+
175+
@Test
176+
public void testQueryFunctionWithoutRestrictedColumns() throws InterruptedException
177+
{
178+
assertQuery("SELECT NAME, NATIONKEY FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME, NATIONKEY FROM nation WHERE NATIONKEY = 4");
179+
}
180+
169181
private LocalDate getDate(String dateString)
170182
{
171183
DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,14 @@
1414
package com.facebook.plugin.arrow;
1515

1616
import com.facebook.airlift.testing.EquivalenceTester;
17+
import com.facebook.presto.common.type.BigintType;
18+
import com.facebook.presto.common.type.VarcharType;
1719
import org.testng.annotations.Test;
1820

21+
import java.util.Arrays;
22+
import java.util.List;
23+
import java.util.Optional;
24+
1925
import static com.facebook.plugin.arrow.ArrowMetadataUtil.TABLE_CODEC;
2026
import static com.facebook.plugin.arrow.ArrowMetadataUtil.assertJsonRoundTrip;
2127

@@ -24,14 +30,16 @@ public class TestArrowTableHandle
2430
@Test
2531
public void testJsonRoundTrip()
2632
{
27-
assertJsonRoundTrip(TABLE_CODEC, new ArrowTableHandle("schema", "table"));
33+
List<ArrowColumnHandle> columnHandles = Arrays.asList(new ArrowColumnHandle("column1", BigintType.BIGINT), new ArrowColumnHandle("column2", VarcharType.VARCHAR));
34+
assertJsonRoundTrip(TABLE_CODEC, new ArrowTableHandle("schema", "table", Optional.of(columnHandles)));
2835
}
2936

3037
@Test
3138
public void testEquivalence()
3239
{
40+
List<ArrowColumnHandle> columnHandles = Arrays.asList(new ArrowColumnHandle("column1", BigintType.BIGINT), new ArrowColumnHandle("column2", VarcharType.VARCHAR));
3341
EquivalenceTester.equivalenceTester()
3442
.addEquivalentGroup(
35-
new ArrowTableHandle("tm_engine", "employees")).check();
43+
new ArrowTableHandle("tm_engine", "employees", Optional.of(columnHandles))).check();
3644
}
3745
}

0 commit comments

Comments
 (0)