-
Notifications
You must be signed in to change notification settings - Fork 1
/
generator.py
36 lines (26 loc) · 1.05 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import tensorflow as tf
import numpy as np
from parameters import Parameters
from utils import load_row
import math
class Generator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields shuffled batches of ``(X, y)`` NumPy arrays.

    Each sample is produced by calling ``load_row`` on one entry of ``rows``.
    The row order is shuffled at construction time and again whenever
    ``on_epoch_end`` is called.

    Args:
        p: Run configuration. This class reads ``p.batch_size`` and
            ``p.input_dim`` (a sequence of dimensions that is unpacked into
            the array shapes).
        rows: Indexable, sized collection of row descriptors, each passed to
            ``load_row``. Element type is defined by ``utils.load_row`` —
            not visible here.
    """

    def __init__(self, p: Parameters, rows):
        # Newer Keras (PyDataset) warns if the base initializer is skipped;
        # it has no required arguments, so calling it is safe on older Keras.
        super().__init__()
        self.params = p
        self.rows = rows
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch.

        A trailing partial batch (``len(rows) % batch_size`` rows) is dropped.
        """
        return len(self.rows) // self.params.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(X, y)`` tuple of arrays."""
        batch_size = self.params.batch_size
        batch_indexes = self.indexes[index * batch_size:(index + 1) * batch_size]
        batch_of_rows = [self.rows[k] for k in batch_indexes]
        return self.__data_generation(batch_of_rows)

    def __data_generation(self, rows):
        """Load ``rows`` into freshly allocated input and target arrays.

        Returns:
            A tuple ``(X, y)`` where ``X`` is a ``uint8`` array of shape
            ``(len(rows), *input_dim, 3)`` and ``y`` is a ``bool`` array of
            shape ``(len(rows), *input_dim, 1)``.
        """
        n_rows = len(rows)
        # Size by the actual number of rows so np.empty never leaves
        # uninitialised (garbage) samples in a short batch; for every full
        # batch n_rows == batch_size, matching the previous shapes.
        X = np.empty((n_rows, *self.params.input_dim, 3), dtype=np.uint8)
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        y = np.empty((n_rows, *self.params.input_dim, 1), dtype=bool)
        for i, row in enumerate(rows):
            # load_row returns a 3-tuple; the first two items fill X[i] and
            # y[i], the third is discarded.
            X[i,], y[i,], _ = load_row(row)
        return X, y

    def on_epoch_end(self):
        """Rebuild and shuffle the row-order permutation used for batching.

        Called once from ``__init__``; Keras also invokes it after each epoch.
        Uses NumPy's global RNG, so results depend on ``np.random`` seeding.
        """
        self.indexes = np.arange(len(self.rows))
        np.random.shuffle(self.indexes)