
Commit f36f764

Merge pull request #128 from hassonlab/dev
Dev to Main 20221222
2 parents b8253c6 + 4588392

14 files changed (+383 -234 lines)

Makefile

Lines changed: 7 additions & 9 deletions
@@ -69,15 +69,15 @@ create-sig-pickle:
 	mkdir -p logs
 	$(CMD) scripts/tfspkl_main.py \
 		--project-id $(PRJCT_ID) \
-		--sig-elec-file data/$(PRJCT_ID)/all-electrodes.csv
+		--sig-elec-file all-electrodes2.csv
 
 # upload pickles to google cloud bucket
 # on bucket we use 247 not tfs, so manually adjust as needed
 # upload-pickle: pid=247
 upload-pickle: pid=podcast
 upload-pickle:
 	for sid in $(SID_LIST); do \
-		gsutil -m rsync results/$(PRJCT_ID)/$$sid/pickles/ gs://247-podcast-data/$(pid)-pickles/$$sid; \
+		gsutil -m rsync -rd results/$(PRJCT_ID)/$$sid/pickles/ gs://247-podcast-data/$(pid)-pickles/$$sid; \
 	done
 
 # upload raw data to google cloud bucket
@@ -107,7 +107,7 @@ download-247-pickles:
 	"facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", \
 	"facebook/opt-2.7b", "facebook/opt-6.7b", "facebook/opt-30b", \
 	"facebook/blenderbot_small-90M"}
-%-embeddings: CNXT_LEN := 1024 512 256 128 64 32 16 8 4 2 1
+%-embeddings: CNXT_LEN := 1024
 %-embeddings: LAYER := all
 # {'all' for all layers | 'last' for the last layer | (list of) integer(s) >= 1}
 # Note: embeddings file is the same for all podcast subjects \
@@ -154,13 +154,11 @@ concatenate-embeddings:
 	done;
 
 # Podcast: copy embeddings to other subjects as well
-# for sid in 662 717 723 741 742 763 798 777; do
 copy-embeddings:
-	@for fn in results/podcast/661/pickles/*embeddings.pkl; do \
-		for sid in 777; do \
-			cp -pf $$fn $$(echo $$fn | sed "s/661/$$sid/g"); \
-		done; \
-	done
+	fn=results/podcast/661/pickles/embeddings
+	for sid in 662 717 723 741 742 763 798 777; do \
+		cp -rpf $$fn $$(echo $$fn | sed "s/661/$$sid/g"); \
+	done; \
 
 
 # Download huggingface models to cache (before generating embeddings)
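A note on the rewritten copy-embeddings recipe: make normally runs each recipe line in its own shell, so unless .ONESHELL is set elsewhere in the Makefile, the fn= assignment on its own line is not visible to the for loop below it; the two lines would need backslash continuations to share a shell. As a sketch of the intended effect (paths and subject list taken from the recipe above, nothing else assumed):

# Sketch of what copy-embeddings intends: mirror subject 661's embeddings
# directory to the other podcast subjects (a cp -rpf equivalent).
import shutil
from pathlib import Path

src = Path("results/podcast/661/pickles/embeddings")
for sid in (662, 717, 723, 741, 742, 763, 798, 777):
    dst = Path(str(src).replace("661", str(sid)))
    shutil.copytree(src, dst, dirs_exist_ok=True)  # copy2 preserves metadata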

requirements.yml

Lines changed: 98 additions & 26 deletions
@@ -2,57 +2,83 @@ name: 247-main
 channels:
   - defaults
 dependencies:
-  - _libgcc_mutex=0.1
-  - _openmp_mutex=5.1
-  - bzip2=1.0.8
-  - ca-certificates=2022.4.26
-  - ld_impl_linux-64=2.38
-  - libffi=3.3
-  - libgcc-ng=11.2.0
-  - libgomp=11.2.0
-  - libstdcxx-ng=11.2.0
-  - libuuid=1.0.3
-  - ncurses=6.3
-  - openssl=1.1.1p
-  - python=3.10.4
-  - readline=8.1.2
-  - sqlite=3.38.5
-  - tk=8.6.12
-  - tzdata=2022a
-  - wheel=0.37.1
-  - xz=5.2.5
-  - zlib=1.2.12
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - bzip2=1.0.8=h7b6447c_0
+  - ca-certificates=2022.4.26=h06a4308_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgomp=11.2.0=h1234567_1
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libuuid=1.0.3=h7f8727e_2
+  - ncurses=6.3=h5eee18b_3
+  - openssl=1.1.1p=h5eee18b_0
+  - python=3.10.4=h12debd9_0
+  - readline=8.1.2=h7f8727e_1
+  - sqlite=3.38.5=hc218d9a_0
+  - tk=8.6.12=h1ccaba5_0
+  - tzdata=2022a=hda174b7_0
+  - wheel=0.37.1=pyhd3eb1b0_0
+  - xz=5.2.5=h7f8727e_1
+  - zlib=1.2.12=h7f8727e_2
   - pip:
+    - accelerate==0.14.0
+    - aiohttp==3.8.3
+    - aiosignal==1.2.0
+    - alabaster==0.7.12
     - anyio==3.6.1
+    - appdirs==1.4.4
     - argon2-cffi==21.3.0
     - argon2-cffi-bindings==21.2.0
     - asttokens==2.0.5
+    - async-timeout==4.0.2
     - attrs==21.4.0
+    - audioread==3.0.0
     - babel==2.10.3
     - backcall==0.2.0
     - beautifulsoup4==4.11.1
-    - black==22.6.0
+    - black==22.12.0
     - bleach==5.0.1
+    - boltons==21.0.0
+    - bracex==2.3.post1
     - certifi==2022.6.15
     - cffi==1.15.1
     - charset-normalizer==2.1.0
     - click==8.1.3
+    - click-extra==3.5.0
+    - click-log==0.4.0
+    - cloup==2.0.0.post1
+    - colorama==0.4.5
+    - commentjson==0.9.0
     - cycler==0.11.0
+    - datasets==2.5.2
     - debugpy==1.6.2
     - decorator==5.1.1
     - defusedxml==0.7.1
+    - dill==0.3.5.1
+    - docutils==0.19
     - entrypoints==0.4
     - executing==0.8.3
     - fastjsonschema==2.15.3
+    - ffmpeg-python==0.2.0
     - filelock==3.7.1
     - fonttools==4.34.4
+    - frozenlist==1.3.1
+    - fsspec==2022.8.2
+    - future==0.18.2
     - gensim==4.2.0
     - h5py==3.7.0
-    - huggingface-hub==0.8.1
+    - htmlmin==0.1.12
+    - huggingface-hub==0.10.0
+    - icecream==2.1.3
     - idna==3.3
+    - imagehash==4.2.1
+    - imagesize==1.4.1
     - ipykernel==6.15.1
     - ipython==8.4.0
     - ipython-genutils==0.2.0
+    - ipywidgets==7.7.1
     - isort==5.10.1
     - jedi==0.18.1
     - jinja2==3.1.2
@@ -65,73 +91,119 @@ dependencies:
     - jupyterlab==3.4.3
     - jupyterlab-pygments==0.2.2
     - jupyterlab-server==2.15.0
+    - jupyterlab-widgets==1.1.1
     - kiwisolver==1.4.3
+    - lark-parser==0.7.8
+    - librosa==0.9.2
     - llvmlite==0.38.1
     - markupsafe==2.1.1
     - mat73==0.59
     - matplotlib==3.5.2
     - matplotlib-inline==0.1.3
+    - mergedeep==1.3.4
+    - missingno==0.5.1
     - mistune==0.8.4
+    - more-itertools==8.14.0
+    - multidict==6.0.2
+    - multimethod==1.8
+    - multiprocess==0.70.13
     - mypy-extensions==0.4.3
     - nbclassic==0.4.2
     - nbclient==0.6.6
     - nbconvert==6.5.0
     - nbformat==5.4.0
     - nest-asyncio==1.5.5
+    - networkx==2.8.5
     - nltk==3.7
+    - notebook==6.4.12
     - notebook-shim==0.1.0
     - numba==0.55.2
     - numexpr==2.8.3
     - numpy==1.22.4
+    - nvidia-cublas-cu11==11.10.3.66
+    - nvidia-cuda-nvrtc-cu11==11.7.99
+    - nvidia-cuda-runtime-cu11==11.7.99
+    - nvidia-cudnn-cu11==8.5.0.96
+    - packageurl-python==0.10.4
     - packaging==21.3
+    - pallets-sphinx-themes==2.0.2
     - pandas==1.4.3
+    - pandas-profiling==3.2.0
     - pandocfilters==1.5.0
     - parso==0.8.3
     - pathspec==0.9.0
     - pexpect==4.8.0
+    - phik==0.12.2
     - pickleshare==0.7.5
     - pillow==9.2.0
     - pip==21.2.4
     - platformdirs==2.5.2
+    - pooch==1.6.0
     - prometheus-client==0.14.1
     - prompt-toolkit==3.0.30
     - psutil==5.9.1
     - ptyprocess==0.7.0
     - pure-eval==0.2.2
+    - pyarrow==9.0.0
     - pycparser==2.21
+    - pydantic==1.9.2
     - pygments==2.12.0
+    - pygments-ansi-color==0.1.0
     - pyparsing==3.0.9
     - pyrsistent==0.18.1
     - python-dateutil==2.8.2
     - pytz==2022.1
+    - pywavelets==1.3.0
     - pyyaml==6.0
     - pyzmq==23.2.0
     - regex==2022.7.9
     - requests==2.28.1
+    - resampy==0.4.2
+    - responses==0.18.0
     - scikit-learn==1.1.1
     - scipy==1.8.1
+    - seaborn==0.11.2
     - send2trash==1.8.0
     - setuptools==61.2.0
     - six==1.16.0
     - smart-open==6.0.0
     - sniffio==1.2.0
+    - snowballstemmer==2.2.0
+    - soundfile==0.11.0
     - soupsieve==2.3.2.post1
+    - sphinx==5.3.0
+    - sphinxcontrib-applehelp==1.0.2
+    - sphinxcontrib-devhelp==1.0.2
+    - sphinxcontrib-htmlhelp==2.0.0
+    - sphinxcontrib-jsmath==1.0.1
+    - sphinxcontrib-qthelp==1.0.3
+    - sphinxcontrib-serializinghtml==1.1.5
     - stack-data==0.3.0
+    - tabulate==0.9.0
+    - tangled-up-in-unicode==0.2.0
     - terminado==0.15.0
     - threadpoolctl==3.1.0
     - tinycss2==1.1.1
     - tokenizers==0.12.1
     - tomli==2.0.1
-    - torch==1.12.0+cu113
-    - torchaudio==0.12.0+cu113
-    - torchvision==0.13.0+cu113
+    - tomli-w==1.0.0
+    - torch==1.13.1
+    - torchaudio==0.13.1
+    - torchvision==0.14.1
     - tornado==6.2
     - tqdm==4.64.0
     - traitlets==5.3.0
-    - transformers==4.20.1
+    - transformers==4.25.1
     - typing-extensions==4.3.0
     - urllib3==1.26.10
+    - visions==0.7.4
+    - wcmatch==8.4.1
     - wcwidth==0.2.5
     - webencodings==0.5.1
     - websocket-client==1.3.3
+    - whisper==1.0
+    - widgetsnbextension==3.6.1
+    - xmltodict==0.13.0
+    - xxhash==3.0.0
+    - yarl==1.8.1
 prefix: /home/hgazula/.conda/envs/247-main
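The environment now pins full conda build strings and replaces the CUDA-suffixed torch 1.12.0+cu113 wheels with the generic 1.13.1 builds; the nvidia-*-cu11 pip packages above are the CUDA runtime those generic wheels pull in. A quick post-install sanity check (a sketch, not part of the diff):

# Confirm the upgraded torch resolves its bundled CUDA runtime.
import torch

print(torch.__version__)          # expect 1.13.1
print(torch.cuda.is_available())  # True on a machine with an NVIDIA GPU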

scripts/electrode_utils.py

Lines changed: 10 additions & 15 deletions
@@ -6,6 +6,7 @@
 
 import numpy as np
 from scipy.io import loadmat
+from tfspkl_config import ELECTRODE_FOLDER_MAP
 
 
 def get_electrode(CONFIG, elec_id):
@@ -19,21 +20,15 @@ def get_electrode(CONFIG, elec_id):
     """
     conversation, electrode = elec_id
 
-    if CONFIG["project_id"] == "podcast":
-        search_str = conversation + f"/preprocessed_all/*_{electrode}.mat"
-    elif CONFIG["project_id"] == "tfs":
-        if CONFIG["subject"] == "7170":
-            search_str = conversation + f"/preprocessed_v2/*_{electrode}.mat"
-            # TODO: check if it is preprocessed or preprocessed_v2
-        elif CONFIG["subject"] == "798":
-            search_str = (
-                conversation + f"/preprocessed_allElec/*_{electrode}.mat"
-            )
-        else:
-            search_str = conversation + f"/preprocessed/*_{electrode}.mat"
-    else:
-        print("Incorrect Project ID")
-        sys.exit()
+    electrode_folder = ELECTRODE_FOLDER_MAP.get(CONFIG["project_id"], None).get(
+        CONFIG["subject"], None
+    )
+
+    if not electrode_folder:
+        print("Incorrect Project ID or Subject")
+        exit()
+
+    search_str = conversation + f"/{electrode_folder}/*_{electrode}.mat"
 
     mat_fn = glob.glob(search_str)
     if mat_fn:
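The hard-coded per-project, per-subject branches are replaced with a table lookup; the table itself lives in tfspkl_config.py, which is not part of this diff. A hypothetical sketch of its shape, using the folder names from the removed branches, plus a slightly safer variant of the lookup: passing {} rather than None as the first default avoids an AttributeError on the chained .get() when the project id itself is unknown.

# Hypothetical shape of the map imported from tfspkl_config; the real
# folder names and subject ids live in tfspkl_config.py, not shown here.
ELECTRODE_FOLDER_MAP = {
    "podcast": {"661": "preprocessed_all"},
    "tfs": {
        "7170": "preprocessed_v2",
        "798": "preprocessed_allElec",
        # other tfs subjects presumably map to plain "preprocessed"
    },
}


def lookup_electrode_folder(project_id, subject):
    # {} as the default keeps an unknown project id from raising
    # AttributeError before the "Incorrect Project ID" path is reached.
    return ELECTRODE_FOLDER_MAP.get(project_id, {}).get(subject)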

scripts/tfsemb_LMBase.py

Lines changed: 36 additions & 11 deletions
@@ -7,6 +7,27 @@
 from utils import save_pickle as svpkl
 
 
+def clean_lm_model_name(item):
+    """Remove unnecessary parts from the language model name.
+
+    Args:
+        item (str/list): full model name from HF Hub
+
+    Returns:
+        (str/list): pretty model name
+
+    Example:
+        clean_lm_model_name("EleutherAI/gpt-neo-1.3B") == "gpt-neo-1.3B"
+    """
+    if isinstance(item, str):
+        return item.split("/")[-1]
+
+    if isinstance(item, list):
+        return [clean_lm_model_name(i) for i in item]
+
+    print("Invalid input. Please check.")
+
+
 def add_vocab_columns(args, df, column=None):
     """Add columns to the dataframe indicating whether each word is in the
     vocabulary of the language models we're using.
@@ -27,17 +48,20 @@ def add_vocab_columns(args, df, column=None):
             model, local_files_only=False
         )
 
-        key = model.split("/")[-1]
+        key = clean_lm_model_name(model)
         print(f"Adding column: (token) in_{key}")
 
         try:
             curr_vocab = tokenizer.vocab
         except AttributeError:
             curr_vocab = tokenizer.get_vocab()
 
-        df[f"in_{key}"] = df[column].apply(
-            lambda x: isinstance(curr_vocab.get(x), int)
-        )
+        def helper(x):
+            if len(tokenizer.tokenize(x)) == 1:
+                return isinstance(curr_vocab.get(tokenizer.tokenize(x)[0]), int)
+            return False
+
+        df[f"in_{key}"] = df[column].apply(helper)
 
     return df
 
@@ -49,16 +73,17 @@ def main():
 
     base_df = load_pickle(args.labels_pickle, "labels")
 
+    glove = api.load("glove-wiki-gigaword-50")
+    base_df["in_glove50"] = base_df.word.str.lower().apply(
+        lambda x: isinstance(glove.key_to_index.get(x), int)
+    )
+
     if args.embedding_type == "glove50":
-        base_df = add_vocab_columns(args, base_df, column="word")
+        base_df = base_df[base_df["in_glove50"]]
+        base_df = add_vocab_columns(args, base_df, column="word", flag=True)
     else:
-        # Add glove
-        glove = api.load("glove-wiki-gigaword-50")
-        base_df["in_glove"] = base_df.word.str.lower().apply(
-            lambda x: isinstance(glove.key_to_index.get(x), int)
-        )
        base_df = tokenize_and_explode(args, base_df)
-        base_df = add_vocab_columns(args, base_df, column="token")
+        base_df = add_vocab_columns(args, base_df, column="token2word")
 
     svpkl(base_df, args.base_df_file)
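Two things are worth noting here. First, the vocabulary check changed in substance, not just in name: the old version looked the raw string up in the tokenizer vocabulary, while the new helper tokenizes the word first and counts it as in-vocabulary only when it maps to a single token. A standalone sketch of that logic (the gpt2 checkpoint is illustrative, not from this diff):

# Sketch of the new single-token vocabulary check, mirroring helper() above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
try:
    curr_vocab = tokenizer.vocab
except AttributeError:
    curr_vocab = tokenizer.get_vocab()


def in_vocab(word):
    pieces = tokenizer.tokenize(word)
    # In-vocabulary only if the word survives as exactly one token piece.
    return len(pieces) == 1 and isinstance(curr_vocab.get(pieces[0]), int)


print(in_vocab("hello"))           # a single GPT-2 token, so True
print(in_vocab("neurosemantics"))  # splits into several pieces, so False

Second, main() now calls add_vocab_columns(..., flag=True) while the signature visible in this diff is still def add_vocab_columns(args, df, column=None); unless the flag parameter is added elsewhere in this commit, that call would raise a TypeError.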
