Skip to content

Commit 48a35cb

Browse files
Handle restricted output columns in Arrow Page Source
Restricted output columns need to be handled when table-valued functions are used in a query.
1 parent 62bf52b commit 48a35cb

17 files changed

+630
-17
lines changed

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,16 @@
2222
import org.apache.arrow.vector.FieldVector;
2323

2424
import java.util.ArrayList;
25+
import java.util.HashMap;
2526
import java.util.List;
27+
import java.util.Map;
28+
import java.util.Optional;
29+
import java.util.OptionalInt;
30+
import java.util.stream.IntStream;
2631

2732
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR;
33+
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_INTERNAL_ERROR;
34+
import static com.facebook.presto.common.Utils.checkArgument;
2835
import static java.util.Objects.requireNonNull;
2936

3037
public class ArrowPageSource
@@ -34,6 +41,8 @@ public class ArrowPageSource
3441
private final List<ArrowColumnHandle> columnHandles;
3542
private final ArrowBlockBuilder arrowBlockBuilder;
3643
private final ClientClosingFlightStream flightStreamAndClient;
44+
private final Optional<List<ArrowColumnHandle>> columnHandlesInSplit;
45+
private final Map<Integer, Integer> columnIndexToVectorIndexMap;
3746
private boolean completed;
3847
private int currentPosition;
3948

@@ -49,6 +58,8 @@ public ArrowPageSource(
4958
requireNonNull(clientHandler, "clientHandler is null");
5059
this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null");
5160
this.flightStreamAndClient = clientHandler.getFlightStream(connectorSession, split);
61+
this.columnHandlesInSplit = requireNonNull(split.getColumns(), "split.getColumns is null");
62+
columnIndexToVectorIndexMap = getColumnIndexToVectorIndexMap();
5263
}
5364

5465
@Override
@@ -98,8 +109,11 @@ public Page getNextPage()
98109
// Create blocks from the loaded Arrow record batch
99110
List<Block> blocks = new ArrayList<>();
100111
List<FieldVector> vectors = flightStreamAndClient.getRoot().getFieldVectors();
112+
113+
columnHandlesInSplit.ifPresent(arrowColumnHandles -> checkArgument(vectors.size() == arrowColumnHandles.size(), "Number of field vectors is not the same as number of column handles in split"));
114+
101115
for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) {
102-
FieldVector vector = vectors.get(columnIndex);
116+
FieldVector vector = vectors.get(columnIndexToVectorIndexMap.get(columnIndex));
103117
Type type = columnHandles.get(columnIndex).getColumnType();
104118
Block block = arrowBlockBuilder.buildBlockFromFieldVector(vector, type, flightStreamAndClient.getDictionaryProvider());
105119
blocks.add(block);
@@ -122,4 +136,44 @@ public void close()
122136
throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, e.getMessage(), e);
123137
}
124138
}
139+
140+
private int getVectorIndexForColumnHandleIndex(int columnHandleIndex)
141+
{
142+
int vectorIndex;
143+
if (columnHandlesInSplit.isPresent()) {
144+
// If the ArrowSplit defines the list of columns for the data in the split,
145+
// get the vector to read by finding the index of the required column in this list.
146+
OptionalInt index = IntStream
147+
.range(0, columnHandlesInSplit.get().size())
148+
.filter(k -> columnHandlesInSplit.get().get(k).getColumnName().equals(columnHandles.get(columnHandleIndex).getColumnName()))
149+
.findFirst();
150+
if (index.isPresent()) {
151+
vectorIndex = index.getAsInt();
152+
}
153+
else {
154+
throw new ArrowException(ARROW_INTERNAL_ERROR, "Unable to find column " + columnHandles.get(columnHandleIndex).getColumnName() + " in the column handles given in split");
155+
}
156+
}
157+
else {
158+
// If the ArrowSplit does not define the list of columns for the data in the split,
159+
// assume that the vector to read is in the same position as the column handle
160+
vectorIndex = columnHandleIndex;
161+
}
162+
163+
return vectorIndex;
164+
}
165+
166+
private Map<Integer, Integer> getColumnIndexToVectorIndexMap()
167+
{
168+
Map<Integer, Integer> columnIndexToVectorIndexMap = new HashMap<>();
169+
if (columnHandlesInSplit.isPresent()) {
170+
for (int i = 0; i < columnHandles.size(); i++) {
171+
columnIndexToVectorIndexMap.put(i, getVectorIndexForColumnHandleIndex(i));
172+
}
173+
}
174+
else {
175+
IntStream.range(0, columnHandles.size()).forEach(i -> columnIndexToVectorIndexMap.put(i, i));
176+
}
177+
return columnIndexToVectorIndexMap;
178+
}
125179
}

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,27 @@
2323

2424
import java.util.Collections;
2525
import java.util.List;
26+
import java.util.Optional;
2627

2728
public class ArrowSplit
2829
implements ConnectorSplit
2930
{
3031
private final String schemaName;
3132
private final String tableName;
3233
private final byte[] flightEndpointBytes;
34+
private final Optional<List<ArrowColumnHandle>> columns;
3335

3436
@JsonCreator
3537
public ArrowSplit(
3638
@JsonProperty("schemaName") @Nullable String schemaName,
3739
@JsonProperty("tableName") String tableName,
38-
@JsonProperty("flightEndpointBytes") byte[] flightEndpointBytes)
40+
@JsonProperty("flightEndpointBytes") byte[] flightEndpointBytes,
41+
@JsonProperty("columns") Optional<List<ArrowColumnHandle>> columns)
3942
{
4043
this.schemaName = schemaName;
4144
this.tableName = tableName;
4245
this.flightEndpointBytes = flightEndpointBytes;
46+
this.columns = columns;
4347
}
4448

4549
@Override
@@ -77,4 +81,10 @@ public byte[] getFlightEndpointBytes()
7781
{
7882
return flightEndpointBytes;
7983
}
84+
85+
/**
 * Returns the optional column list this split was constructed with.
 * The value is handed back exactly as it was passed to the constructor.
 */
@JsonProperty
86+
public Optional<List<ArrowColumnHandle>> getColumns()
87+
{
88+
return columns;
89+
}
8090
}

presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.arrow.flight.FlightInfo;
2424

2525
import java.util.List;
26+
import java.util.Optional;
2627

2728
import static com.google.common.collect.ImmutableList.toImmutableList;
2829
import static java.util.Objects.requireNonNull;
@@ -49,8 +50,14 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHand
4950
.map(info -> new ArrowSplit(
5051
tableHandle.getSchema(),
5152
tableHandle.getTable(),
52-
info.serialize().array()))
53+
info.serialize().array(),
54+
Optional.empty()))
5355
.collect(toImmutableList());
5456
return new FixedSplitSource(splits);
5557
}
58+
59+
/**
 * Returns the {@code clientHandler} held by this split manager.
 */
public BaseArrowFlightClientHandler getClientHandler()
60+
{
61+
return clientHandler;
62+
}
5663
}

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,30 @@ public static DistributedQueryRunner createQueryRunner(int flightServerPort) thr
6262
return createQueryRunner(flightServerPort, ImmutableMap.of(), ImmutableMap.of(), Optional.empty(), Optional.empty());
6363
}
6464

65+
/**
 * Convenience overload: delegates to the six-argument {@code createQueryRunner}
 * with empty extra and coordinator properties, no external worker launcher and
 * no mTLS setting, forwarding {@code testingWithTVF} unchanged.
 */
public static DistributedQueryRunner createQueryRunner(int flightServerPort, boolean testingWithTVF) throws Exception
66+
{
67+
return createQueryRunner(flightServerPort, ImmutableMap.of(), ImmutableMap.of(), Optional.empty(), Optional.empty(), testingWithTVF);
68+
}
69+
6570
/**
 * Five-argument overload: forwards every argument to the six-argument
 * {@code createQueryRunner} and passes {@code false} for {@code testingWithTVF}.
 */
public static DistributedQueryRunner createQueryRunner(
6671
int flightServerPort,
6772
Map<String, String> extraProperties,
6873
Map<String, String> coordinatorProperties,
6974
Optional<BiFunction<Integer, URI, Process>> externalWorkerLauncher,
7075
Optional<Boolean> mTLSEnabled)
7176
throws Exception
77+
{
78+
return createQueryRunner(flightServerPort, extraProperties, coordinatorProperties, externalWorkerLauncher, mTLSEnabled, false);
79+
}
80+
81+
public static DistributedQueryRunner createQueryRunner(
82+
int flightServerPort,
83+
Map<String, String> extraProperties,
84+
Map<String, String> coordinatorProperties,
85+
Optional<BiFunction<Integer, URI, Process>> externalWorkerLauncher,
86+
Optional<Boolean> mTLSEnabled,
87+
boolean testingWithTVF)
88+
throws Exception
7289
{
7390
Session session = testSessionBuilder()
7491
.setCatalog(ARROW_FLIGHT_CATALOG)
@@ -87,7 +104,7 @@ public static DistributedQueryRunner createQueryRunner(
87104

88105
try {
89106
boolean nativeExecution = externalWorkerLauncher.isPresent();
90-
queryRunner.installPlugin(new TestingArrowFlightPlugin(nativeExecution));
107+
queryRunner.installPlugin(new TestingArrowFlightPlugin(nativeExecution, testingWithTVF));
91108
Map<String, String> catalogProperties = ImmutableMap.of("arrow-flight.server.port", String.valueOf(flightServerPort));
92109

93110
ImmutableMap.Builder<String, String> properties = ImmutableMap.<String, String>builder()

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public class TestArrowFlightQueries
5353
extends AbstractTestQueries
5454
{
5555
private static final Logger logger = Logger.get(TestArrowFlightQueries.class);
56-
private int serverPort;
56+
protected int serverPort;
5757
private RootAllocator allocator;
5858
private FlightServer server;
5959
private DistributedQueryRunner arrowFlightQueryRunner;
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.plugin.arrow;
15+
16+
import com.facebook.presto.testing.MaterializedResult;
17+
import com.facebook.presto.testing.QueryRunner;
18+
import org.testng.annotations.Test;
19+
20+
import static com.facebook.presto.common.type.IntegerType.INTEGER;
21+
import static com.facebook.presto.testing.MaterializedResult.resultBuilder;
22+
import static com.facebook.presto.testing.assertions.Assert.assertEquals;
23+
24+
public class TestArrowFlightQueriesWithTVF
25+
extends TestArrowFlightQueries
26+
{
27+
@Override
28+
protected QueryRunner createQueryRunner()
29+
throws Exception
30+
{
31+
serverPort = ArrowFlightQueryRunner.findUnusedPort();
32+
return ArrowFlightQueryRunner.createQueryRunner(serverPort, true);
33+
}
34+
35+
@Test
36+
public void testQueryFunction()
37+
{
38+
MaterializedResult actualRow = computeActual("SELECT id from TABLE(system.query_function('SELECT name, id FROM tpch.member WHERE id = 1', 'name VARCHAR, id INTEGER'))");
39+
MaterializedResult expectedRow = resultBuilder(getSession(), INTEGER)
40+
.row(1)
41+
.build();
42+
assertEquals(actualRow, expectedRow);
43+
}
44+
}

presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.net.URISyntaxException;
2525
import java.nio.ByteBuffer;
2626
import java.util.List;
27+
import java.util.Optional;
2728

2829
import static org.testng.Assert.assertEquals;
2930
import static org.testng.Assert.assertNotNull;
@@ -48,7 +49,7 @@ public void setUp()
4849
Location location = new Location("http://localhost:8080");
4950
flightEndpoint = new FlightEndpoint(ticket, location);
5051
// Instantiate ArrowSplit with mock data
51-
arrowSplit = new ArrowSplit(schemaName, tableName, flightEndpoint.serialize().array());
52+
arrowSplit = new ArrowSplit(schemaName, tableName, flightEndpoint.serialize().array(), Optional.empty());
5253
}
5354

5455
@Test
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.plugin.arrow.testingConnector;
15+
16+
import com.facebook.presto.common.type.IntegerType;
17+
import com.facebook.presto.common.type.Type;
18+
import com.facebook.presto.spi.PrestoException;
19+
20+
import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType;
21+
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
22+
23+
public final class PrimitiveToPrestoTypeMappings
24+
{
25+
private PrimitiveToPrestoTypeMappings()
26+
{
27+
throw new UnsupportedOperationException();
28+
}
29+
30+
public static Type fromPrimitiveToPrestoType(String dataType)
31+
{
32+
switch (dataType) {
33+
case "INTEGER":
34+
return IntegerType.INTEGER;
35+
36+
case "VARCHAR":
37+
return createUnboundedVarcharType();
38+
}
39+
throw new PrestoException(NOT_SUPPORTED, "Unsupported datatype '" + dataType + "' in the selected table.");
40+
}
41+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.plugin.arrow.testingConnector;
15+
16+
import com.facebook.plugin.arrow.ArrowColumnHandle;
17+
import com.facebook.plugin.arrow.ArrowTableHandle;
18+
import com.fasterxml.jackson.annotation.JsonProperty;
19+
20+
import java.util.Collections;
21+
import java.util.List;
22+
import java.util.Objects;
23+
import java.util.UUID;
24+
25+
public class QueryArrowTableHandle
26+
extends ArrowTableHandle
27+
{
28+
private final String query;
29+
private final List<ArrowColumnHandle> columns;
30+
31+
public QueryArrowTableHandle(
32+
@JsonProperty String query,
33+
@JsonProperty List<ArrowColumnHandle> columns)
34+
{
35+
super("schema-" + UUID.randomUUID(), "table-" + UUID.randomUUID());
36+
this.query = query;
37+
this.columns = Collections.unmodifiableList(columns);
38+
}
39+
40+
@JsonProperty("query")
41+
public String getQuery()
42+
{
43+
return query;
44+
}
45+
46+
@JsonProperty("columns")
47+
public List<ArrowColumnHandle> getColumns()
48+
{
49+
return columns;
50+
}
51+
52+
@Override
53+
public String toString()
54+
{
55+
return query + ":" + columns;
56+
}
57+
58+
@Override
59+
public boolean equals(Object o)
60+
{
61+
if (this == o) {
62+
return true;
63+
}
64+
if (o == null || getClass() != o.getClass()) {
65+
return false;
66+
}
67+
QueryArrowTableHandle that = (QueryArrowTableHandle) o;
68+
return Objects.equals(getSchema(), that.getSchema()) && Objects.equals(getTable(), that.getTable()) &&
69+
Objects.equals(getQuery(), that.getQuery()) && Objects.equals(getColumns(), that.getColumns());
70+
}
71+
72+
@Override
73+
public int hashCode()
74+
{
75+
return Objects.hash(getSchema(), getTable(), query, columns);
76+
}
77+
}

0 commit comments

Comments
 (0)