
Commit de7811c

Merge pull request #12 from NxNiki/dev

add analysis

2 parents ebf6ab4 + 64b253f, commit de7811c

31 files changed: +2201 −1101 lines

.gitattributes

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+results/**/*.html filter=lfs diff=lfs merge=lfs -text

.gitignore

Lines changed: 6 additions & 1 deletion
@@ -35,5 +35,10 @@ src/brain_decoding/__pycache__/
 data/
 ._data
 config/*.yaml
-results/
+
+results/**/*.npy
+results/**/*.png
+results/**/*.tar
+!results/**/*.html
+
 wandb/

.run/run_twilight_merge.run.xml

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+<component name="ProjectRunConfigurationManager">
+  <configuration default="false" name="run_twilight_merge" type="PythonConfigurationType" factoryName="Python">
+    <module name="brain_decoding" />
+    <option name="ENV_FILES" value="" />
+    <option name="INTERPRETER_OPTIONS" value="" />
+    <option name="PARENT_ENVS" value="true" />
+    <envs>
+      <env name="PYTHONUNBUFFERED" value="1" />
+    </envs>
+    <option name="SDK_HOME" value="" />
+    <option name="SDK_NAME" value="movie_decoding" />
+    <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
+    <option name="IS_MODULE_SDK" value="false" />
+    <option name="ADD_CONTENT_ROOTS" value="true" />
+    <option name="ADD_SOURCE_ROOTS" value="true" />
+    <option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/scripts/run_model_twilight_merge.py" />
+    <option name="PARAMETERS" value="" />
+    <option name="SHOW_COMMAND_LINE" value="false" />
+    <option name="EMULATE_TERMINAL" value="false" />
+    <option name="MODULE_MODE" value="false" />
+    <option name="REDIRECT_INPUT" value="false" />
+    <option name="INPUT_FILE" value="" />
+    <method v="2" />
+  </configuration>
+</component>

.run/run_twilight_vs_24.run.xml

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+<component name="ProjectRunConfigurationManager">
+  <configuration default="false" name="run_twilight_vs_24" type="PythonConfigurationType" factoryName="Python">
+    <module name="brain_decoding" />
+    <option name="ENV_FILES" value="" />
+    <option name="INTERPRETER_OPTIONS" value="" />
+    <option name="PARENT_ENVS" value="true" />
+    <envs>
+      <env name="PYTHONUNBUFFERED" value="1" />
+    </envs>
+    <option name="SDK_HOME" value="" />
+    <option name="SDK_NAME" value="movie_decoding" />
+    <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
+    <option name="IS_MODULE_SDK" value="false" />
+    <option name="ADD_CONTENT_ROOTS" value="true" />
+    <option name="ADD_SOURCE_ROOTS" value="true" />
+    <option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/scripts/run_model_twilight_vs_24.py" />
+    <option name="PARAMETERS" value="" />
+    <option name="SHOW_COMMAND_LINE" value="false" />
+    <option name="EMULATE_TERMINAL" value="false" />
+    <option name="MODULE_MODE" value="false" />
+    <option name="REDIRECT_INPUT" value="false" />
+    <option name="INPUT_FILE" value="" />
+    <method v="2" />
+  </configuration>
+</component>
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f548b2d34b96898a91570a6fcb3e4764b0ec95a68482cbc88475f3166c2deef0
+size 1047036
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33edd3c48cfca7166d95d7faf8b3a9e83016725c671183b621e02c7d80f09213
+size 433763

scripts/plot_activation.ipynb

Lines changed: 0 additions & 236 deletions
This file was deleted.
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+from pydantic import BaseModel
+
+
+class Config(BaseModel):
+    class Config:
+        # arbitrary_types_allowed = True
+        extra = "allow"  # Allow arbitrary attributes
+
+    _list_fields = set()  # A set to track which fields should be treated as lists
+
+    def ensure_list(self, name: str):
+        value = getattr(self, name, None)
+        if value is not None and not isinstance(value, list):
+            setattr(self, name, [value])
+        # Mark the field to always be treated as a list
+        self._list_fields.add(name)
+
+    def __setattr__(self, name, value):
+        if name in self._list_fields and not isinstance(value, list):
+            # Automatically convert to a list if it's in the list fields
+            value = [value]
+        super().__setattr__(name, value)
+
+
+class SupConfig(Config):
+    pass
+
+
+# Example usage
+config = SupConfig()

+# Dynamically adding attributes
+config.param1 = "a"
+
+# Ensuring param1 is a list
+config.ensure_list("param1")
+print(config.param1)  # Output: ['a']
+
+# Assigning new value to param1
+config.param1 = "ab"
+print(config.param1)  # Output: ['ab'] (the scalar "ab" is automatically converted to a list)
+
+# Adding another parameter and ensuring it's a list
+config.ensure_list("param2")
+config.param2 = 123
+print(config.param2)  # Output: [123]
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+from typing import Any, Dict, Set
+
+from pydantic import BaseModel, Field
+
+
+class BaseConfig(BaseModel):
+    class Config:
+        extra = "allow"  # Allow arbitrary attributes
+
+    def __init__(self, **data: Any) -> None:
+        super().__init__(**data)
+        self.__dict__["_list_fields"]: Set[str] = set()
+        self.__dict__["_alias"]: Dict[str, str] = {}
+
+    def __getitem__(self, key: str) -> Any:
+        return getattr(self, key)
+
+    def __setitem__(self, key: str, value: Any):
+        setattr(self, key, value)
+
+    def __getattr__(self, name):
+        """Handles alias access and custom parameters."""
+        if name in self._alias:
+            return getattr(self, self._alias[name])
+
+    def __setattr__(self, name, value):
+        """Handles alias assignment, field setting, or adding to _param."""
+        if name in self._alias:
+            name = self._alias[name]
+        if name in self._list_fields and not isinstance(value, list):
+            value = [value]
+        super().__setattr__(name, value)
+
+    def __contains__(self, key: str) -> bool:
+        return hasattr(self, key)
+
+    def __repr__(self):
+        attrs = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
+        attr_str = "\n".join(f" {key}: {value!r}" for key, value in attrs.items())
+        return f"{self.__class__.__name__}(\n{attr_str}\n)"
+
+    def set_alias(self, name: str, alias: str) -> None:
+        self.__dict__["_alias"][alias] = name
+
+    def ensure_list(self, name: str):
+        """Mark the field to always be treated as a list"""
+        value = getattr(self, name, None)
+        if value is not None and not isinstance(value, list):
+            setattr(self, name, [value])
+        self._list_fields.add(name)
+
+
+class Foo(BaseConfig):
+    a: int = 1
+
+    class Config:
+        extra = "allow"
+
+
+print(Foo(**{"a": 1, "b": 2}).model_dump())  # == {'a': 1, 'b': 2}
+
+foo = Foo()
+foo.b = 2
+print(foo.model_dump())
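
For reference, a minimal usage sketch of the alias and list helpers defined above (not part of this commit). The field names "train_phases" and "phases" are hypothetical, and the expected output assumes pydantic v2 with extra="allow" as declared:

cfg = Foo()
cfg.ensure_list("train_phases")          # hypothetical field; later scalar assignments get wrapped in a list
cfg.train_phases = "twilight_1"          # stored as ["twilight_1"]
cfg.set_alias("train_phases", "phases")  # "phases" now writes through to "train_phases"
cfg.phases = "movie_1"                   # overwrites the same underlying field via the alias
print(cfg.model_dump())                  # expected: {'a': 1, 'train_phases': ['movie_1']}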

src/brain_decoding/config/file_path.py

Lines changed: 4 additions & 0 deletions
@@ -6,3 +6,7 @@
 SURROGATE_FILE_PATH = ROOT_PATH / "data/surrogate_windows"
 CONFIG_FILE_PATH = ROOT_PATH / "config"
 RESULT_PATH = ROOT_PATH / "results"
+MOVIE24_LABEL_PATH = f"{DATA_PATH}/8concepts_merged.npy"
+TWILIGHT_LABEL_PATH = f"{DATA_PATH}/twilight_concepts.npy"
+TWILIGHT_MERGE_LABEL_PATH = f"{DATA_PATH}/twilight_concepts_merged.npy"
+MOVIE_LABEL_TWILIGHT_VS_24 = f"{DATA_PATH}/twilight_vs_24.npy"
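
These constants are plain paths to .npy label arrays, so a quick sanity check (not part of the commit) is just a numpy load; it requires the data file to exist locally, and the printed shape depends on how the labels were saved:

import numpy as np

from brain_decoding.config.file_path import TWILIGHT_MERGE_LABEL_PATH

labels = np.load(TWILIGHT_MERGE_LABEL_PATH)  # hypothetical check on the new label path
print(labels.shape, labels.dtype)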

scripts/save_config.py renamed to src/brain_decoding/config/save_config.py

Lines changed: 8 additions & 3 deletions
@@ -3,6 +3,8 @@
 Custom parameters can be added to any of the three fields of config (experiment, model, data).
 """
 
+from torch import nn
+
 from brain_decoding.config.config import ExperimentConfig, PipelineConfig
 from brain_decoding.config.file_path import CONFIG_FILE_PATH, DATA_PATH, RESULT_PATH
 
@@ -18,7 +20,7 @@
 config.model.lr_drop = 50
 config.model.validation_step = 10
 config.model.early_stop = 75
-config.model.num_labels = 8
+config.model.num_labels = 18  # 8 for 24, 18 for twilight
 config.model.merge_label = True
 config.model.img_embedding_size = 192
 config.model.hidden_size = 256
@@ -27,6 +29,7 @@
 config.model.patch_size = (1, 5)
 config.model.intermediate_size = 192 * 2
 config.model.classifier_proj_size = 192
+config.model.train_loss = nn.BCEWithLogitsLoss(reduction="none")
 
 config.experiment.seed = 42
 config.experiment.use_spike = True
@@ -44,8 +47,8 @@
 config.experiment.use_shuffle_diagnostic = True
 config.experiment.testing_mode = False  # in testing mode, a maximum of 1e4 clusterless data will be loaded.
 config.experiment.model_aggregate_type = "sum"
-config.experiment.train_phases = ["movie_1"]
-config.experiment.test_phases = ["sleep_2"]
+config.experiment.train_phases = ["twilight_1"]
+config.experiment.test_phases = ["sleep_1"]
 config.experiment.compute_accuracy = False
 
 config.experiment.ensure_list("train_phases")
@@ -61,6 +64,8 @@
 config.data.spike_data_sd_inference = 3.5
 config.data.model_aggregate_type = "sum"
 config.data.movie_label_path = str(DATA_PATH / "8concepts_merged.npy")
+config.data.movie_label_sr = 1
 config.data.movie_sampling_rate = 30
+config.data.filter_low_occurrence_samples = True
 
 # config.export_config(CONFIG_FILE_PATH)
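
As context for the new train_loss setting: reduction="none" makes BCEWithLogitsLoss return one loss value per label rather than a scalar, leaving any averaging or masking to the training loop. A small standalone sketch (batch size 4 is arbitrary; 18 matches num_labels above):

import torch
from torch import nn

loss_fn = nn.BCEWithLogitsLoss(reduction="none")
logits = torch.randn(4, 18)                     # (batch, num_labels)
targets = torch.randint(0, 2, (4, 18)).float()  # multi-label 0/1 targets
loss = loss_fn(logits, targets)
print(loss.shape)                               # torch.Size([4, 18]) -- per-element, unreduced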

src/brain_decoding/dataloader/clusterless_clean.py

Lines changed: 7 additions & 5 deletions
@@ -1,6 +1,7 @@
 import glob
 import os
 import re
+from typing import List, Union
 
 import numpy as np
 import pandas as pd
@@ -16,6 +17,11 @@ def __init__(self, operator: str, threshold: int):
         self.threshold = threshold
 
 
+def sort_file_name(filenames: str) -> List[Union[int, str]]:
+    """Extract the numeric part of the filename and use it as the sort key"""
+    return [int(x) if x.isdigit() else x for x in re.findall(r"\d+|\D+", filenames)]
+
+
 def find_true_indices(mask, op_thresh: OpThresh = None):
     """
     Returns an nx3 matrix containing start, end, and length of all true samples in a 1D boolean mask.
@@ -161,14 +167,10 @@ def load_data_from_bundle(clu_bundle_filepaths):
 
 
 def get_oneshot_clean(patient_number, desired_samplerate, mode, category="recall", phase=None, version="notch"):
-    def sort_filename(filename):
-        """Extract the numeric part of the filename and use it as the sort key"""
-        return [int(x) if x.isdigit() else x for x in re.findall(r"\d+|\D+", filename)]
-
     # folder contains the clustless data, I saved the folder downloaded from the drive as '562/clustless_raw'
     spike_path = f"/mnt/SSD2/yyding/Datasets/neuron/spike_data/{patient_number}/raw_{mode}/"
    spike_files = glob.glob(os.path.join(spike_path, "*.csv"))
-    spike_files = sorted(spike_files, key=sort_filename)
+    spike_files = sorted(spike_files, key=sort_file_name)
 
     for bundle in range(0, len(spike_files), 8):
         df = load_data_from_bundle(spike_files[bundle : bundle + 8])
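
The module-level sort_file_name gives a natural (numeric-aware) ordering, which matters because plain lexicographic sorting would put, for example, channel 10 before channel 2. A small illustration with made-up filenames:

from brain_decoding.dataloader.clusterless_clean import sort_file_name

files = ["times_csc10.csv", "times_csc2.csv", "times_csc1.csv"]  # hypothetical channel files
print(sorted(files, key=sort_file_name))
# ['times_csc1.csv', 'times_csc2.csv', 'times_csc10.csv']
print(sorted(files))
# ['times_csc1.csv', 'times_csc10.csv', 'times_csc2.csv']  (lexicographic, wrong order)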
