Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE] Parallel execution #259

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions brian2cuda/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def __init__(self):
# list of pre/post ID arrays that are not needed in device memory
self.delete_synaptic_pre = {}
self.delete_synaptic_post = {}
# dictionary to store parallelization information
self.stream_info = {}
# The following nested dictionary collects all codeobjects that use random
# number generation (RNG).
self.codeobjects_with_rng = {
Expand Down Expand Up @@ -359,6 +361,7 @@ def code_object(self, owner, name, abstract_code, variables, template_name,
template_kwds["sm_multiplier"] = prefs["devices.cuda_standalone.SM_multiplier"]
template_kwds["syn_launch_bounds"] = prefs["devices.cuda_standalone.syn_launch_bounds"]
template_kwds["calc_occupancy"] = prefs["devices.cuda_standalone.calc_occupancy"]
template_kwds["stream_info"] = self.stream_info
if template_name in ["threshold", "spikegenerator"]:
template_kwds["extra_threshold_kernel"] = prefs["devices.cuda_standalone.extra_threshold_kernel"]
codeobj = super(CUDAStandaloneDevice, self).code_object(owner, name, abstract_code, variables,
Expand All @@ -374,7 +377,7 @@ def check_openmp_compatible(self, nb_threads):
if nb_threads > 0:
raise NotImplementedError("Using OpenMP in a CUDA standalone project is not supported")

def generate_objects_source(self, writer, arange_arrays, synapses, static_array_specs, networks):
def generate_objects_source(self, writer, arange_arrays, synapses, static_array_specs, networks, stream_info):
sm_multiplier = prefs.devices.cuda_standalone.SM_multiplier
num_parallel_blocks = prefs.devices.cuda_standalone.parallel_blocks
curand_generator_type = prefs.devices.cuda_standalone.random_number_generator_type
Expand All @@ -393,6 +396,9 @@ def generate_objects_source(self, writer, arange_arrays, synapses, static_array_
for syn in synapses:
if syn.multisynaptic_index is not None:
multisyn_vars.append(syn.variables[syn.multisynaptic_index])
# get number of unique streams

num_stream = max(Counter(stream_info).values())
arr_tmp = self.code_object_class().templater.objects(
None, None,
array_specs=self.arrays,
Expand All @@ -415,7 +421,9 @@ def generate_objects_source(self, writer, arange_arrays, synapses, static_array_
eventspace_arrays=self.eventspace_arrays,
spikegenerator_eventspaces=self.spikegenerator_eventspaces,
multisynaptic_idx_vars=multisyn_vars,
profiled_codeobjects=self.profiled_codeobjects)
profiled_codeobjects=self.profiled_codeobjects,
parallelize=True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should become a preference later on. Just putting it here as TODO, not to forget.

stream_size=num_stream)
# Reinsert deleted entries, in case we use self.arrays later? maybe unnecessary...
self.arrays.update(self.eventspace_arrays)
writer.write('objects.*', arr_tmp)
Expand Down Expand Up @@ -445,7 +453,8 @@ def generate_main_source(self, writer):
# For codeobjects run every tick, this happens in the init() of
# the random number buffer called at first clock cycle of the network
main_lines.append('random_number_buffer.ensure_enough_curand_states();')
main_lines.append(f'_run_{codeobj.name}();')
# add stream - default
main_lines.append(f'_run_{codeobj.name}(0);')
elif func == 'after_run_code_object':
codeobj, = args
main_lines.append(f'_after_run_{codeobj.name}();')
Expand Down Expand Up @@ -986,10 +995,14 @@ def generate_network_source(self, writer):
maximum_run_time = self._maximum_run_time
if maximum_run_time is not None:
maximum_run_time = float(maximum_run_time)
num_stream = max(Counter(self.stream_info).values())
network_tmp = self.code_object_class().templater.network(None, None,
maximum_run_time=maximum_run_time,
eventspace_arrays=self.eventspace_arrays,
spikegenerator_eventspaces=self.spikegenerator_eventspaces)
spikegenerator_eventspaces=self.spikegenerator_eventspaces,
parallelize = True,
stream_info = self.stream_info,
num_stream= num_stream)
writer.write('network.*', network_tmp)

def generate_synapses_classes_source(self, writer):
Expand Down Expand Up @@ -1310,7 +1323,7 @@ def build(self, directory='output',

self.generate_objects_source(self.writer, self.arange_arrays,
net_synapses, self.static_array_specs,
self.networks)
self.networks, self.stream_info)
self.generate_network_source(self.writer)
self.generate_synapses_classes_source(self.writer)
self.generate_run_source(self.writer)
Expand Down Expand Up @@ -1382,6 +1395,25 @@ def network_run(self, net, duration, report=None, report_period=10*second,
self.clocks.update(net._clocks)
net.t_ = float(t_end)


# Create dictionary for parallelisation with stream
streams_organization = defaultdict(list)
for obj in net.sorted_objects:
streams_organization[(obj.when, obj.order)].append(obj)

# associate each code object with a particular stream
streams_details = defaultdict(list)
count = 1
for key in streams_organization:
for object in streams_organization[key]:
streams_details[object.name] = count
count +=1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, lets make the default 0. Or do we even need a default? Can't we just pass 0 to the kernel (which would run it in the actual CUDA default stream)? Lets check this later.


self.stream_info = streams_details
self.stream_info['default'] = 0



# TODO: remove this horrible hack
for clock in self.clocks:
if clock.name=='clock':
Expand Down Expand Up @@ -1516,11 +1548,21 @@ def network_run(self, net, duration, report=None, report_period=10*second,

# create all random numbers needed for the next clock cycle
for clock in net._clocks:
run_lines.append(f'{net.name}.add(&{clock.name}, _run_random_number_buffer);')
run_lines.append(f'{net.name}.add(&{clock.name}, _run_random_number_buffer, {self.stream_info["default"]});')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The random number buffer is a special case. It is not generated from common_group.cu, but is defined separately in rand.cu. So you don't need to add a stream argument here at all (I think this should even fail, because _run_random_number_buffer in rand.cu is defined without arguments).

For context: The random number buffer has a fixed size of memory on the GPU (which can be controlled via preference). It generates random number from the host, knowing how many random numbers the kernels will require. The kernels then use this data for multiple time steps (where the _run_random_number_buffer only increments the data pointer to the random number). And only when the generated numbers on the GPU are empty, new numbers are generated.

Each random number generation call should generate enough random numbers to occupy the entire GPU. So no need for concurrent kernel execution here at all.


all_clocks = set()
# TODO: for every code object, record where in the list it is.
# TODO create new dic (code object, position in list)
for clock, codeobj in code_objects:
run_lines.append(f'{net.name}.add(&{clock.name}, _run_{codeobj.name});')
# add this position as additional number here
# check if codeobj.name has _codeobject in it
name = codeobj.name
if "_codeobject" in codeobj.name:
name = codeobj.name[:-11]
if name in self.stream_info.keys():
run_lines.append(f'{net.name}.add(&{clock.name}, _run_{codeobj.name}, {self.stream_info[name]});')
else:
run_lines.append(f'{net.name}.add(&{clock.name}, _run_{codeobj.name}, {self.stream_info["default"]});')
all_clocks.add(clock)

# Under some rare circumstances (e.g. a NeuronGroup only defining a
Expand Down
8 changes: 4 additions & 4 deletions brian2cuda/templates/common_group.cu
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ _run_kernel_{{codeobj_name}}(
{% endblock kernel %}


void _run_{{codeobj_name}}()
void _run_{{codeobj_name}}(cudaStream_t stream)
{
using namespace brian;

Expand Down Expand Up @@ -292,7 +292,7 @@ void _run_{{codeobj_name}}()
{% endblock %}

{% block kernel_call %}
_run_kernel_{{codeobj_name}}<<<num_blocks, num_threads>>>(
_run_kernel_{{codeobj_name}}<<<num_blocks, num_threads, 0, stream>>>(
_N,
num_threads,
///// HOST_PARAMETERS /////
Expand Down Expand Up @@ -326,7 +326,7 @@ void _run_{{codeobj_name}}()
#ifndef _INCLUDED_{{codeobj_name}}
#define _INCLUDED_{{codeobj_name}}

void _run_{{codeobj_name}}();
void _run_{{codeobj_name}}(cudaStream_t);

{% block extra_functions_h %}
{% endblock %}
Expand Down Expand Up @@ -362,7 +362,7 @@ void _after_run_{{codeobj_name}}()
}
{% endmacro %}


// {{codeobj_name}}
{% macro after_run_h_file() %}
#ifndef _INCLUDED_{{codeobj_name}}_after
#define _INCLUDED_{{codeobj_name}}_after
Expand Down
66 changes: 55 additions & 11 deletions brian2cuda/templates/network.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,33 @@

double Network::_last_run_time = 0.0;
double Network::_last_run_completed_fraction = 0.0;
{% if parallelize %}
cudaStream_t custom_stream[{{num_stream}}];
{% endif %}

// Constructor: resets the simulation time and, when the template is rendered
// with `parallelize` enabled, creates the fixed pool of CUDA streams that
// Network::run uses to launch independent code objects concurrently.
Network::Network()
{
t = 0.0;
{% if parallelize %}
// NOTE(review): this diff declares `custom_stream` both at file scope and as
// a class member of the same name; inside this method the member shadows the
// file-scope array — confirm only one of the two is intended.
for(int i=0;i<{{num_stream}};i++){
CUDA_SAFE_CALL(cudaStreamCreate(&(custom_stream[i])));
}
{% endif %}
}

// Drop every registered (clock, code-object function, stream-group) entry
// from the network's schedule.
void Network::clear()
{
objects.clear();
}

void Network::add(Clock *clock, codeobj_func func)
// TODO have to mark change in objects - make it a tuple
// make decision which object has which stream
void Network::add(Clock *clock, codeobj_func func, int group_num)
{
#if defined(_MSC_VER) && (_MSC_VER>=1700)
objects.push_back(std::make_pair(std::move(clock), std::move(func)));
objects.push_back(std::make_tuple(std::move(clock), std::move(func), std::move(group_num)));
#else
objects.push_back(std::make_pair(clock, func));
objects.push_back(std::make_tuple(clock, func, group_num));
#endif
}

Expand All @@ -56,7 +66,7 @@ void Network::run(const double duration, void (*report_func)(const double, const
Clock* clock = next_clocks();
double elapsed_realtime;
bool did_break_early = false;

//TODO here
while(clock && clock->running())
{
t = clock->t[0];
Expand All @@ -73,17 +83,42 @@ void Network::run(const double duration, void (*report_func)(const double, const
next_report_time += report_period;
}
}
Clock *obj_clock = objects[i].first;
// TODO tuple of clock and function
//Clock *obj_clock = objects[i].first;
Clock *obj_clock = std::get<0>(objects[i]);
int group_int = std::get<2>(objects[i]);
// Only execute the object if it uses the right clock for this step
if (curclocks.find(obj_clock) != curclocks.end())
{
codeobj_func func = objects[i].second;
// function -> which is in templates like common_group.cu
// sort the code object - waiting mechanism between groups
// cudaEvent or cudaSynchronise
//codeobj_func func = objects[i].second;
codeobj_func func = std::get<1>(objects[i]);
int func_group_int = std::get<2>(objects[i]);
if (func) // code objects can be NULL in cases where we store just the clock
{
func();
func_groups[func_group_int].push_back(func);
//func_groups.push_back(std::make_pair(func_group_int,func));
//func();
// [[func1,func2,func3],[func4...]]
}
}
}

// get maximum in objects.cu array

// go through each list of func group - 2 loops
for(int i=0; i<func_groups.size(); i++){
for(int j=0; j<func_groups[i].size(); j++){
codeobj_func func = func_groups[i][j];
func(custom_stream[j]);
}
// reset the func group for that sub stream
cudaDeviceSynchronize();
func_groups[i].resize(0);
}

for(std::set<Clock*>::iterator i=curclocks.begin(); i!=curclocks.end(); i++)
(*i)->tick();
clock = next_clocks();
Expand Down Expand Up @@ -129,7 +164,8 @@ void Network::compute_clocks()
clocks.clear();
for(int i=0; i<objects.size(); i++)
{
Clock *clock = objects[i].first;
Clock *clock = std::get<0>(objects[i]);
// Clock *clock = std::get<0>()objects[i].first;
clocks.insert(clock);
}
}
Expand Down Expand Up @@ -174,22 +210,30 @@ Clock* Network::next_clocks()
#include <ctime>
#include "brianlib/clocks.h"

typedef void (*codeobj_func)();
typedef void (*codeobj_func)(cudaStream_t);

class Network
{
std::set<Clock*> clocks, curclocks;
void compute_clocks();
Clock* next_clocks();
public:
std::vector< std::pair< Clock*, codeobj_func > > objects;
// TODO vector of tuples holding clock, codeobj_func and stream integer
std::vector< std::tuple< Clock*, codeobj_func, int > > objects;
//std::vector< std::pair< Clock*, codeobj_func > > objects;
std::vector<std::vector<codeobj_func >> func_groups = std::vector<std::vector<codeobj_func >>({{num_stream}});
//std::vector<std::pair< int, codeobj_func >> func_groups;
double t;
static double _last_run_time;
static double _last_run_completed_fraction;
int num_streams;
{% if parallelize %}
cudaStream_t custom_stream[{{num_stream}}];
{% endif %}

Network();
void clear();
void add(Clock *clock, codeobj_func func);
void add(Clock *clock, codeobj_func func, int num_streams);
void run(const double duration, void (*report_func)(const double, const double, const double, const double), const double report_period);
};

Expand Down
21 changes: 21 additions & 0 deletions brian2cuda/templates/objects.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ const int brian::_num_{{varname}} = {{var.size}};
{% endif %}
{% endfor %}


///////////////// array of streams for parallelization //////////////////////////
// {% if parallelize %}
// cudaStream_t brian::custom_stream[{{stream_size}}];
// {% endif %}

//////////////// eventspaces ///////////////
// we dynamically create multiple eventspaces in no_or_const_delay_mode
// for initiating the first spikespace, we need a host pointer
Expand Down Expand Up @@ -226,6 +232,14 @@ void _init_arrays()
);
{% endif %}

// {% if parallelize %}
// for(int i=0;i<{{stream_size}};i++){
// CUDA_SAFE_CALL(cudaStreamCreate(&(custom_stream[i])));
// }
// {% endif %}



// this sets seed for host and device api RNG
random_number_buffer.set_seed(seed);

Expand Down Expand Up @@ -546,6 +560,7 @@ typedef {{curand_float_type}} randomNumber_t; // random number type
#include "network.h"
#include "rand.h"

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <curand.h>
#include <curand_kernel.h>
Expand Down Expand Up @@ -597,6 +612,12 @@ extern thrust::device_vector<{{c_data_type(var.dtype)}}*> addresses_monitor_{{va
extern thrust::device_vector<{{c_data_type(var.dtype)}}>* {{varname}};
{% endfor %}

//////////////// stream ////////////
// {% if parallelize %}
// extern cudaStream_t custom_stream[{{stream_size}}];
// {% endif %}


/////////////// static arrays /////////////
{% for (name, dtype_spec, N, filename) in static_array_specs | sort %}
{# arrays that are initialized from static data are already declared #}
Expand Down
9 changes: 5 additions & 4 deletions brian2cuda/templates/rand.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ namespace {


// need a function pointer for Network::add(), can't pass a pointer to a class
// method, which is of different type
void _run_random_number_buffer()
// method, which is of different type. Random number buffer runs in default
// stream always, the `stream` parameter is not used.
void _run_random_number_buffer(cudaStream_t stream)
{
// random_number_buffer is a RandomNumberBuffer instance, declared in objects.cu
random_number_buffer.next_time_step();
Expand Down Expand Up @@ -472,7 +473,7 @@ void RandomNumberBuffer::next_time_step()

#include <curand.h>

void _run_random_number_buffer();
void _run_random_number_buffer(cudaStream_t);

class RandomNumberBuffer
{
Expand Down Expand Up @@ -562,4 +563,4 @@ public:

#endif

{% endmacro %}
{% endmacro %}
5 changes: 3 additions & 2 deletions brian2cuda/templates/synapses.cu
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,10 @@ if ({{pathway.name}}_max_size > 0)
{
if (defaultclock.timestep[0] >= {{pathway.name}}_delay)
{
cudaMemcpy(&num_spiking_neurons,
CUDA_SAFE_CALL(cudaMemcpyAsync(&num_spiking_neurons,
&dev{{_eventspace}}[{{pathway.name}}_eventspace_idx][_num_{{_eventspace}} - 1],
sizeof(int32_t), cudaMemcpyDeviceToHost);
sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
CUDA_SAFE_CALL(cudaStreamSynchronize(stream));
num_blocks = num_parallel_blocks * num_spiking_neurons;
//TODO collect info abt mean, std of num spiking neurons per time
//step and print INFO at end of simulation
Expand Down
Loading