@@ -9,19 +9,20 @@
 from fairseq.criterions import register_criterion
 from fairseq.criterions.label_smoothed_cross_entropy import (
     LabelSmoothedCrossEntropyCriterion,
-    LabelSmoothedCrossEntropyCriterionConfig
+    LabelSmoothedCrossEntropyCriterionConfig,
 )
 
 try:
     from simuleval.metrics.latency import (
         AverageLagging,
         AverageProportion,
-        DifferentiableAverageLagging
+        DifferentiableAverageLagging,
     )
+
     LATENCY_METRICS = {
         "average_lagging": AverageLagging,
         "average_proportion": AverageProportion,
-        "differentiable_average_lagging": DifferentiableAverageLagging,
+        "differentiable_average_lagging": DifferentiableAverageLagging,
     }
 except ImportError:
     LATENCY_METRICS = None
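Beyond the trailing commas, this hunk separates the import from the `LATENCY_METRICS` table with a blank line. The surrounding try/except makes simuleval an optional dependency: if it is not installed, `LATENCY_METRICS` is left as `None` instead of raising at import time. A minimal sketch of the pattern, with a hypothetical `get_latency_metric` helper that is not part of this file:

```python
# Optional-dependency pattern, as in the file above; `get_latency_metric`
# is a hypothetical helper added here only for illustration.
try:
    from simuleval.metrics.latency import AverageLagging

    LATENCY_METRICS = {"average_lagging": AverageLagging}
except ImportError:
    # simuleval is absent: leave a sentinel so callers can fail with a
    # clear message instead of an ImportError at module load.
    LATENCY_METRICS = None


def get_latency_metric(name: str):
    assert LATENCY_METRICS is not None, "latency metrics require simuleval"
    return LATENCY_METRICS[name]
```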
@@ -56,9 +57,10 @@ class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
         metadata={"help": "Add latency loss after certain steps"},
     )
 
+
 @register_criterion(
     "latency_augmented_label_smoothed_cross_entropy",
-    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig
+    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
 )
 class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
     LabelSmoothedCrossEntropyCriterion
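This hunk gains a trailing comma and a second blank line before the decorated class, matching black's module-level spacing. `register_criterion` publishes the class in fairseq's criterion registry under the given string, which is how a training config selects it by name. A simplified sketch of such a name-based registry, not fairseq's actual implementation:

```python
# Toy name-based registry illustrating what @register_criterion does;
# fairseq's real version also records the config dataclass.
CRITERIONS = {}


def register_criterion(name, dataclass=None):
    def wrapper(cls):
        CRITERIONS[name] = cls  # looked up later by the configured name
        return cls

    return wrapper


@register_criterion("toy_criterion")
class ToyCriterion:
    pass


assert CRITERIONS["toy_criterion"] is ToyCriterion
```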
@@ -101,9 +103,9 @@ def forward(self, model, sample, reduce=True):
 
         if self.latency_update_after > 0:
             num_updates = getattr(model.decoder, "num_updates", None)
-            assert num_updates is not None, (
-                "model.decoder doesn't have attribute 'num_updates'"
-            )
+            assert (
+                num_updates is not None
+            ), "model.decoder doesn't have attribute 'num_updates'"
             if num_updates <= self.latency_update_after:
                 latency_loss = 0
 
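This hunk only reflows the assert, but the logic it guards is worth spelling out: the latency term is zeroed until the decoder has seen `latency_update_after` optimizer updates, so early training is driven purely by cross-entropy. A standalone mirror of that gate (the helper name is assumed, for illustration):

```python
def gated_latency_loss(latency_loss, num_updates, latency_update_after):
    """Hypothetical standalone mirror of the warm-up gate in forward()."""
    if latency_update_after > 0:
        assert num_updates is not None, "model.decoder doesn't have attribute 'num_updates'"
        if num_updates <= latency_update_after:
            return 0  # pure cross-entropy during warm-up
    return latency_loss


assert gated_latency_loss(1.5, num_updates=10, latency_update_after=50) == 0
assert gated_latency_loss(1.5, num_updates=100, latency_update_after=50) == 1.5
```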
@@ -134,9 +136,7 @@ def compute_latency_loss(self, model, sample, net_output):
         assert (
             net_output[-1].encoder_padding_mask is None
             or not net_output[-1].encoder_padding_mask[:, 0].any()
-        ), (
-            "Only right padding on source is supported."
-        )
+        ), "Only right padding on source is supported."
         # 1. Obtain the expected alignment
         alpha_list = [item["alpha"] for item in net_output[1].attn_list]
         num_layers = len(alpha_list)
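The condensed assert checks that sources are right-padded: with right padding, position 0 of every sequence is a real token, so the first column of the padding mask must be all False. A small self-contained check of the same invariant (mask contents made up):

```python
import torch

# True marks padding. Right-padded batches have no padding in column 0.
encoder_padding_mask = torch.tensor(
    [
        [False, False, True],   # length 2, padded on the right
        [False, False, False],  # length 3, no padding
    ]
)
assert not encoder_padding_mask[:, 0].any(), "Only right padding on source is supported."
```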
@@ -174,8 +174,7 @@ def compute_latency_loss(self, model, sample, net_output):
             .view(-1)
         )
         expected_latency = LATENCY_METRICS[self.latency_avg_type](
-            expected_delays, src_lengths, None,
-            target_padding_mask=target_padding_mask
+            expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
         )
 
         # 2.1 average expected latency of heads
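The metric call is reflowed onto one line; the arguments are unchanged. `LATENCY_METRICS[self.latency_avg_type]` dispatches by name to one of the simuleval metric classes imported above. A hedged sketch of that dispatch with a stand-in metric, since simuleval's own classes are not reproduced here:

```python
import torch


def fake_average_lagging(delays, src_lens, _unused, target_padding_mask=None):
    # Stand-in for simuleval's AverageLagging, only to show the dispatch;
    # the real metric computes lagging, not a plain mean.
    return delays.mean(dim=1, keepdim=True)


LATENCY_METRICS = {"average_lagging": fake_average_lagging}
latency_avg_type = "average_lagging"

expected_delays = torch.tensor([[1.0, 2.0, 3.0]])  # (batch, tgt_len), illustrative
src_lengths = torch.tensor([[3.0]])
expected_latency = LATENCY_METRICS[latency_avg_type](
    expected_delays, src_lengths, None, target_padding_mask=None
)
```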
@@ -210,24 +209,12 @@ def compute_latency_loss(self, model, sample, net_output):
     @classmethod
     def reduce_metrics(cls, logging_outputs) -> None:
         super().reduce_metrics(logging_outputs)
-        latency = sum(
-            log.get("latency", 0) for log in logging_outputs
-        )
-        delays_var = sum(
-            log.get("delays_var", 0) for log in logging_outputs
-        )
-        latency_loss = sum(
-            log.get("latency_loss", 0) for log in logging_outputs
-        )
+        latency = sum(log.get("latency", 0) for log in logging_outputs)
+        delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
+        latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
         nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+        metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3)
+        metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
         metrics.log_scalar(
-            "latency", latency.float() / nsentences, nsentences, round=3
-        )
-        metrics.log_scalar(
-            "delays_var", delays_var / nsentences,
-            nsentences, round=3
-        )
-        metrics.log_scalar(
-            "latency_loss", latency_loss / nsentences,
-            nsentences, round=3
+            "latency_loss", latency_loss / nsentences, nsentences, round=3
         )
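The final hunk collapses the per-metric sums and two of the `log_scalar` calls onto single lines; the arithmetic is untouched: each statistic is summed over the per-worker logging dicts, then logged as a per-sentence average. The same aggregation in miniature, with dummy numbers and `print` standing in for `metrics.log_scalar`:

```python
# Dummy per-worker logging outputs; real ones come from forward().
logging_outputs = [
    {"latency": 4.0, "latency_loss": 0.2, "nsentences": 2},
    {"latency": 6.0, "latency_loss": 0.4, "nsentences": 3},
]

latency = sum(log.get("latency", 0) for log in logging_outputs)
latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)

print(round(latency / nsentences, 3))       # 2.0 average latency per sentence
print(round(latency_loss / nsentences, 3))  # 0.12
```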