Skip to content

Commit 8791f96

Browse files
committed
feat(optimizer): Add WindowFilterPushdown rule changes for rank queries
1 parent 32fc45c commit 8791f96

File tree

8 files changed

+250
-32
lines changed

8 files changed

+250
-32
lines changed

presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,7 @@ public final class SystemSessionProperties
182182
public static final String ADAPTIVE_PARTIAL_AGGREGATION = "adaptive_partial_aggregation";
183183
public static final String ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold";
184184
public static final String OPTIMIZE_TOP_N_ROW_NUMBER = "optimize_top_n_row_number";
185+
public static final String OPTIMIZE_TOP_N_RANK = "optimize_top_n_rank";
185186
public static final String OPTIMIZE_CASE_EXPRESSION_PREDICATE = "optimize_case_expression_predicate";
186187
public static final String MAX_GROUPING_SETS = "max_grouping_sets";
187188
public static final String LEGACY_UNNEST = "legacy_unnest";
@@ -980,6 +981,11 @@ public SystemSessionProperties(
980981
"Use top N row number optimization",
981982
featuresConfig.isOptimizeTopNRowNumber(),
982983
false),
984+
booleanProperty(
985+
OPTIMIZE_TOP_N_RANK,
986+
"Use top N rank and dense_rank optimization",
987+
featuresConfig.isOptimizeTopNRank(),
988+
false),
983989
booleanProperty(
984990
OPTIMIZE_CASE_EXPRESSION_PREDICATE,
985991
"Optimize case expression predicates",
@@ -2555,6 +2561,11 @@ public static boolean isOptimizeTopNRowNumber(Session session)
25552561
return session.getSystemProperty(OPTIMIZE_TOP_N_ROW_NUMBER, Boolean.class);
25562562
}
25572563

2564+
public static boolean isOptimizeTopNRank(Session session)
2565+
{
2566+
return session.getSystemProperty(OPTIMIZE_TOP_N_RANK, Boolean.class);
2567+
}
2568+
25582569
public static boolean isOptimizeCaseExpressionPredicate(Session session)
25592570
{
25602571
return session.getSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, Boolean.class);

presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -163,6 +163,8 @@ public class FeaturesConfig
163163
private boolean adaptivePartialAggregationEnabled;
164164
private double adaptivePartialAggregationRowsReductionRatioThreshold = 0.8;
165165
private boolean optimizeTopNRowNumber = true;
166+
167+
private boolean optimizeTopNRank;
166168
private boolean pushLimitThroughOuterJoin = true;
167169
private boolean optimizeConstantGroupingKeys = true;
168170

@@ -1141,13 +1143,25 @@ public boolean isOptimizeTopNRowNumber()
11411143
return optimizeTopNRowNumber;
11421144
}
11431145

1146+
public boolean isOptimizeTopNRank()
1147+
{
1148+
return optimizeTopNRank;
1149+
}
1150+
11441151
@Config("optimizer.optimize-top-n-row-number")
11451152
public FeaturesConfig setOptimizeTopNRowNumber(boolean optimizeTopNRowNumber)
11461153
{
11471154
this.optimizeTopNRowNumber = optimizeTopNRowNumber;
11481155
return this;
11491156
}
11501157

1158+
@Config("optimizer.optimize-top-n-rank")
1159+
public FeaturesConfig setOptimizeTopNRank(boolean optimizeTopNRank)
1160+
{
1161+
this.optimizeTopNRank = optimizeTopNRank;
1162+
return this;
1163+
}
1164+
11511165
public boolean isOptimizeCaseExpressionPredicate()
11521166
{
11531167
return optimizeCaseExpressionPredicate;

presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java

Lines changed: 78 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -41,11 +41,14 @@
4141
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
4242
import com.facebook.presto.sql.relational.RowExpressionDomainTranslator;
4343
import com.google.common.collect.ImmutableList;
44+
import com.google.common.collect.Iterables;
4445

4546
import java.util.Map;
4647
import java.util.Optional;
4748
import java.util.OptionalInt;
4849

50+
import static com.facebook.presto.SystemSessionProperties.isNativeExecutionEnabled;
51+
import static com.facebook.presto.SystemSessionProperties.isOptimizeTopNRank;
4952
import static com.facebook.presto.SystemSessionProperties.isOptimizeTopNRowNumber;
5053
import static com.facebook.presto.common.predicate.Marker.Bound.BELOW;
5154
import static com.facebook.presto.common.type.BigintType.BIGINT;
@@ -134,6 +137,12 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
134137
return replaceChildren(node, ImmutableList.of(rewrittenSource));
135138
}
136139

140+
private boolean canReplaceWithTopNRowNumber(WindowNode node)
141+
{
142+
return (canOptimizeRowNumberFunction(node, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) ||
143+
(canOptimizeRankFunction(node, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRank(session) && isNativeExecutionEnabled(session));
144+
}
145+
137146
@Override
138147
public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
139148
{
@@ -152,16 +161,16 @@ public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
152161
planChanged = true;
153162
source = rowNumberNode;
154163
}
155-
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) {
164+
else if (source instanceof WindowNode) {
156165
WindowNode windowNode = (WindowNode) source;
157-
// verify that unordered row_number window functions are replaced by RowNumberNode
158-
verify(windowNode.getOrderingScheme().isPresent());
159-
TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit);
160-
if (windowNode.getPartitionBy().isEmpty()) {
161-
return topNRowNumberNode;
166+
if (canReplaceWithTopNRowNumber(windowNode)) {
167+
TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit);
168+
if (windowNode.getPartitionBy().isEmpty()) {
169+
return topNRowNumberNode;
170+
}
171+
planChanged = true;
172+
source = topNRowNumberNode;
162173
}
163-
planChanged = true;
164-
source = topNRowNumberNode;
165174
}
166175
return replaceChildren(node, ImmutableList.of(source));
167176
}
@@ -183,15 +192,17 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
183192
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
184193
}
185194
}
186-
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) {
195+
else if (source instanceof WindowNode) {
187196
WindowNode windowNode = (WindowNode) source;
188-
VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable());
189-
OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable);
190-
191-
if (upperBound.isPresent()) {
192-
source = convertToTopNRowNumber(windowNode, upperBound.getAsInt());
193-
planChanged = true;
194-
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
197+
if (canReplaceWithTopNRowNumber(windowNode)) {
198+
VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable());
199+
OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable);
200+
201+
if (upperBound.isPresent()) {
202+
source = convertToTopNRowNumber(windowNode, upperBound.getAsInt());
203+
planChanged = true;
204+
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
205+
}
195206
}
196207
}
197208
return replaceChildren(node, ImmutableList.of(source));
@@ -287,12 +298,44 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi
287298
Optional.empty());
288299
}
289300

301+
private TopNRowNumberNode convertToTopNRank(WindowNode windowNode, int limit)
302+
{
303+
String windowFunction = Iterables.getOnlyElement(windowNode.getWindowFunctions().values()).getFunctionCall().getFunctionHandle().getName();
304+
String[] parts = windowFunction.split("\\.");
305+
String windowFunctionName = parts[parts.length - 1];
306+
TopNRowNumberNode.RankingFunction rankingFunction;
307+
switch (windowFunctionName) {
308+
case "row_number":
309+
rankingFunction = TopNRowNumberNode.RankingFunction.ROW_NUMBER;
310+
break;
311+
case "rank":
312+
rankingFunction = TopNRowNumberNode.RankingFunction.RANK;
313+
break;
314+
case "dense_rank":
315+
rankingFunction = TopNRowNumberNode.RankingFunction.DENSE_RANK;
316+
break;
317+
default:
318+
throw new IllegalArgumentException("Unsupported window function for TopNRowNumberNode: " + windowFunctionName);
319+
}
320+
321+
return new TopNRowNumberNode(
322+
windowNode.getSourceLocation(),
323+
idAllocator.getNextId(),
324+
windowNode.getSource(),
325+
windowNode.getSpecification(),
326+
rankingFunction,
327+
getOnlyElement(windowNode.getCreatedVariable()),
328+
limit,
329+
false,
330+
Optional.empty());
331+
}
332+
290333
private static boolean canReplaceWithRowNumber(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
291334
{
292-
return canOptimizeWindowFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent();
335+
return canOptimizeRowNumberFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent();
293336
}
294337

295-
private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
338+
private static boolean canOptimizeRowNumberFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
296339
{
297340
if (node.getWindowFunctions().size() != 1) {
298341
return false;
@@ -301,10 +344,27 @@ private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTyp
301344
return isRowNumberMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()));
302345
}
303346

347+
private static boolean canOptimizeRankFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
348+
{
349+
if (node.getWindowFunctions().size() != 1) {
350+
return false;
351+
}
352+
VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet());
353+
return isRankMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()));
354+
}
355+
304356
private static boolean isRowNumberMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
305357
{
306358
FunctionHandle rowNumberFunction = functionAndTypeManager.lookupFunction("row_number", ImmutableList.of());
307359
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rowNumberFunction));
308360
}
361+
362+
private static boolean isRankMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
363+
{
364+
FunctionHandle rankFunction = functionAndTypeManager.lookupFunction("rank", ImmutableList.of());
365+
FunctionHandle denseRankFunction = functionAndTypeManager.lookupFunction("dense_rank", ImmutableList.of());
366+
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rankFunction)) ||
367+
functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(denseRankFunction));
368+
}
309369
}
310370
}

presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ public TopNRowNumberNode(
8585
requireNonNull(rowNumberVariable, "rowNumberVariable is null");
8686
checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0");
8787
requireNonNull(hashVariable, "hashVariable is null");
88+
requireNonNull(rankingFunction, "rankingFunction is null");
8889

8990
this.source = source;
9091
this.specification = specification;

presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ public void testDefaults()
145145
.setAdaptivePartialAggregationEnabled(false)
146146
.setAdaptivePartialAggregationRowsReductionRatioThreshold(0.8)
147147
.setOptimizeTopNRowNumber(true)
148+
.setOptimizeTopNRank(false)
148149
.setOptimizeCaseExpressionPredicate(false)
149150
.setDistributedSortEnabled(true)
150151
.setMaxGroupingSets(2048)
@@ -367,6 +368,7 @@ public void testExplicitPropertyMappings()
367368
.put("experimental.adaptive-partial-aggregation", "true")
368369
.put("experimental.adaptive-partial-aggregation-rows-reduction-ratio-threshold", "0.9")
369370
.put("optimizer.optimize-top-n-row-number", "false")
371+
.put("optimizer.optimize-top-n-rank", "true")
370372
.put("optimizer.optimize-case-expression-predicate", "true")
371373
.put("distributed-sort", "false")
372374
.put("analyzer.max-grouping-sets", "2047")
@@ -584,6 +586,7 @@ public void testExplicitPropertyMappings()
584586
.setAdaptivePartialAggregationEnabled(true)
585587
.setAdaptivePartialAggregationRowsReductionRatioThreshold(0.9)
586588
.setOptimizeTopNRowNumber(false)
589+
.setOptimizeTopNRank(true)
587590
.setOptimizeCaseExpressionPredicate(true)
588591
.setDistributedSortEnabled(false)
589592
.setMaxGroupingSets(2047)

0 commit comments

Comments
 (0)