Commit 4815426

Merge pull request #2 from brenjohn/velocity
Velocity
2 parents 07a0e2c + a385583 commit 4815426

36 files changed, +1811 -768 lines changed

dmsr/data_tools/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from .dataset import *
+from .utils import *

dmsr/data_tools/dataset.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Sep 13 13:45:00 2024
+
+@author: brennan
+
+This file defines a DMSR-Dataset class for handling training data for a
+DMSR-WGAN model.
+"""
+
+import torch
+
+from torch.utils.data import Dataset
+from ..field_operations.augmentation import permute_tensor
+
+
+class DMSRDataset(Dataset):
+    """A Dataset class for holding training data for the DMSR-WGAN.
+    """
+
+    def __init__(
+            self,
+            lr_position,
+            hr_position,
+            lr_velocity = None,
+            hr_velocity = None,
+            augment = True
+        ):
+        self.lr_position = lr_position
+        self.hr_position = hr_position
+        self.lr_velocity = lr_velocity
+        self.hr_velocity = hr_velocity
+        self.velocities_included = lr_velocity is not None
+        self.augment = augment
+
+
+    def __len__(self):
+        return self.lr_position.size(0)
+
+
+    def __getitem__(self, idx):
+        lr_data = self.lr_position[idx]
+        hr_data = self.hr_position[idx]
+
+        # Apply augmentation (random flip/permutation) if specified
+        if self.augment:
+            random_perm = torch.randperm(3)
+            lr_data = permute_tensor(lr_data, random_perm)
+            hr_data = permute_tensor(hr_data, random_perm)
+
+        if self.velocities_included:
+            lr_velocity = self.lr_velocity[idx]
+            hr_velocity = self.hr_velocity[idx]
+
+            if self.augment:
+                lr_velocity = permute_tensor(lr_velocity, random_perm)
+                hr_velocity = permute_tensor(hr_velocity, random_perm)
+
+            lr_data = torch.concat((lr_data, lr_velocity))
+            hr_data = torch.concat((hr_data, hr_velocity))
+
+        return lr_data, hr_data
+
+
+    def normalise_dataset(self):
+        """Scales position and velocity data by dividing by their respective
+        standard deviations. The standard deviations are also returned as a
+        dictionary.
+        """
+        params = {}
+        field_names = ["lr_position", "hr_position"]
+        if self.velocities_included:
+            field_names += ["lr_velocity", "hr_velocity"]
+
+        for field in field_names:
+            standard_deviation = vars(self)[field].std()
+            params[field + "_std"] = standard_deviation.item()
+            vars(self)[field] /= standard_deviation
+
+        return params

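For orientation, here is a minimal usage sketch of the new dataset class. It is not part of the commit: the tensor shapes, the DataLoader settings and the assumption that the dmsr package is importable are all illustrative.

import torch
from torch.utils.data import DataLoader
from dmsr.data_tools import DMSRDataset

# Hypothetical training tensors: 8 samples, 3 channels, 16^3 LR and 32^3 HR grids.
lr_pos = torch.rand(8, 3, 16, 16, 16)
hr_pos = torch.rand(8, 3, 32, 32, 32)
lr_vel = torch.rand(8, 3, 16, 16, 16)
hr_vel = torch.rand(8, 3, 32, 32, 32)

dataset = DMSRDataset(lr_pos, hr_pos, lr_vel, hr_vel, augment=True)

# Scale each field to unit standard deviation; the standard deviations used
# are returned as a dictionary.
scale_params = dataset.normalise_dataset()

loader = DataLoader(dataset, batch_size=4, shuffle=True)
lr_batch, hr_batch = next(iter(loader))

# Positions and velocities are concatenated along the channel dimension, so
# each sample has 6 channels when velocities are included.
print(lr_batch.shape)  # torch.Size([4, 6, 16, 16, 16])
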
dmsr/data_tools/utils.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Feb 27 11:56:32 2025
+
+@author: brennan
+"""
+
+import torch
+import numpy as np
+
+from os.path import exists
+
+
+def load_numpy_dataset(data_directory):
+    """Returns LR and HR data contained in numpy files saved in the given
+    directory.
+    """
+    LR_data = np.load(data_directory + 'LR_fields.npy')
+    LR_data = torch.from_numpy(LR_data)
+
+    HR_data = np.load(data_directory + 'HR_fields.npy')
+    HR_data = torch.from_numpy(HR_data)
+
+    meta_file = data_directory + 'metadata.npy'
+    meta_data = np.load(meta_file)
+    box_size, HR_patch_size, LR_size, HR_size, LR_mass, HR_mass = meta_data
+
+    return LR_data, HR_data, HR_patch_size, LR_size, HR_size
+
+
+def load_normalisation_parameters(param_file):
+    """Reads the standard deviations from the given .npy file used to normalise
+    dmsr training data.
+    """
+    lr_pos_std = hr_pos_std = lr_vel_std = hr_vel_std = 1
+
+    if exists(param_file):
+        scale_params = np.load(param_file, allow_pickle=True).item()
+        scale_params = {k : v.item() for k, v in scale_params.items()}
+        lr_pos_std = scale_params.get('lr_position_std', 1)
+        hr_pos_std = scale_params.get('hr_position_std', 1)
+        lr_vel_std = scale_params.get('lr_velocity_std', 1)
+        hr_vel_std = scale_params.get('hr_velocity_std', 1)
+
+    return lr_pos_std, hr_pos_std, lr_vel_std, hr_vel_std
+
+
+def generate_mock_data(lr_grid_size, hr_grid_size, channels, samples):
+    """Create a mock training data set for testing.
+    """
+    box_size = 1
+    shape = (samples, channels, lr_grid_size, lr_grid_size, lr_grid_size)
+    LR_data = torch.rand(*shape)
+    shape = (samples, channels, hr_grid_size, hr_grid_size, hr_grid_size)
+    HR_data = torch.rand(*shape)
+    return LR_data, HR_data, box_size, lr_grid_size, hr_grid_size

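These utilities are small enough that a quick smoke test is easy to sketch (illustrative only; the file name normalisation.npy is a hypothetical placeholder, and when no such file exists every standard deviation falls back to 1):

from dmsr.data_tools import generate_mock_data, load_normalisation_parameters

# Mock LR/HR fields: 4 samples, 3 channels, on 16^3 and 32^3 grids.
LR_data, HR_data, box_size, lr_size, hr_size = generate_mock_data(
    lr_grid_size=16, hr_grid_size=32, channels=3, samples=4
)
print(LR_data.shape, HR_data.shape)  # (4, 3, 16, 16, 16) (4, 3, 32, 32, 32)

# With no parameter file on disk, the defaults of 1 are returned, so
# un-normalised data passes through unscaled.
stds = load_normalisation_parameters('normalisation.npy')
lr_pos_std, hr_pos_std, lr_vel_std, hr_vel_std = stds
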
dmsr/dmsr_gan/__init__.py

Whitespace-only changes.

dmsr/dmsr_gan/dmsr_dataset.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

dmsr/field_operations/resize.py

Lines changed: 27 additions & 61 deletions
@@ -5,11 +5,9 @@
 
 @author: brennan
 
-This file defines functions for resizing tensors using various methods.
+This file defines functions for resizing tensors.
 """
 
-import numpy as np
-
 
 def crop(field, crop_size):
     """Crops the spatial dimensions of the given tensor by size crop_size.
@@ -20,65 +18,33 @@ def crop(field, crop_size):
     return field[ind]
 
 
-def cut_field(fields, cut_size, stride=0, pad=0):
-    """Cuts the given field tensor into blocks of size `cut_size`.
-
-    Arguments:
-        - fields   : A numpy tensor of shape (batch_size, channels, N, N, N)
-                     where N is the grid size of the fields.
-        - cut_size : The base size of the blocks to cut the given fields into.
-
-        - stride   : The number of cells to move in each direction before
-                     extracting the next block.
-        - pad      : The number of cells to pad the base blocks on each side.
-
-    Returns:
-        A numpy tensor containing the blocks/subfields cut from the given
-        fields tensor. The shape of the returned tensor is:
-            (number_of_cuts * batch_size, channels, n, n, n),
-        where number_of_cuts is the number of subfields extracted from each
-        field and n is the grid size of each subfield (ie cut_size + 2 * pad).
+def pixel_unshuffle(tensor, scale):
     """
-    grid_size = fields.shape[-1]
-    if not stride:
-        stride = cut_size
+    Reshapes the given tensor of shape (B, C, D, H, W) to shape
+    (B, C * scale**3, D // scale, H // scale, W // scale).
 
-    cuts = []
-    for i in range(0, grid_size, stride):
-        for j in range(0, grid_size, stride):
-            for k in range(0, grid_size, stride):
-
-                slice_x = [n % grid_size for n in range(i-pad, i+cut_size+pad)]
-                slice_y = [n % grid_size for n in range(j-pad, j+cut_size+pad)]
-                slice_z = [n % grid_size for n in range(k-pad, k+cut_size+pad)]
-
-                patch = np.take(fields, slice_x, axis=2)
-                patch = np.take(patch, slice_y, axis=3)
-                patch = np.take(patch, slice_z, axis=4)
-
-                cuts.append(patch)
-
-    return np.concatenate(cuts)
-
-
-def stitch_fields(patches, patches_per_dim):
-    """Combines or stitches the given collection of patches into a single
-    tensor.
-
-    This function can be thought of as performing the reverse operation
-    performed by `cut_field`.
+    The reshaping procedure uses the pixel shuffle method of Shi et al 2016 -
+    "Real-Time Single Image and Video Super-Resolution Using an Efficient
+    Sub-Pixel Convolutional Neural Network"
     """
+    # Ensure tensor has the right shape
+    batch_size, channels, depth, height, width = tensor.shape
+
+    new_channels = channels * scale**3
+    new_depth = depth // scale
+    new_height = height // scale
+    new_width = width // scale
+
+    # Reshape and permute to rearrange data
+    tensor = tensor.contiguous().view(
+        batch_size, channels,
+        new_depth, scale,
+        new_height, scale,
+        new_width, scale
+    )
+    tensor = tensor.permute(0, 1, 3, 5, 7, 2, 4, 6)
+    tensor = tensor.contiguous().view(
+        batch_size, new_channels, new_depth, new_height, new_width
+    )
 
-    patch_size = patches[0].shape[-1]
-    field_size = patch_size * patches_per_dim
-    field = np.zeros((3, field_size, field_size, field_size))
-
-    for n, patch in enumerate(patches):
-        i = n // patches_per_dim**2
-        j = (n % patches_per_dim**2) // patches_per_dim
-        k = n % patches_per_dim
-
-        N = patch_size
-        field[:, i*N:(i+1)*N, j*N:(j+1)*N, k*N:(k+1)*N] = patch
-
-    return field
+    return tensor

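As a rough check of the new pixel_unshuffle (a standalone sketch, not part of the commit), the transform packs each scale x scale x scale block of voxels into the channel dimension while preserving the total number of elements:

import torch
from dmsr.field_operations.resize import pixel_unshuffle

x = torch.rand(2, 3, 8, 8, 8)   # (B, C, D, H, W); sizes chosen for illustration
y = pixel_unshuffle(x, scale=2)

print(y.shape)                  # torch.Size([2, 24, 4, 4, 4])
assert y.numel() == x.numel()   # only the layout changes, not the data volume
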
dmsr/monitors/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from .monitor import *
+from .manager import *

dmsr/monitors/manager.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Feb 27 12:25:54 2025
+
+@author: brennan
+"""
+
+import time
+
+
+class MonitorManager():
+    """A class to manage monitor objects.
+
+    The Monitor Manager class stores and calls Monitor objects during DMSR-WGAN
+    training at appropriate times.
+
+    Monitor objects are stored in a monitors dictionary. During DMSR training,
+    at the end of a batch update the `post_batch_processing` method of each
+    monitor object is called. Similarly, at the end of each epoch, the
+    `post_epoch_processing` method of each monitor is called by the monitor
+    manager.
+
+    Any messages returned by the `post_batch_processing` calls are passed to a
+    batch report method which prints them along with some information regarding
+    batch/epoch number and timings. At the end of each epoch, the monitor
+    manager also prints some timing information regarding the epoch and epoch
+    post processing.
+    """
+
+    def __init__(self, report_rate, device):
+        self.device = device
+        self.report_rate = report_rate
+
+
+    def set_monitors(self, monitors):
+        self.monitors = monitors
+
+
+    def init_monitoring(self, num_epochs, num_batches):
+        """Initializes values for variables used for timing batches and epochs.
+        """
+        self.num_epochs = num_epochs
+        self.num_batches = num_batches
+        self.batch_start_time = time.time()
+        self.epoch_start_time = time.time()
+
+
+    def end_of_epoch(self, epoch):
+        """Calls the `post_epoch_processing` method of each monitor.
+        """
+        epoch_time = time.time() - self.epoch_start_time
+        print(f"[Epoch {epoch} took: {epoch_time:.4f} sec]")
+        post_processing_start_time = time.time()
+
+        for monitor in self.monitors.values():
+            monitor.post_epoch_processing(epoch)
+
+        self.epoch_start_time = time.time()
+        self.batch_start_time = time.time()
+        post_processing_time = time.time() - post_processing_start_time
+        print(
+            f"[Epoch post-processing took: {post_processing_time:.4f} sec]",
+            flush=True
+        )
+
+
+    def end_of_batch(self, epoch, batch, batch_counter, losses):
+        """Calls the `post_batch_processing` method of each monitor.
+        """
+        monitor_report = ''
+
+        for monitor in self.monitors.values():
+            monitor_report += monitor.post_batch_processing(
+                epoch, batch, batch_counter, losses
+            )
+
+        self.batch_report(epoch, batch, monitor_report)
+
+
+    def batch_report(self, epoch, batch, monitor_report):
+        """Report some statistics for the last few batch updates.
+        """
+        if (batch > 0 and batch % self.report_rate == 0):
+            time_curr = time.time()
+            time_prev = self.batch_start_time
+            average_batch_time = (time_curr - time_prev) / self.report_rate
+
+            report = f"[Epoch {epoch:04}/{self.num_epochs}]"
+            report += f"[Batch {batch:03}/{self.num_batches}]"
+            report += f"[time per batch: {average_batch_time*1000:.4f} ms]"
+            report += monitor_report
+
+            print(report)
+            self.batch_start_time = time.time()

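A minimal sketch of how the manager might be driven from a training loop (not part of the commit; the NullMonitor class and the loop are illustrative assumptions, with the real monitor classes presumably defined in dmsr/monitors/monitor.py):

from dmsr.monitors import MonitorManager

class NullMonitor:
    """Hypothetical stand-in for a real monitor."""
    def post_batch_processing(self, epoch, batch, batch_counter, losses):
        return ''  # the manager concatenates these strings into its batch report
    def post_epoch_processing(self, epoch):
        pass

manager = MonitorManager(report_rate=10, device='cpu')
manager.set_monitors({'null': NullMonitor()})
manager.init_monitoring(num_epochs=2, num_batches=100)

batch_counter = 0
for epoch in range(2):
    for batch in range(100):
        losses = {}      # in real training, the WGAN losses for this batch
        batch_counter += 1
        manager.end_of_batch(epoch, batch, batch_counter, losses)
    manager.end_of_epoch(epoch)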