Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve channel selection #117

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 81 additions & 63 deletions bci_essentials/channel_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def channel_selection_by_method(
preds : numpy.ndarray
The predictions from the model.
1D array with the same shape as `y`.

shape = (`n_trials`)
accuracy : float
The accuracy of the trained classification model.
Expand Down Expand Up @@ -206,6 +205,7 @@ def channel_selection_by_method(


def __check_stopping_criterion(
algorithm,
current_time,
n_channels,
current_performance_delta,
Expand All @@ -218,6 +218,8 @@ def __check_stopping_criterion(

Parameters
----------
algorithm : str
The algorithm being used for channel selection.
current_time : float
The time elapsed since the start of the channel selection method.
n_channels : int
Expand Down Expand Up @@ -246,19 +248,21 @@ def __check_stopping_criterion(
logger.debug("Stopping based on time")
return True

elif n_channels <= min_channels:
logger.debug("Stopping because minimum number of channels reached")
return True
if algorithm == "SBS" or algorithm == "SBFS":
if n_channels <= min_channels:
logger.debug("Stopping because minimum number of channels reached")
return True

elif n_channels >= max_channels:
logger.debug("Stopping because maximum number of channels reached")
return True
if algorithm == "SFS" or algorithm == "SFFS":
if n_channels >= max_channels:
logger.debug("Stopping because maximum number of channels reached")
return True

elif current_performance_delta < performance_delta:
if current_performance_delta < performance_delta:
logger.debug("Stopping because performance improvements are declining")
return True
else:
return False

return False


def __sfs(
Expand Down Expand Up @@ -410,7 +414,7 @@ def __sfs(
for channel in range(n_channels):
if channel not in sfs_subset:
set_to_try = sfs_subset.copy()
set_to_try.append(c)
set_to_try.append(channel)
sets_to_try.append(set_to_try)

# Get the new subset of data
Expand Down Expand Up @@ -492,22 +496,22 @@ def __sfs(
best_precision = precision
best_recall = recall

if record_performance is True:
new_channel_subset.sort()
results_df.loc[step] = [
step,
time.time() - start_time,
len(new_channel_subset),
"".join(new_channel_subset),
len(sets_to_try),
accuracy,
precision,
recall,
]
new_channel_subset.sort()
results_df.loc[step] = [
step,
time.time() - start_time,
len(new_channel_subset),
"".join(new_channel_subset),
len(sets_to_try),
accuracy,
precision,
recall,
]

step += 1

stop_criterion = __check_stopping_criterion(
"SFS",
time.time() - start_time,
len(new_channel_subset),
p_delta,
Expand All @@ -523,6 +527,9 @@ def __sfs(
logger.debug("%s : %s", metric, best_performance)
logger.debug("Time to optimal subset: %s s", time.time() - start_time)

if record_performance is True:
logger.info(results_df)

# Get the best model

return (
Expand Down Expand Up @@ -677,9 +684,9 @@ def __sbs(
# Exclusion Step
sets_to_try = []
X_to_try = []
for c in sbs_subset:
for channel in sbs_subset:
set_to_try = sbs_subset.copy()
set_to_try.remove(c)
set_to_try.remove(channel)
set_to_try.sort()

# Only try sets that have not been tried before
Expand Down Expand Up @@ -769,26 +776,26 @@ def __sbs(
best_precision = precision
best_recall = recall

if record_performance is True:
new_channel_subset.sort()
results_df.loc[step] = [
step,
time.time() - start_time,
len(new_channel_subset),
"".join(new_channel_subset),
len(sets_to_try),
accuracy,
precision,
recall,
]
new_channel_subset.sort()
results_df.loc[step] = [
step,
time.time() - start_time,
len(new_channel_subset),
"".join(new_channel_subset),
len(sets_to_try),
accuracy,
precision,
recall,
]

step += 1

# Break if SBFS subset is 1 channel
# Break if SBS subset is 1 channel
if len(sbs_subset) == 1:
break

stop_criterion = __check_stopping_criterion(
"SBS",
time.time() - start_time,
len(new_channel_subset),
p_delta,
Expand All @@ -804,6 +811,9 @@ def __sbs(
logger.debug("%s : %s", metric, best_performance)
logger.debug("Time to optimal subset: %s s", time.time() - start_time)

if record_performance is True:
logger.info(results_df)

return (
best_channel_subset,
best_model,
Expand Down Expand Up @@ -1204,19 +1214,18 @@ def __sbfs(
best_precision = precision
best_recall = recall

if record_performance:
new_channel_subset.sort()
results_df.loc[step] = [
step,
time.time() - start_time,
len(new_channel_subset),
"".join(new_channel_subset),
len(sets_to_try),
accuracy,
precision,
recall,
]
step += 1
new_channel_subset.sort()
results_df.loc[step] = [
step,
time.time() - start_time,
len(new_channel_subset),
"".join(new_channel_subset),
len(sets_to_try),
accuracy,
precision,
recall,
]
step += 1

performance_at_n_channels[length_of_resultant_set - 1] = (
current_performance
Expand All @@ -1229,6 +1238,7 @@ def __sbfs(

# Check stopping criterion
stop_criterion = __check_stopping_criterion(
"SBFS",
time.time() - start_time,
len(new_channel_subset),
p_delta,
Expand All @@ -1239,6 +1249,7 @@ def __sbfs(
)

stop_criterion = __check_stopping_criterion(
"SBFS",
time.time() - start_time,
len(new_channel_subset),
p_delta,
Expand All @@ -1258,6 +1269,9 @@ def __sbfs(
logger.debug("%s : %s", metric, best_performance)
logger.debug("Time to optimal subset: %s s", time.time() - start_time)

if record_performance is True:
logger.info(results_df)

return (
best_channel_subset,
best_model,
Expand Down Expand Up @@ -1517,18 +1531,17 @@ def __sffs(
best_precision = precision
best_recall = recall

if record_performance:
new_channel_subset.sort()
results_df.loc[step] = [
step,
time.time() - start_time,
len(new_channel_subset),
"".join(new_channel_subset),
len(sets_to_try),
accuracy,
precision,
recall,
]
new_channel_subset.sort()
results_df.loc[step] = [
step,
time.time() - start_time,
len(new_channel_subset),
"".join(new_channel_subset),
len(sets_to_try),
accuracy,
precision,
recall,
]

step += 1

Expand Down Expand Up @@ -1678,6 +1691,7 @@ def __sffs(
# Check stopping criterion
if pass_stopping_criterion is False:
stop_criterion = __check_stopping_criterion(
"SFFS",
time.time() - start_time,
len(new_channel_subset),
p_delta,
Expand All @@ -1692,6 +1706,7 @@ def __sffs(
continue
else:
stop_criterion = __check_stopping_criterion(
"SFFS",
time.time() - start_time,
len(new_channel_subset),
p_delta,
Expand All @@ -1707,6 +1722,9 @@ def __sffs(
logger.debug("%s : %s", metric, best_performance)
logger.debug("Time to optimal subset: %s s", time.time() - start_time)

if record_performance is True:
logger.info(results_df)

return (
best_channel_subset,
best_model,
Expand Down
2 changes: 1 addition & 1 deletion examples/mi_ch_select_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
# Define channel selection settings
initial_subset = []
classifier.setup_channel_selection(
method="SFFS",
method="SBFS",
metric="accuracy",
iterative_selection=True,
initial_channels=initial_subset, # wrapper setup
Expand Down
15 changes: 15 additions & 0 deletions examples/mi_unity_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@
# Set settings
classifier.set_mi_classifier_settings(n_splits=3, type="TS", random_seed=35)

# Define channel selection settings
initial_subset = []
classifier.setup_channel_selection(
method="SFS",
metric="accuracy",
iterative_selection=False,
initial_channels=initial_subset, # wrapper setup
max_time=100,
min_channels=0,
max_channels=4,
performance_delta=-1, # stopping criterion
n_jobs=-1,
record_performance=True,
)

# Define the MI data object
mi_data = BciController(
classifier, eeg_source, marker_source, paradigm, data_tank, messenger
Expand Down
Loading
Loading