Commit 29a0f04

Format files to conform to flake8
1 parent 8fc57a9 commit 29a0f04
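The edits in this commit wrap over-long lines, split long comments and docstrings, and replace "== True" comparisons with plain truthiness checks, which are the kinds of issues flake8 reports as E501 and E712. As a rough illustration (not part of this commit; the file path is taken from the diff below, and the project's actual flake8 configuration is not visible here), conformance could be re-checked with flake8's documented legacy API:

# Illustrative sketch only: re-checks the touched file with flake8's
# documented legacy API. The path comes from the diff below; the flake8
# configuration (line-length limit etc.) is assumed to come from the
# repository's own config and is not passed explicitly here.
from flake8.api import legacy as flake8

style_guide = flake8.get_style_guide()
report = style_guide.check_files(["src/mokka/equalizers/adaptive/torch.py"])

# E-codes include E501 (line too long) and E712 (comparison to True)
print(report.get_statistics("E"))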

File tree

3 files changed: +101 -109 lines changed

src/mokka/equalizers/adaptive/torch.py

Lines changed: 87 additions & 99 deletions
@@ -1,15 +1,15 @@
 """PyTorch implementations of adaptive equalizers."""
+
 from ..torch import Butterfly2x2
-from ..torch import correct_start_polarization, correct_start, find_start_offset
+from ..torch import correct_start_polarization, correct_start
 from ...functional.torch import convolve_overlap_save
 from ..torch import h2f
 import torch
 import logging
+from collections import namedtuple

 logger = logging.getLogger(__name__)

-from collections import namedtuple
-

 class CMA(torch.nn.Module):
     """Class to perform CMA equalization."""
@@ -33,12 +33,17 @@ def __init__(
         :param R: constant modulus radius
         :param sps: samples per symbol
         :param lr: learning rate
-        :param butterfly_filter: Optional :py:class:`mokka.equalizers.torch.Butterfly2x2` object
-        :param filter_length: butterfly filter length (if object is not given)
-        :param block_size: Number of symbols to process before updating the equalizer taps
-        :param no_singularity: Initialize the x- and y-polarization to avoid singularity by decoding
-          the signal from the same polarization twice.
-        :param singularity_length: Delay for initialization with the no_singularity approach
+        :param butterfly_filter: Optional
+          :py:class:`mokka.equalizers.torch.Butterfly2x2` object
+        :param filter_length: butterfly filter length
+          (if object is not given)
+        :param block_size: Number of symbols to process before
+          updating the equalizer taps
+        :param no_singularity: Initialize the x- and y-polarization to avoid
+          singularity by decoding
+          the signal from the same polarization twice.
+        :param singularity_length: Delay for initialization with the
+          no_singularity approach
         """
         super(CMA, self).__init__()
         self.register_buffer("R", torch.as_tensor(R))
@@ -79,7 +84,8 @@ def forward(self, y):
         """
         # Implement CMA "by hand"
         # Basically step through the signal advancing always +sps symbols
-        # and filtering 2*filter_len samples which will give one output sample with mode "valid"
+        # and filtering 2*filter_len samples which will give one output
+        # sample with mode "valid"

         equalizer_length = self.butterfly_filter.taps.size()[1]
         num_samp = y.shape[1]
@@ -92,9 +98,8 @@ def forward(self, y):
         out = torch.zeros(
             2, (num_samp - equalizer_length) // self.sps, dtype=torch.complex64
         )
-        eq_offset = (
-            equalizer_length - 1
-        ) // 2  # We try to put the symbol of interest in the center tap of the equalizer
+        # We try to put the symbol of interest in the center tap of the equalizer
+        eq_offset = (equalizer_length - 1) // 2
         for i, k in enumerate(
             range(eq_offset, num_samp - 1 - eq_offset * 2, self.sps * self.block_size)
         ):
@@ -143,11 +148,6 @@ def get_error_signal(self):
         return self.out_e


-##############################################################################################
-########################### Variational Autoencoer based Equalizer ###########################
-##############################################################################################
-
-
 def ELBO_DP(
     y,
     q,
@@ -168,8 +168,8 @@ def ELBO_DP(
     N = y.shape[1]
     # Now we have two polarizations in the first dimension
     # We assume the same transmit constellation for both, calculating
-    # q needs to be shaped 2 x N x M -> for each observation on each polarization we have M q-values
-    # we have M constellation symbols
+    # q needs to be shaped 2 x N x M -> for each observation on each polarization we
+    # have M q-values and we have M constellation symbols
     L = butterfly_filter.taps.shape[1]
     L_offset = (L - 1) // 2
     if p_constellation is None:
if p_constellation is None:
@@ -180,7 +180,7 @@ def ELBO_DP(
180180
# # Precompute E_Q{c} = sum( q * c) where c is x and |x|**2
181181
E_Q_x = torch.zeros(2, N, device=q.device, dtype=torch.complex64)
182182
E_Q_x_abssq = torch.zeros(2, N, device=q.device, dtype=torch.float32)
183-
if IQ_separate == True:
183+
if IQ_separate:
184184
num_lev = constellation_symbols.shape[0]
185185
E_Q_x[:, ::sps] = torch.complex(
186186
torch.sum(
@@ -217,8 +217,8 @@ def ELBO_DP(
             axis=-1,
         )

-    # Term A - sum all the things, but spare the first dimension, since the two polarizations
-    # are sorta independent
+    # Term A - sum all the things, but spare the first dimension,
+    # since the two polarizations are sorta independent
     bias = 1e-14
     A = torch.sum(
         q[:, L_offset:-L_offset, :]
@@ -230,11 +230,13 @@ def ELBO_DP(
     )

     # Precompute h \ast E_Q{x}
-    h_conv_E_Q_x = butterfly_filter(
-        E_Q_x, mode="valid"
-    )  # Due to definition that we assume the symbol is at the center tap we remove (filter_length - 1)//2 at start and end
-    # Limit the length of y to the "computable space" because y depends on more past values than given
-    # We try to generate the received symbol sequence with the estimated symbol sequence
+    # Due to definition that we assume the symbol is at the center tap
+    # we remove (filter_length - 1)//2 at start and end
+    h_conv_E_Q_x = butterfly_filter(E_Q_x, mode="valid")
+    # Limit the length of y to the "computable space" because y depends
+    # on more past values than given
+    # We try to generate the received symbol sequence
+    # with the estimated symbol sequence
     C = torch.sum(
         y[:, L_offset:-L_offset].real ** 2 + y[:, L_offset:-L_offset].imag ** 2, axis=1
     )
@@ -256,9 +258,6 @@ def ELBO_DP(
     return loss, var


-##############################################################################################
-
-
 class VAE_LE_DP(torch.nn.Module):
     """
     Adaptive Equalizer based on the variational autoencoder principle with a linear equalizer.
@@ -267,7 +266,7 @@ class VAE_LE_DP(torch.nn.Module):

     [1] V. Lauinger, F. Buchali, and L. Schmalen, ‘Blind equalization and channel estimation in coherent optical communications using variational autoencoders’,
     IEEE Journal on Selected Areas in Communications, vol. 40, no. 9, pp. 2529–2539, Sep. 2022, doi: 10.1109/JSAC.2022.3191346.
-    """
+    """  # noqa

     def __init__(
         self,
@@ -285,20 +284,25 @@ def __init__(
         """
         Initialize :py:class:`VAE_LE_DP`.

-        This VAE equalizer is implemented with a butterfly linear equalizer in the forward path and a butterfly linear equalizer in
-        the backward pass. Therefore, it is limited to correct impairments of linear channels.
+        This VAE equalizer is implemented with a butterfly linear equalizer in the
+        forward path and a butterfly linear equalizer in the backward pass.
+        Therefore, it is limited to correct impairments of linear channels.

         :param num_taps_forward: number of equalizer taps
         :param num_taps_backward: number of channel taps
         :param demapper: mokka demapper object to perform complex symbol demapping
         :param sps: samples per symbol
-        :param block_size: number of symbols per block - defines the update rate of the equalizer
+        :param block_size: number of symbols per block - defines the update rate
+          of the equalizer
         :param lr: learning rate for the adam algorithm
         :param requires_q: return q-values in forward call
-        :param IQ_separate: process I and Q separately - requires a demapper which performs demapping on real values
-          and a bit-mapping which is equal on I and Q.
-        :param var_from_estimate: Update the variance in the demapper from the SNR estimate of the output
-        :param num_block_train: Number of blocks to train the equalizer before switching to non-training equalization mode (for static channels only)
+        :param IQ_separate: process I and Q separately - requires a demapper
+          which performs demapping on real values
+          and a bit-mapping which is equal on I and Q.
+        :param var_from_estimate: Update the variance in the demapper from
+          the SNR estimate of the output
+        :param num_block_train: Number of blocks to train the equalizer before
+          switching to non-training equalization mode (for static channels only)
         """
         super(VAE_LE_DP, self).__init__()

@@ -365,17 +369,17 @@ def forward(self, y):
         out = []
         out_q = []
         # We start our loop already at num_taps (because we cannot equalize the start)
-        # We will end the loop at num_samps - num_taps - sps*block_size (safety, so we don't overrun)
-        # We will process sps * block_size - 2 * num_taps because we will cut out the first and last block
+        # We will end the loop at num_samps - num_taps - sps*block_size
+        # (safety, so we don't overrun)
+        # We will process sps * block_size - 2 * num_taps because we will cut out
+        # the first and last block

         index_padding = (self.butterfly_forward.num_taps - 1) // 2
+        # Back-off one block-size + filter_overlap from end to avoid overrunning
         for i, k in enumerate(
             range(
                 index_padding,
-                num_samps
-                - index_padding
-                - self.sps
-                * self.block_size,  # Back-off one block-size + filter_overlap from end to avoid overrunning
+                num_samps - index_padding - self.sps * self.block_size,
                 self.sps * self.block_size,
             )
         ):
@@ -387,15 +391,18 @@ def forward(self, y):
                 k - index_padding,
                 k + self.sps * self.block_size + index_padding,
             )
-            # Equalization will give sps * block_size samples (because we add (num_taps - 1) in the beginning)
+            # Equalization will give sps * block_size samples (because we add
+            # (num_taps - 1) in the beginning)
             y_hat = self.butterfly_forward(y[:, in_index], "valid")

-            # We downsample so we will have floor(((sps * block_size - num_taps + 1) / sps) = floor(block_size - (num_taps - 1)/sps)
+            # We downsample so we will have
+            # floor(((sps * block_size - num_taps + 1) / sps)
+            # = floor(block_size - (num_taps - 1)/sps)
             y_symb = y_hat[
                 :, 0 :: self.sps
             ]  # ---> y[0,(self.butterfly_forward.num_taps + 1)//2 +1 ::self.sps]

-            if self.IQ_separate == True:
+            if self.IQ_separate:
                 q_hat = torch.cat(
                     (
                         torch.cat(
@@ -422,8 +429,8 @@ def forward(self, y):
                         self.demapper(y_symb[1, :]).unsqueeze(0),
                     )
                 )
-            # We calculate the loss with less symbols, since the forward operation with "valid"
-            # is missing some symbols
+            # We calculate the loss with less symbols, since the forward operation
+            # with "valid" is missing some symbols
             # We assume the symbol of interest is at the center tap of the filter
             y_index = in_index[
                 (self.butterfly_forward.num_taps - 1)
@@ -440,7 +447,9 @@ def forward(self, y):
                 IQ_separate=self.IQ_separate,
             )

-            # logger.info("Iteration: %s/%s, VAE loss: %s", i+1, ((num_samps - index_padding - self.sps * self.block_size) // (self.sps * self.block_size)).item(), loss.item())
+            # logger.info("Iteration: %s/%s, VAE loss: %s", i+1,
+            # ((num_samps - index_padding - self.sps * self.block_size)
+            # // (self.sps * self.block_size)).item(), loss.item())

             if self.num_block_train is None or (self.num_block_train > i):
                 # print("noise_sigma: ", self.demapper.noise_sigma)
@@ -450,12 +459,11 @@ def forward(self, y):
                 self.optimizer.zero_grad()
                 # self.optimizer_var.zero_grad()

-            if self.var_from_estimate == True:
+            if self.var_from_estimate:
                 self.demapper.noise_sigma = torch.clamp(
                     torch.sqrt(torch.mean(var.detach().clone()) / 2),
                     min=torch.tensor(0.05, requires_grad=False, device=q_hat.device),
-                    max=2
-                    * self.demapper.noise_sigma.detach().clone(),  # torch.sqrt(var).detach()), min=0.1
+                    max=2 * self.demapper.noise_sigma.detach().clone(),
                 )

             output_symbols = y_symb[
@@ -470,9 +478,10 @@ def forward(self, y):
             out_q.append(output_q)

            # print("loss: ", loss, "\t\t\t var: ", var)
-            # out.append(y_symb[:, self.block_size - self.butterfly_forward.num_taps // 2 :])
+            # out.append(y_symb[:, self.block_size - self.butterfly_forward.num_taps
+            # // 2 :])

-        if self.requires_q == True:
+        if self.requires_q:
             eq_out = namedtuple("eq_out", ["y", "q", "var", "loss"])
             return eq_out(torch.cat(out, axis=1), torch.cat(out_q, axis=1), var, loss)
         return torch.cat(out, axis=1)
@@ -519,10 +528,10 @@ class PilotAEQ_DP(torch.nn.Module):
     """
     Perform pilot-based adaptive equalization.

-    This class performs equalization on a dual polarization signal with a known dual polarization
-    pilot sequence. The equalization is performed either with the LMS method, ZF method or a
-    novel LMSZF method which combines the regression vectors of LMS and ZF to improve stability
-    and channel estimation properties.
+    This class performs equalization on a dual polarization signal with a known dual
+    polarization pilot sequence. The equalization is performed either with the LMS
+    method, ZF method or a novel LMSZF method which combines the regression vectors
+    of LMS and ZF to improve stability and channel estimation properties.
     """

     def __init__(
550559
:param pilot_sequence: Known dual polarization pilot sequence
551560
:param pilot_sequence_up: Upsampled dual polarization pilot sequence
552561
:param butterfly_filter: :py:class:`mokka.equalizers.torch.Butterfly2x2` object
553-
:param filter_length: If a butterfly_filter argument is not provided the filter length to initialize
554-
the butterfly filter.
562+
:param filter_length: If a butterfly_filter argument is not provided the filter
563+
length to initialize the butterfly filter.
555564
:param method: adaptive update method for the equalizer filter taps
556565
:param block_size: number of symbols to process before each update step
557566
:param adaptive_lr: Adapt learning rate during simulation
558-
:param preeq_method: Use a different method to perform a first-stage equalization
567+
:param preeq_method: Use a different method to perform a first-stage
568+
equalization
559569
:param preeq_offset: Length of first-stage equalization
560-
:param preeq_lradjust: Change learning rate by this factor for first-stage equalization
561-
:param lmszf_weight: if LMSZF is used as equalization method the weight between ZF and LMS update algorithms.
570+
:param preeq_lradjust: Change learning rate by this factor for first-stage
571+
equalization
572+
:param lmszf_weight: if LMSZF is used as equalization method the weight between
573+
ZF and LMS update algorithms.
562574
"""
563575
super(PilotAEQ_DP, self).__init__()
564576
self.register_buffer("sps", torch.as_tensor(sps))
@@ -597,7 +609,8 @@ def forward(self, y):

         :param y: Complex receive signal y
         """
-        # y_cut is perfectly aligned with pilot_sequence_up (after cross correlation & using peak)
+        # y_cut is perfectly aligned with pilot_sequence_up (after cross
+        # correlation & using peak)
         # The adaptive filter should be able to correct polarization flip on its own
         y_cut = correct_start_polarization(
             y, self.pilot_sequence_up[:, : y.shape[1]], correct_polarization=False
678691
+ torch.sqrt(1.0 - torch.as_tensor(self.lmszf_weight))
679692
* self.pilot_sequence_up.clone().conj().resolve_conj()
680693
)
681-
# print(
682-
# "mean y_cut energy: ",
683-
# torch.mean(
684-
# torch.pow(
685-
# torch.abs(
686-
# y_cut.clone()[:, : self.pilot_sequence_up.shape[1]]
687-
# ),
688-
# 2,
689-
# )
690-
# ),
691-
# )
692-
# print(
693-
# "mean pilot_seq_up energy: ",
694-
# torch.mean(
695-
# torch.pow(
696-
# torch.abs(
697-
# self.pilot_sequence_up.clone().conj().resolve_conj()
698-
# ),
699-
# 2,
700-
# )
701-
# ),
702-
# )
703-
704-
# print(
705-
# "mean regression seq energy: ",
706-
# torch.mean(torch.pow(torch.abs(regression_seq), 2)),
707-
# )
708694
if i == self.preeq_offset:
709695
lr = lr * self.preeq_lradjust
710696

711697
if eq_method == "ZFadv":
712-
# Update regression seq by calculating h from f and estimating \hat{y}
698+
# Update regression seq by calculating h from f and
699+
# estimating \hat{y}
713700
# We can use the same function as in the forward pass
714701
f = torch.stack(
715702
(
@@ -785,7 +772,8 @@ def forward(self, y):
                 self.sps,
             )
             if self.adaptive_lr:
-                # For LMS according to Rupp 2011 this stepsize ensures the stability/robustness
+                # For LMS according to Rupp 2011 this stepsize ensures the
+                # stability/robustness
                 lr = (
                     self.adaptive_scale
                     * 2
@@ -974,7 +962,8 @@ def reset(self):
     def forward(self, y):
         # Implement CMA "by hand"
         # Basically step through the signal advancing always +sps symbols
-        # and filtering 2*filter_len samples which will give one output sample with mode "valid"
+        # and filtering 2*filter_len samples which will give one output sample with
+        # mode "valid"

         equalizer_length = self.taps.shape[0]
         num_samp = y.shape[0]
@@ -987,9 +976,8 @@ def forward(self, y):
         out = torch.zeros(
             (num_samp - equalizer_length) // self.sps, dtype=torch.complex64
         )
-        eq_offset = (
-            equalizer_length - 1
-        ) // 2  # We try to put the symbol of interest in the center tap of the equalizer
+        # We try to put the symbol of interest in the center tap of the equalizer
+        eq_offset = (equalizer_length - 1) // 2
         for i, k in enumerate(
             range(eq_offset, num_samp - 1 - eq_offset * 2, self.sps * self.block_size)
         ):
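
For orientation, below is a minimal usage sketch of the CMA class touched by this diff. It is not part of the commit: the keyword names follow the parameters documented in the docstring above, but the exact constructor signature, default values, and output shape are assumptions and may differ in the actual mokka release.

# Hypothetical usage sketch: keyword names follow the CMA docstring above;
# the exact signature, defaults, and return shape in mokka may differ.
import torch
from mokka.equalizers.adaptive.torch import CMA

# Dual-polarization receive signal: 2 x num_samples, complex-valued
y = torch.randn(2, 20000, dtype=torch.complex64)

cma = CMA(R=1.0, sps=2, lr=1e-3, filter_length=31, block_size=1)
y_eq = cma(y)  # equalized output, roughly one symbol per sps input samples
print(y_eq.shape)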

0 commit comments