
Commit

Format files to conform to flake8
noc0lour committed Jan 7, 2025
1 parent 8fc57a9 commit 29a0f04
Showing 3 changed files with 101 additions and 109 deletions.
186 changes: 87 additions & 99 deletions src/mokka/equalizers/adaptive/torch.py
@@ -1,15 +1,15 @@
"""PyTorch implementations of adaptive equalizers."""

from ..torch import Butterfly2x2
from ..torch import correct_start_polarization, correct_start, find_start_offset
from ..torch import correct_start_polarization, correct_start
from ...functional.torch import convolve_overlap_save
from ..torch import h2f
import torch
import logging
from collections import namedtuple

logger = logging.getLogger(__name__)

from collections import namedtuple


class CMA(torch.nn.Module):
"""Class to perform CMA equalization."""
@@ -33,12 +33,17 @@ def __init__(
:param R: constant modulus radius
:param sps: samples per symbol
:param lr: learning rate
:param butterfly_filter: Optional :py:class:`mokka.equalizers.torch.Butterfly2x2` object
:param filter_length: butterfly filter length (if object is not given)
:param block_size: Number of symbols to process before updating the equalizer taps
:param no_singularity: Initialize the x- and y-polarization to avoid singularity by decoding
the signal from the same polarization twice.
:param singularity_length: Delay for initialization with the no_singularity approach
:param butterfly_filter: Optional
:py:class:`mokka.equalizers.torch.Butterfly2x2` object
:param filter_length: butterfly filter length
(if object is not given)
:param block_size: Number of symbols to process before
updating the equalizer taps
:param no_singularity: Initialize the x- and y-polarization to avoid
singularity by decoding
the signal from the same polarization twice.
:param singularity_length: Delay for initialization with the
no_singularity approach
"""
super(CMA, self).__init__()
self.register_buffer("R", torch.as_tensor(R))
@@ -79,7 +84,8 @@ def forward(self, y):
"""
# Implement CMA "by hand"
# Basically step through the signal advancing always +sps symbols
# and filtering 2*filter_len samples which will give one output sample with mode "valid"
# and filtering 2*filter_len samples which will give one output
# sample with mode "valid"

equalizer_length = self.butterfly_filter.taps.size()[1]
num_samp = y.shape[1]
@@ -92,9 +98,8 @@
out = torch.zeros(
2, (num_samp - equalizer_length) // self.sps, dtype=torch.complex64
)
eq_offset = (
equalizer_length - 1
) // 2 # We try to put the symbol of interest in the center tap of the equalizer
# We try to put the symbol of interest in the center tap of the equalizer
eq_offset = (equalizer_length - 1) // 2
for i, k in enumerate(
range(eq_offset, num_samp - 1 - eq_offset * 2, self.sps * self.block_size)
):
@@ -143,11 +148,6 @@ def get_error_signal(self):
return self.out_e
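A hypothetical usage of the class, with keyword names taken from the __init__ docstring above and values chosen for illustration only:

cma = CMA(R=1.0, sps=2, lr=1e-3, filter_length=31, block_size=1)
x_eq = cma(y)               # y: complex tensor of shape (2, num_samples)
e = cma.get_error_signal()  # per-update CMA error, useful for convergence checks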


##############################################################################################
########################### Variational Autoencoder based Equalizer ###########################
##############################################################################################


def ELBO_DP(
y,
q,
@@ -168,8 +168,8 @@
N = y.shape[1]
# Now we have two polarizations in the first dimension
# We assume the same transmit constellation for both, calculating
# q needs to be shaped 2 x N x M -> for each observation on each polarization we have M q-values
# we have M constellation symbols
# q needs to be shaped 2 x N x M -> for each observation on each polarization we
# have M q-values and we have M constellation symbols
L = butterfly_filter.taps.shape[1]
L_offset = (L - 1) // 2
if p_constellation is None:
@@ -180,7 +180,7 @@
# # Precompute E_Q{c} = sum( q * c) where c is x and |x|**2
E_Q_x = torch.zeros(2, N, device=q.device, dtype=torch.complex64)
E_Q_x_abssq = torch.zeros(2, N, device=q.device, dtype=torch.float32)
if IQ_separate == True:
if IQ_separate:
num_lev = constellation_symbols.shape[0]
E_Q_x[:, ::sps] = torch.complex(
torch.sum(
@@ -217,8 +217,8 @@ def ELBO_DP(
axis=-1,
)

# Term A - sum all the things, but spare the first dimension, since the two polarizations
# are sorta independent
# Term A - sum all the things, but spare the first dimension,
# since the two polarizations are sorta independent
bias = 1e-14
A = torch.sum(
q[:, L_offset:-L_offset, :]
@@ -230,11 +230,13 @@ def ELBO_DP(
)

# Precompute h \ast E_Q{x}
h_conv_E_Q_x = butterfly_filter(
E_Q_x, mode="valid"
) # Due to definition that we assume the symbol is at the center tap we remove (filter_length - 1)//2 at start and end
# Limit the length of y to the "computable space" because y depends on more past values than given
# We try to generate the received symbol sequence with the estimated symbol sequence
# Due to definition that we assume the symbol is at the center tap
# we remove (filter_length - 1)//2 at start and end
h_conv_E_Q_x = butterfly_filter(E_Q_x, mode="valid")
# Limit the length of y to the "computable space" because y depends
# on more past values than given
# We try to generate the received symbol sequence
# with the estimated symbol sequence
C = torch.sum(
y[:, L_offset:-L_offset].real ** 2 + y[:, L_offset:-L_offset].imag ** 2, axis=1
)
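The precomputation commented above reduces to weighted sums over the constellation. A small sketch with assumed shapes, ignoring the sps upsampling that the real code handles via strided assignment:

import torch

q = torch.softmax(torch.randn(2, 1000, 16), dim=-1)     # (2, N, M) soft assignments
constellation = torch.randn(16, dtype=torch.complex64)  # (M,) symbols

E_Q_x = torch.sum(q * constellation, dim=-1)                   # E_Q{x}, shape (2, N)
E_Q_x_abssq = torch.sum(q * constellation.abs() ** 2, dim=-1)  # E_Q{|x|^2}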
@@ -256,9 +258,6 @@ def ELBO_DP(
return loss, var
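A hypothetical call of ELBO_DP, using only keyword names visible in this diff (the full signature is collapsed above, so treat this as a sketch):

loss, var = ELBO_DP(
    y,                              # (2, N) received dual-pol samples
    q,                              # soft demapper outputs, one row per polarization
    sps=2,
    constellation_symbols=constellation,
    butterfly_filter=butterfly,     # Butterfly2x2 instance
    p_constellation=None,           # handled by the `if p_constellation is None` branch
    IQ_separate=False,
)
loss.backward()                     # gradient step on the variational bound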


##############################################################################################


class VAE_LE_DP(torch.nn.Module):
"""
Adaptive Equalizer based on the variational autoencoder principle with a linear equalizer.
@@ -267,7 +266,7 @@ class VAE_LE_DP(torch.nn.Module):
[1] V. Lauinger, F. Buchali, and L. Schmalen, ‘Blind equalization and channel estimation in coherent optical communications using variational autoencoders’,
IEEE Journal on Selected Areas in Communications, vol. 40, no. 9, pp. 2529–2539, Sep. 2022, doi: 10.1109/JSAC.2022.3191346.
"""
""" # noqa

def __init__(
self,
Expand All @@ -285,20 +284,25 @@ def __init__(
"""
Initialize :py:class:`VAE_LE_DP`.
This VAE equalizer is implemented with a butterfly linear equalizer in the forward path and a butterfly linear equalizer in
the backward pass. Therefore, it is limited to correcting impairments of linear channels.
This VAE equalizer is implemented with a butterfly linear equalizer in the
forward path and a butterfly linear equalizer in the backward pass.
Therefore, it is limited to correcting impairments of linear channels.
:param num_taps_forward: number of equalizer taps
:param num_taps_backward: number of channel taps
:param demapper: mokka demapper object to perform complex symbol demapping
:param sps: samples per symbol
:param block_size: number of symbols per block - defines the update rate of the equalizer
:param block_size: number of symbols per block - defines the update rate
of the equalizer
:param lr: learning rate for the adam algorithm
:param requires_q: return q-values in forward call
:param IQ_separate: process I and Q separately - requires a demapper which performs demapping on real values
and a bit-mapping which is equal on I and Q.
:param var_from_estimate: Update the variance in the demapper from the SNR estimate of the output
:param num_block_train: Number of blocks to train the equalizer before switching to non-training equalization mode (for static channels only)
:param IQ_separate: process I and Q separately - requires a demapper
which performs demapping on real values
and a bit-mapping which is equal on I and Q.
:param var_from_estimate: Update the variance in the demapper from
the SNR estimate of the output
:param num_block_train: Number of blocks to train the equalizer before
switching to non-training equalization mode (for static channels only)
"""
super(VAE_LE_DP, self).__init__()
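A hypothetical construction of the equalizer, using the parameter names from the docstring (values are placeholders, and `demapper` is assumed to be a mokka demapper instance):

eq = VAE_LE_DP(
    num_taps_forward=25,
    num_taps_backward=25,
    demapper=demapper,
    sps=2,
    block_size=200,
    lr=1e-3,
    requires_q=True,
    IQ_separate=False,
    var_from_estimate=False,
    num_block_train=None,
)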

@@ -365,17 +369,17 @@ def forward(self, y):
out = []
out_q = []
# We start our loop already at num_taps (because we cannot equalize the start)
# We will end the loop at num_samps - num_taps - sps*block_size (safety, so we don't overrun)
# We will process sps * block_size - 2 * num_taps because we will cut out the first and last block
# We will end the loop at num_samps - num_taps - sps*block_size
# (safety, so we don't overrun)
# We will process sps * block_size - 2 * num_taps because we will cut out
# the first and last block

index_padding = (self.butterfly_forward.num_taps - 1) // 2
# Back-off one block-size + filter_overlap from end to avoid overrunning
for i, k in enumerate(
range(
index_padding,
num_samps
- index_padding
- self.sps
* self.block_size, # Back-off one block-size + filter_overlap from end to avoid overrunning
num_samps - index_padding - self.sps * self.block_size,
self.sps * self.block_size,
)
):
@@ -387,15 +391,18 @@ def forward(self, y):
k - index_padding,
k + self.sps * self.block_size + index_padding,
)
# Equalization will give sps * block_size samples (because we add (num_taps - 1) in the beginning)
# Equalization will give sps * block_size samples (because we add
# (num_taps - 1) in the beginning)
y_hat = self.butterfly_forward(y[:, in_index], "valid")

# We downsample so we will have floor(((sps * block_size - num_taps + 1) / sps) = floor(block_size - (num_taps - 1)/sps)
# We downsample so we will have
# floor(((sps * block_size - num_taps + 1) / sps)
# = floor(block_size - (num_taps - 1)/sps)
y_symb = y_hat[
:, 0 :: self.sps
] # ---> y[0,(self.butterfly_forward.num_taps + 1)//2 +1 ::self.sps]

if self.IQ_separate == True:
if self.IQ_separate:
q_hat = torch.cat(
(
torch.cat(
@@ -422,8 +429,8 @@ def forward(self, y):
self.demapper(y_symb[1, :]).unsqueeze(0),
)
)
# We calculate the loss with less symbols, since the forward operation with "valid"
# is missing some symbols
# We calculate the loss with less symbols, since the forward operation
# with "valid" is missing some symbols
# We assume the symbol of interest is at the center tap of the filter
y_index = in_index[
(self.butterfly_forward.num_taps - 1)
@@ -440,7 +447,9 @@ def forward(self, y):
IQ_separate=self.IQ_separate,
)

# logger.info("Iteration: %s/%s, VAE loss: %s", i+1, ((num_samps - index_padding - self.sps * self.block_size) // (self.sps * self.block_size)).item(), loss.item())
# logger.info("Iteration: %s/%s, VAE loss: %s", i+1,
# ((num_samps - index_padding - self.sps * self.block_size)
# // (self.sps * self.block_size)).item(), loss.item())

if self.num_block_train is None or (self.num_block_train > i):
# print("noise_sigma: ", self.demapper.noise_sigma)
@@ -450,12 +459,11 @@ def forward(self, y):
self.optimizer.zero_grad()
# self.optimizer_var.zero_grad()

if self.var_from_estimate == True:
if self.var_from_estimate:
self.demapper.noise_sigma = torch.clamp(
torch.sqrt(torch.mean(var.detach().clone()) / 2),
min=torch.tensor(0.05, requires_grad=False, device=q_hat.device),
max=2
* self.demapper.noise_sigma.detach().clone(), # torch.sqrt(var).detach()), min=0.1
max=2 * self.demapper.noise_sigma.detach().clone(),
)

output_symbols = y_symb[
@@ -470,9 +478,10 @@ def forward(self, y):
out_q.append(output_q)

# print("loss: ", loss, "\t\t\t var: ", var)
# out.append(y_symb[:, self.block_size - self.butterfly_forward.num_taps // 2 :])
# out.append(y_symb[:, self.block_size - self.butterfly_forward.num_taps
# // 2 :])

if self.requires_q == True:
if self.requires_q:
eq_out = namedtuple("eq_out", ["y", "q", "var", "loss"])
return eq_out(torch.cat(out, axis=1), torch.cat(out_q, axis=1), var, loss)
return torch.cat(out, axis=1)
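As a quick numeric check of the per-block sample accounting described in the comments above (example values only):

sps, block_size, num_taps = 2, 200, 25
index_padding = (num_taps - 1) // 2                 # 12
samples_in = sps * block_size + 2 * index_padding   # 424 samples fed per block
samples_eq = samples_in - num_taps + 1              # 400 = sps * block_size ("valid")
symbols_out = samples_eq // sps                     # 200 symbols after downsampling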
@@ -519,10 +528,10 @@ class PilotAEQ_DP(torch.nn.Module):
"""
Perform pilot-based adaptive equalization.
This class performs equalization on a dual polarization signal with a known dual polarization
pilot sequence. The equalization is performed either with the LMS method, ZF method or a
novel LMSZF method which combines the regression vectors of LMS and ZF to improve stability
and channel estimation properties.
This class performs equalization on a dual polarization signal with a known dual
polarization pilot sequence. The equalization is performed either with the LMS
method, ZF method or a novel LMSZF method which combines the regression vectors
of LMS and ZF to improve stability and channel estimation properties.
"""

def __init__(
@@ -550,15 +559,18 @@ def __init__(
:param pilot_sequence: Known dual polarization pilot sequence
:param pilot_sequence_up: Upsampled dual polarization pilot sequence
:param butterfly_filter: :py:class:`mokka.equalizers.torch.Butterfly2x2` object
:param filter_length: If a butterfly_filter argument is not provided the filter length to initialize
the butterfly filter.
:param filter_length: If a butterfly_filter argument is not provided the filter
length to initialize the butterfly filter.
:param method: adaptive update method for the equalizer filter taps
:param block_size: number of symbols to process before each update step
:param adaptive_lr: Adapt learning rate during simulation
:param preeq_method: Use a different method to perform a first-stage equalization
:param preeq_method: Use a different method to perform a first-stage
equalization
:param preeq_offset: Length of first-stage equalization
:param preeq_lradjust: Change learning rate by this factor for first-stage equalization
:param lmszf_weight: if LMSZF is used as equalization method the weight between ZF and LMS update algorithms.
:param preeq_lradjust: Change learning rate by this factor for first-stage
equalization
:param lmszf_weight: if LMSZF is used as equalization method the weight between
ZF and LMS update algorithms.
"""
super(PilotAEQ_DP, self).__init__()
self.register_buffer("sps", torch.as_tensor(sps))
@@ -597,7 +609,8 @@ def forward(self, y):
:param y: Complex receive signal y
"""
# y_cut is perfectly aligned with pilot_sequence_up (after cross correlation & using peak)
# y_cut is perfectly aligned with pilot_sequence_up (after cross
# correlation & using peak)
# The adaptive filter should be able to correct polarization flip on its own
y_cut = correct_start_polarization(
y, self.pilot_sequence_up[:, : y.shape[1]], correct_polarization=False
@@ -678,38 +691,12 @@ def forward(self, y):
+ torch.sqrt(1.0 - torch.as_tensor(self.lmszf_weight))
* self.pilot_sequence_up.clone().conj().resolve_conj()
)
# print(
# "mean y_cut energy: ",
# torch.mean(
# torch.pow(
# torch.abs(
# y_cut.clone()[:, : self.pilot_sequence_up.shape[1]]
# ),
# 2,
# )
# ),
# )
# print(
# "mean pilot_seq_up energy: ",
# torch.mean(
# torch.pow(
# torch.abs(
# self.pilot_sequence_up.clone().conj().resolve_conj()
# ),
# 2,
# )
# ),
# )

# print(
# "mean regression seq energy: ",
# torch.mean(torch.pow(torch.abs(regression_seq), 2)),
# )
if i == self.preeq_offset:
lr = lr * self.preeq_lradjust

if eq_method == "ZFadv":
# Update regression seq by calculating h from f and estimating \hat{y}
# Update regression seq by calculating h from f and
# estimating \hat{y}
# We can use the same function as in the forward pass
f = torch.stack(
(
@@ -785,7 +772,8 @@ def forward(self, y):
self.sps,
)
if self.adaptive_lr:
# For LMS according to Rupp 2011 this stepsize ensures the stability/robustness
# For LMS according to Rupp 2011 this stepsize ensures the
# stability/robustness
lr = (
self.adaptive_scale
* 2
@@ -974,7 +962,8 @@ def reset(self):
def forward(self, y):
# Implement CMA "by hand"
# Basically step through the signal advancing always +sps symbols
# and filtering 2*filter_len samples which will give one output sample with mode "valid"
# and filtering 2*filter_len samples which will give one output sample with
# mode "valid"

equalizer_length = self.taps.shape[0]
num_samp = y.shape[0]
@@ -987,9 +976,8 @@ def forward(self, y):
out = torch.zeros(
(num_samp - equalizer_length) // self.sps, dtype=torch.complex64
)
eq_offset = (
equalizer_length - 1
) // 2 # We try to put the symbol of interest in the center tap of the equalizer
# We try to put the symbol of interest in the center tap of the equalizer
eq_offset = (equalizer_length - 1) // 2
for i, k in enumerate(
range(eq_offset, num_samp - 1 - eq_offset * 2, self.sps * self.block_size)
):
