Skip to content

Commit 75d0257

Browse files
committed
feat(optimizer): Add WindowFilterPushdown rule changes for rank queries
1 parent 216dadc commit 75d0257

File tree

7 files changed

+225
-21
lines changed

7 files changed

+225
-21
lines changed

presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ public final class SystemSessionProperties
182182
public static final String ADAPTIVE_PARTIAL_AGGREGATION = "adaptive_partial_aggregation";
183183
public static final String ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold";
184184
public static final String OPTIMIZE_TOP_N_ROW_NUMBER = "optimize_top_n_row_number";
185+
public static final String OPTIMIZE_TOP_N_RANK = "optimize_top_n_rank";
185186
public static final String OPTIMIZE_CASE_EXPRESSION_PREDICATE = "optimize_case_expression_predicate";
186187
public static final String MAX_GROUPING_SETS = "max_grouping_sets";
187188
public static final String LEGACY_UNNEST = "legacy_unnest";
@@ -980,6 +981,11 @@ public SystemSessionProperties(
980981
"Use top N row number optimization",
981982
featuresConfig.isOptimizeTopNRowNumber(),
982983
false),
984+
booleanProperty(
985+
OPTIMIZE_TOP_N_RANK,
986+
"Use top N rank and dense_rank optimization",
987+
featuresConfig.isOptimizeTopNRank(),
988+
false),
983989
booleanProperty(
984990
OPTIMIZE_CASE_EXPRESSION_PREDICATE,
985991
"Optimize case expression predicates",
@@ -2555,6 +2561,11 @@ public static boolean isOptimizeTopNRowNumber(Session session)
25552561
return session.getSystemProperty(OPTIMIZE_TOP_N_ROW_NUMBER, Boolean.class);
25562562
}
25572563

2564+
public static boolean isOptimizeTopNRank(Session session)
2565+
{
2566+
return session.getSystemProperty(OPTIMIZE_TOP_N_RANK, Boolean.class);
2567+
}
2568+
25582569
public static boolean isOptimizeCaseExpressionPredicate(Session session)
25592570
{
25602571
return session.getSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, Boolean.class);

presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ public class FeaturesConfig
163163
private boolean adaptivePartialAggregationEnabled;
164164
private double adaptivePartialAggregationRowsReductionRatioThreshold = 0.8;
165165
private boolean optimizeTopNRowNumber = true;
166+
167+
private boolean optimizeTopNRank = false;
166168
private boolean pushLimitThroughOuterJoin = true;
167169
private boolean optimizeConstantGroupingKeys = true;
168170

@@ -1135,19 +1137,30 @@ public FeaturesConfig setAdaptivePartialAggregationRowsReductionRatioThreshold(d
11351137
this.adaptivePartialAggregationRowsReductionRatioThreshold = adaptivePartialAggregationRowsReductionRatioThreshold;
11361138
return this;
11371139
}
1138-
11391140
public boolean isOptimizeTopNRowNumber()
11401141
{
11411142
return optimizeTopNRowNumber;
11421143
}
11431144

1145+
public boolean isOptimizeTopNRank()
1146+
{
1147+
return optimizeTopNRank;
1148+
}
1149+
11441150
@Config("optimizer.optimize-top-n-row-number")
11451151
public FeaturesConfig setOptimizeTopNRowNumber(boolean optimizeTopNRowNumber)
11461152
{
11471153
this.optimizeTopNRowNumber = optimizeTopNRowNumber;
11481154
return this;
11491155
}
11501156

1157+
@Config("optimizer.optimize-top-n-rank")
1158+
public FeaturesConfig setOptimizeTopNRank(boolean optimizeTopNRank)
1159+
{
1160+
this.optimizeTopNRank = optimizeTopNRank;
1161+
return this;
1162+
}
1163+
11511164
public boolean isOptimizeCaseExpressionPredicate()
11521165
{
11531166
return optimizeCaseExpressionPredicate;

presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,15 @@
4141
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
4242
import com.facebook.presto.sql.relational.RowExpressionDomainTranslator;
4343
import com.google.common.collect.ImmutableList;
44+
import com.google.common.collect.Iterables;
4445

4546
import java.util.Map;
4647
import java.util.Optional;
4748
import java.util.OptionalInt;
4849

4950
import static com.facebook.presto.SystemSessionProperties.isOptimizeTopNRowNumber;
51+
import static com.facebook.presto.SystemSessionProperties.isOptimizeTopNRank;
52+
import static com.facebook.presto.SystemSessionProperties.isNativeExecutionEnabled;
5053
import static com.facebook.presto.common.predicate.Marker.Bound.BELOW;
5154
import static com.facebook.presto.common.type.BigintType.BIGINT;
5255
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
@@ -152,11 +155,12 @@ public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
152155
planChanged = true;
153156
source = rowNumberNode;
154157
}
155-
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) {
158+
else if ((source instanceof WindowNode && canOptimizeRowNumberFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) ||
159+
(source instanceof WindowNode && canOptimizeRankFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRank(session) && isNativeExecutionEnabled(session))) {
156160
WindowNode windowNode = (WindowNode) source;
157161
// verify that unordered row_number window functions are replaced by RowNumberNode
158162
verify(windowNode.getOrderingScheme().isPresent());
159-
TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit);
163+
TopNRowNumberNode topNRowNumberNode = convertToTopNRank(windowNode, limit);
160164
if (windowNode.getPartitionBy().isEmpty()) {
161165
return topNRowNumberNode;
162166
}
@@ -183,7 +187,8 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
183187
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
184188
}
185189
}
186-
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) {
190+
else if ((source instanceof WindowNode && canOptimizeRowNumberFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) ||
191+
(source instanceof WindowNode && canOptimizeRankFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRank(session) && isNativeExecutionEnabled(session))) {
187192
WindowNode windowNode = (WindowNode) source;
188193
VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable());
189194
OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable);
@@ -280,19 +285,51 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi
280285
idAllocator.getNextId(),
281286
windowNode.getSource(),
282287
windowNode.getSpecification(),
288+
283289
TopNRowNumberNode.RankingFunction.ROW_NUMBER,
284290
getOnlyElement(windowNode.getCreatedVariable()),
285291
limit,
286292
false,
287293
Optional.empty());
288294
}
289295

296+
private TopNRowNumberNode convertToTopNRank(WindowNode windowNode, int limit)
297+
{
298+
299+
String windowFunctionName = Iterables.getOnlyElement(windowNode.getWindowFunctions().values()).getFunctionCall().getFunctionHandle().getName();
300+
TopNRowNumberNode.RankingFunction rankingFunction;
301+
switch (windowFunctionName) {
302+
case "row_number":
303+
rankingFunction = TopNRowNumberNode.RankingFunction.ROW_NUMBER;
304+
break;
305+
case "rank":
306+
rankingFunction = TopNRowNumberNode.RankingFunction.RANK;
307+
break;
308+
case "dense_rank":
309+
rankingFunction = TopNRowNumberNode.RankingFunction.DENSE_RANK;
310+
break;
311+
default:
312+
throw new IllegalArgumentException("Unsupported window function for TopNRowNumberNode: " + windowFunctionName);
313+
}
314+
315+
return new TopNRowNumberNode(
316+
windowNode.getSourceLocation(),
317+
idAllocator.getNextId(),
318+
windowNode.getSource(),
319+
windowNode.getSpecification(),
320+
rankingFunction,
321+
getOnlyElement(windowNode.getCreatedVariable()),
322+
limit,
323+
false,
324+
Optional.empty());
325+
}
326+
290327
private static boolean canReplaceWithRowNumber(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
291328
{
292-
return canOptimizeWindowFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent();
329+
return canOptimizeRowNumberFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent();
293330
}
294331

295-
private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
332+
private static boolean canOptimizeRowNumberFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
296333
{
297334
if (node.getWindowFunctions().size() != 1) {
298335
return false;
@@ -301,10 +338,27 @@ private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTyp
301338
return isRowNumberMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()));
302339
}
303340

341+
private static boolean canOptimizeRankFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
342+
{
343+
if (node.getWindowFunctions().size() != 1) {
344+
return false;
345+
}
346+
VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet());
347+
return isRankMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()));
348+
}
349+
304350
private static boolean isRowNumberMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
305351
{
306352
FunctionHandle rowNumberFunction = functionAndTypeManager.lookupFunction("row_number", ImmutableList.of());
307353
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rowNumberFunction));
308354
}
355+
356+
private static boolean isRankMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
357+
{
358+
FunctionHandle rankFunction = functionAndTypeManager.lookupFunction("rank", ImmutableList.of());
359+
FunctionHandle denseRankFunction = functionAndTypeManager.lookupFunction("dense_rank", ImmutableList.of());
360+
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rankFunction)) ||
361+
functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(denseRankFunction));
362+
}
309363
}
310364
}

presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ public void testDefaults()
145145
.setAdaptivePartialAggregationEnabled(false)
146146
.setAdaptivePartialAggregationRowsReductionRatioThreshold(0.8)
147147
.setOptimizeTopNRowNumber(true)
148+
.setOptimizeTopNRank(false)
148149
.setOptimizeCaseExpressionPredicate(false)
149150
.setDistributedSortEnabled(true)
150151
.setMaxGroupingSets(2048)
@@ -367,6 +368,7 @@ public void testExplicitPropertyMappings()
367368
.put("experimental.adaptive-partial-aggregation", "true")
368369
.put("experimental.adaptive-partial-aggregation-rows-reduction-ratio-threshold", "0.9")
369370
.put("optimizer.optimize-top-n-row-number", "false")
371+
.put("optimizer.optimize-top-n-rank", "true")
370372
.put("optimizer.optimize-case-expression-predicate", "true")
371373
.put("distributed-sort", "false")
372374
.put("analyzer.max-grouping-sets", "2047")
@@ -584,6 +586,7 @@ public void testExplicitPropertyMappings()
584586
.setAdaptivePartialAggregationEnabled(true)
585587
.setAdaptivePartialAggregationRowsReductionRatioThreshold(0.9)
586588
.setOptimizeTopNRowNumber(false)
589+
.setOptimizeTopNRank(true)
587590
.setOptimizeCaseExpressionPredicate(true)
588591
.setDistributedSortEnabled(false)
589592
.setMaxGroupingSets(2047)

presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestWindowFilterPushDown.java

Lines changed: 81 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import org.intellij.lang.annotations.Language;
2222
import org.testng.annotations.Test;
2323

24-
import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_TOP_N_ROW_NUMBER;
24+
import static com.facebook.presto.SystemSessionProperties.*;
2525
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyNot;
2626
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
2727
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.limit;
@@ -31,15 +31,11 @@
3131
public class TestWindowFilterPushDown
3232
extends BasePlanTest
3333
{
34-
@Test
35-
public void testLimitAboveWindow()
34+
private void testLimitSql(String sql, boolean rowOrRank)
3635
{
37-
@Language("SQL") String sql = "SELECT " +
38-
"row_number() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10";
39-
4036
assertPlanWithSession(
4137
sql,
42-
optimizeTopNRowNumber(true),
38+
rowOrRank ? optimizeTopNRowNumber(true) : optimizeTopNRank(true),
4339
true,
4440
anyTree(
4541
limit(10, anyTree(
@@ -49,25 +45,47 @@ public void testLimitAboveWindow()
4945

5046
assertPlanWithSession(
5147
sql,
52-
optimizeTopNRowNumber(false),
48+
rowOrRank ? optimizeTopNRowNumber(false) : optimizeTopNRank(false),
5349
true,
5450
anyTree(
5551
limit(10, anyTree(
5652
node(WindowNode.class,
5753
anyTree(
5854
tableScan("lineitem")))))));
59-
}
6055

56+
if (!rowOrRank) {
57+
assertPlanWithSession(
58+
sql,
59+
optimizeTopNRankWithoutNative(true),
60+
true,
61+
anyTree(
62+
limit(10, anyTree(
63+
node(WindowNode.class,
64+
anyTree(
65+
tableScan("lineitem")))))));
66+
}
67+
68+
}
6169
@Test
62-
public void testFilterAboveWindow()
70+
public void testLimitAboveWindow()
6371
{
64-
@Language("SQL") String sql = "SELECT * FROM " +
65-
"(SELECT row_number() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem) " +
66-
"WHERE partition_row_number < 10";
72+
@Language("SQL") String sql = "SELECT " +
73+
"row_number() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10";
74+
testLimitSql(sql, true);
6775

76+
sql = "SELECT " +
77+
"rank() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10";
78+
testLimitSql(sql, false);
79+
80+
sql = "SELECT " +
81+
"dense_rank() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10";
82+
testLimitSql(sql, false);
83+
}
84+
85+
private void testFilterSql(String sql, boolean rowOrRank) {
6886
assertPlanWithSession(
6987
sql,
70-
optimizeTopNRowNumber(true),
88+
rowOrRank ? optimizeTopNRowNumber(true) : optimizeTopNRank(true),
7189
true,
7290
anyTree(
7391
anyNot(FilterNode.class,
@@ -77,14 +95,46 @@ public void testFilterAboveWindow()
7795

7896
assertPlanWithSession(
7997
sql,
80-
optimizeTopNRowNumber(false),
98+
rowOrRank ? optimizeTopNRowNumber(false) : optimizeTopNRank(false),
8199
true,
82100
anyTree(
83101
node(FilterNode.class,
84102
anyTree(
85103
node(WindowNode.class,
86104
anyTree(
87105
tableScan("lineitem")))))));
106+
107+
if (!rowOrRank) {
108+
assertPlanWithSession(
109+
sql,
110+
optimizeTopNRankWithoutNative(true),
111+
true,
112+
anyTree(
113+
node(FilterNode.class,
114+
anyTree(
115+
node(WindowNode.class,
116+
anyTree(
117+
tableScan("lineitem")))))));
118+
}
119+
}
120+
@Test
121+
public void testFilterAboveWindow()
122+
{
123+
@Language("SQL") String sql = "SELECT * FROM " +
124+
"(SELECT row_number() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem) " +
125+
"WHERE partition_row_number < 10";
126+
127+
testFilterSql(sql, true);
128+
129+
sql = "SELECT * FROM " +
130+
"(SELECT rank() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_rank FROM lineitem) " +
131+
"WHERE partition_rank < 10";
132+
testFilterSql(sql, false);
133+
134+
sql = "SELECT * FROM " +
135+
"(SELECT dense_rank() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_dense_rank FROM lineitem) " +
136+
"WHERE partition_dense_rank < 10";
137+
testFilterSql(sql, false);
88138
}
89139

90140
private Session optimizeTopNRowNumber(boolean enabled)
@@ -93,4 +143,20 @@ private Session optimizeTopNRowNumber(boolean enabled)
93143
.setSystemProperty(OPTIMIZE_TOP_N_ROW_NUMBER, Boolean.toString(enabled))
94144
.build();
95145
}
146+
147+
private Session optimizeTopNRank(boolean enabled)
148+
{
149+
return Session.builder(this.getQueryRunner().getDefaultSession())
150+
.setSystemProperty(NATIVE_EXECUTION_ENABLED, Boolean.toString(enabled))
151+
.setSystemProperty(OPTIMIZE_TOP_N_RANK, Boolean.toString(enabled))
152+
.build();
153+
}
154+
155+
private Session optimizeTopNRankWithoutNative(boolean enabled)
156+
{
157+
return Session.builder(this.getQueryRunner().getDefaultSession())
158+
.setSystemProperty(NATIVE_EXECUTION_ENABLED, Boolean.toString(false))
159+
.setSystemProperty(OPTIMIZE_TOP_N_RANK, Boolean.toString(enabled))
160+
.build();
161+
}
96162
}

0 commit comments

Comments
 (0)