Skip to content

Commit

Permalink
Task spawn + sync
Browse files Browse the repository at this point in the history
  • Loading branch information
Arsene Perard-Gayot committed Feb 24, 2015
1 parent aa0dd3c commit 25f5bf4
Showing 1 changed file with 67 additions and 5 deletions.
72 changes: 67 additions & 5 deletions runtime/cpu/cpu_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
#include <iostream>
#include <thread>
#include <vector>
#include <unordered_map>
#include <cassert>

#include "cpu_runtime.h"
#include "thorin_runtime.h"

#ifdef USE_TBB
#include "tbb/tbb.h"
#include "tbb/parallel_for.h"
#include "tbb/task_scheduler_init.h"
#endif
Expand All @@ -25,6 +28,9 @@ void thorin_free(void *ptr) {
void thorin_print_total_timing() { }

#ifndef USE_TBB
static std::unordered_map<int, std::thread> thread_pool;
static std::vector<int> free_ids;

// C++11 threads version
void parallel_for(int num_threads, int lower, int upper, void *args, void *fun) {
// Get number of available hardware threads
Expand Down Expand Up @@ -56,11 +62,28 @@ void parallel_for(int num_threads, int lower, int upper, void *args, void *fun)
}
int parallel_spawn(void *args, void *fun) {
int (*fun_ptr) (void*) = reinterpret_cast<int (*) (void*)>(fun);
fun_ptr(args);

return 0;
int id;
if (free_ids.size()) {
id = free_ids.back();
free_ids.pop_back();
} else {
id = thread_pool.size();
}

auto spawned = std::make_pair(id, std::thread([=](){ fun_ptr(args); }));
thread_pool.emplace(std::move(spawned));
return id;
}
void parallel_sync(int id) {
auto thread = thread_pool.find(id);
if (thread != thread_pool.end()) {
thread->second.join();
free_ids.push_back(thread->first);
thread_pool.erase(thread);
} else {
assert(0 && "Trying to synchronize on invalid thread id");
}
}
#else
// TBB version
Expand All @@ -72,13 +95,52 @@ void parallel_for(int num_threads, int lower, int upper, void *args, void *fun)
fun_ptr(args, range.begin(), range.end());
});
}

static std::unordered_map<int, tbb::task*> task_pool;
static std::vector<int> free_ids;

class RuntimeTask : public tbb::task {
public:
RuntimeTask(void* args, void* fun)
: args_(args), fun_(fun)
{}

tbb::task* execute() {
int (*fun_ptr) (void*) = reinterpret_cast<int (*) (void*)>(fun_);
fun_ptr(args_);
set_ref_count(1);
return nullptr;
}

private:
void* args_;
void* fun_;
};

int parallel_spawn(void *args, void *fun) {
int (*fun_ptr) (void*) = reinterpret_cast<int (*) (void*)>(fun);
fun_ptr(args);
int id;
if (free_ids.size()) {
id = free_ids.back();
free_ids.pop_back();
} else {
id = task_pool.size();
}

return 0;
tbb::task* task = new (tbb::task::allocate_root()) RuntimeTask(args, fun);
tbb::task::spawn(*task);
task_pool[id] = task;
return id;
}
void parallel_sync(int id) {
auto task = task_pool.find(id);
if (task != task_pool.end()) {
task->second->wait_for_all();
tbb::task::destroy(*task->second);
free_ids.push_back(task->first);
task_pool.erase(task);
} else {
assert(0 && "Trying to synchronize on invalid task id");
}
}
#endif

0 comments on commit 25f5bf4

Please sign in to comment.