99
1010from ..data .density_estimation .golden_ratio import fibonacci_ratio , get_centers
1111
12- # from source.data.dataloader import DataLoader
1312from ..strategy import Strategy
14-
1513from matplotlib import pyplot as plt
16-
1714import jax .numpy as jnp
1815from tqdm import tqdm
19-
2016import seaborn as sns
2117
18+
19+ class DataLoader :
20+ pass
21+
22+
2223class TwoDimensionalPlot (Callback ):
2324
2425 name : str = '2d_plot'
2526 save_every : int = 10
2627 num_samples : int = 3000
2728 num_samples_gt : int = 3000
2829
29- def __init__ (self , save_every : int = 10 , num_samples : int = 3000 ):
30+ def __init__ (self , save_every : int = 10 , num_samples : int = 3000 , savedir : str = None ):
3031 super (self .__class__ , self ).__init__ ()
3132
3233 self .save_every = save_every
3334 self .num_samples = num_samples
35+ self .savedir = savedir
3436
3537
3638 def __call__ (self , logs : dict , rng : jr .PRNGKey , * args , ** kwargs ):
@@ -59,8 +61,12 @@ def _call(self, logs: dict, rng: jr.PRNGKey, strategy: Strategy,
5961
6062 ax .scatter (x [:self .num_samples_gt ], y [:self .num_samples_gt ])
6163
64+ if self .savedir is not None :
65+ plt .savefig (f"{ self .savedir } /pictures/samples_gt.png" )
66+
6267 logs [f'samples_gt' ] = wandb .Image (fig )
6368
69+
6470 plt .close (fig )
6571
6672 sample_fn = strategy .sample
@@ -75,6 +81,15 @@ def _call(self, logs: dict, rng: jr.PRNGKey, strategy: Strategy,
7581 fig , ax = plt .subplots ()
7682 ax .scatter (x , y )
7783
84+ epoch = logs .get ("epoch" , 0 )
85+ if self .savedir is not None :
86+
87+ import os
88+ # create directory if it does not exist
89+ os .makedirs (f"{ self .savedir } /pictures" , exist_ok = True )
90+
91+ plt .savefig (f"{ self .savedir } /pictures/samples_{ epoch } .png" )
92+
7893 logs [f'samples' ] = wandb .Image (fig )
7994
8095 plt .close (fig )
@@ -84,6 +99,9 @@ def _call(self, logs: dict, rng: jr.PRNGKey, strategy: Strategy,
8499 def on_train_begin (self , * args , ** kwargs ):
85100 return self .__call__ (* args , init = True , ** kwargs )
86101
102+ def on_test (self , * args , ** kwargs ):
103+ return self .__call__ (* args , init = False , ** kwargs )
104+
87105 def on_epoch_end (self , logs : dict , rng : jr .PRNGKey , * args , ** kwargs ):
88106
89107 if "epoch" in logs and logs ["epoch" ] % self .save_every == 0 :
0 commit comments