Skip to content

Commit

Permalink
use dupnoneed on rgrads
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jan 20, 2025
1 parent d6d2962 commit 490ae8c
Showing 1 changed file with 37 additions and 12 deletions.
49 changes: 37 additions & 12 deletions src/execution/llvm/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ impl CodegenModule for LlvmModule {
self.codegen_mut().compile_gradient(
*func_id,
&[
CompileGradientArgType::Dup,
CompileGradientArgType::Dup,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
Expand Down Expand Up @@ -379,8 +379,8 @@ impl CodegenModule for LlvmModule {
*func_id,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Dup,
CompileGradientArgType::Dup,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
Expand Down Expand Up @@ -416,8 +416,8 @@ impl CodegenModule for LlvmModule {
*func_id,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Dup,
CompileGradientArgType::Dup,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
Expand All @@ -444,7 +444,10 @@ impl CodegenModule for LlvmModule {
) -> Result<Self::FuncId> {
self.codegen_mut().compile_gradient(
*func_id,
&[CompileGradientArgType::Dup, CompileGradientArgType::Dup],
&[
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
],
CompileMode::Reverse,
)
}
Expand Down Expand Up @@ -495,16 +498,38 @@ impl CodegenModule for LlvmModule {
}

/// Optimises the module once automatic differentiation has been applied.
///
/// Steps, in order:
/// 1. Writes the current module IR to the relative path
///    `post_autodiff_optimisation.ll`.
/// 2. Strips the `noinline` attribute from the `barrier` function, if the
///    module contains a function with that name.
/// 3. Builds a target machine for this module's triple, using the host CPU
///    name and feature string.
/// 4. Runs the `default<O3>` pass pipeline over the module.
///
/// # Errors
///
/// Returns an error (instead of panicking) when the IR dump cannot be
/// written, when no target can be resolved for the triple, when the target
/// machine cannot be created, or when the pass pipeline fails.
fn post_autodiff_optimisation(&mut self) -> Result<()> {
    // Dump the IR before any post-processing so the autodiff output can be inspected.
    self.codegen()
        .module()
        .print_to_file("post_autodiff_optimisation.ll")
        .map_err(|e| anyhow!("Failed to write module IR to file: {:?}", e))?;

    // remove noinline attribute from barrier function as only needed for enzyme
    if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
        let noinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
        barrier_func.remove_enum_attribute(AttributeLoc::Function, noinline_kind_id);
    }

    // Targets must be initialised before one can be looked up by triple.
    Target::initialize_all(&InitializationConfig::default());
    let triple = TargetTriple::create(self.0.triple.to_string().as_str());
    let target = Target::from_triple(&triple)
        .map_err(|e| anyhow!("Failed to get target from triple: {:?}", e))?;

    let host_cpu = TargetMachine::get_host_cpu_name().to_string();
    let host_features = TargetMachine::get_host_cpu_features().to_string();
    let machine = target
        .create_target_machine(
            &triple,
            host_cpu.as_str(),
            host_features.as_str(),
            inkwell::OptimizationLevel::Default,
            inkwell::targets::RelocMode::Default,
            inkwell::targets::CodeModel::Default,
        )
        .ok_or_else(|| anyhow!("Failed to create target machine"))?;

    let passes = "default<O3>";
    self.codegen_mut()
        .module()
        .run_passes(passes, &machine, PassBuilderOptions::create())
        .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;

    Ok(())
}
}
Expand Down

0 comments on commit 490ae8c

Please sign in to comment.