Commit 0dfd6b6

dianaml0 authored and facebook-github-bot committed
Add linting with black (facebookresearch#2678)
Summary:

# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?

Fixes # (issue).

## PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding 🙃

Pull Request resolved: fairinternal/fairseq-py#2678

Reviewed By: Mortimerp9

Differential Revision: D32653381

Pulled By: dianaml0

fbshipit-source-id: 2810d14867cd7d64f4d340740e2b590b82de47fe
1 parent 3dc1691 · commit 0dfd6b6

File tree

137 files changed: +2142, -1356 lines changed. (This large commit hides most file diffs by default; the first seven files are shown below.)

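Everything below is the output of running black over the tree, so the hunks are mechanical style rewrites: string quotes normalized to double quotes, bare floats like 0. expanded to 0.0, calls wrapped or collapsed around black's 88-character default line length, and trailing commas added to multi-line argument lists. As a rough sketch of the kind of rewrite involved (the helper below is hypothetical, not code from this commit):

    # Hypothetical helper, illustrative only -- not part of this commit.
    def log_scalar(key, value, weight, round=None):
        """Toy stand-in for a metrics helper."""
        print(key, value, weight, round)

    # A layout a human might write (single quotes, ad-hoc wrapping):
    log_scalar('loss',
        10.0 / 4, 4,
        round=3)

    # What black produces: double quotes, and the call collapsed onto one
    # line because it fits within the 88-character default.
    log_scalar("loss", 10.0 / 4, 4, round=3)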

.github/workflows/build.yml (+5)

@@ -53,3 +53,8 @@ jobs:
     - name: Run tests
       run: |
         python setup.py test
+
+    - name: Lint with black
+      run: |
+        pip install black
+        black --check . --extend-exclude 'examples|fairseq\/model_parallel\/megatron'
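Reproducing the check locally is just the two commands added above, run from the repository root: pip install black, then the same black --check invocation. In --check mode black modifies nothing and exits with a non-zero status if any file would be reformatted; --extend-exclude appends the examples and fairseq/model_parallel/megatron trees to black's default exclusion pattern rather than replacing it.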

fairseq/__init__.py (+1)

@@ -27,6 +27,7 @@

 # initialize hydra
 from fairseq.dataclass.initialize import hydra_init
+
 hydra_init()

 import fairseq.criterions  # noqa

fairseq/benchmark/dummy_mt.py (+2, -2)

@@ -7,10 +7,10 @@

 import numpy as np
 import torch
+
 from fairseq.data import Dictionary, FairseqDataset
 from fairseq.tasks import LegacyFairseqTask, register_task

-
 logger = logging.getLogger(__name__)


@@ -36,7 +36,7 @@ def __init__(self, args, dictionary):

     @classmethod
     def setup_task(cls, args, **kwargs):
-        """Setup the task. """
+        """Setup the task."""
         dictionary = Dictionary()
         for i in range(args.dict_size):
             dictionary.add_symbol("word{}".format(i))

fairseq/checkpoint_utils.py (+15, -22)

@@ -96,10 +96,7 @@ def is_better(a, b):

     checkpoint_conds[
         "checkpoint.best_{}_{:.3f}{}{}.pt".format(
-            cfg.best_checkpoint_metric,
-            val_loss,
-            rand_sfx,
-            suffix
+            cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
         )
     ] = worst_best is None or is_better(val_loss, worst_best)
     checkpoint_conds[

@@ -468,9 +465,7 @@ def load_model_ensemble_and_task(
                 and len(state["optimizer_history"]) > 0
                 and "num_updates" in state["optimizer_history"][-1]
             ):
-                model.set_num_updates(
-                    state["optimizer_history"][-1]["num_updates"]
-                )
+                model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
             model.load_state_dict(
                 state["model"], strict=strict, model_cfg=cfg.model
             )

@@ -588,9 +583,8 @@ def _upgrade_state_dict(state):
     # backward compatibility, cfg updates
     if "args" in state and state["args"] is not None:
         # old model checkpoints may not have separate source/target positions
-        if (
-            hasattr(state["args"], "max_positions")
-            and not hasattr(state["args"], "max_source_positions")
+        if hasattr(state["args"], "max_positions") and not hasattr(
+            state["args"], "max_source_positions"
         ):
             state["args"].max_source_positions = state["args"].max_positions
             state["args"].max_target_positions = state["args"].max_positions

@@ -615,13 +609,10 @@ def _upgrade_state_dict(state):
             state["args"].stop_min_lr = state["args"].min_lr
             del state["args"].min_lr
         # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
-        if (
-            hasattr(state["args"], "criterion")
-            and state["args"].criterion in [
-                "binary_cross_entropy",
-                "kd_binary_cross_entropy",
-            ]
-        ):
+        if hasattr(state["args"], "criterion") and state["args"].criterion in [
+            "binary_cross_entropy",
+            "kd_binary_cross_entropy",
+        ]:
             state["args"].criterion = "wav2vec"
         # remove log_keys if it's None (criteria will supply a default value of [])
         if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:

@@ -659,7 +650,9 @@ def _upgrade_state_dict(state):
         ):
             cfg.task.eval_wer_config.print_alignment = "hard"
         if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
-            cfg.generation.print_alignment = "hard" if cfg.generation.print_alignment else None
+            cfg.generation.print_alignment = (
+                "hard" if cfg.generation.print_alignment else None
+            )
         if (
             "model" in cfg
             and "w2v_args" in cfg.model

@@ -833,16 +826,16 @@ def load_ema_from_checkpoint(fpath):
     params_dict = collections.OrderedDict()
     new_state = None

-    with PathManager.open(fpath, 'rb') as f:
+    with PathManager.open(fpath, "rb") as f:
         new_state = torch.load(
             f,
             map_location=(
-                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
+                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
             ),
         )

     # EMA model is stored in a separate "extra state"
-    model_params = new_state['extra_state']['ema']
+    model_params = new_state["extra_state"]["ema"]

     for key in list(model_params.keys()):
         p = model_params[key]

@@ -860,5 +853,5 @@ def load_ema_from_checkpoint(fpath):
             "ema model weights, is this model trained with EMA?"
         )

-    new_state['model'] = params_dict
+    new_state["model"] = params_dict
     return new_state

fairseq/criterions/fastspeech2_loss.py (+26, -18)

@@ -20,9 +20,7 @@

 @dataclass
 class FastSpeech2CriterionConfig(FairseqDataclass):
-    ctc_weight: float = field(
-        default=0.0, metadata={"help": "weight for CTC loss"}
-    )
+    ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})


 @register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)

@@ -44,7 +42,7 @@ def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
             speaker=sample["speaker"],
             durations=sample["durations"],
             pitches=sample["pitches"],
-            energies=sample["energies"]
+            energies=sample["energies"],
         )

         src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])

@@ -57,8 +55,7 @@ def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
         feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
         l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
         if _feat_out_post is not None:
-            l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat,
-                                 reduction=reduction)
+            l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)

         pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
         energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)

@@ -69,16 +66,23 @@ def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
         log_dur = torch.log(dur + 1)[src_mask]
         dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)

-        ctc_loss = torch.tensor(0.).type_as(l1_loss)
-        if self.ctc_weight > 0.:
+        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
+        if self.ctc_weight > 0.0:
             lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
             lprobs = lprobs.transpose(0, 1)  # T x B x C
             src_mask = lengths_to_mask(src_lens)
             src_tokens_flat = src_tokens.masked_select(src_mask)
-            ctc_loss = F.ctc_loss(
-                lprobs, src_tokens_flat, tgt_lens, src_lens,
-                reduction=reduction, zero_infinity=True
-            ) * self.ctc_weight
+            ctc_loss = (
+                F.ctc_loss(
+                    lprobs,
+                    src_tokens_flat,
+                    tgt_lens,
+                    src_lens,
+                    reduction=reduction,
+                    zero_infinity=True,
+                )
+                * self.ctc_weight
+            )

         loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss

@@ -102,8 +106,12 @@ def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
         ntot = sum(ns)
         ws = [n / (ntot + 1e-8) for n in ns]
         for key in [
-            "loss", "l1_loss", "dur_loss", "pitch_loss", "energy_loss",
-            "ctc_loss"
+            "loss",
+            "l1_loss",
+            "dur_loss",
+            "pitch_loss",
+            "energy_loss",
+            "ctc_loss",
         ]:
             vals = [log.get(key, 0) for log in logging_outputs]
             val = sum(val * w for val, w in zip(vals, ws))

@@ -115,10 +123,10 @@ def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
             return
         n = sum(log.get("targ_frames", 0) for log in logging_outputs)
         for key, new_key in [
-                ("mcd_loss", "mcd_loss"),
-                ("pred_frames", "pred_ratio"),
-                ("nins", "ins_rate"),
-                ("ndel", "del_rate"),
+            ("mcd_loss", "mcd_loss"),
+            ("pred_frames", "pred_ratio"),
+            ("nins", "ins_rate"),
+            ("ndel", "del_rate"),
         ]:
             val = sum(log.get(key, 0) for log in logging_outputs)
             metrics.log_scalar(new_key, val / n, n, round=3)
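Two hunks in this file show black's "magic trailing comma" at work: the comma added after energies=sample["energies"] and the exploded key list in reduce_metrics. When the last element of a bracketed collection or argument list ends in a comma, black keeps one element per line instead of collapsing the whole thing. A minimal sketch (hypothetical values, not from this file):

    # No trailing comma: black collapses the list when it fits in 88 columns.
    keys = ["loss", "l1_loss", "dur_loss"]

    # Trailing ("magic") comma: black preserves the one-per-line layout.
    keys = [
        "loss",
        "l1_loss",
        "dur_loss",
    ]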

fairseq/criterions/hubert_criterion.py (+24, -7)

@@ -37,7 +37,14 @@ class HubertCriterionConfig(FairseqDataclass):

 @register_criterion("hubert", dataclass=HubertCriterionConfig)
 class HubertCriterion(FairseqCriterion):
-    def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
+    def __init__(
+        self,
+        task,
+        pred_masked_weight,
+        pred_nomask_weight,
+        loss_weights=None,
+        log_keys=None,
+    ):
         super().__init__(task)
         self.pred_masked_weight = pred_masked_weight
         self.pred_nomask_weight = pred_nomask_weight

@@ -52,7 +59,7 @@ def forward(self, model, sample, reduce=True, log_pred=False):
         3) logging outputs to display while training
         """
         net_output = model(target_list=sample["target_list"], **sample["net_input"])
-        loss = 0.
+        loss = 0.0
         sample_size = 0
         logging_output = {}
         reduction = "sum" if reduce else "none"

@@ -89,7 +96,9 @@ def forward(self, model, sample, reduce=True, log_pred=False):
                 names = [names]
             if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                 self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
-            assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
+            assert len(extra_losses) == len(
+                self.loss_weights
+            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
             for p, n, coef in zip(extra_losses, names, self.loss_weights):
                 if coef != 0 and p is not None:
                     p = coef * p.float() * sample_size

@@ -140,12 +149,20 @@ def reduce_metrics(logging_outputs) -> None:
         ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
         sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

-        metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
         if sample_size != ntokens:
-            metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
-            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
+            metrics.log_scalar(
+                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
+            )
+            metrics.log_derived(
+                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+            )
         else:
-            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
+            metrics.log_derived(
+                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
+            )

         counts = {}
         for lk in logging_outputs[0].keys():
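The assert rewrite above is also characteristic: when an assert overruns the line limit, black parenthesizes only the condition and leaves the message after the comma, because wrapping the whole statement in parentheses would turn it into an assert on a 2-tuple, which is truthy and could never fail. A short sketch of the same transformation (toy values):

    expected, got = 3, 3

    # Black's layout for an over-long assert: only the condition is
    # parenthesized, the message stays after the comma.
    assert (
        expected == got
    ), f"{expected}, {got}"

    # NOT this -- asserting a non-empty tuple always passes:
    # assert (expected == got, f"{expected}, {got}")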

fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py (+17, -30)

@@ -9,19 +9,20 @@
 from fairseq.criterions import register_criterion
 from fairseq.criterions.label_smoothed_cross_entropy import (
     LabelSmoothedCrossEntropyCriterion,
-    LabelSmoothedCrossEntropyCriterionConfig
+    LabelSmoothedCrossEntropyCriterionConfig,
 )

 try:
     from simuleval.metrics.latency import (
         AverageLagging,
         AverageProportion,
-        DifferentiableAverageLagging
+        DifferentiableAverageLagging,
     )
+
     LATENCY_METRICS = {
         "average_lagging": AverageLagging,
         "average_proportion": AverageProportion,
-        "differentiable_average_lagging": DifferentiableAverageLagging,
+        "differentiable_average_lagging": DifferentiableAverageLagging,
     }
 except ImportError:
     LATENCY_METRICS = None

@@ -56,9 +57,10 @@ class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
         metadata={"help": "Add latency loss after certain steps"},
     )

+
 @register_criterion(
     "latency_augmented_label_smoothed_cross_entropy",
-    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig
+    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
 )
 class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
     LabelSmoothedCrossEntropyCriterion

@@ -101,9 +103,9 @@ def forward(self, model, sample, reduce=True):

         if self.latency_update_after > 0:
             num_updates = getattr(model.decoder, "num_updates", None)
-            assert num_updates is not None, (
-                "model.decoder doesn't have attribute 'num_updates'"
-            )
+            assert (
+                num_updates is not None
+            ), "model.decoder doesn't have attribute 'num_updates'"
             if num_updates <= self.latency_update_after:
                 latency_loss = 0

@@ -134,9 +136,7 @@ def compute_latency_loss(self, model, sample, net_output):
         assert (
             net_output[-1].encoder_padding_mask is None
             or not net_output[-1].encoder_padding_mask[:, 0].any()
-        ), (
-            "Only right padding on source is supported."
-        )
+        ), "Only right padding on source is supported."
         # 1. Obtain the expected alignment
         alpha_list = [item["alpha"] for item in net_output[1].attn_list]
         num_layers = len(alpha_list)

@@ -174,8 +174,7 @@ def compute_latency_loss(self, model, sample, net_output):
             .view(-1)
         )
         expected_latency = LATENCY_METRICS[self.latency_avg_type](
-            expected_delays, src_lengths, None,
-            target_padding_mask=target_padding_mask
+            expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
         )

         # 2.1 average expected latency of heads

@@ -210,24 +209,12 @@ def compute_latency_loss(self, model, sample, net_output):
     @classmethod
     def reduce_metrics(cls, logging_outputs) -> None:
         super().reduce_metrics(logging_outputs)
-        latency = sum(
-            log.get("latency", 0) for log in logging_outputs
-        )
-        delays_var = sum(
-            log.get("delays_var", 0) for log in logging_outputs
-        )
-        latency_loss = sum(
-            log.get("latency_loss", 0) for log in logging_outputs
-        )
+        latency = sum(log.get("latency", 0) for log in logging_outputs)
+        delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
+        latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
         nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+        metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3)
+        metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
         metrics.log_scalar(
-            "latency", latency.float() / nsentences, nsentences, round=3
-        )
-        metrics.log_scalar(
-            "delays_var", delays_var / nsentences,
-            nsentences, round=3
-        )
-        metrics.log_scalar(
-            "latency_loss", latency_loss / nsentences,
-            nsentences, round=3
+            "latency_loss", latency_loss / nsentences, nsentences, round=3
         )
