@@ -95,7 +95,7 @@ def __init__(self,
9595 logging .debug ('Start initialization of Pina DataModule' )
9696 logging .info ('Start initialization of Pina DataModule' )
9797 super ().__init__ ()
98- self .default_batching = automatic_batching
98+ self .automatic_batching = automatic_batching
9999 self .batch_size = batch_size
100100 self .shuffle = shuffle
101101 self .repeat = repeat
@@ -133,24 +133,24 @@ def setup(self, stage=None):
133133 self .train_dataset = PinaDatasetFactory (
134134 self .collector_splits ['train' ],
135135 max_conditions_lengths = self .find_max_conditions_lengths (
136- 'train' ))
136+ 'train' ), automatic_batching = self . automatic_batching )
137137 if 'val' in self .collector_splits .keys ():
138138 self .val_dataset = PinaDatasetFactory (
139139 self .collector_splits ['val' ],
140140 max_conditions_lengths = self .find_max_conditions_lengths (
141- 'val' )
141+ 'val' ), automatic_batching = self . automatic_batching
142142 )
143143 elif stage == 'test' :
144144 self .test_dataset = PinaDatasetFactory (
145145 self .collector_splits ['test' ],
146146 max_conditions_lengths = self .find_max_conditions_lengths (
147- 'test' )
147+ 'test' ), automatic_batching = self . automatic_batching
148148 )
149149 elif stage == 'predict' :
150150 self .predict_dataset = PinaDatasetFactory (
151151 self .collector_splits ['predict' ],
152152 max_conditions_lengths = self .find_max_conditions_lengths (
153- 'predict' )
153+ 'predict' ), automatic_batching = self . automatic_batching
154154 )
155155 else :
156156 raise ValueError (
@@ -237,9 +237,9 @@ def val_dataloader(self):
237237 self .val_dataset )
238238
239239 # Use default batching in torch DataLoader (good if batch size is small)
240- if self .default_batching :
240+ if self .automatic_batching :
241241 collate = Collator (self .find_max_conditions_lengths ('val' ))
242- return DataLoader (self .val_dataset , self . batch_size ,
242+ return DataLoader (self .val_dataset , batch_size ,
243243 collate_fn = collate )
244244 collate = Collator (None )
245245 # Use custom batching (good if batch size is large)
@@ -252,14 +252,16 @@ def train_dataloader(self):
252252 Create the training dataloader
253253 """
254254 # Use default batching in torch DataLoader (good if batch size is small)
255- if self .default_batching :
255+ batch_size = self .batch_size if self .batch_size is not None else len (
256+ self .train_dataset )
257+
258+ if self .automatic_batching :
256259 collate = Collator (self .find_max_conditions_lengths ('train' ))
257- return DataLoader (self .train_dataset , self . batch_size ,
260+ return DataLoader (self .train_dataset , batch_size ,
258261 collate_fn = collate )
259262 collate = Collator (None )
260263 # Use custom batching (good if batch size is large)
261- batch_size = self .batch_size if self .batch_size is not None else len (
262- self .train_dataset )
264+
263265 sampler = PinaBatchSampler (self .train_dataset , batch_size ,
264266 shuffle = False )
265267 return DataLoader (self .train_dataset , sampler = sampler ,
0 commit comments