Skip to content

Commit e58b6c2

Browse files
committed
feat(optimizer): Add WindowFilterPushdown rule changes for rank queries
1 parent e6a3d72 commit e58b6c2

File tree

10 files changed

+254
-36
lines changed

10 files changed

+254
-36
lines changed

presto-docs/src/main/sphinx/presto_cpp/properties-session.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,3 +557,17 @@ output for each input batch.
557557
If this is true, then the protocol::SpatialJoinNode is converted to a
558558
velox::core::SpatialJoinNode. Otherwise, it is converted to a
559559
velox::core::NestedLoopJoinNode.
560+
561+
``optimize_top_n_rank``
562+
^^^^^^^^^^^^^^^^^^^^^^^
563+
564+
* **Type:** ``boolean``
565+
* **Default value:** ``false``
566+
567+
If this is true, then filter and limit queries for ``n`` rows of
568+
``rank()`` and ``dense_rank()`` window function values are executed
569+
with a special TopNRowNumber operator instead of the
570+
WindowFunction operator.
571+
572+
The TopNRowNumber operator is more efficient than the window operator because
573+
it has a streaming behavior and does not need to buffer all input rows.

presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ public final class SystemSessionProperties
182182
public static final String ADAPTIVE_PARTIAL_AGGREGATION = "adaptive_partial_aggregation";
183183
public static final String ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold";
184184
public static final String OPTIMIZE_TOP_N_ROW_NUMBER = "optimize_top_n_row_number";
185+
public static final String OPTIMIZE_TOP_N_RANK = "optimize_top_n_rank";
185186
public static final String OPTIMIZE_CASE_EXPRESSION_PREDICATE = "optimize_case_expression_predicate";
186187
public static final String MAX_GROUPING_SETS = "max_grouping_sets";
187188
public static final String LEGACY_UNNEST = "legacy_unnest";
@@ -982,6 +983,11 @@ public SystemSessionProperties(
982983
"Use top N row number optimization",
983984
featuresConfig.isOptimizeTopNRowNumber(),
984985
false),
986+
booleanProperty(
987+
OPTIMIZE_TOP_N_RANK,
988+
"Use top N rank and dense_rank optimization",
989+
featuresConfig.isOptimizeTopNRank(),
990+
false),
985991
booleanProperty(
986992
OPTIMIZE_CASE_EXPRESSION_PREDICATE,
987993
"Optimize case expression predicates",
@@ -2567,6 +2573,11 @@ public static boolean isOptimizeTopNRowNumber(Session session)
25672573
return session.getSystemProperty(OPTIMIZE_TOP_N_ROW_NUMBER, Boolean.class);
25682574
}
25692575

2576+
/**
 * Reads the {@code optimize_top_n_rank} system session property
 * ({@link #OPTIMIZE_TOP_N_RANK}) from the given session as a boolean.
 *
 * @param session session whose system properties are consulted
 * @return the value of the {@code optimize_top_n_rank} property for this session
 */
public static boolean isOptimizeTopNRank(Session session)
2577+
{
2578+
return session.getSystemProperty(OPTIMIZE_TOP_N_RANK, Boolean.class);
2579+
}
2580+
25702581
public static boolean isOptimizeCaseExpressionPredicate(Session session)
25712582
{
25722583
return session.getSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, Boolean.class);

presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ public class FeaturesConfig
163163
private boolean adaptivePartialAggregationEnabled;
164164
private double adaptivePartialAggregationRowsReductionRatioThreshold = 0.8;
165165
private boolean optimizeTopNRowNumber = true;
166+
167+
private boolean optimizeTopNRank;
166168
private boolean pushLimitThroughOuterJoin = true;
167169
private boolean optimizeConstantGroupingKeys = true;
168170

@@ -1142,13 +1144,25 @@ public boolean isOptimizeTopNRowNumber()
11421144
return optimizeTopNRowNumber;
11431145
}
11441146

1147+
/**
 * Returns the current value of the {@code optimizeTopNRank} flag.
 * The backing field has no initializer, so it is {@code false} unless
 * {@code setOptimizeTopNRank} has been called.
 */
public boolean isOptimizeTopNRank()
1148+
{
1149+
return optimizeTopNRank;
1150+
}
1151+
11451152
@Config("optimizer.optimize-top-n-row-number")
11461153
public FeaturesConfig setOptimizeTopNRowNumber(boolean optimizeTopNRowNumber)
11471154
{
11481155
this.optimizeTopNRowNumber = optimizeTopNRowNumber;
11491156
return this;
11501157
}
11511158

1159+
/**
 * Sets the {@code optimizeTopNRank} flag; bound to the
 * {@code optimizer.optimize-top-n-rank} configuration property.
 *
 * @param optimizeTopNRank new value of the flag
 * @return this config instance, so that setters can be chained
 */
@Config("optimizer.optimize-top-n-rank")
1160+
public FeaturesConfig setOptimizeTopNRank(boolean optimizeTopNRank)
1161+
{
1162+
this.optimizeTopNRank = optimizeTopNRank;
1163+
return this;
1164+
}
1165+
11521166
public boolean isOptimizeCaseExpressionPredicate()
11531167
{
11541168
return optimizeCaseExpressionPredicate;

presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ public Optional<PlanNode> visitTopNRowNumber(TopNRowNumberNode node, Context con
699699
new DataOrganizationSpecification(
700700
partitionBy,
701701
node.getSpecification().getOrderingScheme().map(scheme -> getCanonicalOrderingScheme(scheme, context.getExpressions()))),
702-
TopNRowNumberNode.RankingFunction.ROW_NUMBER,
702+
node.getRankingFunction(),
703703
rowNumberVariable,
704704
node.getMaxRowCountPerPartition(),
705705
node.isPartial(),

presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@
4141
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
4242
import com.facebook.presto.sql.relational.RowExpressionDomainTranslator;
4343
import com.google.common.collect.ImmutableList;
44+
import com.google.common.collect.Iterables;
4445

4546
import java.util.Map;
4647
import java.util.Optional;
4748
import java.util.OptionalInt;
4849

50+
import static com.facebook.presto.SystemSessionProperties.isNativeExecutionEnabled;
51+
import static com.facebook.presto.SystemSessionProperties.isOptimizeTopNRank;
4952
import static com.facebook.presto.SystemSessionProperties.isOptimizeTopNRowNumber;
5053
import static com.facebook.presto.common.predicate.Marker.Bound.BELOW;
5154
import static com.facebook.presto.common.type.BigintType.BIGINT;
@@ -134,6 +137,12 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
134137
return replaceChildren(node, ImmutableList.of(rewrittenSource));
135138
}
136139

140+
/**
 * Decides whether the given window node is eligible to be rewritten into a
 * TopNRowNumberNode. Either of two conditions is sufficient:
 * <ul>
 * <li>{@code canOptimizeRowNumberFunction} accepts the node and the
 * top-N row_number session property is enabled; or</li>
 * <li>{@code canOptimizeRankFunction} accepts the node, the top-N rank
 * session property is enabled, and native execution is enabled.</li>
 * </ul>
 * In each alternative the function check is evaluated before the session flags.
 *
 * @param node window node under consideration
 * @return true when one of the two conditions above holds
 */
private boolean canReplaceWithTopNRowNumber(WindowNode node)
141+
{
142+
return (canOptimizeRowNumberFunction(node, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) ||
143+
(canOptimizeRankFunction(node, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRank(session) && isNativeExecutionEnabled(session));
144+
}
145+
137146
@Override
138147
public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
139148
{
@@ -152,16 +161,18 @@ public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
152161
planChanged = true;
153162
source = rowNumberNode;
154163
}
155-
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) {
164+
else if (source instanceof WindowNode) {
156165
WindowNode windowNode = (WindowNode) source;
157-
// verify that unordered row_number window functions are replaced by RowNumberNode
158-
verify(windowNode.getOrderingScheme().isPresent());
159-
TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit);
160-
if (windowNode.getPartitionBy().isEmpty()) {
161-
return topNRowNumberNode;
166+
if (canReplaceWithTopNRowNumber(windowNode)) {
167+
TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit);
168+
// Limit can be entirely skipped for row_number without partitioning.
169+
if (windowNode.getPartitionBy().isEmpty() &&
170+
canOptimizeRowNumberFunction(windowNode, metadata.getFunctionAndTypeManager())) {
171+
return topNRowNumberNode;
172+
}
173+
planChanged = true;
174+
source = topNRowNumberNode;
162175
}
163-
planChanged = true;
164-
source = topNRowNumberNode;
165176
}
166177
return replaceChildren(node, ImmutableList.of(source));
167178
}
@@ -183,15 +194,17 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
183194
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
184195
}
185196
}
186-
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) {
197+
else if (source instanceof WindowNode) {
187198
WindowNode windowNode = (WindowNode) source;
188-
VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable());
189-
OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable);
190-
191-
if (upperBound.isPresent()) {
192-
source = convertToTopNRowNumber(windowNode, upperBound.getAsInt());
193-
planChanged = true;
194-
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
199+
if (canReplaceWithTopNRowNumber(windowNode)) {
200+
VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable());
201+
OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable);
202+
203+
if (upperBound.isPresent()) {
204+
source = convertToTopNRowNumber(windowNode, upperBound.getAsInt());
205+
planChanged = true;
206+
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
207+
}
195208
}
196209
}
197210
return replaceChildren(node, ImmutableList.of(source));
@@ -275,12 +288,30 @@ private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPa
275288

276289
private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit)
277290
{
291+
String windowFunction = Iterables.getOnlyElement(windowNode.getWindowFunctions().values()).getFunctionCall().getFunctionHandle().getName();
292+
String[] parts = windowFunction.split("\\.");
293+
String windowFunctionName = parts[parts.length - 1];
294+
TopNRowNumberNode.RankingFunction rankingFunction;
295+
switch (windowFunctionName) {
296+
case "row_number":
297+
rankingFunction = TopNRowNumberNode.RankingFunction.ROW_NUMBER;
298+
break;
299+
case "rank":
300+
rankingFunction = TopNRowNumberNode.RankingFunction.RANK;
301+
break;
302+
case "dense_rank":
303+
rankingFunction = TopNRowNumberNode.RankingFunction.DENSE_RANK;
304+
break;
305+
default:
306+
throw new IllegalArgumentException("Unsupported window function for TopNRowNumberNode: " + windowFunctionName);
307+
}
308+
278309
return new TopNRowNumberNode(
279310
windowNode.getSourceLocation(),
280311
idAllocator.getNextId(),
281312
windowNode.getSource(),
282313
windowNode.getSpecification(),
283-
TopNRowNumberNode.RankingFunction.ROW_NUMBER,
314+
rankingFunction,
284315
getOnlyElement(windowNode.getCreatedVariable()),
285316
limit,
286317
false,
@@ -289,22 +320,37 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi
289320

290321
private static boolean canReplaceWithRowNumber(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
291322
{
292-
return canOptimizeWindowFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent();
323+
return canOptimizeRowNumberFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent();
324+
}
325+
326+
/**
 * Returns true when the window node defines exactly one window function and
 * that function's metadata matches {@code row_number} (as judged by
 * {@code isRowNumberMetadata}). Nodes with zero or several window functions
 * are rejected before any metadata lookup is made.
 *
 * @param node window node to inspect
 * @param functionAndTypeManager used to resolve the function handle to its metadata
 */
private static boolean canOptimizeRowNumberFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
327+
{
328+
if (node.getWindowFunctions().size() != 1) {
329+
return false;
330+
}
331+
return isRowNumberMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(getOnlyElement(node.getWindowFunctions().values()).getFunctionHandle()));
293332
}
294333

295-
private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
334+
private static boolean canOptimizeRankFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
296335
{
297336
if (node.getWindowFunctions().size() != 1) {
298337
return false;
299338
}
300-
VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet());
301-
return isRowNumberMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()));
339+
return isRankMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(getOnlyElement(node.getWindowFunctions().values()).getFunctionHandle()));
302340
}
303341

304342
/**
 * Checks whether the supplied function metadata equals the metadata of the
 * zero-argument {@code row_number} function, which is looked up by name on
 * every call.
 *
 * @param functionAndTypeManager used to look up {@code row_number} and its metadata
 * @param functionMetadata metadata of the function being tested
 * @return true when both metadata objects are equal
 */
private static boolean isRowNumberMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
305343
{
306344
FunctionHandle rowNumberFunction = functionAndTypeManager.lookupFunction("row_number", ImmutableList.of());
307345
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rowNumberFunction));
308346
}
347+
348+
/**
 * Checks whether the supplied function metadata equals the metadata of either
 * the zero-argument {@code rank} function or the zero-argument
 * {@code dense_rank} function. Both handles are looked up by name on every call.
 *
 * @param functionAndTypeManager used to look up the two ranking functions and their metadata
 * @param functionMetadata metadata of the function being tested
 * @return true when the metadata matches {@code rank} or {@code dense_rank}
 */
private static boolean isRankMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
349+
{
350+
FunctionHandle rankFunction = functionAndTypeManager.lookupFunction("rank", ImmutableList.of());
351+
FunctionHandle denseRankFunction = functionAndTypeManager.lookupFunction("dense_rank", ImmutableList.of());
352+
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rankFunction)) ||
353+
functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(denseRankFunction));
354+
}
309355
}
310356
}

presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ public TopNRowNumberNode(
8585
requireNonNull(rowNumberVariable, "rowNumberVariable is null");
8686
checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0");
8787
requireNonNull(hashVariable, "hashVariable is null");
88+
requireNonNull(rankingFunction, "rankingFunction is null");
8889

8990
this.source = source;
9091
this.specification = specification;

presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ public void testDefaults()
145145
.setAdaptivePartialAggregationEnabled(false)
146146
.setAdaptivePartialAggregationRowsReductionRatioThreshold(0.8)
147147
.setOptimizeTopNRowNumber(true)
148+
.setOptimizeTopNRank(false)
148149
.setOptimizeCaseExpressionPredicate(false)
149150
.setDistributedSortEnabled(true)
150151
.setMaxGroupingSets(2048)
@@ -368,6 +369,7 @@ public void testExplicitPropertyMappings()
368369
.put("experimental.adaptive-partial-aggregation", "true")
369370
.put("experimental.adaptive-partial-aggregation-rows-reduction-ratio-threshold", "0.9")
370371
.put("optimizer.optimize-top-n-row-number", "false")
372+
.put("optimizer.optimize-top-n-rank", "true")
371373
.put("optimizer.optimize-case-expression-predicate", "true")
372374
.put("distributed-sort", "false")
373375
.put("analyzer.max-grouping-sets", "2047")
@@ -586,6 +588,7 @@ public void testExplicitPropertyMappings()
586588
.setAdaptivePartialAggregationEnabled(true)
587589
.setAdaptivePartialAggregationRowsReductionRatioThreshold(0.9)
588590
.setOptimizeTopNRowNumber(false)
591+
.setOptimizeTopNRank(true)
589592
.setOptimizeCaseExpressionPredicate(true)
590593
.setDistributedSortEnabled(false)
591594
.setMaxGroupingSets(2047)

0 commit comments

Comments
 (0)