 from fairseq.modules import PositionalEmbedding
 
 from fairseq.models import (
-    FairseqIncrementalDecoder, FairseqEncoder, FairseqModel, register_model, register_model_architecture
+    FairseqIncrementalDecoder, FairseqEncoder, FairseqEncoderDecoderModel, register_model, register_model_architecture
 )
 
 from .protected_multihead_attention import ProtectedMultiheadAttention
 
 @register_model('joint_attention')
-class JointAttentionModel(FairseqModel):
+class JointAttentionModel(FairseqEncoderDecoderModel):
     """
     Local Joint Source-Target model from
     `"Joint Source-Target Self Attention with Locality Constraints" (Fonollosa, et al, 2019)
@@ -225,7 +225,7 @@ def max_positions(self):
         """Maximum input length supported by the encoder."""
         if self.embed_positions is None:
             return self.max_source_positions
-        return min(self.max_source_positions, self.embed_positions.max_positions())
+        return min(self.max_source_positions, self.embed_positions.max_positions)
 
 
 class JointAttentionDecoder(FairseqIncrementalDecoder):

@@ -413,7 +413,7 @@ def max_positions(self):
         """Maximum output length supported by the decoder."""
         if self.embed_positions is None:
             return self.max_target_positions
-        return min(self.max_target_positions, self.embed_positions.max_positions())
+        return min(self.max_target_positions, self.embed_positions.max_positions)
 
     def buffered_future_mask(self, tensor):
         """Cached future mask."""