Skip to content

Commit a385583

Browse files
committed
Update analysis scripts
1 parent d0648ba commit a385583

File tree

4 files changed

+44
-28
lines changed

4 files changed

+44
-28
lines changed

scripts/data_processing/dataset_preperation.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,30 @@
1414
import numpy as np
1515

1616
from swift_tools.data import read_snapshots
17-
from dmsr.field_operations.resize import cut_field
17+
from swift_tools.fields import cut_field
1818

1919

2020
#%%
2121
data_directory = '../../data/dmsr_runs/'
2222

23-
LR_snapshots = np.sort(glob.glob(data_directory + '*/064/snap_0002.hdf5'))[:16]
24-
HR_snapshots = np.sort(glob.glob(data_directory + '*/128/snap_0002.hdf5'))[:16]
23+
# LR_snapshots = np.sort(glob.glob(data_directory + '*/064/snap_0002.hdf5'))[:16]
24+
# HR_snapshots = np.sort(glob.glob(data_directory + '*/256/snap_0002.hdf5'))[:16]
2525

26-
# LR_snapshots = np.sort(glob.glob(data_directory + 'run17/064/snap_0002.hdf5'))
27-
# HR_snapshots = np.sort(glob.glob(data_directory + 'run17/128/snap_0002.hdf5'))
26+
LR_snapshots = np.sort(glob.glob(data_directory + 'run17/064/snap_0002.hdf5'))
27+
HR_snapshots = np.sort(glob.glob(data_directory + 'run17/256/snap_0002.hdf5'))
2828

2929

3030
#%%
3131
LR_disp, LR_vel, box_size, LR_grid_size, LR_mass = read_snapshots(LR_snapshots)
3232
HR_disp, HR_vel, box_size, HR_grid_size, HR_mass = read_snapshots(HR_snapshots)
3333

3434
LR_fields = np.concatenate((LR_disp, LR_vel), axis=1)
35+
del LR_disp
36+
del LR_vel
37+
3538
HR_fields = np.concatenate((HR_disp, HR_vel), axis=1)
39+
del HR_disp
40+
del HR_vel
3641

3742
# # Normalise values so that box size is 1
3843
# LR_fields /= box_size
@@ -43,23 +48,23 @@
4348
#%%
4449
padding = 2
4550
LR_patch_size = 16
46-
HR_patch_size = 32
51+
HR_patch_size = 64
4752

4853
LR_fields = cut_field(LR_fields, LR_patch_size, LR_patch_size, pad=padding)
4954
HR_fields = cut_field(HR_fields, HR_patch_size, HR_patch_size)
5055

5156

5257
#%%
53-
LR_file = '../../data/dmsr_training/LR_fields.npy'
54-
# LR_file = '../../data/dmsr_validation/LR_fields.npy'
58+
# LR_file = '../../data/dmsr_training/LR_fields.npy'
59+
LR_file = '../../data/dmsr_validation/LR_fields.npy'
5560
np.save(LR_file, LR_fields)
5661

57-
HR_file = '../../data/dmsr_training/HR_fields.npy'
58-
# HR_file = '../../data/dmsr_validation/HR_fields.npy'
62+
# HR_file = '../../data/dmsr_training/HR_fields.npy'
63+
HR_file = '../../data/dmsr_validation/HR_fields.npy'
5964
np.save(HR_file, HR_fields)
6065

61-
meta_file = '../../data/dmsr_training/metadata.npy'
62-
# meta_file = '../../data/dmsr_validation/metadata.npy'
66+
# meta_file = '../../data/dmsr_training/metadata.npy'
67+
meta_file = '../../data/dmsr_validation/metadata.npy'
6368
LR_size = LR_patch_size + 2 * padding
6469
HR_size = HR_patch_size
6570
np.save(meta_file, [

scripts/training/plot_loss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
import matplotlib.pyplot as plt
1111

12-
losses = np.load('./losses.npz')
12+
losses = np.load('./velocity_run/losses.npz')
1313

1414
critic_loss = losses['critic_loss']
1515
critic_batches = losses['critic_batches']
@@ -54,5 +54,6 @@
5454
generator_batches[window_size//2-1:-window_size//2], moving_average,
5555
# color='black'
5656
)
57+
# plt.savefig('loss_curve.png', dpi=210)
5758
plt.show()
5859
plt.close()

scripts/training/plot_power_spectrum.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
import matplotlib.pyplot as plt
1919

20+
from swift_tools.data import load_normalisation_parameters
2021
from dmsr.analysis import displacement_power_spectrum
2122

2223

@@ -57,31 +58,35 @@ def plot_spectra(
5758

5859

5960
#%%
60-
plots_dir = 'plots/training_spectra/'
61-
data_dir = './data/samples/'
62-
sr_samples = glob.glob(data_dir + 'sr_sample_*.npy')
61+
data_dir = './velocity_run/'
62+
plots_dir = data_dir + 'plots/training_spectra/'
63+
samples_dir = data_dir + 'samples/'
64+
sr_samples = glob.glob(samples_dir + 'sr_sample_*.npy')
6365
sr_samples = np.sort(sr_samples)
6466

6567
existing_plots = glob.glob(plots_dir + 'power_sprectrum_epoch_*.png')
6668
existing_plots = [re.split(r'[._]+', plot)[-2] for plot in existing_plots]
6769
sr_samples = [sample for sample in sr_samples
6870
if not re.split(r'[._]+', sample)[-2] in existing_plots]
6971

70-
lr_sample = np.load(data_dir + 'lr_sample.npy')
71-
lr_sample = torch.from_numpy(lr_sample)
72-
hr_sample = np.load(data_dir + 'hr_sample.npy')
73-
hr_sample = torch.from_numpy(hr_sample)
72+
scale_path = data_dir + 'normalisation.npy'
73+
lr_std, hr_std, _, _ = load_normalisation_parameters(scale_path)
74+
75+
lr_sample = np.load(samples_dir + 'lr_sample.npy')
76+
lr_sample = torch.from_numpy(lr_sample)[:, :3, ...] * lr_std
77+
hr_sample = np.load(samples_dir + 'hr_sample.npy')
78+
hr_sample = torch.from_numpy(hr_sample)[:, :3, ...] * hr_std
7479

7580
# TODO: read this from metadata
7681
box_size = 35.56187768431281
7782

7883
for sr_sample in sr_samples:
7984
epoch = int(re.split(r'[._]+', sr_sample)[-2])
8085
sr_sample = np.load(sr_sample)
81-
sr_sample = torch.from_numpy(sr_sample)
86+
sr_sample = torch.from_numpy(sr_sample)[:, :3, ...] * hr_std
8287

8388
plot_spectra(
8489
lr_sample, sr_sample, hr_sample,
85-
64, 1, 20*box_size/14, box_size, 20, 56,
90+
64, 1, 20*box_size/16, box_size, 20, 64,
8691
epoch, plots_dir
8792
)

scripts/training/plot_samples.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
import matplotlib.pyplot as plt
1919

20+
from swift_tools.data import load_normalisation_parameters
2021
from dmsr.field_operations.conversion import displacements_to_positions
2122

2223
def plot_samples(
@@ -77,25 +78,29 @@ def get_xys(positions):
7778

7879

7980
#%%
80-
plots_dir = 'plots/training_samples/'
81-
data_dir = './data/samples/'
82-
sr_samples = glob.glob(data_dir + 'sr_sample_*.npy')
81+
data_dir = './velocity_run/'
82+
plots_dir = data_dir + 'plots/training_samples/'
83+
samples_dir = data_dir + 'samples/'
84+
sr_samples = glob.glob(samples_dir + 'sr_sample_*.npy')
8385
sr_samples = np.sort(sr_samples)
8486

8587
existing_plots = glob.glob(plots_dir + 'particle_plot_epoch_*.png')
8688
existing_plots = [re.split(r'[._]+', plot)[-2] for plot in existing_plots]
8789
sr_samples = [sample for sample in sr_samples
8890
if not re.split(r'[._]+', sample)[-2] in existing_plots]
8991

90-
lr_sample = np.load(data_dir + 'lr_sample.npy')
91-
hr_sample = np.load(data_dir + 'hr_sample.npy')
92+
scale_path = data_dir + 'normalisation.npy'
93+
lr_std, hr_std, _, _ = load_normalisation_parameters(scale_path)
94+
95+
lr_sample = np.load(samples_dir + 'lr_sample.npy')[:, :3, ...] * lr_std
96+
hr_sample = np.load(samples_dir + 'hr_sample.npy')[:, :3, ...] * hr_std
9297

9398
# TODO: read this from metadata
9499
box_size = 35.56187768431281
95100

96101
for sr_sample in sr_samples:
97102
epoch = int(re.split(r'[._]+', sr_sample)[-2])
98-
sr_sample = np.load(sr_sample)
103+
sr_sample = np.load(sr_sample)[:, :3, ...] * hr_std
99104

100105
plot_samples(
101106
lr_sample, sr_sample, hr_sample,

0 commit comments

Comments
 (0)