4646import java .util .Optional ;
4747import java .util .OptionalInt ;
4848
49+ import static com .facebook .presto .SystemSessionProperties .isNativeExecutionEnabled ;
4950import static com .facebook .presto .SystemSessionProperties .isOptimizeTopNRowNumber ;
5051import static com .facebook .presto .common .predicate .Marker .Bound .BELOW ;
5152import static com .facebook .presto .common .type .BigintType .BIGINT ;
@@ -65,10 +66,13 @@ public class WindowFilterPushDown
6566 private final RowExpressionDomainTranslator domainTranslator ;
6667 private final LogicalRowExpressions logicalRowExpressions ;
6768
68- public WindowFilterPushDown (Metadata metadata )
69+ private boolean isNativeExecution = false ;
70+
71+ public WindowFilterPushDown (Metadata metadata , boolean isNativeExecution )
6972 {
7073 this .metadata = requireNonNull (metadata , "metadata is null" );
7174 this .domainTranslator = new RowExpressionDomainTranslator (metadata );
75+ this .isNativeExecution = isNativeExecution ;
7276 this .logicalRowExpressions = new LogicalRowExpressions (
7377 new RowExpressionDeterminismEvaluator (metadata .getFunctionAndTypeManager ()),
7478 new FunctionResolution (metadata .getFunctionAndTypeManager ().getFunctionAndTypeResolver ()),
@@ -84,7 +88,7 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider
8488 requireNonNull (variableAllocator , "variableAllocator is null" );
8589 requireNonNull (idAllocator , "idAllocator is null" );
8690
87- Rewriter rewriter = new Rewriter (idAllocator , metadata , domainTranslator , logicalRowExpressions , session );
91+ Rewriter rewriter = new Rewriter (idAllocator , metadata , domainTranslator , logicalRowExpressions , session , isNativeExecution );
8892 PlanNode rewrittenPlan = SimplePlanRewriter .rewriteWith (rewriter , plan , null );
8993 return PlanOptimizerResult .optimizerResult (rewrittenPlan , rewriter .isPlanChanged ());
9094 }
@@ -97,15 +101,17 @@ private static class Rewriter
97101 private final RowExpressionDomainTranslator domainTranslator ;
98102 private final LogicalRowExpressions logicalRowExpressions ;
99103 private final Session session ;
104+ private final boolean isNativeExecution ;
100105 private boolean planChanged ;
101106
102- private Rewriter (PlanNodeIdAllocator idAllocator , Metadata metadata , RowExpressionDomainTranslator domainTranslator , LogicalRowExpressions logicalRowExpressions , Session session )
107+ private Rewriter (PlanNodeIdAllocator idAllocator , Metadata metadata , RowExpressionDomainTranslator domainTranslator , LogicalRowExpressions logicalRowExpressions , Session session , boolean isNativeExecution )
103108 {
104109 this .idAllocator = requireNonNull (idAllocator , "idAllocator is null" );
105110 this .metadata = requireNonNull (metadata , "metadata is null" );
106111 this .domainTranslator = requireNonNull (domainTranslator , "domainTranslator is null" );
107112 this .logicalRowExpressions = logicalRowExpressions ;
108113 this .session = requireNonNull (session , "session is null" );
114+ this .isNativeExecution = isNativeExecution ;
109115 }
110116
111117 public boolean isPlanChanged ()
@@ -138,6 +144,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
138144 public PlanNode visitLimit (LimitNode node , RewriteContext <Void > context )
139145 {
140146 // Operators can handle MAX_VALUE rows per page, so do not optimize if count is greater than this value
147+ // TODO (Aditi) : Don't think this check is needed for Native engine.
141148 if (node .getCount () > Integer .MAX_VALUE ) {
142149 return context .defaultRewrite (node );
143150 }
@@ -152,11 +159,11 @@ public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
152159 planChanged = true ;
153160 source = rowNumberNode ;
154161 }
155- else if (source instanceof WindowNode && canOptimizeWindowFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) {
162+ else if (source instanceof WindowNode && canOptimizeWindowFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager (), isNativeExecution ) && isOptimizeTopNRowNumber (session )) {
156163 WindowNode windowNode = (WindowNode ) source ;
157164 // verify that unordered row_number window functions are replaced by RowNumberNode
158165 verify (windowNode .getOrderingScheme ().isPresent ());
159- TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber (windowNode , limit );
166+ TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber (windowNode , limit , metadata . getFunctionAndTypeManager () );
160167 if (windowNode .getPartitionBy ().isEmpty ()) {
161168 return topNRowNumberNode ;
162169 }
@@ -183,13 +190,13 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
183190 return rewriteFilterSource (node , source , rowNumberVariable , upperBound .getAsInt ());
184191 }
185192 }
186- else if (source instanceof WindowNode && canOptimizeWindowFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) {
193+ else if (source instanceof WindowNode && canOptimizeWindowFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager (), isNativeExecution ) && isOptimizeTopNRowNumber (session )) {
187194 WindowNode windowNode = (WindowNode ) source ;
188195 VariableReferenceExpression rowNumberVariable = getOnlyElement (windowNode .getCreatedVariable ());
189196 OptionalInt upperBound = extractUpperBound (tupleDomain , rowNumberVariable );
190197
191198 if (upperBound .isPresent ()) {
192- source = convertToTopNRowNumber (windowNode , upperBound .getAsInt ());
199+ source = convertToTopNRowNumber (windowNode , upperBound .getAsInt (), metadata . getFunctionAndTypeManager () );
193200 planChanged = true ;
194201 return rewriteFilterSource (node , source , rowNumberVariable , upperBound .getAsInt ());
195202 }
@@ -273,13 +280,23 @@ private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPa
273280 return new RowNumberNode (node .getSourceLocation (), node .getId (), node .getSource (), node .getPartitionBy (), node .getRowNumberVariable (), Optional .of (newRowCountPerPartition ), false , node .getHashVariable ());
274281 }
275282
276- private TopNRowNumberNode convertToTopNRowNumber (WindowNode windowNode , int limit )
283+ private TopNRowNumberNode convertToTopNRowNumber (WindowNode windowNode , int limit , FunctionAndTypeManager functionAndTypeManager )
277284 {
285+ VariableReferenceExpression rowNumberVariable = getOnlyElement (windowNode .getWindowFunctions ().keySet ());
286+ FunctionMetadata functionMetadata = functionAndTypeManager .getFunctionMetadata (windowNode .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ());
287+
288+ TopNRowNumberNode .RankingFunction rankingFunction =
289+ isRowNumberMetadata (functionAndTypeManager , functionMetadata ) ?
290+ TopNRowNumberNode .RankingFunction .ROW_NUMBER :
291+ isRankMetadata (functionAndTypeManager , functionMetadata ) ?
292+ TopNRowNumberNode .RankingFunction .RANK :
293+ TopNRowNumberNode .RankingFunction .DENSE_RANK ;
278294 return new TopNRowNumberNode (
279295 windowNode .getSourceLocation (),
280296 idAllocator .getNextId (),
281297 windowNode .getSource (),
282298 windowNode .getSpecification (),
299+ rankingFunction ,
283300 getOnlyElement (windowNode .getCreatedVariable ()),
284301 limit ,
285302 false ,
@@ -288,22 +305,49 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi
288305
289306 private static boolean canReplaceWithRowNumber (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
290307 {
291- return canOptimizeWindowFunction (node , functionAndTypeManager ) && !node .getOrderingScheme ().isPresent ();
308+ if (node .getWindowFunctions ().size () != 1 ) {
309+ return false ;
310+ }
311+ VariableReferenceExpression rowNumberVariable = getOnlyElement (node .getWindowFunctions ().keySet ());
312+
313+ return isRowNumberMetadata (functionAndTypeManager ,
314+ functionAndTypeManager .getFunctionMetadata (node .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ()))
315+ && !node .getOrderingScheme ().isPresent ();
292316 }
293317
294- private static boolean canOptimizeWindowFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
318+ private static boolean canOptimizeWindowFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager , boolean isNativeExecution )
295319 {
296320 if (node .getWindowFunctions ().size () != 1 ) {
297321 return false ;
298322 }
299323 VariableReferenceExpression rowNumberVariable = getOnlyElement (node .getWindowFunctions ().keySet ());
300- return isRowNumberMetadata (functionAndTypeManager , functionAndTypeManager .getFunctionMetadata (node .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ()));
324+ FunctionMetadata functionMetadata = functionAndTypeManager .getFunctionMetadata (node .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ());
325+ if (isNativeExecution ) {
326+ return isRowNumberMetadata (functionAndTypeManager , functionMetadata )
327+ || node .getOrderingScheme ().isPresent () && (isRankMetadata (functionAndTypeManager , functionMetadata )
328+ || isDenseRankMetadata (functionAndTypeManager , functionMetadata ));
329+ }
330+
331+ return isRowNumberMetadata (functionAndTypeManager , functionMetadata );
301332 }
302333
334+
303335 private static boolean isRowNumberMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
304336 {
305337 FunctionHandle rowNumberFunction = functionAndTypeManager .lookupFunction ("row_number" , ImmutableList .of ());
306338 return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rowNumberFunction ));
307339 }
340+
341+ private static boolean isRankMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
342+ {
343+ FunctionHandle rankFunction = functionAndTypeManager .lookupFunction ("rank" , ImmutableList .of ());
344+ return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rankFunction ));
345+ }
346+
347+ private static boolean isDenseRankMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
348+ {
349+ FunctionHandle rankFunction = functionAndTypeManager .lookupFunction ("dense_rank" , ImmutableList .of ());
350+ return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rankFunction ));
351+ }
308352 }
309353}
0 commit comments