Skip to content

Commit 5af4f85

Browse files
committed
Hard-code synchrony_sizes for users, but leave flexible code underneath
1 parent 853d8a4 commit 5af4f85

File tree

4 files changed

+30
-42
lines changed

4 files changed

+30
-42
lines changed

doc/get_started/quickstart.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ compute quality metrics (some quality metrics require certain extensions
673673
'min_spikes': 0,
674674
'window_size_s': 1},
675675
'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'},
676-
'synchrony': {'synchrony_sizes': (2, 4, 8)}}
676+
'synchrony': {}
677677
678678
679679
Since the recording is very short, let’s change some parameters to

doc/modules/qualitymetrics/synchrony.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ trains. This way synchronous events can be found both in multi-unit and single-u
1212
Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index,
1313
within and across spike trains.
1414

15-
Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count.
15+
Synchrony metrics are computed for 2, 4 and 8 synchronous spikes.
1616

1717

1818

@@ -29,7 +29,7 @@ Example code
2929
3030
import spikeinterface.qualitymetrics as sqm
3131
# Combine a sorting and recording into a sorting_analyzer
32-
synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8))
32+
synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer)
3333
# synchrony is a tuple of dicts with the synchrony metrics for each unit
3434
3535

src/spikeinterface/qualitymetrics/misc_metrics.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -520,18 +520,18 @@ def compute_sliding_rp_violations(
520520
)
521521

522522

523-
def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
523+
def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4, 8])):
524524
"""
525525
Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`.
526526
527527
Parameters
528528
----------
529529
spikes : np.array
530530
Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
531-
synchrony_sizes : numpy array
532-
The synchrony sizes to compute. Should be pre-sorted.
533531
all_unit_ids : list or None, default: None
534532
List of unit ids to compute the synchrony metrics. Expecting all units.
533+
synchrony_sizes : numpy array
534+
The synchrony sizes to compute. Should be pre-sorted.
535535
536536
Returns
537537
-------
@@ -565,37 +565,32 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
565565
return synchrony_counts
566566

567567

568-
def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
568+
def compute_synchrony_metrics(sorting_analyzer, unit_ids=None):
569569
"""
570570
Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
571-
"synchrony_size" spikes at the exact same sample index.
571+
spikes at the exact same sample index, with synchrony sizes 2, 4 and 8.
572572
573573
Parameters
574574
----------
575575
sorting_analyzer : SortingAnalyzer
576576
A SortingAnalyzer object.
577-
synchrony_sizes : list or tuple, default: (2, 4, 8)
578-
The synchrony sizes to compute.
579577
unit_ids : list or None, default: None
580578
List of unit ids to compute the synchrony metrics. If None, all units are used.
581579
582580
Returns
583581
-------
584582
sync_spike_{X} : dict
585583
The synchrony metric for synchrony size X.
586-
Returns are as many as synchrony_sizes.
587584
588585
References
589586
----------
590587
Based on concepts described in [Grün]_
591588
This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
592589
"""
593-
assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
594-
# Sort the synchrony times so we can slice numpy arrays, instead of using dicts
595-
synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16)
596-
synchrony_sizes_np.sort()
597590

598-
res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes_np])
591+
synchrony_sizes = np.array([2, 4, 8])
592+
593+
res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes])
599594

600595
sorting = sorting_analyzer.sorting
601596

@@ -606,10 +601,10 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
606601

607602
spikes = sorting.to_spike_vector()
608603
all_unit_ids = sorting.unit_ids
609-
synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids)
604+
synchrony_counts = _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=synchrony_sizes)
610605

611606
synchrony_metrics_dict = {}
612-
for sync_idx, synchrony_size in enumerate(synchrony_sizes_np):
607+
for sync_idx, synchrony_size in enumerate(synchrony_sizes):
613608
sync_id_metrics_dict = {}
614609
for i, unit_id in enumerate(all_unit_ids):
615610
if unit_id not in unit_ids:
@@ -623,7 +618,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
623618
return res(**synchrony_metrics_dict)
624619

625620

626-
_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8))
621+
_default_params["synchrony"] = dict()
627622

628623

629624
def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None):

src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
compute_firing_ranges,
4040
compute_amplitude_cv_metrics,
4141
compute_sd_ratio,
42-
get_synchrony_counts,
42+
_get_synchrony_counts,
4343
compute_quality_metrics,
4444
)
4545

@@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync():
352352
one_spike["sample_index"] = spike_times
353353
one_spike["unit_index"] = spike_units
354354

355-
sync_count = get_synchrony_counts(one_spike, np.array((2)), [0])
355+
sync_count = _get_synchrony_counts(one_spike, [0])
356356

357357
assert np.all(sync_count[0] == np.array([0]))
358358

@@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync():
372372
two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
373373
two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))
374374

375-
sync_count = get_synchrony_counts(two_spikes, np.array((2)), [0, 1])
375+
sync_count = _get_synchrony_counts(two_spikes, [0, 1])
376376

377377
assert np.all(sync_count[0] == np.array([1, 1]))
378378

@@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync():
392392
four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
393393
four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))
394394

395-
sync_count = get_synchrony_counts(four_spikes, np.array((2, 4)), [0, 1, 2, 3])
395+
sync_count = _get_synchrony_counts(four_spikes, [0, 1, 2, 3])
396396

397397
assert np.all(sync_count[0] == np.array([1, 1, 1, 1]))
398398
assert np.all(sync_count[1] == np.array([1, 1, 1, 1]))
@@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units():
409409
three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
410410
three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))
411411

412-
sync_count = get_synchrony_counts(three_spikes, np.array((2)), [0, 1, 2])
412+
sync_count = _get_synchrony_counts(three_spikes, [0, 1, 2])
413413

414414
assert np.all(sync_count[0] == np.array([0, 1, 1]))
415415

@@ -610,9 +610,9 @@ def test_calculate_rp_violations(sorting_analyzer_violations):
610610
def test_synchrony_metrics(sorting_analyzer_simple):
611611
sorting_analyzer = sorting_analyzer_simple
612612
sorting = sorting_analyzer.sorting
613-
synchrony_sizes = (2, 3, 4)
614-
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=synchrony_sizes)
615-
print(synchrony_metrics)
613+
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer)
614+
615+
synchrony_sizes = np.array([2, 4, 8])
616616

617617
# check returns
618618
for size in synchrony_sizes:
@@ -625,10 +625,8 @@ def test_synchrony_metrics(sorting_analyzer_simple):
625625
sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level)
626626
sorting_analyzer_sync = create_sorting_analyzer(sorting_sync, sorting_analyzer.recording, format="memory")
627627

628-
previous_synchrony_metrics = compute_synchrony_metrics(
629-
previous_sorting_analyzer, synchrony_sizes=synchrony_sizes
630-
)
631-
current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes)
628+
previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_analyzer)
629+
current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync)
632630
print(current_synchrony_metrics)
633631
# check that all values increased
634632
for i, col in enumerate(previous_synchrony_metrics._fields):
@@ -647,22 +645,17 @@ def test_synchrony_metrics_unit_id_subset(sorting_analyzer_simple):
647645

648646
unit_ids_subset = [3, 7]
649647

650-
synchrony_sizes = (2,)
651-
(synchrony_metrics,) = compute_synchrony_metrics(
652-
sorting_analyzer_simple, synchrony_sizes=synchrony_sizes, unit_ids=unit_ids_subset
653-
)
648+
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple, unit_ids=unit_ids_subset)
654649

655-
assert list(synchrony_metrics.keys()) == [3, 7]
650+
assert list(synchrony_metrics.sync_spike_2.keys()) == [3, 7]
651+
assert list(synchrony_metrics.sync_spike_4.keys()) == [3, 7]
652+
assert list(synchrony_metrics.sync_spike_8.keys()) == [3, 7]
656653

657654

658655
def test_synchrony_metrics_no_unit_ids(sorting_analyzer_simple):
659656

660-
# all_unit_ids = sorting_analyzer_simple.sorting.unit_ids
661-
662-
synchrony_sizes = (2,)
663-
(synchrony_metrics,) = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes)
664-
665-
assert np.all(list(synchrony_metrics.keys()) == sorting_analyzer_simple.unit_ids)
657+
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple)
658+
assert np.all(list(synchrony_metrics.sync_spike_2.keys()) == sorting_analyzer_simple.unit_ids)
666659

667660

668661
@pytest.mark.sortingcomponents

0 commit comments

Comments
 (0)