Skip to content

Commit aac1159

Browse files
committed
[Native] TopNRank optimization
1 parent fb92cb7 commit aac1159

File tree

15 files changed

+226
-23
lines changed

15 files changed

+226
-23
lines changed

presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,7 @@ public Optional<PlanNode> visitTopNRowNumber(TopNRowNumberNode node, Context con
696696
new WindowNode.Specification(
697697
partitionBy,
698698
node.getSpecification().getOrderingScheme().map(scheme -> getCanonicalOrderingScheme(scheme, context.getExpressions()))),
699+
node.getRankingFunction(),
699700
rowNumberVariable,
700701
node.getMaxRowCountPerPartition(),
701702
node.isPartial(),

presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ public PlanOptimizers(
642642
estimatedExchangesCostCalculator,
643643
ImmutableSet.of(new SimplifyCountOverConstant(metadata.getFunctionAndTypeManager()))),
644644
new LimitPushDown(), // Run LimitPushDown before WindowFilterPushDown
645-
new WindowFilterPushDown(metadata), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits
645+
new WindowFilterPushDown(metadata, featuresConfig.isNativeExecutionEnabled()), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits
646646
prefilterForLimitingAggregation,
647647
new IterativeOptimizer(
648648
metadata,

presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, PreferredPr
504504
idAllocator.getNextId(),
505505
child.getNode(),
506506
node.getSpecification(),
507+
node.getRankingFunction(),
507508
node.getRowNumberVariable(),
508509
node.getMaxRowCountPerPartition(),
509510
true,

presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, HashComputa
329329
node.getId(),
330330
child.getNode(),
331331
node.getSpecification(),
332+
node.getRankingFunction(),
332333
node.getRowNumberVariable(),
333334
node.getMaxRowCountPerPartition(),
334335
node.isPartial(),

presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ public Optional<DecorrelationResult> visitTopN(TopNNode node, Void context)
310310
new Specification(
311311
ImmutableList.copyOf(childDecorrelationResult.variablesToPropagate),
312312
Optional.of(orderingScheme)),
313+
TopNRowNumberNode.RankingFunction.ROW_NUMBER,
313314
variableAllocator.newVariable("row_number", BIGINT),
314315
toIntExact(node.getCount()),
315316
false,

presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Set<Va
702702
node.getStatsEquivalentPlanNode(),
703703
source,
704704
node.getSpecification(),
705+
node.getRankingFunction(),
705706
node.getRowNumberVariable(),
706707
node.getMaxRowCountPerPartition(),
707708
node.isPartial(),

presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Void>
482482
node.getId(),
483483
context.rewrite(node.getSource()),
484484
canonicalizeAndDistinct(node.getSpecification()),
485+
node.getRankingFunction(),
485486
canonicalize(node.getRowNumberVariable()),
486487
node.getMaxRowCountPerPartition(),
487488
node.isPartial(),

presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,13 @@ public class WindowFilterPushDown
6565
private final RowExpressionDomainTranslator domainTranslator;
6666
private final LogicalRowExpressions logicalRowExpressions;
6767

68-
public WindowFilterPushDown(Metadata metadata)
68+
private boolean isNativeExecution;
69+
70+
public WindowFilterPushDown(Metadata metadata, boolean isNativeExecution)
6971
{
7072
this.metadata = requireNonNull(metadata, "metadata is null");
7173
this.domainTranslator = new RowExpressionDomainTranslator(metadata);
74+
this.isNativeExecution = isNativeExecution;
7275
this.logicalRowExpressions = new LogicalRowExpressions(
7376
new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager()),
7477
new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()),
@@ -84,7 +87,7 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider
8487
requireNonNull(variableAllocator, "variableAllocator is null");
8588
requireNonNull(idAllocator, "idAllocator is null");
8689

87-
Rewriter rewriter = new Rewriter(idAllocator, metadata, domainTranslator, logicalRowExpressions, session);
90+
Rewriter rewriter = new Rewriter(idAllocator, metadata, domainTranslator, logicalRowExpressions, session, isNativeExecution);
8891
PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, null);
8992
return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
9093
}
@@ -97,15 +100,17 @@ private static class Rewriter
97100
private final RowExpressionDomainTranslator domainTranslator;
98101
private final LogicalRowExpressions logicalRowExpressions;
99102
private final Session session;
103+
private final boolean isNativeExecution;
100104
private boolean planChanged;
101105

/**
 * Builds the rewriter used for one optimization pass.
 *
 * @param idAllocator allocator for the ids of newly created plan nodes
 * @param metadata metadata used to resolve window function handles
 * @param domainTranslator translator handed over by the enclosing optimizer
 * @param logicalRowExpressions row-expression helper handed over by the enclosing optimizer
 * @param session session consulted for optimizer toggles
 * @param isNativeExecution forwarded to {@code canOptimizeWindowFunction} to decide which
 *        ranking functions are eligible for conversion
 * @throws NullPointerException if any reference argument is null
 */
private Rewriter(
        PlanNodeIdAllocator idAllocator,
        Metadata metadata,
        RowExpressionDomainTranslator domainTranslator,
        LogicalRowExpressions logicalRowExpressions,
        Session session,
        boolean isNativeExecution)
{
    this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
    this.metadata = requireNonNull(metadata, "metadata is null");
    this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null");
    // Validated like every other reference argument so a null fails here instead of deep inside a visit method.
    this.logicalRowExpressions = requireNonNull(logicalRowExpressions, "logicalRowExpressions is null");
    this.session = requireNonNull(session, "session is null");
    this.isNativeExecution = isNativeExecution;
}
110115

111116
public boolean isPlanChanged()
@@ -138,6 +143,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
138143
public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
139144
{
140145
// Operators can handle MAX_VALUE rows per page, so do not optimize if count is greater than this value
146+
// TODO (Aditi) : Don't think this check is needed for Native engine.
141147
if (node.getCount() > Integer.MAX_VALUE) {
142148
return context.defaultRewrite(node);
143149
}
@@ -152,11 +158,11 @@ public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
152158
planChanged = true;
153159
source = rowNumberNode;
154160
}
155-
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) {
161+
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager(), isNativeExecution) && isOptimizeTopNRowNumber(session)) {
156162
WindowNode windowNode = (WindowNode) source;
157163
// verify that unordered row_number window functions are replaced by RowNumberNode
158164
verify(windowNode.getOrderingScheme().isPresent());
159-
TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit);
165+
TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit, metadata.getFunctionAndTypeManager());
160166
if (windowNode.getPartitionBy().isEmpty()) {
161167
return topNRowNumberNode;
162168
}
@@ -183,13 +189,13 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
183189
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
184190
}
185191
}
186-
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) {
192+
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager(), isNativeExecution) && isOptimizeTopNRowNumber(session)) {
187193
WindowNode windowNode = (WindowNode) source;
188194
VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable());
189195
OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable);
190196

191197
if (upperBound.isPresent()) {
192-
source = convertToTopNRowNumber(windowNode, upperBound.getAsInt());
198+
source = convertToTopNRowNumber(windowNode, upperBound.getAsInt(), metadata.getFunctionAndTypeManager());
193199
planChanged = true;
194200
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
195201
}
@@ -273,13 +279,23 @@ private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPa
273279
return new RowNumberNode(node.getSourceLocation(), node.getId(), node.getSource(), node.getPartitionBy(), node.getRowNumberVariable(), Optional.of(newRowCountPerPartition), false, node.getHashVariable());
274280
}
275281

276-
private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit)
282+
private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit, FunctionAndTypeManager functionAndTypeManager)
277283
{
284+
VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getWindowFunctions().keySet());
285+
FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(windowNode.getWindowFunctions().get(rowNumberVariable).getFunctionHandle());
286+
287+
TopNRowNumberNode.RankingFunction rankingFunction =
288+
isRowNumberMetadata(functionAndTypeManager, functionMetadata) ?
289+
TopNRowNumberNode.RankingFunction.ROW_NUMBER :
290+
isRankMetadata(functionAndTypeManager, functionMetadata) ?
291+
TopNRowNumberNode.RankingFunction.RANK :
292+
TopNRowNumberNode.RankingFunction.DENSE_RANK;
278293
return new TopNRowNumberNode(
279294
windowNode.getSourceLocation(),
280295
idAllocator.getNextId(),
281296
windowNode.getSource(),
282297
windowNode.getSpecification(),
298+
rankingFunction,
283299
getOnlyElement(windowNode.getCreatedVariable()),
284300
limit,
285301
false,
@@ -288,22 +304,48 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi
288304

289305
private static boolean canReplaceWithRowNumber(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
290306
{
291-
return canOptimizeWindowFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent();
307+
if (node.getWindowFunctions().size() != 1) {
308+
return false;
309+
}
310+
VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet());
311+
312+
return isRowNumberMetadata(functionAndTypeManager,
313+
functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()))
314+
&& !node.getOrderingScheme().isPresent();
292315
}
293316

294-
private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
317+
private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager, boolean isNativeExecution)
295318
{
296319
if (node.getWindowFunctions().size() != 1) {
297320
return false;
298321
}
299322
VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet());
300-
return isRowNumberMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()));
323+
FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle());
324+
if (isNativeExecution) {
325+
return isRowNumberMetadata(functionAndTypeManager, functionMetadata)
326+
|| node.getOrderingScheme().isPresent() && (isRankMetadata(functionAndTypeManager, functionMetadata)
327+
|| isDenseRankMetadata(functionAndTypeManager, functionMetadata));
328+
}
329+
330+
return isRowNumberMetadata(functionAndTypeManager, functionMetadata);
301331
}
302332

303333
private static boolean isRowNumberMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
304334
{
305335
FunctionHandle rowNumberFunction = functionAndTypeManager.lookupFunction("row_number", ImmutableList.of());
306336
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rowNumberFunction));
307337
}
338+
339+
private static boolean isRankMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
340+
{
341+
FunctionHandle rankFunction = functionAndTypeManager.lookupFunction("rank", ImmutableList.of());
342+
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rankFunction));
343+
}
344+
345+
private static boolean isDenseRankMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
346+
{
347+
FunctionHandle rankFunction = functionAndTypeManager.lookupFunction("dense_rank", ImmutableList.of());
348+
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rankFunction));
349+
}
308350
}
309351
}

presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,18 @@
3636
public final class TopNRowNumberNode
3737
extends InternalPlanNode
3838
{
39+
/**
 * Identifies which ranking window function a {@link TopNRowNumberNode} stands for.
 * The optimizer picks the constant from the window function being replaced:
 * {@code row_number}, {@code rank}, or otherwise {@code dense_rank}.
 */
public enum RankingFunction
40+
{
41+
// Chosen when the replaced window function is row_number.
ROW_NUMBER,
42+
// Chosen when the replaced window function is rank.
RANK,
43+
// Chosen for the remaining eligible function, dense_rank.
DENSE_RANK
44+
}
45+
3946
private final PlanNode source;
4047
private final Specification specification;
48+
private final RankingFunction rankingFunction;
4149
private final VariableReferenceExpression rowNumberVariable;
42-
private final int maxRowCountPerPartition;
50+
private final int maxRankPerPartition;
4351
private final boolean partial;
4452
private final Optional<VariableReferenceExpression> hashVariable;
4553

@@ -49,12 +57,13 @@ public TopNRowNumberNode(
4957
@JsonProperty("id") PlanNodeId id,
5058
@JsonProperty("source") PlanNode source,
5159
@JsonProperty("specification") Specification specification,
60+
@JsonProperty("rankingType") RankingFunction rankingFunction,
5261
@JsonProperty("rowNumberVariable") VariableReferenceExpression rowNumberVariable,
53-
@JsonProperty("maxRowCountPerPartition") int maxRowCountPerPartition,
62+
@JsonProperty("maxRowCountPerPartition") int maxRankPerPartition,
5463
@JsonProperty("partial") boolean partial,
5564
@JsonProperty("hashVariable") Optional<VariableReferenceExpression> hashVariable)
5665
{
57-
this(sourceLocation, id, Optional.empty(), source, specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable);
66+
this(sourceLocation, id, Optional.empty(), source, specification, rankingFunction, rowNumberVariable, maxRankPerPartition, partial, hashVariable);
5867
}
5968

6069
public TopNRowNumberNode(
@@ -63,8 +72,9 @@ public TopNRowNumberNode(
6372
Optional<PlanNode> statsEquivalentPlanNode,
6473
PlanNode source,
6574
Specification specification,
75+
RankingFunction rankingFunction,
6676
VariableReferenceExpression rowNumberVariable,
67-
int maxRowCountPerPartition,
77+
int maxRankPerPartition,
6878
boolean partial,
6979
Optional<VariableReferenceExpression> hashVariable)
7080
{
@@ -74,13 +84,14 @@ public TopNRowNumberNode(
7484
requireNonNull(specification, "specification is null");
7585
checkArgument(specification.getOrderingScheme().isPresent(), "specification orderingScheme is absent");
7686
requireNonNull(rowNumberVariable, "rowNumberVariable is null");
77-
checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0");
87+
checkArgument(maxRankPerPartition > 0, "maxRowCountPerPartition must be > 0");
7888
requireNonNull(hashVariable, "hashVariable is null");
7989

8090
this.source = source;
8191
this.specification = specification;
92+
this.rankingFunction = rankingFunction;
8293
this.rowNumberVariable = rowNumberVariable;
83-
this.maxRowCountPerPartition = maxRowCountPerPartition;
94+
this.maxRankPerPartition = maxRankPerPartition;
8495
this.partial = partial;
8596
this.hashVariable = hashVariable;
8697
}
@@ -108,6 +119,12 @@ public PlanNode getSource()
108119
return source;
109120
}
110121

122+
@JsonProperty
123+
public RankingFunction getRankingFunction()
124+
{
125+
return rankingFunction;
126+
}
127+
111128
@JsonProperty
112129
public Specification getSpecification()
113130
{
@@ -133,7 +150,7 @@ public VariableReferenceExpression getRowNumberVariable()
133150
/**
 * Returns the per-partition limit given to the constructor, which requires it to be greater than zero.
 * The backing field is now named {@code maxRankPerPartition}; the getter and its JSON property
 * keep the name {@code maxRowCountPerPartition}.
 */
@JsonProperty
134151
public int getMaxRowCountPerPartition()
135152
{
136-
return maxRowCountPerPartition;
153+
return maxRankPerPartition;
137154
}
138155

139156
@JsonProperty
@@ -157,12 +174,12 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
157174
@Override
158175
public PlanNode replaceChildren(List<PlanNode> newChildren)
159176
{
160-
return new TopNRowNumberNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable);
177+
return new TopNRowNumberNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), specification, rankingFunction, rowNumberVariable, maxRankPerPartition, partial, hashVariable);
161178
}
162179

163180
@Override
164181
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
165182
{
166-
return new TopNRowNumberNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable);
183+
return new TopNRowNumberNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, specification, rankingFunction, rowNumberVariable, maxRankPerPartition, partial, hashVariable);
167184
}
168185
}

presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,8 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Void context)
387387
{
388388
printNode(node,
389389
"TopNRowNumber",
390-
format("partition by = %s|order by = %s|n = %s",
390+
format("function = %s; partition by = %s|order by = %s|n = %s",
391+
node.getRankingFunction(),
391392
Joiner.on(", ").join(node.getPartitionBy()),
392393
Joiner.on(", ").join(node.getOrderingScheme().getOrderByVariables()), node.getMaxRowCountPerPartition()),
393394
NODE_COLORS.get(NodeType.WINDOW));

0 commit comments

Comments
 (0)