4141import com .facebook .presto .sql .relational .RowExpressionDeterminismEvaluator ;
4242import com .facebook .presto .sql .relational .RowExpressionDomainTranslator ;
4343import com .google .common .collect .ImmutableList ;
44+ import com .google .common .collect .Iterables ;
4445
4546import java .util .Map ;
4647import java .util .Optional ;
4748import java .util .OptionalInt ;
4849
50+ import static com .facebook .presto .SystemSessionProperties .isNativeExecutionEnabled ;
51+ import static com .facebook .presto .SystemSessionProperties .isOptimizeTopNRank ;
4952import static com .facebook .presto .SystemSessionProperties .isOptimizeTopNRowNumber ;
5053import static com .facebook .presto .common .predicate .Marker .Bound .BELOW ;
5154import static com .facebook .presto .common .type .BigintType .BIGINT ;
@@ -134,6 +137,12 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
134137 return replaceChildren (node , ImmutableList .of (rewrittenSource ));
135138 }
136139
140+ private boolean canReplaceWithTopNRowNumber (WindowNode node )
141+ {
142+ return (canOptimizeRowNumberFunction (node , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber (session )) ||
143+ (canOptimizeRankFunction (node , metadata .getFunctionAndTypeManager ()) && isOptimizeTopNRank (session ) && isNativeExecutionEnabled (session ));
144+ }
145+
137146 @ Override
138147 public PlanNode visitLimit (LimitNode node , RewriteContext <Void > context )
139148 {
@@ -152,16 +161,18 @@ public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
152161 planChanged = true ;
153162 source = rowNumberNode ;
154163 }
155- else if (source instanceof WindowNode && canOptimizeWindowFunction (( WindowNode ) source , metadata . getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber ( session ) ) {
164+ else if (source instanceof WindowNode ) {
156165 WindowNode windowNode = (WindowNode ) source ;
157- // verify that unordered row_number window functions are replaced by RowNumberNode
158- verify (windowNode .getOrderingScheme ().isPresent ());
159- TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber (windowNode , limit );
160- if (windowNode .getPartitionBy ().isEmpty ()) {
161- return topNRowNumberNode ;
166+ if (canReplaceWithTopNRowNumber (windowNode )) {
167+ TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber (windowNode , limit );
168+ // Limit can be entirely skipped for row_number without partitioning.
169+ if (windowNode .getPartitionBy ().isEmpty () &&
170+ canOptimizeRowNumberFunction (windowNode , metadata .getFunctionAndTypeManager ())) {
171+ return topNRowNumberNode ;
172+ }
173+ planChanged = true ;
174+ source = topNRowNumberNode ;
162175 }
163- planChanged = true ;
164- source = topNRowNumberNode ;
165176 }
166177 return replaceChildren (node , ImmutableList .of (source ));
167178 }
@@ -183,15 +194,17 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
183194 return rewriteFilterSource (node , source , rowNumberVariable , upperBound .getAsInt ());
184195 }
185196 }
186- else if (source instanceof WindowNode && canOptimizeWindowFunction (( WindowNode ) source , metadata . getFunctionAndTypeManager ()) && isOptimizeTopNRowNumber ( session ) ) {
197+ else if (source instanceof WindowNode ) {
187198 WindowNode windowNode = (WindowNode ) source ;
188- VariableReferenceExpression rowNumberVariable = getOnlyElement (windowNode .getCreatedVariable ());
189- OptionalInt upperBound = extractUpperBound (tupleDomain , rowNumberVariable );
190-
191- if (upperBound .isPresent ()) {
192- source = convertToTopNRowNumber (windowNode , upperBound .getAsInt ());
193- planChanged = true ;
194- return rewriteFilterSource (node , source , rowNumberVariable , upperBound .getAsInt ());
199+ if (canReplaceWithTopNRowNumber (windowNode )) {
200+ VariableReferenceExpression rowNumberVariable = getOnlyElement (windowNode .getCreatedVariable ());
201+ OptionalInt upperBound = extractUpperBound (tupleDomain , rowNumberVariable );
202+
203+ if (upperBound .isPresent ()) {
204+ source = convertToTopNRowNumber (windowNode , upperBound .getAsInt ());
205+ planChanged = true ;
206+ return rewriteFilterSource (node , source , rowNumberVariable , upperBound .getAsInt ());
207+ }
195208 }
196209 }
197210 return replaceChildren (node , ImmutableList .of (source ));
@@ -275,12 +288,30 @@ private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPa
275288
276289 private TopNRowNumberNode convertToTopNRowNumber (WindowNode windowNode , int limit )
277290 {
291+ String windowFunction = Iterables .getOnlyElement (windowNode .getWindowFunctions ().values ()).getFunctionCall ().getFunctionHandle ().getName ();
292+ String [] parts = windowFunction .split ("\\ ." );
293+ String windowFunctionName = parts [parts .length - 1 ];
294+ TopNRowNumberNode .RankingFunction rankingFunction ;
295+ switch (windowFunctionName ) {
296+ case "row_number" :
297+ rankingFunction = TopNRowNumberNode .RankingFunction .ROW_NUMBER ;
298+ break ;
299+ case "rank" :
300+ rankingFunction = TopNRowNumberNode .RankingFunction .RANK ;
301+ break ;
302+ case "dense_rank" :
303+ rankingFunction = TopNRowNumberNode .RankingFunction .DENSE_RANK ;
304+ break ;
305+ default :
306+ throw new IllegalArgumentException ("Unsupported window function for TopNRowNumberNode: " + windowFunctionName );
307+ }
308+
278309 return new TopNRowNumberNode (
279310 windowNode .getSourceLocation (),
280311 idAllocator .getNextId (),
281312 windowNode .getSource (),
282313 windowNode .getSpecification (),
283- TopNRowNumberNode . RankingFunction . ROW_NUMBER ,
314+ rankingFunction ,
284315 getOnlyElement (windowNode .getCreatedVariable ()),
285316 limit ,
286317 false ,
@@ -289,22 +320,37 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi
289320
290321 private static boolean canReplaceWithRowNumber (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
291322 {
292- return canOptimizeWindowFunction (node , functionAndTypeManager ) && !node .getOrderingScheme ().isPresent ();
323+ return canOptimizeRowNumberFunction (node , functionAndTypeManager ) && !node .getOrderingScheme ().isPresent ();
324+ }
325+
326+ private static boolean canOptimizeRowNumberFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
327+ {
328+ if (node .getWindowFunctions ().size () != 1 ) {
329+ return false ;
330+ }
331+ return isRowNumberMetadata (functionAndTypeManager , functionAndTypeManager .getFunctionMetadata (getOnlyElement (node .getWindowFunctions ().values ()).getFunctionHandle ()));
293332 }
294333
295- private static boolean canOptimizeWindowFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
334+ private static boolean canOptimizeRankFunction (WindowNode node , FunctionAndTypeManager functionAndTypeManager )
296335 {
297336 if (node .getWindowFunctions ().size () != 1 ) {
298337 return false ;
299338 }
300- VariableReferenceExpression rowNumberVariable = getOnlyElement (node .getWindowFunctions ().keySet ());
301- return isRowNumberMetadata (functionAndTypeManager , functionAndTypeManager .getFunctionMetadata (node .getWindowFunctions ().get (rowNumberVariable ).getFunctionHandle ()));
339+ return isRankMetadata (functionAndTypeManager , functionAndTypeManager .getFunctionMetadata (getOnlyElement (node .getWindowFunctions ().values ()).getFunctionHandle ()));
302340 }
303341
304342 private static boolean isRowNumberMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
305343 {
306344 FunctionHandle rowNumberFunction = functionAndTypeManager .lookupFunction ("row_number" , ImmutableList .of ());
307345 return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rowNumberFunction ));
308346 }
347+
348+ private static boolean isRankMetadata (FunctionAndTypeManager functionAndTypeManager , FunctionMetadata functionMetadata )
349+ {
350+ FunctionHandle rankFunction = functionAndTypeManager .lookupFunction ("rank" , ImmutableList .of ());
351+ FunctionHandle denseRankFunction = functionAndTypeManager .lookupFunction ("dense_rank" , ImmutableList .of ());
352+ return functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (rankFunction )) ||
353+ functionMetadata .equals (functionAndTypeManager .getFunctionMetadata (denseRankFunction ));
354+ }
309355 }
310356}
0 commit comments