
Commit

Merge pull request #110 from hassonlab/dev
Dev
zkokaja authored Dec 15, 2022
2 parents 2fa1ef1 + 762a421 commit b8253c6
Showing 5 changed files with 72 additions and 97 deletions.
26 changes: 13 additions & 13 deletions scripts/tfsemb_config.py
@@ -22,7 +22,7 @@ def get_model_layer_count(args):
# NOTE: layer_idx is shifted by 1 because the first item in hidden_states
# corresponds to the output of the embeddings_layer
if args.layer_idx == "all":
args.layer_idx = np.arange(1, max_layers + 1)
args.layer_idx = np.arange(0, max_layers + 1)
elif args.layer_idx == "last":
args.layer_idx = [max_layers]
else:
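
For context on the layer_idx change, a minimal sketch (using gpt2 purely as an example model) of why index 0 is meaningful: with output_hidden_states=True, Hugging Face causal LMs return num_hidden_layers + 1 hidden states, where index 0 is the embedding-layer output and 1..num_hidden_layers are the transformer blocks, so arange(0, max_layers + 1) now covers the embedding output as well.

```python
# Illustrative sketch, not part of the commit: inspect hidden_states indexing.
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", output_hidden_states=True)

inputs = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

max_layers = model.config.num_hidden_layers           # 12 for gpt2
assert len(outputs.hidden_states) == max_layers + 1   # embeddings + 12 blocks
layer_idx = np.arange(0, max_layers + 1)               # mirrors the new "all" default
```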
@@ -42,14 +42,9 @@ def select_tokenizer_and_model(args):
return

try:
(
args.model,
args.tokenizer,
) = tfsemb_dwnld.download_tokenizers_and_models(
(args.model, args.tokenizer,) = tfsemb_dwnld.download_tokenizers_and_models(
model_name, local_files_only=True, debug=False
)[
model_name
]
)[model_name]
except OSError:
# NOTE: Please refer to make-target: cache-models for more information.
print(
@@ -73,7 +68,7 @@ def select_tokenizer_and_model(args):
def process_inputs(args):
if len(args.layer_idx) == 1:
if args.layer_idx[0].isdecimal():
args.layer_idx = int(args.layer_idx[0])
args.layer_idx = [int(args.layer_idx[0])]
else:
args.layer_idx = args.layer_idx[0]
else:
@@ -111,9 +106,7 @@ def setup_environ(args):
)

args.input_dir = os.path.join(DATA_DIR, args.subject)
args.conversation_list = sorted(
glob.glob1(args.input_dir, "NY*Part*conversation*")
)
args.conversation_list = sorted(glob.glob1(args.input_dir, "NY*Part*conversation*"))

select_tokenizer_and_model(args)
stra = f"{args.trimmed_model_name}/{args.pkl_identifier}/cnxt_{args.context_length:04d}"
@@ -128,12 +121,19 @@ def setup_environ(args):
output_file_name = args.conversation_list[args.conversation_id - 1]
args.output_file = os.path.join(args.output_dir, output_file_name)

# saving the base dataframe
# saving the base dataframe
args.base_df_file = os.path.join(
args.EMB_DIR,
args.trimmed_model_name,
args.pkl_identifier,
"base_df.pkl",
)

# saving logits as dataframe
args.logits_df_file = os.path.join(
args.EMB_DIR,
stra,
"logits.pkl",
)

return
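
To make the added output paths concrete, a small sketch with hypothetical values (EMB_DIR, the model name, and the identifier below are placeholders, not taken from the repo):

```python
# Hypothetical values, only to illustrate where base_df.pkl and logits.pkl end up.
import os

EMB_DIR = "results/embeddings"     # placeholder for args.EMB_DIR
trimmed_model_name = "gpt2-xl"     # placeholder
pkl_identifier = "full"            # placeholder
context_length = 1024

stra = f"{trimmed_model_name}/{pkl_identifier}/cnxt_{context_length:04d}"
base_df_file = os.path.join(EMB_DIR, trimmed_model_name, pkl_identifier, "base_df.pkl")
logits_df_file = os.path.join(EMB_DIR, stra, "logits.pkl")

print(base_df_file)    # results/embeddings/gpt2-xl/full/base_df.pkl
print(logits_df_file)  # results/embeddings/gpt2-xl/full/cnxt_1024/logits.pkl
```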
75 changes: 10 additions & 65 deletions scripts/tfsemb_download.py
@@ -1,8 +1,12 @@
import os

from transformers import (AutoModel, AutoModelForCausalLM,
AutoModelForMaskedLM, AutoModelForSeq2SeqLM,
AutoTokenizer)
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
)

CAUSAL_MODELS = [
"gpt2",
@@ -18,11 +22,10 @@
"facebook/opt-2.7b",
"facebook/opt-6.7b",
"facebook/opt-30b",
"bigscience/bloom",
]
SEQ2SEQ_MODELS = ["facebook/blenderbot_small-90M", "facebook/blenderbot-3B"]

CLONE_MODELS = []

MLM_MODELS = [
# "gpt2-xl", # uncomment to run this model with MLM input
# "gpt2-medium", # uncomment to run this model with MLM input
@@ -104,58 +107,6 @@ def download_tokenizer_and_model(
return (model, tokenizer)


def clone_model_repo(
CACHE_DIR,
tokenizer_class,
model_class,
model_name,
local_files_only=False,
):
"""Cache (load) the model and tokenizer from the model repository (cache).
Args:
CACHE_DIR (str): path where the model and tokenizer will be cached.
tokenizer_class (Tokenizer): Tokenizer class to be instantiated for the model.
model_class (Huggingface Model): Model class corresponding to model_name.
model_name (str): Model name as seen on https://hugginface.co/models.
local_files_only (bool, optional): False (Default) if caching.
True if loading from cache.
Returns:
tuple or None: (tokenizer, model) if local_files_only is True
None if local_files_only is False.
"""
model_dir = os.path.join(CACHE_DIR, model_name)

if local_files_only:
if os.path.exists(model_dir):
model, tokenizer = download_tokenizer_and_model(
CACHE_DIR,
tokenizer_class,
model_class,
model_dir,
local_files_only,
)
return model, tokenizer
else:
print(f"Model directory {model_dir} does not exist")
else:
try:
if (
"tiger" in os.uname().nodename
): # probably redundant, but just in case we are on tiger
os.system("module load git")

os.system(f"git lfs install")
os.system(
f"git clone https://huggingface.co/{model_name} {model_dir}"
)
except:
# FIXME: Raise appropriate exception
print("Possible git lfs version issues")
exit(1)


def set_cache_dir():
CACHE_DIR = os.path.join(os.path.dirname(os.getcwd()), ".cache")
os.makedirs(CACHE_DIR, exist_ok=True)
@@ -200,13 +151,7 @@ def download_tokenizers_and_models(
for model_name in MODELS:
print(f"Model Name: {model_name}")

cache_function = (
clone_model_repo
if model_name in CLONE_MODELS
else download_tokenizer_and_model
)

model_dict[model_name] = cache_function(
model_dict[model_name] = download_tokenizer_and_model(
CACHE_DIR,
AutoTokenizer,
model_class,
@@ -218,7 +163,7 @@
if debug:
print("Checking if model has been cached successfully")
try:
cache_function(
download_tokenizer_and_model(
CACHE_DIR,
AutoTokenizer,
model_class,
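With the clone_model_repo path removed, every model now goes through download_tokenizer_and_model. Its body is not shown in this diff, but it presumably reduces to the standard from_pretrained caching pattern; a hedged sketch of that pattern (model name and cache location are examples only):

```python
# Sketch of the underlying transformers caching API, not the script's exact code.
import os
from transformers import AutoModelForCausalLM, AutoTokenizer

CACHE_DIR = os.path.join(os.path.dirname(os.getcwd()), ".cache")
os.makedirs(CACHE_DIR, exist_ok=True)
model_name = "gpt2"  # example model

# First run (local_files_only=False): downloads into CACHE_DIR.
# Later runs (local_files_only=True): loads from the cache or raises OSError,
# which select_tokenizer_and_model catches above.
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    model_name, cache_dir=CACHE_DIR, local_files_only=False
)
```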
57 changes: 43 additions & 14 deletions scripts/tfsemb_main.py
@@ -13,6 +13,7 @@
from tfsemb_parser import arg_parser
from utils import load_pickle, main_timer
from utils import save_pickle as svpkl
from accelerate import Accelerator, find_executable_batch_size


def save_pickle(args, item, embeddings=None):
@@ -154,9 +155,14 @@ def process_extracted_logits(args, concat_logits, sentence_token_ids):
-prediction_probabilities * logp, dim=1
).tolist()

top1_probabilities, top1_probabilities_idx = prediction_probabilities.max(
dim=1
top1_probabilities, top1_probabilities_idx = torch.topk(
prediction_probabilities, 1, dim=1
)
top1_probabilities, top1_probabilities_idx = (
top1_probabilities.squeeze(),
top1_probabilities_idx.squeeze(),
)

predicted_tokens = args.tokenizer.convert_ids_to_tokens(
top1_probabilities_idx
)
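
The switch from .max(dim=1) to torch.topk(..., 1) plus squeeze() is behavior-preserving for the top-1 case; a quick standalone check:

```python
# Standalone check, not from the repo: topk(k=1) + squeeze matches max(dim=1).
import torch

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.3, 0.2]])

max_vals, max_idx = probs.max(dim=1)
top_vals, top_idx = torch.topk(probs, 1, dim=1)
top_vals, top_idx = top_vals.squeeze(), top_idx.squeeze()

assert torch.equal(max_vals, top_vals) and torch.equal(max_idx, top_idx)
```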
@@ -195,16 +201,16 @@ def process_extracted_logits(args, concat_logits, sentence_token_ids):


def extract_select_vectors(batch_idx, array):
if batch_idx == 0:
x = array[0, :-1, :].clone()
if batch_idx == 0: # first batch
x = array[0, :-1, :].clone() # first window, all but last embeddings
if array.shape[0] > 1:
try:
try: # (n-1)-th embedding
rem_sentences_preds = array[1:, -2, :].clone()
except:
except: # n-th embedding
rem_sentences_preds = array[1:, -1, :].clone()

x = torch.cat([x, rem_sentences_preds], axis=0)
else:
else: # remaining batches
try:
x = array[:, -2, :].clone()
except:
@@ -499,12 +505,30 @@ def make_input_from_tokens(args, token_list):
return windows


def make_dataloader_from_input(windows):
def make_dataloader_from_input(windows, batch_size):
input_ids = torch.tensor(windows)
data_dl = data.DataLoader(input_ids, batch_size=8, shuffle=False)
data_dl = data.DataLoader(input_ids, batch_size=batch_size, shuffle=False)
return data_dl


def inference_function(args, model_input):
accelerator = Accelerator()

@find_executable_batch_size(starting_batch_size=128)
def inner_training_loop(batch_size=128):
nonlocal accelerator # Ensure they can be used in our context
accelerator.free_memory() # Free all lingering references
accelerator.print(batch_size)
input_dl = make_dataloader_from_input(model_input, batch_size)
embeddings, logits = model_forward_pass(args, input_dl)

return embeddings, logits

embeddings, logits = inner_training_loop()

return embeddings, logits
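
The new inference_function leans on accelerate's find_executable_batch_size, which catches CUDA out-of-memory errors, halves the batch size, and re-runs the wrapped function until a batch fits. A standalone sketch of that retry behavior (the toy body below stands in for the real model_forward_pass):

```python
# Minimal sketch of the retry-on-OOM behaviour; not part of the commit.
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=128)
def run(batch_size=128):
    print(f"trying batch_size={batch_size}")
    if batch_size > 32:
        # Stand-in for a real GPU out-of-memory error raised during the forward pass.
        raise RuntimeError("CUDA out of memory.")
    return batch_size

print(run())  # tries 128, 64, then succeeds at 32
```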


def generate_causal_embeddings(args, df):
if args.embedding_type in tfsemb_dwnld.CAUSAL_MODELS:
args.tokenizer.pad_token = args.tokenizer.eos_token
Expand All @@ -513,11 +537,11 @@ def generate_causal_embeddings(args, df):
final_top1_prob = []
final_true_y_prob = []
final_true_y_rank = []
final_logits = []
for conversation in df.conversation_id.unique():
token_list = get_conversation_tokens(df, conversation)
model_input = make_input_from_tokens(args, token_list)
input_dl = make_dataloader_from_input(model_input)
embeddings, logits = model_forward_pass(args, input_dl)
embeddings, logits = inference_function(args, model_input)

embeddings = process_extracted_embeddings_all_layers(args, embeddings)
for _, item in embeddings.items():
@@ -535,6 +559,7 @@ def generate_causal_embeddings(args, df):
final_top1_prob.extend(top1_prob)
final_true_y_prob.extend(true_y_prob)
final_true_y_rank.extend(true_y_rank)
final_logits.extend([None] + torch.cat(logits, axis=0).tolist())

if len(final_embeddings) > 1:
# TODO concat all embeddings and return a dictionary
@@ -551,7 +576,10 @@ def generate_causal_embeddings(args, df):
df["surprise"] = -df["true_pred_prob"] * np.log2(df["true_pred_prob"])
df["entropy"] = entropy

return df, final_embeddings
df_logits = pd.DataFrame()
df_logits["logits"] = final_logits

return df, df_logits, final_embeddings


def get_vector(x, glove):
@@ -596,12 +624,13 @@ def main():
# Generate Embeddings
embeddings = None
output = generate_func(args, utterance_df)
if len(output) == 2:
df, embeddings = output
if len(output) == 3:
df, df_logits, embeddings = output
else:
df = output

save_pickle(args, df, embeddings)
svpkl(df_logits, args.logits_df_file)

return

6 changes: 2 additions & 4 deletions scripts/tfspkl_build_matrices.py
@@ -53,10 +53,8 @@ def build_design_matrices(CONFIG, delimiter=","):
df = pd.DataFrame(sigelec_list, columns=["subject", "electrode"])
except:
# If the electrode file is in the new format
df = pd.read_csv(
CONFIG["sig_elec_file"], columns=["subject", "electrode"]
)
else:
df = pd.read_csv(CONFIG["sig_elec_file"])
finally:
electrodes_dict = (
df.groupby("subject")["electrode"].apply(list).to_dict()
)
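The rewritten electrode-file handling relies on Python's full try/except/else/finally flow: else runs only when the try body raises nothing, and finally runs either way, so the per-subject electrode dict is always built from df. A generic sketch of that control flow (the column names match the script, but the format handling here is only an assumption):

```python
# Generic control-flow illustration, not the script's exact parsing logic.
import pandas as pd

def load_electrodes(path):
    try:
        # Assumed "old" format: two columns, no header row.
        df = pd.read_csv(path, header=None, names=["subject", "electrode"])
    except pd.errors.ParserError:
        # Assumed "new" format: let pandas read the header itself.
        df = pd.read_csv(path)
    else:
        print("old-format parse succeeded")  # only when the try body raised nothing
    finally:
        print("always runs")                 # regardless of which branch produced df
    return df.groupby("subject")["electrode"].apply(list).to_dict()
```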
5 changes: 4 additions & 1 deletion scripts/tfspkl_parser.py
@@ -18,7 +18,7 @@ def arg_parser():
group.add_argument("--max-electrodes", type=int, default=1e4)

group1 = parser.add_mutually_exclusive_group()
group1.add_argument("--subject", type=str, default=661)
group1.add_argument("--subject", type=str, default=None)
group1.add_argument("--sig-elec-file", type=str, default="")

parser.add_argument("--bin-size", type=int, default=32)
@@ -42,4 +42,7 @@ def arg_parser():

args = parser.parse_args()

if not args.subject:
args.subject = "777"

return args
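
The --subject default moves from 661 to None, with "777" applied after parsing when no subject is given; a minimal sketch of that parse-then-fallback pattern:

```python
# Sketch of the parse-then-fallback pattern; the "777" fallback mirrors the change above.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--subject", type=str, default=None)
group.add_argument("--sig-elec-file", type=str, default="")

args = parser.parse_args(["--sig-elec-file", "electrodes.csv"])
if not args.subject:
    args.subject = "777"

print(args.subject, args.sig_elec_file)  # 777 electrodes.csv
```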
