Skip to content

Dev radio #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions configs/radio_meerkat_macro.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#Change checkpoint and sense_map path
checkpoint_dir: /share/gpu0/mars/TNG_data/rcGAN/models/meerkat_macro/
data_path: /share/gpu0/mars/TNG_data/rcGAN/meerkat_clean/

# Define the experience
experience: radio

# Number of code vectors for each phase
num_z_test: 32
num_z_valid: 8
num_z_train: 2

# Data
in_chans: 2 # Real+Imag parts from obs
out_chans: 1
im_size: 360 # 360x360 pixel images

# Options
alt_upsample: False # False -> convt upsampling, True -> interpolate upsampling
norm: macro # none, micro, macro

# Optimizer:
lr: 0.001
beta_1: 0
beta_2: 0.99

# Loss weights
gp_weight: 10
adv_weight: 1e-5

# Training
batch_size: 2 # per GPU
accumulate_grad_batches: 2

#Remember to increase this for full training
num_epochs: 100
psnr_gain_tol: 0.25

num_workers: 4
40 changes: 40 additions & 0 deletions configs/radio_meerkat_macro_gradient.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#Change checkpoint and sense_map path
checkpoint_dir: /share/gpu0/mars/TNG_data/rcGAN/models/meerkat_macro/
data_path: /share/gpu0/mars/TNG_data/rcGAN/meerkat_clean/

# Define the experience
experience: radio

# Number of code vectors for each phase
num_z_test: 32
num_z_valid: 8
num_z_train: 2

# Data
in_chans: 2 # Real+Imag parts from obs
out_chans: 1
im_size: 360 # 360x360 pixel images

# Options
alt_upsample: False # False -> convt upsampling, True -> interpolate upsampling
norm: macro # none, micro, macro
gradient: True

# Optimizer:
lr: 0.001
beta_1: 0
beta_2: 0.99

# Loss weights
gp_weight: 10
adv_weight: 1e-5

# Training
batch_size: 2 # per GPU
accumulate_grad_batches: 2

#Remember to increase this for full training
num_epochs: 100
psnr_gain_tol: 0.25

num_workers: 4
88 changes: 74 additions & 14 deletions data/datasets/Radio_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,41 @@

class RadioDataset_Test(torch.utils.data.Dataset):
"""Loads the test data."""
def __init__(self, data_dir, transform, norm='micro'):
    """Load the test split (x.npy, y.npy, uv.npy) and configure normalisation.

    Args:
        data_dir (pathlib.Path): Directory holding the split's x.npy, y.npy
            and uv.npy arrays.
        transform (callable): Pre-processes raw samples into model inputs.
            For 'none' and 'macro' modes its mean_*/std_* attributes are
            set here; for 'micro' the transform computes them per sample.
        norm (str): 'none' (no normalisation), 'micro' (per-sample
            normalisation), or 'macro' (statistics computed once on the
            train split and reused here).

    Raises:
        ValueError: If `norm` is not one of 'none', 'micro', 'macro'.
    """
    self.transform = transform

    # Test/x.npy, Test/y.npy, Test/uv.npy
    self.x = np.load(data_dir.joinpath("x.npy")).astype(np.float64)
    self.y = np.load(data_dir.joinpath("y.npy")).astype(np.float64)
    self.uv = np.load(data_dir.joinpath("uv.npy")).real.astype(np.float64)
    # Min-max scale the uv values into (0, 1) before any further normalisation.
    self.uv = (self.uv - self.uv.min()) / (self.uv.max() - self.uv.min())

    if norm == 'none':
        # Identity statistics: (v - 0) / 1 leaves values untouched.
        self.transform.mean_x, self.transform.std_x = 0, 1
        self.transform.mean_y, self.transform.std_y = 0, 1
        self.transform.mean_uv, self.transform.std_uv = 0, 1
    elif norm == 'micro':
        # Per-sample statistics are computed inside the transform itself.
        pass
    elif norm == 'macro':
        # Reuse the statistics saved by the train split so test data is
        # scaled exactly like the data the model was trained on.
        train_dir = data_dir.parent.joinpath("train")
        for stat in ("mean_x", "std_x", "mean_y", "std_y", "mean_uv", "std_uv"):
            setattr(self.transform, stat, np.load(train_dir.joinpath(f"{stat}.npy")))
    else:
        # Fail fast on typos instead of silently leaving the transform
        # with undefined statistics.
        raise ValueError(f"Unknown norm mode: {norm!r}")

def __len__(self):
"""Returns the number of samples in the dataset."""
Expand All @@ -37,21 +56,41 @@ def __getitem__(self,i):

class RadioDataset_Val(torch.utils.data.Dataset):
"""Loads the test data."""
def __init__(self, data_dir, transform, norm='micro'):
    """Load the validation split (x.npy, y.npy, uv.npy) and configure normalisation.

    Args:
        data_dir (pathlib.Path): Directory holding the split's x.npy, y.npy
            and uv.npy arrays.
        transform (callable): Pre-processes raw samples into model inputs.
            For 'none' and 'macro' modes its mean_*/std_* attributes are
            set here; for 'micro' the transform computes them per sample.
        norm (str): 'none' (no normalisation), 'micro' (per-sample
            normalisation), or 'macro' (statistics computed once on the
            train split and reused here).

    Raises:
        ValueError: If `norm` is not one of 'none', 'micro', 'macro'.
    """
    self.transform = transform

    # Val/x.npy, Val/y.npy, Val/uv.npy
    self.x = np.load(data_dir.joinpath("x.npy")).astype(np.float64)
    self.y = np.load(data_dir.joinpath("y.npy")).astype(np.float64)
    self.uv = np.load(data_dir.joinpath("uv.npy")).real.astype(np.float64)
    # Min-max scale the uv values into (0, 1) before any further normalisation.
    self.uv = (self.uv - self.uv.min()) / (self.uv.max() - self.uv.min())

    if norm == 'none':
        # Identity statistics: (v - 0) / 1 leaves values untouched.
        self.transform.mean_x, self.transform.std_x = 0, 1
        self.transform.mean_y, self.transform.std_y = 0, 1
        self.transform.mean_uv, self.transform.std_uv = 0, 1
    elif norm == 'micro':
        # Per-sample statistics are computed inside the transform itself.
        pass
    elif norm == 'macro':
        # Reuse the statistics saved by the train split so validation data
        # is scaled exactly like the data the model was trained on.
        train_dir = data_dir.parent.joinpath("train")
        for stat in ("mean_x", "std_x", "mean_y", "std_y", "mean_uv", "std_uv"):
            setattr(self.transform, stat, np.load(train_dir.joinpath(f"{stat}.npy")))
    else:
        # Fail fast on typos instead of silently leaving the transform
        # with undefined statistics.
        raise ValueError(f"Unknown norm mode: {norm!r}")

def __len__(self):
"""Returns the number of samples in the dataset."""
Expand All @@ -66,22 +105,43 @@ def __getitem__(self,i):

class RadioDataset_Train(torch.utils.data.Dataset):
"""Loads the test data."""
def __init__(self, data_dir, transform, norm='micro'):
    """Load the train split (x.npy, y.npy, uv.npy) and configure normalisation.

    In 'macro' mode this split is the source of truth: the dataset-wide
    statistics are computed here and saved to disk so the val/test splits
    can load and reuse them.

    Args:
        data_dir (pathlib.Path): Directory holding the split's x.npy, y.npy
            and uv.npy arrays.
        transform (callable): Pre-processes raw samples into model inputs.
            For 'none' and 'macro' modes its mean_*/std_* attributes are
            set here; for 'micro' the transform computes them per sample.
        norm (str): 'none' (no normalisation), 'micro' (per-sample
            normalisation), or 'macro' (normalisation across all samples).

    Raises:
        ValueError: If `norm` is not one of 'none', 'micro', 'macro'.
    """
    self.transform = transform

    # Train/x.npy, Train/y.npy, Train/uv.npy
    self.x = np.load(data_dir.joinpath("x.npy")).astype(np.float64)
    self.y = np.load(data_dir.joinpath("y.npy")).astype(np.float64)
    self.uv = np.load(data_dir.joinpath("uv.npy")).real.astype(np.float64)
    # Min-max scale the uv values into (0, 1) before any further normalisation.
    self.uv = (self.uv - self.uv.min()) / (self.uv.max() - self.uv.min())

    if norm == 'none':
        # Identity statistics: (v - 0) / 1 leaves values untouched.
        self.transform.mean_x, self.transform.std_x = 0, 1
        self.transform.mean_y, self.transform.std_y = 0, 1
        self.transform.mean_uv, self.transform.std_uv = 0, 1
    elif norm == 'micro':
        # Per-sample statistics are computed inside the transform itself.
        pass
    elif norm == 'macro':
        # Dataset-wide mean; std is the average of the per-sample stds
        # (std over the last two axes of each sample).
        self.transform.mean_x, self.transform.std_x = self.x.mean(), np.mean(self.x.std(axis=(1, 2)))
        self.transform.mean_y, self.transform.std_y = self.y.mean(), np.mean(self.y.std(axis=(1, 2)))
        self.transform.mean_uv, self.transform.std_uv = self.uv.mean(), np.mean(self.uv.std(axis=(1, 2)))

        # Persist the statistics so the val/test splits scale identically.
        for stat in ("mean_x", "std_x", "mean_y", "std_y", "mean_uv", "std_uv"):
            np.save(data_dir.joinpath(f"{stat}.npy"), getattr(self.transform, stat))
    else:
        # Fail fast on typos instead of silently leaving the transform
        # with undefined statistics.
        raise ValueError(f"Unknown norm mode: {norm!r}")
def __len__(self):
"""Returns the number of samples in the dataset."""
Expand Down
31 changes: 22 additions & 9 deletions data/lightning/RadioDataModule.py
Original file line number Diff line number Diff line change
def __init__(self, args, test=False, ISNR=30):
    """Store the transform's configuration.

    Args:
        args: Parsed config namespace; the optional `args.norm` attribute
            selects the normalisation mode (defaults to 'micro').
        test (bool): Whether this transform serves the val/test pipeline.
        ISNR (int): Input signal-to-noise ratio setting.
    """
    self.args = args
    self.test = test
    self.ISNR = ISNR
    # getattr is the idiomatic way to read an optional attribute with a
    # default; args.__dict__.get would miss class-level attributes.
    self.norm = getattr(args, 'norm', 'micro')

def __call__(self, data) -> Tuple[float, float, float, float]:
""" Transforms the data.
Expand All @@ -36,21 +38,28 @@ def __call__(self, data) -> Tuple[float, float, float, float]:
x, y, uv = data



# Format input gt data.
pt_x = transforms.to_tensor(x) # Shape (H, W, 2)
pt_x = transforms.to_tensor(x)[:, :, None] # Shape (H, W, 2)
pt_x = pt_x.permute(2, 0, 1) # Shape (2, H, W)
# Format observation data.
pt_y = transforms.to_tensor(y) # Shape (H, W, 2)
pt_y = transforms.to_tensor(y)[:, :, None] # Shape (H, W, 2)
pt_y = pt_y.permute(2, 0, 1) # Shape (2, H, W)
# Format uv data
pt_uv = transforms.to_tensor(uv)[:, :, None] # Shape (H, W, 1)
pt_uv = pt_uv.permute(2, 0, 1) # Shape (1, H, W)
# Normalize everything based on measurements y
normalized_y, mean, std = transforms.normalize_instance(pt_y)
normalized_x = transforms.normalize(pt_x, mean, std)
normalized_uv = transforms.normalize(pt_uv, mean, std)


if self.norm != 'micro':
normalized_y = transforms.normalize(pt_y, self.mean_y, self.std_y) # scale globally
normalized_x = transforms.normalize(pt_x, self.mean_x, self.std_x) # scale globally
normalized_uv = transforms.normalize(pt_uv, self.mean_uv, self.std_uv) # scale globally
mean, std = self.mean_x, self.std_x
elif self.norm == 'micro':
normalized_y, mean, std = transforms.normalize_instance(pt_y)
normalized_x = transforms.normalize(pt_x, mean, std) # scale based on input
normalized_uv, _, _ = transforms.normalize_instance(pt_uv) # scale on intself

# Use normalized stack of y + uv
normalized_y = torch.cat([normalized_y, normalized_uv], dim=0)
def __init__(self, args):
    """Create the data module from the parsed config namespace `args`."""
    super().__init__()
    self.prepare_data_per_node = True
    self.args = args
    # getattr is the idiomatic way to read an optional attribute with a
    # default; args.__dict__.get would miss class-level attributes.
    self.norm = getattr(args, 'norm', 'micro')

def prepare_data(self):
    """No-op: the datasets are constructed in setup() instead."""
    pass
def setup(self, stage: Optional[str] = None):
    """Instantiate the train / val / test datasets under args.data_path."""
    root = pathlib.Path(self.args.data_path)

    # Each split differs only in its subdirectory, dataset class and the
    # transform's test flag, so build all three from one table.
    built = {}
    for split_name, is_test, dataset_cls in (
        ('train', False, RadioDataset_Train),
        ('val', True, RadioDataset_Val),
        ('test', True, RadioDataset_Test),
    ):
        built[split_name] = dataset_cls(
            data_dir=root / split_name,
            transform=RadioDataTransform(self.args, test=is_test),
            norm=self.norm,
        )

    self.train, self.validate, self.test = built['train'], built['val'], built['test']
Expand Down
Empty file.
Loading