44
44
- Non-IID-xxx
45
45
- Non-IID-xxx
46
46
"""
47
+
47
48
import argparse
48
49
49
50
# from progress.bar import Bar as Bar
58
59
FedNLLMNIST ,
59
60
FedNLLSVHN ,
60
61
FedNLLClothing1M ,
62
+ FedNLLWebVision ,
63
+ FedNLLSynthetic ,
61
64
)
65
+ from fednoisy .data .NLLData import functional as nllF
62
66
63
67
64
68
def read_args ():
@@ -74,9 +78,25 @@ def read_args():
74
78
"--partition" ,
75
79
default = "iid" ,
76
80
type = str ,
77
- choices = ["iid" , "noniid-#label" , "noniid-labeldir" , "noniid-quantity" ],
81
+ choices = [
82
+ "iid" ,
83
+ "noniid" ,
84
+ "noniid-#label" ,
85
+ "noniid-labeldir" ,
86
+ "noniid-quantity" ,
87
+ ],
78
88
help = "Data partition scheme for federated setting." ,
79
89
)
90
+ parser .add_argument (
91
+ "--personalize" ,
92
+ action = "store_true" ,
93
+ help = "Whether use personalized local test set for each client. If True, then each client's class ratio of local test set is same as the training set" ,
94
+ )
95
+ parser .add_argument (
96
+ "--balance" ,
97
+ action = "store_true" ,
98
+ help = "whether use balance partition for Synthetic dataset." ,
99
+ )
80
100
parser .add_argument (
81
101
"--num_clients" ,
82
102
default = 10 ,
@@ -140,24 +160,56 @@ def read_args():
140
160
"--num_samples" ,
141
161
default = 32 * 2 * 1000 ,
142
162
type = int ,
143
- help = "Number of samples used for Clothing1M training. Defaults as 64000." ,
163
+ help = "Number of samples used for Clothing1M/Synthetic data training. Defaults as 64000." ,
164
+ )
165
+
166
+ parser .add_argument (
167
+ "--num_test_samples" ,
168
+ default = 1000 ,
169
+ type = int ,
170
+ help = "Number of test samples for synthetic dataset." ,
171
+ )
172
+ parser .add_argument (
173
+ "--feature_dim" ,
174
+ type = int ,
175
+ default = 100 ,
176
+ help = "Feature dimension for synthetic dataset." ,
177
+ )
178
+ parser .add_argument (
179
+ "--use_bias" ,
180
+ action = "store_true" ,
181
+ help = "Whether to use bias in synthetic data generation. If True, Y = Xw + b + ε; otherwise Y = Xw + ε." ,
144
182
)
145
183
146
184
# ----Dataset path options----
147
185
parser .add_argument (
148
186
"--dataset" ,
149
187
default = "cifar10" ,
150
188
type = str ,
151
- choices = ["mnist" , "cifar10" , "cifar100" , "svhn" , "clothing1m" , "webvision" ],
189
+ choices = [
190
+ "mnist" ,
191
+ "cifar10" ,
192
+ "cifar100" ,
193
+ "svhn" ,
194
+ "clothing1m" ,
195
+ "webvision" ,
196
+ "synthetic" ,
197
+ ],
152
198
help = "Dataset for experiment. Current support: ['mnist', 'cifar10', "
153
- "'cifar100', 'svhn', 'clothing1m', 'webvision']" ,
199
+ "'cifar100', 'svhn', 'clothing1m', 'webvision', 'synthetic ]" ,
154
200
)
155
201
parser .add_argument (
156
202
"--raw_data_dir" ,
157
203
default = "../data" ,
158
204
type = str ,
159
205
help = "Directory for raw dataset download" ,
160
206
)
207
+ parser .add_argument (
208
+ "--raw_imagenet_dir" ,
209
+ default = "../rawdata/imagenet" ,
210
+ type = str ,
211
+ help = "Directory for raw dataset download" ,
212
+ )
161
213
parser .add_argument (
162
214
"--data_dir" ,
163
215
default = "../noisy_label_data" ,
@@ -242,9 +294,11 @@ def read_args():
242
294
max_noise_ratio = args .max_noise_ratio ,
243
295
root_dir = args .raw_data_dir ,
244
296
out_dir = args .data_dir ,
297
+ personalize = args .personalize ,
245
298
)
246
299
nll_cifar10 .create_nll_scene (seed = args .seed )
247
300
nll_cifar10 .save_nll_scene ()
301
+
248
302
elif args .dataset == "cifar100" :
249
303
nll_cifar100 = FedNLLCIFAR100 (
250
304
globalize = args .globalize ,
@@ -258,9 +312,11 @@ def read_args():
258
312
max_noise_ratio = args .max_noise_ratio ,
259
313
root_dir = args .raw_data_dir ,
260
314
out_dir = args .data_dir ,
315
+ personalize = args .personalize ,
261
316
)
262
317
nll_cifar100 .create_nll_scene (seed = args .seed )
263
318
nll_cifar100 .save_nll_scene ()
319
+
264
320
elif args .dataset == "mnist" :
265
321
nll_mnist = FedNLLMNIST (
266
322
globalize = args .globalize ,
@@ -274,6 +330,7 @@ def read_args():
274
330
max_noise_ratio = args .max_noise_ratio ,
275
331
root_dir = args .raw_data_dir ,
276
332
out_dir = args .data_dir ,
333
+ personalize = args .personalize ,
277
334
)
278
335
nll_mnist .create_nll_scene (seed = args .seed )
279
336
nll_mnist .save_nll_scene ()
@@ -291,11 +348,15 @@ def read_args():
291
348
max_noise_ratio = args .max_noise_ratio ,
292
349
root_dir = args .raw_data_dir ,
293
350
out_dir = args .data_dir ,
351
+ personalize = args .personalize ,
294
352
)
295
353
nll_svhn .create_nll_scene (seed = args .seed )
296
354
nll_svhn .save_nll_scene ()
297
355
298
356
elif args .dataset == "clothing1m" :
357
+ args .noise_mode = "real"
358
+ args .globalize = True
359
+ args .noise_ratio = 0.39
299
360
nll_clothing1m = FedNLLClothing1M (
300
361
root_dir = args .raw_data_dir ,
301
362
out_dir = args .data_dir ,
@@ -308,5 +369,42 @@ def read_args():
308
369
nll_clothing1m .create_nll_scene (seed = args .seed )
309
370
nll_clothing1m .save_nll_scene ()
310
371
372
+ elif args .dataset == "webvision" :
373
+ args .noise_mode = "real"
374
+ args .globalize = True
375
+ args .noise_ratio = 0.20
376
+ nll_webvision = FedNLLWebVision (
377
+ root_dir = args .raw_data_dir ,
378
+ imagenet_root_dir = args .raw_imagenet_dir ,
379
+ out_dir = args .data_dir ,
380
+ partition = args .partition ,
381
+ num_clients = args .num_clients ,
382
+ dir_alpha = args .dir_alpha ,
383
+ major_classes_num = args .major_classes_num ,
384
+ )
385
+ nll_webvision .create_nll_scene (seed = args .seed )
386
+ nll_webvision .save_nll_scene ()
387
+
388
+ elif args .dataset == "synthetic" :
389
+ nll_synthetic = FedNLLSynthetic (
390
+ out_dir = args .data_dir ,
391
+ num_clients = args .num_clients ,
392
+ init_mu = 0 ,
393
+ init_sigma = 1 ,
394
+ partition = args .partition ,
395
+ balance = args .balance ,
396
+ train_sample_num = args .num_samples ,
397
+ test_sample_num = args .num_test_samples ,
398
+ feature_dim = args .feature_dim ,
399
+ use_bias = args .use_bias ,
400
+ dir_alpha = args .dir_alpha ,
401
+ )
402
+ args .init_mu = 0
403
+ args .init_sigma = 1
404
+ nll_synthetic .create_nll_scene (seed = args .seed )
405
+ nll_synthetic .save_nll_scene ()
406
+ nll_name = nllF .FedNLL_name (** vars (args ))
407
+ print (f"{ nll_name } " )
408
+
311
409
else :
312
410
raise ValueError (f"dataset='{ args .dataset } ' is not supported!" )
0 commit comments