4141import com .facebook .presto .sql .relational .RowExpressionDeterminismEvaluator ;
4242import com .facebook .presto .sql .relational .RowExpressionDomainTranslator ;
4343import com .google .common .collect .ImmutableList ;
44+ import com .google .common .collect .Iterables ;
4445
4546import java .util .Map ;
4647import java .util .Optional ;
4748import java .util .OptionalInt ;
4849
50+ import static com .facebook .presto .SystemSessionProperties .isNativeExecutionEnabled ;
51+ import static com .facebook .presto .SystemSessionProperties .isOptimizeTopNRank ;
4952import static com .facebook .presto .SystemSessionProperties .isOptimizeTopNRowNumber ;
5053import static com .facebook .presto .common .predicate .Marker .Bound .BELOW ;
5154import static com .facebook .presto .common .type .BigintType .BIGINT ;
@@ -134,6 +137,12 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
134137 return replaceChildren (node , ImmutableList .of (rewrittenSource ));
135138 }
136139
140+ private boolean canReplaceWithTopNRowNumber (WindowNode node )
141+ {
142+ return (canOptimizeRowNumberFunction (node , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) ||
143+ (canOptimizeRankFunction (node , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRank (session ) && isNativeExecutionEnabled (session ));
144+ }
145+
137146 @ Override
138147 public PlanNode visitLimit (LimitNode node , RewriteContext <Void > context )
139148 {
@@ -152,16 +161,16 @@ public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
152161 planChanged = true ;
153162 source = rowNumberNode ;
154163 }
155- else if (source instanceof WindowNode && canOptimizeWindowFunction (( WindowNode ) source , metadata . getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber ( session ) ) {
164+ else if (source instanceof WindowNode ) {
156165 WindowNode windowNode = (WindowNode ) source ;
157- // verify that unordered row_number window functions are replaced by RowNumberNode
158- verify (windowNode .getOrderingScheme ().isPresent ());
159- TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber (windowNode , limit );
160- if (windowNode .getPartitionBy ().isEmpty ()) {
161- return topNRowNumberNode ;
166+ if (canReplaceWithTopNRowNumber (windowNode )) {
167+ TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber (windowNode , limit );
168+ if (windowNode .getPartitionBy ().isEmpty ()) {
169+ return topNRowNumberNode ;
170+ }
171+ planChanged = true ;
172+ source = topNRowNumberNode ;
162173 }
163- planChanged = true ;
164- source = topNRowNumberNode ;
165174 }
166175 return replaceChildren (node , ImmutableList .of (source ));
167176 }
@@ -183,15 +192,17 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
183192 return rewriteFilterSource (node , source , rowNumberVariable , upperBound .getAsInt ());
184193 }
185194 }
186- else if (source instanceof WindowNode && canOptimizeWindowFunction (( WindowNode ) source , metadata . getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber ( session ) ) {
195+ else if (source instanceof WindowNode ) {
187196 WindowNode windowNode = (WindowNode ) source ;
188- VariableReferenceExpression rowNumberVariable = getOnlyElement (windowNode .getCreatedVariable ());
189- OptionalInt upperBound = extractUpperBound (tupleDomain , rowNumberVariable );
190-
191- if (upperBound .isPresent ()) {
192- source = convertToTopNRowNumber (windowNode , upperBound .getAsInt ());
193- planChanged = true ;
194- return rewriteFilterSource (node , source , rowNumberVariable , upperBound .getAsInt ());
197+ if (canReplaceWithTopNRowNumber (windowNode )) {
198+ VariableReferenceExpression rowNumberVariable = getOnlyElement (windowNode .getCreatedVariable ());
199+ OptionalInt upperBound = extractUpperBound (tupleDomain , rowNumberVariable );
200+
201+ if (upperBound .isPresent ()) {
202+ source = convertToTopNRowNumber (windowNode , upperBound .getAsInt ());
203+ planChanged = true ;
204+ return rewriteFilterSource (node , source , rowNumberVariable , upperBound .getAsInt ());
205+ }
195206 }
196207 }
197208 return replaceChildren (node , ImmutableList .of (source ));
@@ -287,12 +298,44 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi
287298 Optional .empty ());
288299 }
289300
301+ private TopNRowNumberNode convertToTopNRank (WindowNode windowNode , int limit )
302+ {
303+ String windowFunction = Iterables .getOnlyElement (windowNode .getWindowFunctions ().values ()).getFunctionCall ().getFunctionHandle ().getName ();
304+ String [] parts = windowFunction .split ("\\ ." );
305+ String windowFunctionName = parts [parts .length - 1 ];
306+ TopNRowNumberNode .RankingFunction rankingFunction ;
307+ switch (windowFunctionName ) {
308+ case "row_number" :
309+ rankingFunction = TopNRowNumberNode .RankingFunction .ROW_NUMBER ;
310+ break ;
311+ case "rank" :
312+ rankingFunction = TopNRowNumberNode .RankingFunction .RANK ;
313+ break ;
314+ case "dense_rank" :
315+ rankingFunction = TopNRowNumberNode .RankingFunction .DENSE_RANK ;
316+ break ;
317+ default :
318+ throw new IllegalArgumentException ("Unsupported window function for TopNRowNumberNode: " + windowFunctionName );
319+ }
320+
321+ return new TopNRowNumberNode (
322+ windowNode .getSourceLocation (),
323+ idAllocator .getNextId (),
324+ windowNode .getSource (),
325+ windowNode .getSpecification (),
326+ rankingFunction ,
327+ getOnlyElement (windowNode .getCreatedVariable ()),
328+ limit ,
329+ false ,
330+ Optional .empty ());
331+ }
332+
290333 private static boolean canReplaceWithRowNumber (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
291334 {
292- return canOptimizeWindowFunction (node , functionAndTypeManager ) && !node .getOrderingScheme ().isPresent ();
335+ return canOptimizeRowNumberFunction (node , functionAndTypeManager ) && !node .getOrderingScheme ().isPresent ();
293336 }
294337
295- private static boolean canOptimizeWindowFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
338+ private static boolean canOptimizeRowNumberFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
296339 {
297340 if (node .getWindowFunctions ().size () != 1 ) {
298341 return false ;
@@ -301,10 +344,27 @@ private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTyp
301344 return isRowNumberMetadata (functionAndTypeManager , functionAndTypeManager .getFunctionMetadata (node .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ()));
302345 }
303346
347+ private static boolean canOptimizeRankFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
348+ {
349+ if (node .getWindowFunctions ().size () != 1 ) {
350+ return false ;
351+ }
352+ VariableReferenceExpression rowNumberVariable = getOnlyElement (node .getWindowFunctions ().keySet ());
353+ return isRankMetadata (functionAndTypeManager , functionAndTypeManager .getFunctionMetadata (node .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ()));
354+ }
355+
304356 private static boolean isRowNumberMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
305357 {
306358 FunctionHandle rowNumberFunction = functionAndTypeManager .lookupFunction ("row_number" , ImmutableList .of ());
307359 return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rowNumberFunction ));
308360 }
361+
362+ private static boolean isRankMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
363+ {
364+ FunctionHandle rankFunction = functionAndTypeManager .lookupFunction ("rank" , ImmutableList .of ());
365+ FunctionHandle denseRankFunction = functionAndTypeManager .lookupFunction ("dense_rank" , ImmutableList .of ());
366+ return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rankFunction )) ||
367+ functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (denseRankFunction ));
368+ }
309369 }
310370}
0 commit comments