Skip to content

Commit 8acbe04

Browse files
committed
synchronizers: Add pilot synchronizers
1 parent d86156b commit 8acbe04

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

src/mokka/synchronizers/phase/torch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .bps import BPS # noqa
44
from . import vandv # noqa
55
from . import cycleslip_comp # noqa
6+
from . import pilot # noqa
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""Module implementing pilot phase synchronizers."""
2+
import torch
3+
from ....functional.torch import convolve
4+
from ....functional.torch import unwrap_torch
5+
6+
7+
class PilotInserter(torch.nn.Module):
8+
"""
9+
Insert a pilot symbol at a regular interval.
10+
11+
Either a QPSK symbol or another pre-defined symbol.
12+
"""
13+
14+
def __init__(self, block_length=64, pilot_symbol=None, **kwargs):
15+
"""
16+
Initialize :py:class:`PilotInserter`.
17+
18+
:param block_length: Blocklength of signal plus one pilot symbol
19+
:param pilot_symbol: Pilot symbol or pilot sequence for consecutive blocks
20+
"""
21+
super(PilotInserter, self).__init__(**kwargs)
22+
self.block_length = block_length
23+
if pilot_symbol is None:
24+
self.pilot_symbol = 1 / (torch.sqrt(torch.tensor(2))) * (1 + 1j)
25+
else:
26+
self.pilot_symbol = pilot_symbol
27+
28+
def forward(self, y):
29+
"""
30+
Insert pilot symbols into transmit signal y.
31+
32+
:param y: Transmit symbols
33+
"""
34+
y = torch.squeeze(y)
35+
if len(y.size()) < 2:
36+
y = torch.reshape(torch.unsqueeze(y, -1), (-1, self.block_length - 1))
37+
else:
38+
assert y.size()[1] == self.block_length - 1
39+
40+
if not self.pilot_symbol.size() or self.pilot_symbol.size()[0] < y.size()[0]:
41+
if not self.pilot_symbol.size():
42+
pilot_size = 1
43+
else:
44+
pilot_size = self.pilot_symbol.size()[0]
45+
pilots = torch.repeat_interleave(
46+
self.pilot_symbol,
47+
torch.tensor((-1 * (-y.size()[0] // pilot_size),)),
48+
)[: y.size()[0]].to(y.device)
49+
else:
50+
pilots = self.pilot_symbol.to(y.device)
51+
pilots = torch.unsqueeze(pilots, -1)
52+
signal = torch.concat((pilots, y), dim=-1).flatten()
53+
return signal
54+
55+
56+
class PilotPhaseCompensation(torch.nn.Module):
57+
"""
58+
Phase compensation based on inserted pilot symbols.
59+
60+
This block performs the phase compensation based on pilot symbols inserted by
61+
:py:class:`PilotInserter`.
62+
"""
63+
64+
def __init__(
65+
self, block_length=64, pilot_symbol=None, window_size=10 * 64, **kwargs
66+
):
67+
"""
68+
Initialize :py:class:`PilotPhaseCompensation`.
69+
70+
:param block_length: Block length of the transmit signal plus pilots
71+
:param pilot_symbol: Either a single pilot symbol or a sequence which is
72+
inserted to consecutive transmit symbol blocks
73+
:param window_size: Window size to use for moving average of the phase
74+
estimation.
75+
"""
76+
super(PilotPhaseCompensation, self).__init__(**kwargs)
77+
self.block_length = block_length
78+
self.window_size = window_size
79+
if pilot_symbol is None:
80+
self.pilot_symbol = 1 / (torch.sqrt(torch.tensor(2))) * (1 + 1j)
81+
else:
82+
self.pilot_symbol = pilot_symbol
83+
84+
def forward(self, y):
85+
"""
86+
Apply the phase compensation to received signal `y`.
87+
88+
:param y: Received signal
89+
"""
90+
if len(y.size()) < 2:
91+
y = torch.reshape(torch.unsqueeze(y, 1), (-1, self.block_length))
92+
else:
93+
assert y.size()[1] == self.block_length
94+
if not self.pilot_symbol.size() or self.pilot_symbol.size()[0] < y.size()[0]:
95+
if not self.pilot_symbol.size():
96+
pilot_size = 1
97+
else:
98+
pilot_size = self.pilot_symbol.size()[0]
99+
pilots = torch.repeat_interleave(
100+
self.pilot_symbol,
101+
torch.tensor(
102+
-1 * (-y.size()[0] // pilot_size),
103+
),
104+
)[: y.size()[0]]
105+
else:
106+
pilots = self.pilot_symbol
107+
pilots = torch.unsqueeze(pilots, -1).to(y.device)
108+
109+
received_pilots = torch.unsqueeze(y[:, 0], -1)
110+
phase_est = (
111+
torch.angle(received_pilots * torch.conj(pilots))
112+
.type(torch.float32)
113+
.flatten()
114+
)
115+
phase_est = unwrap_torch(phase_est)
116+
117+
phase_comp = torch.zeros_like(y, dtype=torch.float32)
118+
phase_comp[:, 0] = phase_est
119+
phase_comp = phase_comp.flatten()
120+
121+
filter_kernel = torch.ones((self.window_size,), dtype=torch.complex64).to(
122+
y.device
123+
)
124+
125+
phase_comp_val = torch.zeros_like(phase_comp, dtype=torch.float32)
126+
phase_comp_val[:: self.block_length] = 1.0
127+
phase_comp_norm = convolve(
128+
phase_comp_val.type(torch.float32),
129+
filter_kernel.type(torch.float32),
130+
mode="same",
131+
)
132+
133+
phase_est = (
134+
convolve(phase_comp.type(torch.complex64), filter_kernel, mode="same")
135+
/ phase_comp_norm
136+
)
137+
y = y.flatten() * torch.exp(-1j * phase_est)
138+
y = torch.reshape(y, (-1, self.block_length))[:, 1:].flatten()
139+
140+
return y

0 commit comments

Comments
 (0)