Skip to content

Commit e235882

Browse files
committed
allow multiple process_fn arguments
1 parent 5f8ef20 commit e235882

File tree

3 files changed

+25
-7
lines changed

3 files changed

+25
-7
lines changed

generators.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,11 @@ def _data_generation(self, ids_batch):
198198
# if needed, process each image, and add to X_list (inputs list)
199199
if params.process_fn not in [None, False]:
200200
data_list = []
201-
for i, row in enumerate(ids_batch.itertuples()):
202-
arg = [] if args_name is None else [getattr(row, args_name)]
203-
data_i = params.process_fn(data[i], *arg)
201+
for i, row in enumerate(ids_batch.itertuples()):
202+
args = []
203+
if args_name is not None:
204+
args = [getattr(row, name) for name in force_list(args_name)]
205+
data_i = params.process_fn(data[i], *args)
204206
data_list.append(force_list(data_i))
205207

206208
# transpose list, sublists become batches

tests/test_generators.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,22 @@ def preproc(im, *arg):
146146
assert np.array_equal(np.squeeze(g[0][0][1]), np.arange(1,5))
147147
assert np.array_equal(np.squeeze(g[0][1]), np.arange(1,5))
148148

149+
def test_multi_process_args_DataGeneratorHDF5():
150+
def preproc(im, arg1, arg2):
151+
return np.zeros(1) + arg1 + arg2
152+
153+
gen_params_local = gen_params.copy()
154+
gen_params_local.process_fn = preproc
155+
gen_params_local.process_args = {'filename': ['filename_args','filename_args']}
156+
gen_params_local.batch_size = 4
157+
158+
ids_local = ids.copy()
159+
ids_local['filename_args'] = range(len(ids_local))
160+
161+
g = gr.DataGeneratorDisk(ids_local, **gen_params_local)
162+
x = g[0]
163+
assert np.array_equal(np.squeeze(x[0][0].T), np.arange(4)*2)
164+
149165
def test_callable_outputs_DataGeneratorHDF5():
150166
d = {'features': [1, 2, 3, 4, 5],
151167
'mask': [1, 0, 1, 1, 0]}

tests/tests.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@
5050
"metadata": {},
5151
"outputs": [],
5252
"source": [
53-
"def preproc(im, arg):\n",
54-
" return np.zeros(1) + arg\n",
53+
"def preproc(im, arg1, arg2):\n",
54+
" return np.zeros(1) + arg1 + arg2\n",
5555
"\n",
5656
"gen_params_local = gen_params.copy()\n",
5757
"gen_params_local.process_fn = preproc\n",
58-
"gen_params_local.process_args = {'filename': 'filename_args'}\n",
58+
"gen_params_local.process_args = {'filename': ['filename_args','filename_args']}\n",
5959
"gen_params_local.batch_size = 4\n",
6060
"\n",
6161
"ids_local = ids.copy()\n",
@@ -64,7 +64,7 @@
6464
"g = gr.DataGeneratorDisk(ids_local, **gen_params_local)\n",
6565
"x = g[0]\n",
6666
"# gen.pretty(g)\n",
67-
"assert np.array_equal(np.squeeze(x[0].T), np.arange(4))"
67+
"assert np.array_equal(np.squeeze(x[0][0].T), np.arange(4)*2)"
6868
]
6969
},
7070
{

0 commit comments

Comments
 (0)