4141import com .facebook .presto .sql .relational .RowExpressionDeterminismEvaluator ;
4242import com .facebook .presto .sql .relational .RowExpressionDomainTranslator ;
4343import com .google .common .collect .ImmutableList ;
44+ import com .google .common .collect .Iterables ;
4445
4546import java .util .Map ;
4647import java .util .Optional ;
4748import java .util .OptionalInt ;
4849
4950import static com .facebook .presto .SystemSessionProperties .isOptimizeTopNRowNumber ;
51+ import static com .facebook .presto .SystemSessionProperties .isOptimizeTopNRank ;
52+ import static com .facebook .presto .SystemSessionProperties .isNativeExecutionEnabled ;
5053import static com .facebook .presto .common .predicate .Marker .Bound .BELOW ;
5154import static com .facebook .presto .common .type .BigintType .BIGINT ;
5255import static com .facebook .presto .expressions .LogicalRowExpressions .TRUE_CONSTANT ;
@@ -152,11 +155,12 @@ public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
152155 planChanged = true ;
153156 source = rowNumberNode ;
154157 }
155- else if (source instanceof WindowNode && canOptimizeWindowFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) {
158+ else if ((source instanceof WindowNode && canOptimizeRowNumberFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) ||
159+ (source instanceof WindowNode && canOptimizeRankFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRank (session ) && isNativeExecutionEnabled (session ))) {
156160 WindowNode windowNode = (WindowNode ) source ;
157161 // verify that unordered row_number window functions are replaced by RowNumberNode
158162 verify (windowNode .getOrderingScheme ().isPresent ());
159- TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber (windowNode , limit );
163+ TopNRowNumberNode topNRowNumberNode = convertToTopNRank (windowNode , limit );
160164 if (windowNode .getPartitionBy ().isEmpty ()) {
161165 return topNRowNumberNode ;
162166 }
@@ -183,7 +187,8 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
183187 return rewriteFilterSource (node , source , rowNumberVariable , upperBound .getAsInt ());
184188 }
185189 }
186- else if (source instanceof WindowNode && canOptimizeWindowFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) {
190+ else if ((source instanceof WindowNode && canOptimizeRowNumberFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) ||
191+ (source instanceof WindowNode && canOptimizeRankFunction ((WindowNode ) source , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRank (session ) && isNativeExecutionEnabled (session ))) {
187192 WindowNode windowNode = (WindowNode ) source ;
188193 VariableReferenceExpression rowNumberVariable = getOnlyElement (windowNode .getCreatedVariable ());
189194 OptionalInt upperBound = extractUpperBound (tupleDomain , rowNumberVariable );
@@ -280,19 +285,51 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi
280285 idAllocator .getNextId (),
281286 windowNode .getSource (),
282287 windowNode .getSpecification (),
288+
283289 TopNRowNumberNode .RankingFunction .ROW_NUMBER ,
284290 getOnlyElement (windowNode .getCreatedVariable ()),
285291 limit ,
286292 false ,
287293 Optional .empty ());
288294 }
289295
296+ private TopNRowNumberNode convertToTopNRank (WindowNode windowNode , int limit )
297+ {
298+
299+ String windowFunctionName = Iterables .getOnlyElement (windowNode .getWindowFunctions ().values ()).getFunctionCall ().getFunctionHandle ().getName ();
300+ TopNRowNumberNode .RankingFunction rankingFunction ;
301+ switch (windowFunctionName ) {
302+ case "row_number" :
303+ rankingFunction = TopNRowNumberNode .RankingFunction .ROW_NUMBER ;
304+ break ;
305+ case "rank" :
306+ rankingFunction = TopNRowNumberNode .RankingFunction .RANK ;
307+ break ;
308+ case "dense_rank" :
309+ rankingFunction = TopNRowNumberNode .RankingFunction .DENSE_RANK ;
310+ break ;
311+ default :
312+ throw new IllegalArgumentException ("Unsupported window function for TopNRowNumberNode: " + windowFunctionName );
313+ }
314+
315+ return new TopNRowNumberNode (
316+ windowNode .getSourceLocation (),
317+ idAllocator .getNextId (),
318+ windowNode .getSource (),
319+ windowNode .getSpecification (),
320+ rankingFunction ,
321+ getOnlyElement (windowNode .getCreatedVariable ()),
322+ limit ,
323+ false ,
324+ Optional .empty ());
325+ }
326+
290327 private static boolean canReplaceWithRowNumber (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
291328 {
292- return canOptimizeWindowFunction (node , functionAndTypeManager ) && !node .getOrderingScheme ().isPresent ();
329+ return canOptimizeRowNumberFunction (node , functionAndTypeManager ) && !node .getOrderingScheme ().isPresent ();
293330 }
294331
295- private static boolean canOptimizeWindowFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
332+ private static boolean canOptimizeRowNumberFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
296333 {
297334 if (node .getWindowFunctions ().size () != 1 ) {
298335 return false ;
@@ -301,10 +338,27 @@ private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTyp
301338 return isRowNumberMetadata (functionAndTypeManager , functionAndTypeManager .getFunctionMetadata (node .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ()));
302339 }
303340
341+ private static boolean canOptimizeRankFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
342+ {
343+ if (node .getWindowFunctions ().size () != 1 ) {
344+ return false ;
345+ }
346+ VariableReferenceExpression rowNumberVariable = getOnlyElement (node .getWindowFunctions ().keySet ());
347+ return isRankMetadata (functionAndTypeManager , functionAndTypeManager .getFunctionMetadata (node .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ()));
348+ }
349+
304350 private static boolean isRowNumberMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
305351 {
306352 FunctionHandle rowNumberFunction = functionAndTypeManager .lookupFunction ("row_number" , ImmutableList .of ());
307353 return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rowNumberFunction ));
308354 }
355+
356+ private static boolean isRankMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
357+ {
358+ FunctionHandle rankFunction = functionAndTypeManager .lookupFunction ("rank" , ImmutableList .of ());
359+ FunctionHandle denseRankFunction = functionAndTypeManager .lookupFunction ("dense_rank" , ImmutableList .of ());
360+ return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rankFunction )) ||
361+ functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (denseRankFunction ));
362+ }
309363 }
310364}
0 commit comments