"""PyTorch implementations of adaptive equalizers."""
+
from ..torch import Butterfly2x2
-from ..torch import correct_start_polarization, correct_start, find_start_offset
+from ..torch import correct_start_polarization, correct_start
from ...functional.torch import convolve_overlap_save
from ..torch import h2f
import torch
import logging
+from collections import namedtuple

logger = logging.getLogger(__name__)

-from collections import namedtuple
-

class CMA(torch.nn.Module):
    """Class to perform CMA equalization."""
@@ -33,12 +33,17 @@ def __init__(
        :param R: constant modulus radius
        :param sps: samples per symbol
        :param lr: learning rate
-        :param butterfly_filter: Optional :py:class:`mokka.equalizers.torch.Butterfly2x2` object
-        :param filter_length: butterfly filter length (if object is not given)
-        :param block_size: Number of symbols to process before updating the equalizer taps
-        :param no_singularity: Initialize the x- and y-polarization to avoid singularity by decoding
-            the signal from the same polarization twice.
-        :param singularity_length: Delay for initialization with the no_singularity approach
+        :param butterfly_filter: Optional
+            :py:class:`mokka.equalizers.torch.Butterfly2x2` object
+        :param filter_length: butterfly filter length
+            (if object is not given)
+        :param block_size: Number of symbols to process before
+            updating the equalizer taps
+        :param no_singularity: Initialize the x- and y-polarization to avoid
+            singularity by decoding
+            the signal from the same polarization twice.
+        :param singularity_length: Delay for initialization with the
+            no_singularity approach
        """
        super(CMA, self).__init__()
        self.register_buffer("R", torch.as_tensor(R))
@@ -79,7 +84,8 @@ def forward(self, y):
        """
        # Implement CMA "by hand"
        # Basically step through the signal advancing always +sps symbols
-        # and filtering 2*filter_len samples which will give one output sample with mode "valid"
+        # and filtering 2*filter_len samples which will give one output
+        # sample with mode "valid"

        equalizer_length = self.butterfly_filter.taps.size()[1]
        num_samp = y.shape[1]
@@ -92,9 +98,8 @@ def forward(self, y):
        out = torch.zeros(
            2, (num_samp - equalizer_length) // self.sps, dtype=torch.complex64
        )
-        eq_offset = (
-            equalizer_length - 1
-        ) // 2  # We try to put the symbol of interest in the center tap of the equalizer
+        # We try to put the symbol of interest in the center tap of the equalizer
+        eq_offset = (equalizer_length - 1) // 2
        for i, k in enumerate(
            range(eq_offset, num_samp - 1 - eq_offset * 2, self.sps * self.block_size)
        ):
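
Note (illustrative, not part of the commit): the loop above steps through the received samples block by block and applies the constant-modulus update to the 2x2 butterfly taps. A minimal single-polarization sketch of the underlying CMA update, with hypothetical names, could look like this:

    import torch

    def cma_update(taps, samples, R, lr, sps):
        """One block of constant-modulus updates on a single polarization."""
        eq_len = taps.shape[0]
        for k in range(0, samples.shape[0] - eq_len, sps):
            window = samples[k : k + eq_len]
            z = torch.sum(taps * window)              # one equalized output symbol
            e = R - torch.abs(z) ** 2                 # constant-modulus error
            taps = taps + lr * e * z * window.conj()  # stochastic-gradient tap step
        return taps
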
@@ -143,11 +148,6 @@ def get_error_signal(self):
        return self.out_e


-##############################################################################################
-########################### Variational Autoencoer based Equalizer ###########################
-##############################################################################################
-
-
def ELBO_DP(
    y,
    q,
@@ -168,8 +168,8 @@ def ELBO_DP(
    N = y.shape[1]
    # Now we have two polarizations in the first dimension
    # We assume the same transmit constellation for both, calculating
-    # q needs to be shaped 2 x N x M -> for each observation on each polarization we have M q-values
-    # we have M constellation symbols
+    # q needs to be shaped 2 x N x M -> for each observation on each polarization we
+    # have M q-values and we have M constellation symbols
    L = butterfly_filter.taps.shape[1]
    L_offset = (L - 1) // 2
    if p_constellation is None:
@@ -180,7 +180,7 @@ def ELBO_DP(
    # # Precompute E_Q{c} = sum( q * c) where c is x and |x|**2
    E_Q_x = torch.zeros(2, N, device=q.device, dtype=torch.complex64)
    E_Q_x_abssq = torch.zeros(2, N, device=q.device, dtype=torch.float32)
-    if IQ_separate == True:
+    if IQ_separate:
        num_lev = constellation_symbols.shape[0]
        E_Q_x[:, ::sps] = torch.complex(
            torch.sum(
@@ -217,8 +217,8 @@ def ELBO_DP(
            axis=-1,
        )

-    # Term A - sum all the things, but spare the first dimension, since the two polarizations
-    # are sorta independent
+    # Term A - sum all the things, but spare the first dimension,
+    # since the two polarizations are sorta independent
    bias = 1e-14
    A = torch.sum(
        q[:, L_offset:-L_offset, :]
@@ -230,11 +230,13 @@ def ELBO_DP(
    )

    # Precompute h \ast E_Q{x}
-    h_conv_E_Q_x = butterfly_filter(
-        E_Q_x, mode="valid"
-    )  # Due to definition that we assume the symbol is at the center tap we remove (filter_length - 1)//2 at start and end
-    # Limit the length of y to the "computable space" because y depends on more past values than given
-    # We try to generate the received symbol sequence with the estimated symbol sequence
+    # Due to definition that we assume the symbol is at the center tap
+    # we remove (filter_length - 1)//2 at start and end
+    h_conv_E_Q_x = butterfly_filter(E_Q_x, mode="valid")
+    # Limit the length of y to the "computable space" because y depends
+    # on more past values than given
+    # We try to generate the received symbol sequence
+    # with the estimated symbol sequence
    C = torch.sum(
        y[:, L_offset:-L_offset].real ** 2 + y[:, L_offset:-L_offset].imag ** 2, axis=1
    )
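
Note (illustrative): the quantities assembled in this hunk are expectations of the transmit symbols under the demapper output q, as stated in the "Precompute E_Q{c}" comment. A simplified sketch of E_Q{x} and E_Q{|x|^2} for the plain (non IQ-separate) case, ignoring the sps upsampling the real function handles:

    import torch

    def soft_symbol_moments(q, constellation):
        # q: (2, N, M) posteriors per polarization, symbol slot and constellation point
        # constellation: (M,) complex constellation symbols
        E_x = torch.sum(q * constellation, dim=-1)                   # E_Q{x}
        E_x_abssq = torch.sum(q * constellation.abs() ** 2, dim=-1)  # E_Q{|x|^2}
        return E_x, E_x_abssq
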
@@ -256,9 +258,6 @@ def ELBO_DP(
    return loss, var


-##############################################################################################
-
-
class VAE_LE_DP(torch.nn.Module):
    """
    Adaptive Equalizer based on the variational autoencoder principle with a linear equalizer.
@@ -267,7 +266,7 @@ class VAE_LE_DP(torch.nn.Module):

    [1] V. Lauinger, F. Buchali, and L. Schmalen, ‘Blind equalization and channel estimation in coherent optical communications using variational autoencoders’,
    IEEE Journal on Selected Areas in Communications, vol. 40, no. 9, pp. 2529–2539, Sep. 2022, doi: 10.1109/JSAC.2022.3191346.
-    """
+    """  # noqa

    def __init__(
        self,
@@ -285,20 +284,25 @@ def __init__(
        """
        Initialize :py:class:`VAE_LE_DP`.

-        This VAE equalizer is implemented with a butterfly linear equalizer in the forward path and a butterfly linear equalizer in
-        the backward pass. Therefore, it is limited to correct impairments of linear channels.
+        This VAE equalizer is implemented with a butterfly linear equalizer in the
+        forward path and a butterfly linear equalizer in the backward pass.
+        Therefore, it is limited to correct impairments of linear channels.

        :param num_taps_forward: number of equalizer taps
        :param num_taps_backward: number of channel taps
        :param demapper: mokka demapper object to perform complex symbol demapping
        :param sps: samples per symbol
-        :param block_size: number of symbols per block - defines the update rate of the equalizer
+        :param block_size: number of symbols per block - defines the update rate
+            of the equalizer
        :param lr: learning rate for the adam algorithm
        :param requires_q: return q-values in forward call
-        :param IQ_separate: process I and Q separately - requires a demapper which performs demapping on real values
-            and a bit-mapping which is equal on I and Q.
-        :param var_from_estimate: Update the variance in the demapper from the SNR estimate of the output
-        :param num_block_train: Number of blocks to train the equalizer before switching to non-training equalization mode (for static channels only)
+        :param IQ_separate: process I and Q separately - requires a demapper
+            which performs demapping on real values
+            and a bit-mapping which is equal on I and Q.
+        :param var_from_estimate: Update the variance in the demapper from
+            the SNR estimate of the output
+        :param num_block_train: Number of blocks to train the equalizer before
+            switching to non-training equalization mode (for static channels only)
        """
        super(VAE_LE_DP, self).__init__()

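Note (hedged usage sketch, not from the repository): assuming the constructor accepts keyword arguments named as documented above, and given a mokka demapper instance and a dual-polarization receive signal, the equalizer could be driven as below; `demapper` and `rx` are placeholders.

    eq = VAE_LE_DP(
        num_taps_forward=25,
        num_taps_backward=25,
        demapper=demapper,   # placeholder: a mokka demapper object
        sps=2,
        block_size=200,
        lr=1e-3,
        requires_q=True,
        IQ_separate=False,
    )
    y_eq, q, var, loss = eq(rx)  # rx: (2, num_samples) complex tensor

With requires_q=True the forward pass returns the eq_out namedtuple (y, q, var, loss) shown in a later hunk.
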
@@ -365,17 +369,17 @@ def forward(self, y):
        out = []
        out_q = []
        # We start our loop already at num_taps (because we cannot equalize the start)
-        # We will end the loop at num_samps - num_taps - sps*block_size (safety, so we don't overrun)
-        # We will process sps * block_size - 2 * num_taps because we will cut out the first and last block
+        # We will end the loop at num_samps - num_taps - sps*block_size
+        # (safety, so we don't overrun)
+        # We will process sps * block_size - 2 * num_taps because we will cut out
+        # the first and last block

        index_padding = (self.butterfly_forward.num_taps - 1) // 2
+        # Back-off one block-size + filter_overlap from end to avoid overrunning
        for i, k in enumerate(
            range(
                index_padding,
-                num_samps
-                - index_padding
-                - self.sps
-                * self.block_size,  # Back-off one block-size + filter_overlap from end to avoid overrunning
+                num_samps - index_padding - self.sps * self.block_size,
                self.sps * self.block_size,
            )
        ):
@@ -387,15 +391,18 @@ def forward(self, y):
                k - index_padding,
                k + self.sps * self.block_size + index_padding,
            )
-            # Equalization will give sps * block_size samples (because we add (num_taps - 1) in the beginning)
+            # Equalization will give sps * block_size samples (because we add
+            # (num_taps - 1) in the beginning)
            y_hat = self.butterfly_forward(y[:, in_index], "valid")

-            # We downsample so we will have floor(((sps * block_size - num_taps + 1) / sps) = floor(block_size - (num_taps - 1)/sps)
+            # We downsample so we will have
+            # floor(((sps * block_size - num_taps + 1) / sps)
+            # = floor(block_size - (num_taps - 1)/sps)
            y_symb = y_hat[
                :, 0 :: self.sps
            ]  # ---> y[0,(self.butterfly_forward.num_taps + 1)//2 +1 ::self.sps]

-            if self.IQ_separate == True:
+            if self.IQ_separate:
                q_hat = torch.cat(
                    (
                        torch.cat(
@@ -422,8 +429,8 @@ def forward(self, y):
                        self.demapper(y_symb[1, :]).unsqueeze(0),
                    )
                )
-            # We calculate the loss with less symbols, since the forward operation with "valid"
-            # is missing some symbols
+            # We calculate the loss with less symbols, since the forward operation
+            # with "valid" is missing some symbols
            # We assume the symbol of interest is at the center tap of the filter
            y_index = in_index[
                (self.butterfly_forward.num_taps - 1)
@@ -440,7 +447,9 @@ def forward(self, y):
                IQ_separate=self.IQ_separate,
            )

-            # logger.info("Iteration: %s/%s, VAE loss: %s", i+1, ((num_samps - index_padding - self.sps * self.block_size) // (self.sps * self.block_size)).item(), loss.item())
+            # logger.info("Iteration: %s/%s, VAE loss: %s", i+1,
+            # ((num_samps - index_padding - self.sps * self.block_size)
+            # // (self.sps * self.block_size)).item(), loss.item())

            if self.num_block_train is None or (self.num_block_train > i):
                # print("noise_sigma: ", self.demapper.noise_sigma)
@@ -450,12 +459,11 @@ def forward(self, y):
                self.optimizer.zero_grad()
                # self.optimizer_var.zero_grad()

-            if self.var_from_estimate == True:
+            if self.var_from_estimate:
                self.demapper.noise_sigma = torch.clamp(
                    torch.sqrt(torch.mean(var.detach().clone()) / 2),
                    min=torch.tensor(0.05, requires_grad=False, device=q_hat.device),
-                    max=2
-                    * self.demapper.noise_sigma.detach().clone(),  # torch.sqrt(var).detach()), min=0.1
+                    max=2 * self.demapper.noise_sigma.detach().clone(),
                )

            output_symbols = y_symb[
@@ -470,9 +478,10 @@ def forward(self, y):
            out_q.append(output_q)

            # print("loss: ", loss, "\t\t\t var: ", var)
-            # out.append(y_symb[:, self.block_size - self.butterfly_forward.num_taps // 2 :])
+            # out.append(y_symb[:, self.block_size - self.butterfly_forward.num_taps
+            # // 2 :])

-        if self.requires_q == True:
+        if self.requires_q:
            eq_out = namedtuple("eq_out", ["y", "q", "var", "loss"])
            return eq_out(torch.cat(out, axis=1), torch.cat(out_q, axis=1), var, loss)
        return torch.cat(out, axis=1)
@@ -519,10 +528,10 @@ class PilotAEQ_DP(torch.nn.Module):
    """
    Perform pilot-based adaptive equalization.

-    This class performs equalization on a dual polarization signal with a known dual polarization
-    pilot sequence. The equalization is performed either with the LMS method, ZF method or a
-    novel LMSZF method which combines the regression vectors of LMS and ZF to improve stability
-    and channel estimation properties.
+    This class performs equalization on a dual polarization signal with a known dual
+    polarization pilot sequence. The equalization is performed either with the LMS
+    method, ZF method or a novel LMSZF method which combines the regression vectors
+    of LMS and ZF to improve stability and channel estimation properties.
    """

    def __init__(
@@ -550,15 +559,18 @@ def __init__(
        :param pilot_sequence: Known dual polarization pilot sequence
        :param pilot_sequence_up: Upsampled dual polarization pilot sequence
        :param butterfly_filter: :py:class:`mokka.equalizers.torch.Butterfly2x2` object
-        :param filter_length: If a butterfly_filter argument is not provided the filter length to initialize
-            the butterfly filter.
+        :param filter_length: If a butterfly_filter argument is not provided the filter
+            length to initialize the butterfly filter.
        :param method: adaptive update method for the equalizer filter taps
        :param block_size: number of symbols to process before each update step
        :param adaptive_lr: Adapt learning rate during simulation
-        :param preeq_method: Use a different method to perform a first-stage equalization
+        :param preeq_method: Use a different method to perform a first-stage
+            equalization
        :param preeq_offset: Length of first-stage equalization
-        :param preeq_lradjust: Change learning rate by this factor for first-stage equalization
-        :param lmszf_weight: if LMSZF is used as equalization method the weight between ZF and LMS update algorithms.
+        :param preeq_lradjust: Change learning rate by this factor for first-stage
+            equalization
+        :param lmszf_weight: if LMSZF is used as equalization method the weight between
+            ZF and LMS update algorithms.
        """
        super(PilotAEQ_DP, self).__init__()
        self.register_buffer("sps", torch.as_tensor(sps))
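
Note (conceptual sketch): for the LMSZF method named above, a later hunk combines the two regression vectors with square-root weights derived from lmszf_weight. Conceptually, with `lms_seq` and `zf_seq` as placeholder regression vectors:

    import torch

    def lmszf_regression(lms_seq, zf_seq, weight):
        # Square-root weights so that the energies of the two regressors
        # are mixed in proportion to `weight`.
        w = torch.as_tensor(weight)
        return torch.sqrt(w) * lms_seq + torch.sqrt(1.0 - w) * zf_seq
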
@@ -597,7 +609,8 @@ def forward(self, y):

        :param y: Complex receive signal y
        """
-        # y_cut is perfectly aligned with pilot_sequence_up (after cross correlation & using peak)
+        # y_cut is perfectly aligned with pilot_sequence_up (after cross
+        # correlation & using peak)
        # The adaptive filter should be able to correct polarization flip on its own
        y_cut = correct_start_polarization(
            y, self.pilot_sequence_up[:, : y.shape[1]], correct_polarization=False
@@ -678,38 +691,12 @@ def forward(self, y):
                    + torch.sqrt(1.0 - torch.as_tensor(self.lmszf_weight))
                    * self.pilot_sequence_up.clone().conj().resolve_conj()
                )
-                # print(
-                #     "mean y_cut energy: ",
-                #     torch.mean(
-                #         torch.pow(
-                #             torch.abs(
-                #                 y_cut.clone()[:, : self.pilot_sequence_up.shape[1]]
-                #             ),
-                #             2,
-                #         )
-                #     ),
-                # )
-                # print(
-                #     "mean pilot_seq_up energy: ",
-                #     torch.mean(
-                #         torch.pow(
-                #             torch.abs(
-                #                 self.pilot_sequence_up.clone().conj().resolve_conj()
-                #             ),
-                #             2,
-                #         )
-                #     ),
-                # )
-
-                # print(
-                #     "mean regression seq energy: ",
-                #     torch.mean(torch.pow(torch.abs(regression_seq), 2)),
-                # )
            if i == self.preeq_offset:
                lr = lr * self.preeq_lradjust

            if eq_method == "ZFadv":
-                # Update regression seq by calculating h from f and estimating \hat{y}
+                # Update regression seq by calculating h from f and
+                # estimating \hat{y}
                # We can use the same function as in the forward pass
                f = torch.stack(
                    (
@@ -785,7 +772,8 @@ def forward(self, y):
                self.sps,
            )
            if self.adaptive_lr:
-                # For LMS according to Rupp 2011 this stepsize ensures the stability/robustness
+                # For LMS according to Rupp 2011 this stepsize ensures the
+                # stability/robustness
                lr = (
                    self.adaptive_scale
                    * 2
@@ -974,7 +962,8 @@ def reset(self):
    def forward(self, y):
        # Implement CMA "by hand"
        # Basically step through the signal advancing always +sps symbols
-        # and filtering 2*filter_len samples which will give one output sample with mode "valid"
+        # and filtering 2*filter_len samples which will give one output sample with
+        # mode "valid"

        equalizer_length = self.taps.shape[0]
        num_samp = y.shape[0]
@@ -987,9 +976,8 @@ def forward(self, y):
        out = torch.zeros(
            (num_samp - equalizer_length) // self.sps, dtype=torch.complex64
        )
-        eq_offset = (
-            equalizer_length - 1
-        ) // 2  # We try to put the symbol of interest in the center tap of the equalizer
+        # We try to put the symbol of interest in the center tap of the equalizer
+        eq_offset = (equalizer_length - 1) // 2
        for i, k in enumerate(
            range(eq_offset, num_samp - 1 - eq_offset * 2, self.sps * self.block_size)
        ):