Commit 7ec045c

Adapted to new fairseq releases
1 parent d39d58f

File tree

models/joint.py
models/protected_multihead_attention.py

2 files changed: +9 -9 lines

models/joint.py

Lines changed: 4 additions & 4 deletions
@@ -15,13 +15,13 @@
 from fairseq.modules import PositionalEmbedding
 
 from fairseq.models import (
-    FairseqIncrementalDecoder, FairseqEncoder, FairseqModel, register_model, register_model_architecture
+    FairseqIncrementalDecoder, FairseqEncoder, FairseqEncoderDecoderModel, register_model, register_model_architecture
 )
 
 from .protected_multihead_attention import ProtectedMultiheadAttention
 
 @register_model('joint_attention')
-class JointAttentionModel(FairseqModel):
+class JointAttentionModel(FairseqEncoderDecoderModel):
     """
     Local Joint Source-Target model from
     `"Joint Source-Target Self Attention with Locality Constraints" (Fonollosa, et al, 2019)

@@ -225,7 +225,7 @@ def max_positions(self):
         """Maximum input length supported by the encoder."""
         if self.embed_positions is None:
             return self.max_source_positions
-        return min(self.max_source_positions, self.embed_positions.max_positions())
+        return min(self.max_source_positions, self.embed_positions.max_positions)
 
 
 class JointAttentionDecoder(FairseqIncrementalDecoder):

@@ -413,7 +413,7 @@ def max_positions(self):
         """Maximum output length supported by the decoder."""
         if self.embed_positions is None:
             return self.max_target_positions
-        return min(self.max_target_positions, self.embed_positions.max_positions())
+        return min(self.max_target_positions, self.embed_positions.max_positions)
 
     def buffered_future_mask(self, tensor):
         """Cached future mask."""

models/protected_multihead_attention.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@
66
# can be found in the PATENTS file in the same directory.
77

88
import torch
9+
import torch.nn.functional as F
910
from torch import nn
1011
from torch.nn import Parameter
11-
import torch.nn.functional as F
12+
from fairseq.incremental_decoding_utils import with_incremental_state
1213

1314
from fairseq import utils
1415

1516
# Adapted from faiserq/modules/multihead_attention to deal with local attention
1617
# Local attetion masking in combination with padding masking can lead to
1718
# all -Inf attention rows. This version detects and corrects this situation
19+
@with_incremental_state
1820
class ProtectedMultiheadAttention(nn.Module):
1921
"""Multi-headed attention.
2022
@@ -247,15 +249,13 @@ def reorder_incremental_state(self, incremental_state, new_order):
247249
self._set_input_buffer(incremental_state, input_buffer)
248250

249251
def _get_input_buffer(self, incremental_state):
250-
return utils.get_incremental_state(
251-
self,
252+
return self.get_incremental_state(
252253
incremental_state,
253254
'attn_state',
254255
) or {}
255256

256257
def _set_input_buffer(self, incremental_state, buffer):
257-
utils.set_incremental_state(
258-
self,
258+
return self.set_incremental_state(
259259
incremental_state,
260260
'attn_state',
261261
buffer,
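
This second file adopts the newer incremental-state API: instead of calling utils.get_incremental_state(self, ...) and utils.set_incremental_state(self, ...), the module is wrapped with @with_incremental_state and uses the get_incremental_state/set_incremental_state methods that the decorator adds. A minimal sketch of that usage with a hypothetical toy module (not this repository's attention class), assuming only the fairseq API shown in the diff above:

# Hypothetical toy module; assumes the fairseq API visible in the diff above.
from torch import nn
from fairseq.incremental_decoding_utils import with_incremental_state


@with_incremental_state
class CachedBufferStub(nn.Module):
    """Exercises only the incremental-state helpers added by the decorator."""

    def _get_input_buffer(self, incremental_state):
        # Replaces utils.get_incremental_state(self, incremental_state, 'attn_state').
        return self.get_incremental_state(incremental_state, 'attn_state') or {}

    def _set_input_buffer(self, incremental_state, buffer):
        # Replaces utils.set_incremental_state(self, incremental_state, 'attn_state', buffer).
        return self.set_incremental_state(incremental_state, 'attn_state', buffer)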
