Skip to content

Commit 9f90948

Browse files
Start work on LICM
1 parent ab6a360 commit 9f90948

File tree

6 files changed

+201
-14
lines changed

6 files changed

+201
-14
lines changed

common/src/cfg.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ impl NodeEntry for BasicBlock {
4949
pub struct Dominators<'a> {
5050
cfg: &'a Cfg,
5151
// The dominator set for each node
52+
// If an element is present in set set_per_node[i] then it is a dominator of node i
5253
pub set_per_node: Vec<HashSet<NodeIndex>>,
5354
}
5455

@@ -109,7 +110,7 @@ impl<Data: NodeEntry> DirectedGraph<Data> {
109110
let node_name = self.get_node_name(index);
110111
let node_text = self.nodes[index].data.get_textual_representation();
111112
statements.push(format!(
112-
"\"{node_name}\" [shape=record, label=\"{node_name} | {node_text}\"]",
113+
"\"{node_name}\" [shape=record, label=\"{node_name} \\| idx={index} | {node_text}\"]",
113114
));
114115
// Add more information for each node
115116
}
@@ -356,7 +357,7 @@ impl<'a> Dominators<'a> {
356357
}
357358
}
358359

359-
fn convert_cfg_to_instruction_stream(cfg: Cfg) -> Vec<Code> {
360+
pub fn convert_cfg_to_instruction_stream(cfg: Cfg) -> Vec<Code> {
360361
cfg.dag
361362
.nodes
362363
.into_iter()

dataflow_analysis/src/lib.rs

+37-11
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@ enum Direction {
1919
}
2020

2121
trait Analysis<'a, ValueType: Clone + Hash + Eq + Display> {
22-
fn run(&self, cfg: &'a Cfg, init: HashSet<ValueType>, direction: Direction) -> () {
22+
fn run(
23+
&self,
24+
cfg: &'a Cfg,
25+
init: HashSet<ValueType>,
26+
direction: Direction,
27+
display: Option<bool>,
28+
) -> (Vec<HashSet<ValueType>>, Vec<HashSet<ValueType>>) {
29+
let display = display.unwrap_or(false);
2330
let all_predecessors: Vec<&[usize]> = (0..cfg.dag.number_of_nodes())
2431
.map(|node_index| cfg.dag.get_predecessor_indices(node_index))
2532
.collect();
@@ -57,7 +64,10 @@ trait Analysis<'a, ValueType: Clone + Hash + Eq + Display> {
5764
worklist.extend(output_edges[node_index]);
5865
}
5966
}
60-
self.display(cfg, &input_list, &output_list);
67+
if display {
68+
self.display(cfg, &input_list, &output_list);
69+
}
70+
(input_list, output_list)
6171
}
6272

6373
fn display(
@@ -127,11 +137,21 @@ trait Analysis<'a, ValueType: Clone + Hash + Eq + Display> {
127137
) -> HashSet<ValueType>;
128138
}
129139

130-
struct LiveVariableAnalysis {}
140+
pub struct LiveVariableAnalysis {}
141+
142+
impl LiveVariableAnalysis {
143+
pub fn run_analysis<'a>(
144+
&self,
145+
cfg: &'a Cfg,
146+
display: Option<bool>,
147+
) -> (Vec<HashSet<&'a str>>, Vec<HashSet<&'a str>>) {
148+
self.run(cfg, HashSet::new(), Direction::Backward, display)
149+
}
150+
}
131151

132152
#[derive(Derivative)]
133153
#[derivative(Eq, PartialEq, Hash)]
134-
struct Definition<'a> {
154+
pub struct Definition<'a> {
135155
destination_variable: &'a str,
136156
basic_block_index: usize,
137157
instruction_index: usize,
@@ -152,7 +172,17 @@ impl Clone for Definition<'_> {
152172
}
153173
}
154174

155-
struct ReachingDefinitions {}
175+
pub struct ReachingDefinitions {}
176+
177+
impl ReachingDefinitions {
178+
pub fn run_analysis<'a>(
179+
&self,
180+
cfg: &'a Cfg,
181+
display: Option<bool>,
182+
) -> (Vec<HashSet<Definition<'a>>>, Vec<HashSet<Definition<'a>>>) {
183+
self.run(cfg, HashSet::new(), Direction::Forward, display)
184+
}
185+
}
156186

157187
impl<'a> Analysis<'a, &'a str> for LiveVariableAnalysis {
158188
fn merge(
@@ -313,14 +343,10 @@ pub fn run_analysis(dataflow_analysis_name: DataflowAnalyses, program: &Program)
313343
.map(|f| (f, Cfg::new(f)))
314344
.for_each(|(f, cfg)| match dataflow_analysis_name {
315345
DataflowAnalyses::LiveVariable => {
316-
LiveVariableAnalysis {}.run(&cfg, HashSet::new(), Direction::Backward);
346+
let _ = LiveVariableAnalysis {}.run_analysis(&cfg, Some(true));
317347
}
318348
DataflowAnalyses::ReachingDefinitions => {
319-
ReachingDefinitions {}.run(
320-
&cfg,
321-
create_set_of_definitions_from_function_arguments(&cfg, &f.args),
322-
Direction::Forward,
323-
);
349+
let _ = ReachingDefinitions {}.run_analysis(&cfg, Some(true));
324350
}
325351
});
326352
}

driver/src/main.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ struct Args {
1616
#[arg(short, long, value_enum, help = "Type of dataflow analysis to run")]
1717
dataflow_analysis: Option<DataflowAnalyses>,
1818

19-
#[arg(long, help = "Dump the AST as a DOT file")]
19+
#[arg(
20+
long,
21+
help = "Dump the AST of each function in the program as DOT/Graphviz format"
22+
)]
2023
dump_ast_as_dot: bool,
2124

2225
#[arg(long, help = "Output the program after optimizations")]

optimizations/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ edition = "2021"
77
brilirs = { version = "0.1.0", path = "../bril/brilirs" }
88
clap = "4.5.20"
99
common = { version = "0.1.0", path = "../common" }
10+
dataflow_analysis = { version = "0.1.0", path = "../dataflow_analysis" }
1011
indoc = "2.0.5"
1112
smallstr = "0.3.0"
1213
smallvec = "1.13.2"

optimizations/src/lib.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mod local_dead_code_elimination;
22
mod local_value_numbering;
3+
mod loop_invariant_code_motion;
34

45
use std::vec;
56

@@ -10,6 +11,7 @@ use common::BasicBlock;
1011
pub enum OptimizationPass {
1112
LocalDeadCodeElimination,
1213
LocalValueNumbering,
14+
LoopInvariantCodeMotion,
1315
}
1416

1517
pub struct PassManager {
@@ -25,6 +27,9 @@ impl PassManager {
2527
OptimizationPass::LocalValueNumbering => {
2628
Box::new(local_value_numbering::LocalValueNumberingPass::new())
2729
}
30+
OptimizationPass::LoopInvariantCodeMotion => {
31+
Box::new(loop_invariant_code_motion::LoopInvariantCodeMotionPass::new())
32+
}
2833
}
2934
}
3035
pub fn new() -> Self {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
use crate::Pass;
2+
use common::cfg::{self, Cfg, Dominators};
3+
use dataflow_analysis::{run_analysis, DataflowAnalyses, ReachingDefinitions};
4+
use std::collections::HashSet;
5+
6+
pub struct LoopInvariantCodeMotionPass {}
7+
8+
impl Pass for LoopInvariantCodeMotionPass {
9+
fn apply(&mut self, mut program: bril_rs::Program) -> bril_rs::Program {
10+
let mut output_program = cfg::convert_to_ssa(program);
11+
for function in output_program.functions.iter_mut() {
12+
function.instrs = common::cfg::convert_cfg_to_instruction_stream(
13+
self.process_cfg(Cfg::new(function)),
14+
);
15+
}
16+
output_program
17+
}
18+
}
19+
20+
fn find_back_edges(
21+
cfg: &Cfg,
22+
dominators: &Dominators,
23+
) -> HashSet<(usize /*Start index*/, usize /*End index*/)> {
24+
if cfg.dag.number_of_nodes() == 0 {
25+
return HashSet::new();
26+
}
27+
let mut back_edges = HashSet::new();
28+
let mut visited = vec![false; cfg.dag.number_of_nodes()];
29+
let mut nodes_to_visit = vec![0];
30+
while !nodes_to_visit.is_empty() {
31+
let node = nodes_to_visit.pop().unwrap();
32+
visited[node] = true;
33+
for &successor in cfg.dag.get_successor_indices(node) {
34+
// If the successor is visited and is a dominator of the current node, then it is a back edge.
35+
if visited[successor] && dominators.set_per_node[node].contains(&successor) {
36+
back_edges.insert((node, successor));
37+
} else {
38+
nodes_to_visit.push(successor);
39+
}
40+
}
41+
}
42+
back_edges
43+
}
44+
45+
fn find_loop_nodes(
46+
cfg: &Cfg,
47+
dominators: &Dominators,
48+
loop_header: usize,
49+
seed: usize,
50+
) -> Vec<usize> {
51+
let mut loop_nodes = vec![loop_header];
52+
let mut visited = vec![false; cfg.dag.number_of_nodes()];
53+
visited[loop_header] = true;
54+
let mut nodes_to_visit = vec![seed];
55+
while !nodes_to_visit.is_empty() {
56+
let node = nodes_to_visit.pop().unwrap();
57+
visited[node] = true;
58+
loop_nodes.push(node);
59+
for &predecessor in cfg.dag.get_predecessor_indices(node) {
60+
// If the predecessor is not visited and if the loop header dominates the predecessor, then add it to the list of nodes to visit.
61+
// All nodes in the loop should have the loop header as a dominator.
62+
if !visited[predecessor] && dominators.set_per_node[predecessor].contains(&loop_header)
63+
{
64+
nodes_to_visit.push(predecessor);
65+
}
66+
}
67+
}
68+
loop_nodes
69+
}
70+
71+
impl LoopInvariantCodeMotionPass {
72+
pub fn new() -> Self {
73+
LoopInvariantCodeMotionPass {}
74+
}
75+
fn process_cfg(&mut self, cfg: Cfg) -> Cfg {
76+
// Precondition: We make the assumption that the CFG is reducible.
77+
let dominators = cfg::Dominators::new(&cfg);
78+
let (reaching_definitions_in, reaching_definitions_out) =
79+
ReachingDefinitions {}.run_analysis(&cfg, Some(false));
80+
find_back_edges(&cfg, &dominators)
81+
.iter()
82+
.map(|(src, loop_header)| {
83+
// For each back edge, find the loop nodes.
84+
assert!(
85+
dominators.set_per_node[*src].contains(&loop_header),
86+
"{src}->{loop_header}| dominators of {} are: {:#?}",
87+
loop_header,
88+
dominators.set_per_node[*loop_header]
89+
);
90+
(
91+
*loop_header,
92+
find_loop_nodes(&cfg, &dominators, *loop_header, *src),
93+
)
94+
})
95+
.for_each(
96+
|(loop_header, loop_nodes)| // For each loop, find the loop invariant instructions.
97+
{
98+
// todo!("Find loop invariant instructions for loop: {:#?}", loop_nodes);
99+
},
100+
);
101+
102+
cfg
103+
}
104+
}
105+
106+
#[cfg(test)]
107+
mod tests {
108+
use crate::Pass;
109+
use bril_rs::Program;
110+
111+
fn parse_program(text: &str) -> Program {
112+
let program = common::parse_bril_text(text);
113+
assert!(program.is_ok(), "{}", program.err().unwrap());
114+
program.unwrap()
115+
}
116+
117+
#[test]
118+
fn test_loop_invariant_code_motion() {
119+
let program = parse_program(indoc::indoc! {r#"
120+
@main {
121+
n: int = const 10;
122+
inc: int = const 5;
123+
one: int = const 1;
124+
invariant: int = const 100;
125+
i: int = const 0;
126+
sum: int = const 0;
127+
.loop:
128+
cond: bool = lt i n;
129+
br cond .body .done;
130+
.body:
131+
temp: int = add invariant inc;
132+
sum: int = add sum temp;
133+
i: int = add i one;
134+
body_cond: bool = lt temp sum;
135+
br body_cond .body_left .body_right;
136+
.body_left:
137+
jmp .body_join;
138+
.body_right:
139+
dead_store: int = const 0;
140+
jmp .body_join;
141+
.body_join:
142+
jmp .loop;
143+
.done:
144+
print sum;
145+
ret;
146+
}
147+
"#});
148+
let optimized_program = super::LoopInvariantCodeMotionPass::new().apply(program);
149+
println!("{}", optimized_program);
150+
}
151+
}

0 commit comments

Comments
 (0)