-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
This is to address #1275 . We added a new phase to perform tail recursion elimination. We added test cases in the [GenCSuite](https://github.com/epfl-lara/stainless/pull/1626/files#diff-2091c70888d42120d35c352d0b060558b0c5ad02baa02de0ecd77b0fd0ded464). As discussed during the presentation, we may want to take a closer look at ghost elimination and see whether it is doing the job correctly. --------- Co-authored-by: Kacper Korban <[email protected]>
- Loading branch information
1 parent
c92fee2
commit 233cdcd
Showing
39 changed files
with
673 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 89 additions & 0 deletions
89
core/src/main/scala/stainless/genc/ir/TailRecSimpTransformer.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
package stainless | ||
package genc | ||
package ir | ||
|
||
import PrimitiveTypes.{ PrimitiveType => PT, _ } // For desambiguation | ||
import Literals._ | ||
import Operators._ | ||
import IRs._ | ||
import scala.collection.mutable | ||
|
||
final class TailRecSimpTransformer extends Transformer(SIR, SIR) with NoEnv { | ||
import from._ | ||
|
||
private given givenDebugSection: DebugSectionGenC.type = DebugSectionGenC | ||
|
||
/** | ||
* Replace a variable assignment that is immediately | ||
* returned | ||
* | ||
* val i = f(...); | ||
* return i; | ||
* | ||
* ==> | ||
* | ||
* return f(...); | ||
* | ||
*/ | ||
private def replaceImmediateReturn(fd: Expr): Expr = { | ||
val transformer = new ir.Transformer(from, to) with NoEnv { | ||
override protected def recImpl(expr: Expr)(using Env): (Expr, Env) = expr match { | ||
case Block(stmts) => | ||
Block(stmts.zipWithIndex.flatMap { | ||
case (expr @ Decl(id, Some(rhs)), idx) => | ||
stmts.lift(idx + 1) match { | ||
case Some(Return(Binding(retId))) if retId == id => | ||
List(Return(rhs)) | ||
case _ => List(recImpl(expr)._1) | ||
} | ||
case (expr @ Return(Binding(retId)), idx) => | ||
stmts.lift(idx - 1) match { | ||
case Some(Decl(id, rhs)) if id == retId => | ||
Nil | ||
case _ => List(recImpl(expr)._1) | ||
} | ||
case (expr, idx) => List(recImpl(expr)._1) | ||
}) -> () | ||
case expr => super.recImpl(expr) | ||
} | ||
} | ||
transformer(fd) | ||
} | ||
|
||
/** | ||
* Remove all statements after a return statement | ||
* | ||
* return f(...); | ||
* someStmt; | ||
* | ||
* ==> | ||
* | ||
* return f(...); | ||
* | ||
*/ | ||
private def removeAfterReturn(fd: Expr): Expr = { | ||
val transformer = new ir.Transformer(from, to) with NoEnv { | ||
override protected def recImpl(expr: Expr)(using Env): (Expr, Env) = expr match { | ||
case Block(stmts) => | ||
val transformedStmts = stmts.map(recImpl(_)._1) | ||
val firstReturn = transformedStmts.find { | ||
case Return(_) => true | ||
case _ => false | ||
}.toList | ||
val newStmts = transformedStmts.takeWhile { | ||
case Return(_) => false | ||
case _ => true | ||
} | ||
Block(newStmts ++ firstReturn) -> () | ||
case expr => super.recImpl(expr) | ||
} | ||
} | ||
transformer(fd) | ||
} | ||
|
||
override protected def recImpl(fd: Expr)(using Env): (to.Expr, Env) = { | ||
val afterReturn = removeAfterReturn(fd) | ||
val immediateReturn = replaceImmediateReturn(afterReturn) | ||
immediateReturn -> () | ||
} | ||
} |
181 changes: 181 additions & 0 deletions
181
core/src/main/scala/stainless/genc/ir/TailRecTransformer.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
package stainless | ||
package genc | ||
package ir | ||
|
||
import PrimitiveTypes.{ PrimitiveType => PT, _ } // For desambiguation | ||
import Literals._ | ||
import Operators._ | ||
import IRs._ | ||
import scala.collection.mutable | ||
|
||
final class TailRecTransformer(val ctx: inox.Context) extends Transformer(SIR, TIR) with NoEnv { | ||
import from._ | ||
|
||
private given givenDebugSection: DebugSectionGenC.type = DebugSectionGenC | ||
|
||
private given printer.Context = printer.Context(0) | ||
|
||
/** | ||
* If the function returns Unit type and the last one statement is a recursive call, | ||
* put the recursive call in a return statement. | ||
* | ||
* Example: | ||
* def countDown(n: Int): Unit = | ||
* if (n == 0) return | ||
* countDown(n - 1) | ||
* | ||
* ==> | ||
* | ||
* def countDown(n: Int): Unit = | ||
* if (n == 0) return | ||
* return countDown(n - 1) | ||
*/ | ||
private def putTailRecursiveUnitCallInReturn(fd: FunDef): FunDef = { | ||
def go(expr: Expr): Expr = expr match { | ||
case Block(stmts) if stmts.nonEmpty => | ||
Block(stmts.init :+ go(stmts.last)) | ||
case IfElse(cond, thenn, elze) => | ||
IfElse(cond, go(thenn), go(elze)) | ||
case app @ App(FunVal(calledFd), _, _) if calledFd.id == fd.id => | ||
Return(app) | ||
case _ => expr | ||
} | ||
fd.body match { | ||
case FunBodyAST(expr) if fd.returnType.isUnitType => | ||
fd.copy(body = FunBodyAST(go(expr))) | ||
case _ => fd | ||
} | ||
} | ||
|
||
private def isTailRecursive(fd: FunDef): Boolean = { | ||
var functionRefs = mutable.ListBuffer.empty[FunDef] | ||
val functionRefVisitor = new ir.Visitor(from) { | ||
override protected def visit(expr: Expr): Unit = expr match { | ||
case FunVal(fd) => functionRefs += fd | ||
case _ => | ||
} | ||
} | ||
var tailFunctionRefs = mutable.ListBuffer.empty[FunDef] | ||
val tailRecCallVisitor = new ir.Visitor(from) { | ||
override protected def visit(expr: Expr): Unit = expr match { | ||
case Return(App(FunVal(fdcall), _, _)) => tailFunctionRefs += fdcall | ||
|
||
case _ => | ||
} | ||
} | ||
functionRefVisitor(fd) | ||
tailRecCallVisitor(fd) | ||
functionRefs.contains(fd) && functionRefs.filter(_ == fd).size == tailFunctionRefs.filter(_ == fd).size | ||
} | ||
|
||
/* Rewrite a tail recursive function to a while loop | ||
* Example: | ||
* def fib(n: Int, i: Int = 0, j: Int = 1): Int = | ||
* if (n == 0) | ||
* return i | ||
* else | ||
* return fib(n-1, j, i+j) | ||
* | ||
* ==> | ||
* | ||
* def fib(n: Int, i: Int = 0, j: Int = 1): Int = { | ||
* | ||
* var n$ = n | ||
* var i$ = i | ||
* var j$ = j | ||
* while (true) { | ||
* someLabel: | ||
* if (n$ == 0) { | ||
* return i$ | ||
* } else { | ||
* val n$1 = n$ - 1 | ||
* val i$1 = j$ | ||
* val j$1 = i$ + j$ | ||
* n$ = n$1 | ||
* i$ = i$1 | ||
* j$ = j$1 | ||
* goto someLabel | ||
* } | ||
* } | ||
* } | ||
* Steps: | ||
* - Create a new variable for each parameter of the function | ||
* - Replace existing parameter references with the new variables | ||
* - Create a while loop with a condition true | ||
* - Replace the recursive return with a variable assignments (updating the state) and a continue statement | ||
*/ | ||
private def rewriteToAWhileLoop(fd: FunDef): FunDef = fd.body match { | ||
case FunBodyAST(body) => | ||
val newParams = fd.params.map(p => ValDef(freshId(p.id), p.typ, isVar = true)) | ||
val newParamMap = fd.params.zip(newParams).toMap | ||
val labelName = freshId("label") | ||
val bodyWithNewParams = replaceBindings(newParamMap, body) | ||
val bodyWithUnitReturn = bodyWithNewParams match { | ||
case Block(stmts) => | ||
if fd.returnType.isUnitType then | ||
Block(stmts :+ Return(Lit(UnitLit))) | ||
else | ||
bodyWithNewParams | ||
case _ => bodyWithNewParams | ||
} | ||
val declarations = newParamMap.toList.map { case (old, nw) => Decl(nw, Some(Binding(old))) } | ||
val newBody = replaceRecursiveCalls(fd, bodyWithUnitReturn, newParams.toList, labelName) | ||
val newBodyWithALabel = Labeled(labelName, newBody) | ||
val newBodyWithAWhileLoop = While(True, newBodyWithALabel) | ||
FunDef(fd.id, fd.returnType, fd.ctx, fd.params, FunBodyAST(Block(declarations :+ newBodyWithAWhileLoop)), fd.isExported, fd.isPure) | ||
case _ => fd | ||
} | ||
|
||
private def replaceRecursiveCalls(fd: FunDef, body: Expr, valdefs: List[ValDef], labelName: String): Expr = { | ||
val replacer = new Transformer(from, from) with NoEnv { | ||
override def recImpl(e: Expr)(using Env): (Expr, Env) = e match { | ||
case Return(App(FunVal(fdcall), _, args)) if fdcall == fd => | ||
val tmpValDefs = valdefs.map(vd => ValDef(freshId(vd.id), vd.typ, isVar = false)) | ||
val tmpDecls = tmpValDefs.zip(args).map { case (vd, arg) => Decl(vd, Some(arg)) } | ||
val valdefAssign = valdefs.zip(tmpValDefs).map { case (vd, tmp) => Assign(Binding(vd), Binding(tmp)) } | ||
Block(tmpDecls ++ valdefAssign :+ Goto(labelName)) -> () | ||
case _ => | ||
super.recImpl(e) | ||
} | ||
} | ||
replacer(body) | ||
} | ||
|
||
/* Replace the bindings in the function body with the mapped variables */ | ||
private def replaceBindings(mapping: Map[ValDef, ValDef], funBody: Expr): Expr = { | ||
val replacer = new Transformer(from, from) with NoEnv { | ||
override protected def rec(vd: ValDef)(using Env): to.ValDef = | ||
mapping.getOrElse(vd, vd) | ||
} | ||
replacer(funBody) | ||
} | ||
|
||
private def replaceWithNewFuns(prog: Prog, newFdsMap: Map[FunDef, FunDef]): Prog = { | ||
val replacer = new Transformer(from, from) with NoEnv { | ||
override protected def recImpl(fd: FunDef)(using Env): FunDef = | ||
super.recImpl(newFdsMap.getOrElse(fd, fd)) | ||
} | ||
replacer(prog) | ||
} | ||
|
||
override protected def rec(prog: from.Prog)(using Unit): to.Prog = { | ||
super.rec { | ||
val newFdsMap = prog.functions.map { fd => | ||
val fdWithTailRecUnitInReturn = putTailRecursiveUnitCallInReturn(fd) | ||
if isTailRecursive(fdWithTailRecUnitInReturn) then | ||
val fdRewrittenToLoop = rewriteToAWhileLoop(fdWithTailRecUnitInReturn) | ||
// val irPrinter = IRPrinter(SIR) | ||
// print(irPrinter.apply(newFd)(using irPrinter.Context(0))) | ||
fd -> fdRewrittenToLoop | ||
else | ||
fd -> fdWithTailRecUnitInReturn | ||
}.toMap | ||
val newProg = Prog(prog.decls, newFdsMap.values.toSeq, prog.classes) | ||
replaceWithNewFuns(newProg, newFdsMap) | ||
} | ||
} | ||
|
||
private def freshId(id: String): to.Id = id + "_" + freshCounter.next(id) | ||
|
||
private val freshCounter = new utils.UniqueCounter[String]() | ||
} |
Oops, something went wrong.