Yin yang dataset, TTFS readout and event-prop implementation #111

Closed · wants to merge 17 commits
137 changes: 137 additions & 0 deletions examples/event_prop/yin_yang.py
@@ -0,0 +1,137 @@
import numpy as np
import matplotlib.pyplot as plt

from time import perf_counter

from ml_genn import InputLayer, Layer, SequentialNetwork
from ml_genn.callbacks import (Checkpoint, OptimiserParamSchedule,
                               SpikeRecorder, VarRecorder)
from ml_genn.compilers import EventPropCompiler, InferenceCompiler
from ml_genn.compilers.event_prop_compiler import default_params
from ml_genn.connectivity import Dense
from ml_genn.initializers import Normal
from ml_genn.neurons import LeakyIntegrateFire, SpikeInput
from ml_genn.optimisers import Adam
from ml_genn.serialisers import Numpy
from ml_genn.synapses import Exponential
from ml_genn.utils.data import generate_yin_yang_dataset

NUM_INPUT = 4
NUM_HIDDEN = 100
NUM_OUTPUT = 3
BATCH_SIZE = 512
NUM_EPOCHS = 10
NUM_TRAIN = BATCH_SIZE * 10 * NUM_OUTPUT
NUM_TEST = BATCH_SIZE * 2 * NUM_OUTPUT
EXAMPLE_TIME = 30.0  # Duration of each example [ms]
DT = 0.01            # Simulation timestep [ms]
TRAIN = True
KERNEL_PROFILING = True

spikes, labels = generate_yin_yang_dataset(NUM_TRAIN if TRAIN else NUM_TEST,
                                           EXAMPLE_TIME - (4 * DT))

# Plot training data (spike times of the first two input neurons)
fig, axis = plt.subplots()
axis.scatter([d.spike_times[0] for d in spikes],
             [d.spike_times[1] for d in spikes], c=labels)

serialiser = Numpy("yin_yang_checkpoints")
network = SequentialNetwork(default_params)
with network:
    # Populations
    input = InputLayer(SpikeInput(max_spikes=BATCH_SIZE * NUM_INPUT),
                       NUM_INPUT, record_spikes=True)
    hidden = Layer(Dense(Normal(mean=1.5, sd=0.78)),
                   LeakyIntegrateFire(v_thresh=1.0, tau_mem=20.0,
                                      tau_refrac=None),
                   NUM_HIDDEN, Exponential(5.0), record_spikes=True)
    output = Layer(Dense(Normal(mean=0.93, sd=0.1)),
                   LeakyIntegrateFire(v_thresh=1.0, tau_mem=20.0,
                                      tau_refrac=None,
                                      readout="first_spike_time"),
                   NUM_OUTPUT, Exponential(5.0), record_spikes=True)

max_example_timesteps = int(np.ceil(EXAMPLE_TIME / DT))
if TRAIN:
    compiler = EventPropCompiler(example_timesteps=max_example_timesteps,
                                 losses="sparse_categorical_crossentropy",
                                 optimiser=Adam(0.003, 0.9, 0.99),
                                 batch_size=BATCH_SIZE,
                                 softmax_temperature=0.5, ttfs_alpha=0.1,
                                 dt=DT, kernel_profiling=KERNEL_PROFILING)
    compiled_net = compiler.compile(network)

    with compiled_net:
        # Exponentially decay learning rate with epoch
        def alpha_schedule(epoch, alpha):
            return 0.003 * (0.998 ** epoch)

        # Train model on dataset, checkpointing weights and recording
        # spikes and state variables from a fixed subset of examples
        # (examples_back is the same subset offset by BATCH_SIZE)
        start_time = perf_counter()
        examples = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
        examples_back = [512, 522, 532, 542, 552, 562, 572, 582, 592, 602]
        callbacks = ["batch_progress_bar", Checkpoint(serialiser),
                     OptimiserParamSchedule("alpha", alpha_schedule),
                     SpikeRecorder(input, key="InputSpikes",
                                   example_filter=examples),
                     SpikeRecorder(hidden, key="HiddenSpikes",
                                   example_filter=examples),
                     SpikeRecorder(output, key="OutputSpikes",
                                   example_filter=examples),
                     VarRecorder(output, key="OutputTTFS",
                                 genn_var="TFirstSpike",
                                 example_filter=examples),
                     VarRecorder(output, key="OutputLambdaV",
                                 genn_var="LambdaV",
                                 example_filter=examples_back)]
        metrics, cb_data = compiled_net.train({input: spikes},
                                              {output: labels},
                                              num_epochs=NUM_EPOCHS,
                                              shuffle=False,
                                              callbacks=callbacks)
        # Plot recorded spikes and adjoint variables for each epoch
        for e in range(NUM_EPOCHS):
            fig, axes = plt.subplots(4, 10, sharex="col", sharey="row")
            timesteps = np.arange(0.0, EXAMPLE_TIME, DT)
            for i in range(10):
                axes[0, i].set_title(f"Example {(e * 10) + i}")
                axes[0, i].scatter(cb_data["InputSpikes"][0][(e * 10) + i],
                                   cb_data["InputSpikes"][1][(e * 10) + i], s=1)
                axes[1, i].scatter(cb_data["HiddenSpikes"][0][(e * 10) + i],
                                   cb_data["HiddenSpikes"][1][(e * 10) + i], s=1)
                axes[2, i].scatter(cb_data["OutputSpikes"][0][(e * 10) + i],
                                   cb_data["OutputSpikes"][1][(e * 10) + i], s=1)

                # TFirstSpike holds negated spike times so flip sign to plot
                axes[2, i].scatter(-cb_data["OutputTTFS"][(e * 10) + i][-1, :],
                                   np.arange(3), marker="X", alpha=0.5)

                # Plot adjoint variable LambdaV for each output neuron,
                # reversed in time; solid line marks the labelled class
                for j in range(NUM_OUTPUT):
                    axes[3, i].plot(timesteps,
                                    (j * 0.002) + cb_data["OutputLambdaV"][(e * 10) + i][::-1, j],
                                    linestyle=("-" if labels[examples[i]] == j else "--"))
                axes[3, i].set_xlabel("Time [ms]")
                axes[3, i].set_xlim((0, EXAMPLE_TIME))

            axes[0, 0].set_ylabel("Input neuron ID")
            axes[1, 0].set_ylabel("Hidden neuron ID")
            axes[2, 0].set_ylabel("Output neuron ID")
        plt.show()

        end_time = perf_counter()
        print(f"Accuracy = {100 * metrics[output].result}%")
        print(f"Time = {end_time - start_time}s")

        if KERNEL_PROFILING:
            print(f"Neuron update time = {compiled_net.genn_model.neuron_update_time}")
            print(f"Presynaptic update time = {compiled_net.genn_model.presynaptic_update_time}")
            print(f"Gradient batch reduce time = {compiled_net.genn_model.get_custom_update_time('GradientBatchReduce')}")
            print(f"Gradient learn time = {compiled_net.genn_model.get_custom_update_time('GradientLearn')}")
            print(f"Reset time = {compiled_net.genn_model.get_custom_update_time('Reset')}")
            print(f"Softmax1 time = {compiled_net.genn_model.get_custom_update_time('BatchSoftmax1')}")
            print(f"Softmax2 time = {compiled_net.genn_model.get_custom_update_time('BatchSoftmax2')}")
            print(f"Softmax3 time = {compiled_net.genn_model.get_custom_update_time('BatchSoftmax3')}")
else:
    # Load network state from final checkpoint
    network.load((NUM_EPOCHS - 1,), serialiser)

    compiler = InferenceCompiler(evaluate_timesteps=max_example_timesteps,
                                 reset_in_syn_between_batches=True,
                                 batch_size=BATCH_SIZE)
    compiled_net = compiler.compile(network)

    with compiled_net:
        # Evaluate model on numpy dataset
        start_time = perf_counter()
        metrics, _ = compiled_net.evaluate({input: spikes},
                                           {output: labels})
        end_time = perf_counter()
        print(f"Accuracy = {100 * metrics[output].result}%")
        print(f"Time = {end_time - start_time}s")
1 change: 0 additions & 1 deletion ml_genn/ml_genn/callbacks/conn_var_recorder.py
@@ -12,7 +12,6 @@
from pygenn import get_var_access_dim
from ..utils.filter import get_neuron_filter_mask
from ..utils.network import get_underlying_conn
-from ..utils.value import get_genn_var_name
from ..connection import Connection

logger = logging.getLogger(__name__)
3 changes: 1 addition & 2 deletions ml_genn/ml_genn/compilers/compiled_training_network.py
@@ -5,10 +5,9 @@
from .compiled_network import CompiledNetwork
from ..callbacks import BatchProgressBar
from ..connectivity.sparse_base import SparseBase
-from ..metrics import Metric
+from ..metrics import Metric, MetricsType
from ..serialisers import Serialiser
from ..utils.callback_list import CallbackList
-from ..utils.data import MetricsType

from ..utils.data import (batch_dataset, get_dataset_size,
                          permute_dataset, split_dataset)
13 changes: 8 additions & 5 deletions ml_genn/ml_genn/compilers/compiler.py
@@ -37,21 +37,23 @@

# Second pass of softmax - calculate scaled sum of exp(value)
softmax_2_model = {
+   "params": [("Temp", "scalar")],
    "vars": [("SumExpVal", "scalar", CustomUpdateVarAccess.REDUCE_NEURON_SUM)],
    "var_refs": [("Val", "scalar", VarAccessMode.READ_ONLY),
                 ("MaxVal", "scalar", VarAccessMode.READ_ONLY)],
    "update_code": """
-   SumExpVal = exp(Val - MaxVal);
+   SumExpVal = exp((Val - MaxVal) / Temp);
    """}

# Third pass of softmax - calculate softmax value
softmax_3_model = {
+   "params": [("Temp", "scalar")],
    "var_refs": [("Val", "scalar", VarAccessMode.READ_ONLY),
                 ("MaxVal", "scalar", VarAccessMode.READ_ONLY),
                 ("SumExpVal", "scalar", VarAccessMode.READ_ONLY),
                 ("SoftmaxVal", "scalar")],
    "update_code": """
-   SoftmaxVal = exp(Val - MaxVal) / SumExpVal;
+   SoftmaxVal = exp((Val - MaxVal) / Temp) / SumExpVal;
    """}

def set_dynamic_param(param_names, set_param_dynamic):
@@ -343,7 +345,8 @@ def add_out_post_zero_custom_update(self, genn_model, genn_syn_pop,

    def add_softmax_custom_updates(self, genn_model, genn_pop,
                                   input_var_name: str, output_var_name: str,
-                                  custom_update_group_prefix: str = ""):
+                                  custom_update_group_prefix: str = "",
+                                  temperature: float = 1.0):
        """Adds a numerically stable softmax to the model:

        .. math::
@@ -379,7 +382,7 @@ def add_softmax_custom_updates(self, genn_model, genn_pop,
        # Create custom update model to implement
        # second softmax pass and add to model
        softmax_2 = CustomUpdateModel(
-           softmax_2_model, {}, {"SumExpVal": 0.0},
+           softmax_2_model, {"Temp": temperature}, {"SumExpVal": 0.0},
            {"Val": create_var_ref(genn_pop, input_var_name),
             "MaxVal": create_var_ref(genn_softmax_1, "MaxVal")})

@@ -391,7 +394,7 @@ def add_softmax_custom_updates(self, genn_model, genn_pop,
        # Create custom update model to implement
        # third softmax pass and add to model
        softmax_3 = CustomUpdateModel(
-           softmax_3_model, {}, {},
+           softmax_3_model, {"Temp": temperature}, {},
            {"Val": create_var_ref(genn_pop, input_var_name),
             "MaxVal": create_var_ref(genn_softmax_1, "MaxVal"),
             "SumExpVal": create_var_ref(genn_softmax_2, "SumExpVal"),
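
A hedged sketch of how a Compiler subclass might forward a softmax temperature to the updated helper. Only add_softmax_custom_updates and its arguments come from this diff; the class name, method name and variable names ("V", "Softmax", "Batch") are illustrative assumptions:

from ml_genn.compilers import Compiler

class TemperatureSoftmaxCompiler(Compiler):
    def _add_output_softmax(self, genn_model, genn_pop):
        # Hypothetical helper: "V" as input and "Softmax" as output
        # variable names are assumptions, not taken from this diff
        self.add_softmax_custom_updates(genn_model, genn_pop,
                                        input_var_name="V",
                                        output_var_name="Softmax",
                                        custom_update_group_prefix="Batch",
                                        temperature=0.5)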
2 changes: 1 addition & 1 deletion ml_genn/ml_genn/compilers/eprop_compiler.py
@@ -10,13 +10,13 @@
                         CustomUpdateOnBatchEnd, CustomUpdateOnTimestepEnd)
from ..communicators import Communicator
from ..losses import Loss, SparseCategoricalCrossentropy
+from ..metrics import MetricsType
from ..neurons import (AdaptiveLeakyIntegrateFire, Input,
                       LeakyIntegrate, LeakyIntegrateFire,
                       LeakyIntegrateFireInput)
from ..optimisers import Optimiser
from ..synapses import Delta
from ..utils.callback_list import CallbackList
-from ..utils.data import MetricsType
from ..utils.model import (CustomUpdateModel, NeuronModel,
                           SynapseModel, WeightUpdateModel)
from ..utils.snippet import ConnectivitySnippet