diff --git a/api.py b/api.py
index 7bf20c9..59ea2cc 100644
--- a/api.py
+++ b/api.py
@@ -16,7 +16,7 @@ def __init__(self, model_name: str):
         '''
         super().__init__()
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = AutoModel.from_pretrained(model_name, output_attentions=True).to(
+        self.model = AutoModel.from_pretrained(model_name).to(
             self.device
         )
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -26,7 +26,7 @@ def _grab_attn(self, context):
        function to get the attention for a model. First runs a forward pass
        and then extracts and formats attn.
        '''
-        output = self.model(context)
+        output = self.model(context, output_attentions=True)
        # Grab the attention from the output
        # Format as Layer x Head x From x To
        attn = torch.cat([l for l in output[-1]], dim=0)