Skip to content

Commit

Permalink
refactor: rewrite for loops into infinite loops with if statement (#35)
Browse files Browse the repository at this point in the history
Fixes #27
  • Loading branch information
junlarsen authored Feb 4, 2025
1 parent d4d667c commit d95f756
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 65 deletions.
36 changes: 23 additions & 13 deletions compiler/eight-middle/src/ast_lowering_pass/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -680,9 +680,12 @@ impl<'ast, 'hir> AstLoweringPass<'ast, 'hir> {
///
/// {
/// let x = 1;
/// loop (x < 10) {
/// loop {
/// if (!(x < 10)) {
/// break;
/// }
/// { foo(); }
/// { x = x + 1; }
/// x = x + 1;
/// }
/// }
/// ```
Expand Down Expand Up @@ -721,24 +724,31 @@ impl<'ast, 'hir> AstLoweringPass<'ast, 'hir> {
))
});
let increment = node.increment.map(|i| self.visit_expr(i)).transpose()?;
let body = HirBuilder::build_vec(node.body.iter(), |stmt| self.visit_stmt(stmt))?;
let mut body = HirBuilder::build_vec(node.body.iter(), |stmt| self.visit_stmt(stmt))?;
// Build the new block statement with the loop
let hir = HirStmt::Block(HirBuilder::build_block_stmt(
node.span,
vec![
initializer.map(HirStmt::Let).unwrap_or_else(|| {
HirStmt::Block(HirBuilder::build_block_stmt(Span::empty(), vec![]))
}),
HirStmt::Loop(HirBuilder::build_loop_stmt(node.span, condition, {
let mut stmts = body;
if let Some(i) = increment {
stmts.push(HirStmt::Expr(HirExprStmt {
span: i.span(),
expr: i,
}));
}
stmts
})),
HirStmt::Loop(HirBuilder::build_loop_stmt(
node.span,
vec![HirStmt::If(HirBuilder::build_if_stmt(
node.span,
condition,
{
if let Some(i) = increment {
body.push(HirStmt::Expr(HirExprStmt {
span: i.span(),
expr: i,
}));
}
body
},
vec![HirStmt::Break(HirBuilder::build_break_stmt(node.span))],
))],
)),
],
));
self.loop_depth.pop_back();
Expand Down
4 changes: 0 additions & 4 deletions compiler/eight-middle/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,10 +779,6 @@ pub struct HirExprStmt<'hir> {
#[derive(Debug)]
pub struct HirLoopStmt<'hir> {
pub span: Span,
/// The condition for the loop to continue.
///
/// For infinite loops, this node should be a constant literal of boolean true.
pub condition: HirExpr<'hir>,
pub body: Vec<HirStmt<'hir>>,
}

Expand Down
12 changes: 2 additions & 10 deletions compiler/eight-middle/src/hir_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,8 @@ impl<'hir> HirBuilder {
}

/// Build a HIR loop statement.
pub fn build_loop_stmt(
span: Span,
condition: HirExpr<'hir>,
body: Vec<HirStmt<'hir>>,
) -> HirLoopStmt<'hir> {
HirLoopStmt {
span,
condition,
body,
}
pub fn build_loop_stmt(span: Span, body: Vec<HirStmt<'hir>>) -> HirLoopStmt<'hir> {
HirLoopStmt { span, body }
}

/// Build a HIR block statement.
Expand Down
6 changes: 1 addition & 5 deletions compiler/eight-middle/src/hir_textual_pass/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,7 @@ impl<'a> HirModuleTextualPass<'a> {
stmt: &'hir HirLoopStmt,
) -> DocBuilder<'a, Arena<'a>> {
self.arena
.text("while")
.append(self.arena.space())
.append(self.arena.text("("))
.append(self.visit_expr(&stmt.condition))
.append(self.arena.text(")"))
.text("loop")
.append(self.arena.space())
.append(self.arena.text("{"))
.append(
Expand Down
11 changes: 0 additions & 11 deletions compiler/eight-middle/src/hir_type_check_pass/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1047,12 +1047,6 @@ impl HirModuleTypeCheckerPass {
cx: &mut TypingContext<'hir>,
node: &mut HirLoopStmt<'hir>,
) -> HirResult<Option<HirStmt<'hir>>> {
substitute_if_changed!(
&mut node.condition,
Self::enter_expr(cx, &mut node.condition)?
);
// We also impose a new constraint that the condition must be a boolean
cx.infer(&mut node.condition, cx.cc.hir_boolean_type())?;
cx.enter_let_binding_scope();
for stmt in node.body.iter_mut() {
substitute_if_changed!(stmt, Self::enter_stmt(cx, stmt)?);
Expand All @@ -1066,11 +1060,6 @@ impl HirModuleTypeCheckerPass {
cx: &mut TypingContext<'hir>,
node: &mut HirLoopStmt<'hir>,
) -> HirResult<Option<HirStmt<'hir>>> {
substitute_if_changed!(
&mut node.condition,
Self::leave_expr(cx, &mut node.condition)?
);
Self::leave_expr(cx, &mut node.condition)?;
for stmt in node.body.iter_mut() {
substitute_if_changed!(stmt, Self::leave_stmt(cx, stmt)?);
}
Expand Down
10 changes: 7 additions & 3 deletions compiler/eight-tests/tests/snapshots/hir/[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ hir_module {
} as vec512);
{
let i: i32 = (0 as i32);
while (((Ord<i32,i32>::lt::<> as fn(i32, i32)->bool)((i as i32),((Div<i32,i32,i32>::div::<> as fn(i32, i32)->i32)((512 as i32),(4 as i32)) as i32)) as bool)) {
((((vs as vec512).vs as *i32)[(i as i32)] as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)(((Mul<i32,i32,i32>::mul::<> as fn(i32, i32)->i32)((((x as vec512).vs as *i32)[(i as i32)] as i32),(((y as vec512).vs as *i32)[(i as i32)] as i32)) as i32),(((acc as vec512).vs as *i32)[(i as i32)] as i32)) as i32) as unit);
((i as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((i as i32),(1 as i32)) as i32) as unit);
loop {
if (((Ord<i32,i32>::lt::<> as fn(i32, i32)->bool)((i as i32),((Div<i32,i32,i32>::div::<> as fn(i32, i32)->i32)((512 as i32),(4 as i32)) as i32)) as bool)){
((((vs as vec512).vs as *i32)[(i as i32)] as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)(((Mul<i32,i32,i32>::mul::<> as fn(i32, i32)->i32)((((x as vec512).vs as *i32)[(i as i32)] as i32),(((y as vec512).vs as *i32)[(i as i32)] as i32)) as i32),(((acc as vec512).vs as *i32)[(i as i32)] as i32)) as i32) as unit);
((i as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((i as i32),(1 as i32)) as i32) as unit);
} else {
break;
}
}
}
return (vs as vec512);
Expand Down
38 changes: 25 additions & 13 deletions compiler/eight-tests/tests/snapshots/hir/[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,35 @@ hir_module {
} as Matrix);
{
let i: i32 = (0 as i32);
while (((Ord<i32,i32>::lt::<> as fn(i32, i32)->bool)((i as i32),((a as Matrix).r as i32)) as bool)) {
{
let j: i32 = (0 as i32);
while (((Ord<i32,i32>::lt::<> as fn(i32, i32)->bool)((j as i32),((b as Matrix).c as i32)) as bool)) {
let sum: i32 = (0 as i32);
{
let k: i32 = (0 as i32);
while (((Ord<i32,i32>::lt::<> as fn(i32, i32)->bool)((k as i32),((a as Matrix).c as i32)) as bool)) {
((sum as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((sum as i32),((Mul<i32,i32,i32>::mul::<> as fn(i32, i32)->i32)((((a as Matrix).buf as *i32)[((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)(((Mul<i32,i32,i32>::mul::<> as fn(i32, i32)->i32)((i as i32),((a as Matrix).c as i32)) as i32),(k as i32)) as i32)] as i32),(((b as Matrix).buf as *i32)[((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)(((Mul<i32,i32,i32>::mul::<> as fn(i32, i32)->i32)((k as i32),((b as Matrix).c as i32)) as i32),(j as i32)) as i32)] as i32)) as i32)) as i32) as unit);
((k as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((k as i32),(1 as i32)) as i32) as unit);
loop {
if (((Ord<i32,i32>::lt::<> as fn(i32, i32)->bool)((i as i32),((a as Matrix).r as i32)) as bool)){
{
let j: i32 = (0 as i32);
loop {
if (((Ord<i32,i32>::lt::<> as fn(i32, i32)->bool)((j as i32),((b as Matrix).c as i32)) as bool)){
let sum: i32 = (0 as i32);
{
let k: i32 = (0 as i32);
loop {
if (((Ord<i32,i32>::lt::<> as fn(i32, i32)->bool)((k as i32),((a as Matrix).c as i32)) as bool)){
((sum as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((sum as i32),((Mul<i32,i32,i32>::mul::<> as fn(i32, i32)->i32)((((a as Matrix).buf as *i32)[((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)(((Mul<i32,i32,i32>::mul::<> as fn(i32, i32)->i32)((i as i32),((a as Matrix).c as i32)) as i32),(k as i32)) as i32)] as i32),(((b as Matrix).buf as *i32)[((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)(((Mul<i32,i32,i32>::mul::<> as fn(i32, i32)->i32)((k as i32),((b as Matrix).c as i32)) as i32),(j as i32)) as i32)] as i32)) as i32)) as i32) as unit);
((k as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((k as i32),(1 as i32)) as i32) as unit);
} else {
break;
}
}
}
((((c as Matrix).buf as *i32)[((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)(((Mul<i32,i32,i32>::mul::<> as fn(i32, i32)->i32)((i as i32),((b as Matrix).c as i32)) as i32),(j as i32)) as i32)] as i32) = (sum as i32) as unit);
((j as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((j as i32),(1 as i32)) as i32) as unit);
} else {
break;
}
}
((((c as Matrix).buf as *i32)[((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)(((Mul<i32,i32,i32>::mul::<> as fn(i32, i32)->i32)((i as i32),((b as Matrix).c as i32)) as i32),(j as i32)) as i32)] as i32) = (sum as i32) as unit);
((j as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((j as i32),(1 as i32)) as i32) as unit);
}
((i as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((i as i32),(1 as i32)) as i32) as unit);
} else {
break;
}
((i as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((i as i32),(1 as i32)) as i32) as unit);
}
}
return (c as Matrix);
Expand Down
16 changes: 10 additions & 6 deletions compiler/eight-tests/tests/snapshots/hir/[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,18 @@ hir_module {
let i: i32 = ((Sub<i32,i32,i32>::sub::<> as fn(i32, i32)->i32)((l as i32),(1 as i32)) as i32);
{
let j: i32 = (l as i32);
while (((Ord<i32,i32>::lt::<> as fn(i32, i32)->bool)((j as i32),(r as i32)) as bool)) {
if (((Ord<i32,i32>::le::<> as fn(i32, i32)->bool)(((xs as *i32)[(j as i32)] as i32),(x as i32)) as bool)){
((i as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((i as i32),(1 as i32)) as i32) as unit);
((swap::<> as fn(*i32, *i32)->unit)((&((xs as *i32)[(i as i32)] as i32) as *i32),(&((xs as *i32)[(j as i32)] as i32) as *i32)) as unit);
loop {
if (((Ord<i32,i32>::lt::<> as fn(i32, i32)->bool)((j as i32),(r as i32)) as bool)){
if (((Ord<i32,i32>::le::<> as fn(i32, i32)->bool)(((xs as *i32)[(j as i32)] as i32),(x as i32)) as bool)){
((i as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((i as i32),(1 as i32)) as i32) as unit);
((swap::<> as fn(*i32, *i32)->unit)((&((xs as *i32)[(i as i32)] as i32) as *i32),(&((xs as *i32)[(j as i32)] as i32) as *i32)) as unit);
} else {

}
((j as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((j as i32),(1 as i32)) as i32) as unit);
} else {

break;
}
((j as i32) = ((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((j as i32),(1 as i32)) as i32) as unit);
}
}
((swap::<> as fn(*i32, *i32)->unit)((&((xs as *i32)[((Add<i32,i32,i32>::add::<> as fn(i32, i32)->i32)((i as i32),(1 as i32)) as i32)] as i32) as *i32),(&((xs as *i32)[(r as i32)] as i32) as *i32)) as unit);
Expand Down

0 comments on commit d95f756

Please sign in to comment.