From c3e786436eee66438e05346ab08dbcc0cd4673c6 Mon Sep 17 00:00:00 2001
From: Erick Fonseca
Date: Sun, 5 Jul 2020 15:45:30 +0100
Subject: [PATCH] fix in output_attentions keyword

---
 api.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/api.py b/api.py
index 7bf20c9..59ea2cc 100644
--- a/api.py
+++ b/api.py
@@ -16,7 +16,7 @@ def __init__(self, model_name: str):
         '''
         super().__init__()
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = AutoModel.from_pretrained(model_name, output_attentions=True).to(
+        self.model = AutoModel.from_pretrained(model_name).to(
             self.device
         )
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -26,7 +26,7 @@ def _grab_attn(self, context):
         function to get the attention for a model.
         First runs a forward pass and then extracts and formats attn.
         '''
-        output = self.model(context)
+        output = self.model(context, output_attentions=True)
         # Grab the attention from the output
         # Format as Layer x Head x From x To
         attn = torch.cat([l for l in output[-1]], dim=0)