Skip to content

Commit

Permalink
Support for spawn(body) and sync(id).
Browse files Browse the repository at this point in the history
  • Loading branch information
richardmembarth committed Feb 23, 2015
1 parent a24760c commit aa0dd3c
Show file tree
Hide file tree
Showing 11 changed files with 136 additions and 5 deletions.
16 changes: 16 additions & 0 deletions runtime/cpu/cpu_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ void parallel_for(int num_threads, int lower, int upper, void *args, void *fun)
for (int i = 0; i < num_threads; i++)
pool[i].join();
}
int parallel_spawn(void *args, void *fun) {
int (*fun_ptr) (void*) = reinterpret_cast<int (*) (void*)>(fun);
fun_ptr(args);

return 0;
}
void parallel_sync(int id) {
}
#else
// TBB version
void parallel_for(int num_threads, int lower, int upper, void *args, void *fun) {
Expand All @@ -64,5 +72,13 @@ void parallel_for(int num_threads, int lower, int upper, void *args, void *fun)
fun_ptr(args, range.begin(), range.end());
});
}
int parallel_spawn(void *args, void *fun) {
int (*fun_ptr) (void*) = reinterpret_cast<int (*) (void*)>(fun);
fun_ptr(args);

return 0;
}
void parallel_sync(int id) {
}
#endif

2 changes: 2 additions & 0 deletions runtime/cpu/cpu_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
extern "C" {
// parallel runtime functions
void parallel_for(int num_threads, int lower, int upper, void *args, void *fun);
int parallel_spawn(void *args, void *fun);
void parallel_sync(int id);
}

#endif
2 changes: 2 additions & 0 deletions runtime/platforms/generic.s
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ declare i32 @map_memory(i32, i32, i8*, i32, i32);
declare void @unmap_memory(i32);

declare void @parallel_for(i32, i32, i32, i8*, i8*);
declare i32 @parallel_spawn(i8*, i8*);
declare void @parallel_sync(i32);

2 changes: 2 additions & 0 deletions runtime/platforms/intrinsics_thorin.impala
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ extern "thorin" {
fn spir(int, (int, int, int), (int, int, int), fn() -> ()) -> ();
fn opencl(int, (int, int, int), (int, int, int), fn() -> ()) -> ();
fn parallel(num_threads: int, lower: int, upper: int, body: fn(i32) -> ()) -> ();
fn spawn(body: fn() -> ()) -> int;
fn sync(id: int) -> ();
fn vectorize(vector_length: int, lower: int, upper: int, body: fn() -> ()) -> ();
}
2 changes: 2 additions & 0 deletions src/thorin/be/llvm/llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ Lambda* CodeGen::emit_intrinsic(Lambda* lambda) {
case Intrinsic::SPIR: return spir_runtime_->emit_host_code(*this, lambda);
case Intrinsic::OpenCL: return opencl_runtime_->emit_host_code(*this, lambda);
case Intrinsic::Parallel: return emit_parallel(lambda);
case Intrinsic::Spawn: return emit_spawn(lambda);
case Intrinsic::Sync: return emit_sync(lambda);
#ifdef WFV2_SUPPORT
case Intrinsic::Vectorize: return emit_vectorize_continuation(lambda);
#endif
Expand Down
2 changes: 2 additions & 0 deletions src/thorin/be/llvm/llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class CodeGen {
private:
Lambda* emit_intrinsic(Lambda*);
Lambda* emit_parallel(Lambda*);
Lambda* emit_spawn(Lambda*);
Lambda* emit_sync(Lambda*);
Lambda* emit_vectorize_continuation(Lambda*);
Lambda* emit_atomic(Lambda*);
void emit_vectorize(u32, llvm::Value*, llvm::Function*, llvm::CallInst*);
Expand Down
93 changes: 90 additions & 3 deletions src/thorin/be/llvm/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ Lambda* CodeGen::emit_parallel(Lambda* lambda) {
// wrapper(void* closure, int lower, int upper)
llvm::Type* wrapper_arg_types[] = { builder_.getInt8PtrTy(0), builder_.getInt32Ty(), builder_.getInt32Ty() };
auto wrapper_ft = llvm::FunctionType::get(builder_.getVoidTy(), wrapper_arg_types, false);
auto wrapper_name = kernel->unique_name() + "_parallel";
auto wrapper_name = kernel->unique_name() + "_parallel_for";
auto wrapper = (llvm::Function*)module_->getOrInsertFunction(wrapper_name, wrapper_ft);
runtime_->parallel_for(num_threads, lower, upper, ptr, wrapper);

// set insert point to the wrapper function
auto oldBB = builder_.GetInsertBlock();
auto old_bb = builder_.GetInsertBlock();
auto bb = llvm::BasicBlock::Create(context_, wrapper_name, wrapper);
builder_.SetInsertPoint(bb);

Expand All @@ -79,10 +79,97 @@ Lambda* CodeGen::emit_parallel(Lambda* lambda) {
builder_.CreateRetVoid();

// restore old insert point
builder_.SetInsertPoint(oldBB);
builder_.SetInsertPoint(old_bb);

return lambda->arg(PAR_ARG_RETURN)->as_lambda();
}

enum {
SPAWN_ARG_MEM,
SPAWN_ARG_BODY,
SPAWN_ARG_RETURN,
SPAWN_NUM_ARGS
};

Lambda* CodeGen::emit_spawn(Lambda* lambda) {
auto target = lambda->to()->as_lambda();
assert(target->intrinsic() == Intrinsic::Spawn);
assert(lambda->num_args() >= SPAWN_NUM_ARGS && "required arguments are missing");

auto kernel = lambda->arg(SPAWN_ARG_BODY)->as<Global>()->init()->as_lambda();
const size_t num_kernel_args = lambda->num_args() - SPAWN_NUM_ARGS;

// build parallel-function signature
Array<llvm::Type*> par_args(num_kernel_args);
for (size_t i = 0; i < num_kernel_args; ++i) {
Type type = lambda->arg(i + SPAWN_NUM_ARGS)->type();
par_args[i] = convert(type);
}

// fetch values and create a unified struct which contains all values (closure)
auto closure_type = convert(world_.tuple_type(lambda->arg_fn_type()->args().slice_from_begin(SPAWN_NUM_ARGS)));
llvm::Value* closure = llvm::UndefValue::get(closure_type);
for (size_t i = 0; i < num_kernel_args; ++i)
closure = builder_.CreateInsertValue(closure, lookup(lambda->arg(i + SPAWN_NUM_ARGS)), unsigned(i));

// allocate closure object and write values into it
auto ptr = builder_.CreateAlloca(closure_type, nullptr);
builder_.CreateStore(closure, ptr, false);

// create wrapper function and call the runtime
// wrapper(void* closure)
llvm::Type* wrapper_arg_types[] = { builder_.getInt8PtrTy(0) };
auto wrapper_ft = llvm::FunctionType::get(builder_.getVoidTy(), wrapper_arg_types, false);
auto wrapper_name = kernel->unique_name() + "_parallel_spawn";
auto wrapper = (llvm::Function*)module_->getOrInsertFunction(wrapper_name, wrapper_ft);
auto tid = runtime_->parallel_spawn(ptr, wrapper);

// set insert point to the wrapper function
auto old_bb = builder_.GetInsertBlock();
auto bb = llvm::BasicBlock::Create(context_, wrapper_name, wrapper);
builder_.SetInsertPoint(bb);

// extract all arguments from the closure
auto wrapper_args = wrapper->arg_begin();
auto load_ptr = builder_.CreateBitCast(&*wrapper_args, llvm::PointerType::get(closure_type, 0));
auto val = builder_.CreateLoad(load_ptr);
std::vector<llvm::Value*> target_args(num_kernel_args);
for (size_t i = 0; i < num_kernel_args; ++i)
target_args[i] = builder_.CreateExtractValue(val, { unsigned(i) });

// call kernel body
auto par_type = llvm::FunctionType::get(builder_.getVoidTy(), llvm_ref(par_args), false);
auto kernel_par_func = (llvm::Function*)module_->getOrInsertFunction(kernel->unique_name(), par_type);
builder_.CreateCall(kernel_par_func, target_args);
builder_.CreateRetVoid();

// restore old insert point
builder_.SetInsertPoint(old_bb);

// bind parameter of continuation to received handle
auto ret = lambda->arg(SPAWN_ARG_RETURN)->as_lambda();
params_[ret->params().back()] = tid;

return ret;
}

enum {
SYNC_ARG_MEM,
SYNC_ARG_ID,
SYNC_ARG_RETURN,
SYNC_NUM_ARGS
};

Lambda* CodeGen::emit_sync(Lambda* lambda) {
auto target = lambda->to()->as_lambda();
assert(target->intrinsic() == Intrinsic::Sync);
assert(lambda->num_args() == SYNC_NUM_ARGS && "wrong number of arguments");

auto id = lookup(lambda->arg(SYNC_ARG_ID));
runtime_->parallel_sync(id);

return lambda->arg(SYNC_ARG_RETURN)->as_lambda();
}

}

16 changes: 14 additions & 2 deletions src/thorin/be/llvm/runtimes/generic_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,24 @@ llvm::Value* GenericRuntime::munmap(llvm::Value* mem) {

llvm::Value* GenericRuntime::parallel_for(llvm::Value* num_threads, llvm::Value* lower, llvm::Value* upper,
llvm::Value* closure_ptr, llvm::Value* fun_ptr) {
llvm::Value* parallel_args[] = {
llvm::Value* parallel_for_args[] = {
num_threads, lower, upper,
builder_.CreateBitCast(closure_ptr, builder_.getInt8PtrTy()),
builder_.CreateBitCast(fun_ptr, builder_.getInt8PtrTy())
};
return builder_.CreateCall(get("parallel_for"), parallel_args);
return builder_.CreateCall(get("parallel_for"), parallel_for_args);
}

llvm::Value* GenericRuntime::parallel_spawn(llvm::Value* closure_ptr, llvm::Value* fun_ptr) {
llvm::Value* parallel_spawn_args[] = {
builder_.CreateBitCast(closure_ptr, builder_.getInt8PtrTy()),
builder_.CreateBitCast(fun_ptr, builder_.getInt8PtrTy())
};
return builder_.CreateCall(get("parallel_spawn"), parallel_spawn_args);
}

llvm::Value* GenericRuntime::parallel_sync(llvm::Value* id) {
return builder_.CreateCall(get("parallel_sync"), id);
}

}
2 changes: 2 additions & 0 deletions src/thorin/be/llvm/runtimes/generic_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class GenericRuntime : public Runtime {
virtual llvm::Value* munmap(llvm::Value* mem);
virtual llvm::Value* parallel_for(llvm::Value* num_threads, llvm::Value* lower, llvm::Value* upper,
llvm::Value* closure_ptr, llvm::Value* fun_ptr);
virtual llvm::Value* parallel_spawn(llvm::Value* closure_ptr, llvm::Value* fun_ptr);
virtual llvm::Value* parallel_sync(llvm::Value* id);
};

}
Expand Down
2 changes: 2 additions & 0 deletions src/thorin/lambda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ void Lambda::set_intrinsic() {
else if (name == "spir") intrinsic_ = Intrinsic::SPIR;
else if (name == "opencl") intrinsic_ = Intrinsic::OpenCL;
else if (name == "parallel") intrinsic_ = Intrinsic::Parallel;
else if (name == "spawn") intrinsic_ = Intrinsic::Spawn;
else if (name == "sync") intrinsic_ = Intrinsic::Sync;
else if (name == "vectorize") intrinsic_ = Intrinsic::Vectorize;
else if (name == "mmap") intrinsic_ = Intrinsic::Mmap;
else if (name == "munmap") intrinsic_ = Intrinsic::Munmap;
Expand Down
2 changes: 2 additions & 0 deletions src/thorin/lambda.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ enum class Intrinsic : uint8_t {
SPIR, ///< Internal SPIR-Backend.
OpenCL, ///< Internal OpenCL-Backend.
Parallel, ///< Internal Parallel-CPU-Backend.
Spawn, ///< Internal Parallel-CPU-Backend.
Sync, ///< Internal Parallel-CPU-Backend.
Vectorize, ///< External vectorizer.
_Accelerator_End,
Mmap = _Accelerator_End, ///< Intrinsic memory-mapping function.
Expand Down

0 comments on commit aa0dd3c

Please sign in to comment.