4141import com .facebook .presto .sql .relational .RowExpressionDeterminismEvaluator ;
4242import com .facebook .presto .sql .relational .RowExpressionDomainTranslator ;
4343import com .google .common .collect .ImmutableList ;
44+ import com .google .common .collect .Iterables ;
4445
4546import java .util .Map ;
4647import java .util .Optional ;
4748import java .util .OptionalInt ;
4849
50+ import static com .facebook .presto .SystemSessionProperties .isNativeExecutionEnabled ;
51+ import static com .facebook .presto .SystemSessionProperties .isOptimizeTopNRank ;
4952import static com .facebook .presto .SystemSessionProperties .isOptimizeTopNRowNumber ;
5053import static com .facebook .presto .common .predicate .Marker .Bound .BELOW ;
5154import static com .facebook .presto .common .type .BigintType .BIGINT ;
@@ -152,11 +155,12 @@ public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
152155 planChanged = true ;
153156 source = rowNumberNode ;
154157 }
155- else if (source instanceof WindowNode && canOptimizeWindowFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) {
158+ else if ((source instanceof WindowNode && canOptimizeRowNumberFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) ||
159+ (source instanceof WindowNode && canOptimizeRankFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRank (session ) && isNativeExecutionEnabled (session ))) {
156160 WindowNode windowNode = (WindowNode ) source ;
157161 // verify that unordered row_number window functions are replaced by RowNumberNode
158162 verify (windowNode .getOrderingScheme ().isPresent ());
159- TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber (windowNode , limit );
163+ TopNRowNumberNode topNRowNumberNode = convertToTopNRank (windowNode , limit );
160164 if (windowNode .getPartitionBy ().isEmpty ()) {
161165 return topNRowNumberNode ;
162166 }
@@ -183,7 +187,8 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
183187 return rewriteFilterSource (node , source , rowNumberVariable , upperBound .getAsInt ());
184188 }
185189 }
186- else if (source instanceof WindowNode && canOptimizeWindowFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) {
190+ else if ((source instanceof WindowNode && canOptimizeRowNumberFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) ||
191+ (source instanceof WindowNode && canOptimizeRankFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRank (session ) && isNativeExecutionEnabled (session ))) {
187192 WindowNode windowNode = (WindowNode ) source ;
188193 VariableReferenceExpression rowNumberVariable = getOnlyElement (windowNode .getCreatedVariable ());
189194 OptionalInt upperBound = extractUpperBound (tupleDomain , rowNumberVariable );
@@ -280,19 +285,52 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi
280285 idAllocator .getNextId (),
281286 windowNode .getSource (),
282287 windowNode .getSpecification (),
288+
283289 TopNRowNumberNode .RankingFunction .ROW_NUMBER ,
284290 getOnlyElement (windowNode .getCreatedVariable ()),
285291 limit ,
286292 false ,
287293 Optional .empty ());
288294 }
289295
296+ private TopNRowNumberNode convertToTopNRank (WindowNode windowNode , int limit )
297+ {
298+ String windowFunction = Iterables .getOnlyElement (windowNode .getWindowFunctions ().values ()).getFunctionCall ().getFunctionHandle ().getName ();
299+ String [] parts = windowFunction .split ("\\ ." );
300+ String windowFunctionName = parts [parts .length - 1 ];
301+ TopNRowNumberNode .RankingFunction rankingFunction ;
302+ switch (windowFunctionName ) {
303+ case "row_number" :
304+ rankingFunction = TopNRowNumberNode .RankingFunction .ROW_NUMBER ;
305+ break ;
306+ case "rank" :
307+ rankingFunction = TopNRowNumberNode .RankingFunction .RANK ;
308+ break ;
309+ case "dense_rank" :
310+ rankingFunction = TopNRowNumberNode .RankingFunction .DENSE_RANK ;
311+ break ;
312+ default :
313+ throw new IllegalArgumentException ("Unsupported window function for TopNRowNumberNode: " + windowFunctionName );
314+ }
315+
316+ return new TopNRowNumberNode (
317+ windowNode .getSourceLocation (),
318+ idAllocator .getNextId (),
319+ windowNode .getSource (),
320+ windowNode .getSpecification (),
321+ rankingFunction ,
322+ getOnlyElement (windowNode .getCreatedVariable ()),
323+ limit ,
324+ false ,
325+ Optional .empty ());
326+ }
327+
290328 private static boolean canReplaceWithRowNumber (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
291329 {
292- return canOptimizeWindowFunction (node , functionAndTypeManager ) && !node .getOrderingScheme ().isPresent ();
330+ return canOptimizeRowNumberFunction (node , functionAndTypeManager ) && !node .getOrderingScheme ().isPresent ();
293331 }
294332
295- private static boolean canOptimizeWindowFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
333+ private static boolean canOptimizeRowNumberFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
296334 {
297335 if (node .getWindowFunctions ().size () != 1 ) {
298336 return false ;
@@ -301,10 +339,27 @@ private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTyp
301339 return isRowNumberMetadata (functionAndTypeManager , functionAndTypeManager .getFunctionMetadata (node .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ()));
302340 }
303341
342+ private static boolean canOptimizeRankFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
343+ {
344+ if (node .getWindowFunctions ().size () != 1 ) {
345+ return false ;
346+ }
347+ VariableReferenceExpression rowNumberVariable = getOnlyElement (node .getWindowFunctions ().keySet ());
348+ return isRankMetadata (functionAndTypeManager , functionAndTypeManager .getFunctionMetadata (node .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ()));
349+ }
350+
304351 private static boolean isRowNumberMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
305352 {
306353 FunctionHandle rowNumberFunction = functionAndTypeManager .lookupFunction ("row_number" , ImmutableList .of ());
307354 return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rowNumberFunction ));
308355 }
356+
357+ private static boolean isRankMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
358+ {
359+ FunctionHandle rankFunction = functionAndTypeManager .lookupFunction ("rank" , ImmutableList .of ());
360+ FunctionHandle denseRankFunction = functionAndTypeManager .lookupFunction ("dense_rank" , ImmutableList .of ());
361+ return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rankFunction )) ||
362+ functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (denseRankFunction ));
363+ }
309364 }
310365}
0 commit comments