Skip to content

Commit 85829e1

Browse files
committed
add cifar-N to docta
1 parent 61e54b9 commit 85829e1

File tree

5 files changed

+58
-3
lines changed

5 files changed

+58
-3
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,8 @@ dist
33
docta.ai.egg-info
44
*.pyc
55
results/Tabular_train/label_error_Tabular_train_diagnose_report.pt
6+
data/cifar/cifar-10-python.tar.gz
7+
data/cifar/cifar-10-batches-py/*
8+
data/cifar/cifar-100-python.tar.gz
9+
data/cifar/cifar-100-python/*
10+
test_cifar_N.py

config/cifar100.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Dataset settings for the CIFAR-100 ("c100") diagnosis run.
seed = 0
dataset_type = 'CIFAR'
modality = 'image'  # one of: image, text, tabular
num_classes = 100
data_root = './data/cifar/'

# Which label/attribute to diagnose; 1 selects the noisy labels.
label_sel = 1
train_label_sel = label_sel
test_label_sel = train_label_sel

# Results are written under a directory named after the dataset type and
# the short file name, i.e. './results/CIFAR_c100/'.
file_name = 'c100'
dataset_type = f'{dataset_type}_{file_name}'
save_path = './results/' + dataset_type + '/'

data/cifar/CIFAR-100_human.pt

977 KB
Binary file not shown.

docta/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# different dataset builder
2-
from .cifar import Cifar10_noisy, Cifar100_noisy, Cifar10_clean, Cifar100_clean
2+
from .cifar import Cifar10_noisy, Cifar100_noisy, Cifar10_clean, Cifar100_clean, Cifar10N, Cifar100N
33
from .hh_rlhf import HH_RLHF
44
from .customize import CustomizedDataset
55
from .customize_img_folder import Customize_Image_Folder
66
from .csv_loder import TabularDataset
77

88
__all__ = [
9-
'Cifar10_noisy', 'Cifar100_noisy', 'HH_RLHF', 'CustomizedDataset', 'Customize_Image_Folder', 'Cifar10_clean', 'Cifar100_clean', 'TabularDataset'
9+
'Cifar10_noisy', 'Cifar100_noisy', 'HH_RLHF', 'CustomizedDataset', 'Customize_Image_Folder', 'Cifar10_clean', 'Cifar100_clean', 'TabularDataset', 'Cifar10N', 'Cifar100N'
1010
]

docta/datasets/cifar.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,31 @@ def __getitem__(self, index: int):
6868
return img, label, index
6969

7070

71+
class Cifar10N(Cifar10_noisy):
    """CIFAR-10 dataset whose labels come from ``CIFAR-10_human.pt``.

    The constructor points ``cfg.label_path`` at the human-annotated label
    file, optionally overrides the noisy/clean label keys on ``cfg``, and
    then defers to ``Cifar10_noisy`` for the actual loading.
    """

    # Training pipeline: padded random crop + random horizontal flip,
    # followed by tensor conversion and per-channel normalisation.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ])

    # Evaluation pipeline: tensor conversion and normalisation only.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ])

    def __init__(self, cfg, train=True, preprocess=None, noisy_label_key=None, clean_label_key=None) -> None:
        """Configure ``cfg`` for CIFAR-10N and build the underlying dataset.

        Args:
            cfg: Configuration object; it is stored on the instance and
                mutated in place (``label_path`` and, when given, the
                label keys).
            train: Selects ``train_transform`` (True) or ``test_transform``
                (False) when ``preprocess`` is not supplied; also forwarded
                to the parent constructor.
            preprocess: Optional transform that replaces the default one.
            noisy_label_key: If not None, written to ``cfg.noisy_label_key``.
            clean_label_key: If not None, written to ``cfg.clean_label_key``.
        """
        print(
            "CIFAR-10N includes five sets of noisy labels: random_label1, "
            "random_label2, random_label3, aggre_label, worse_label.\n"
            "Please set cfg.noisy_label_key to one of them."
        )

        # Fall back to the class-level transform matching the split.
        if preprocess is None:
            if train:
                preprocess = self.train_transform
            else:
                preprocess = self.test_transform

        self.cfg = cfg
        self.cfg.label_path = "./data/cifar/CIFAR-10_human.pt"

        # Only overwrite the keys the caller explicitly provided, so values
        # already present on cfg are kept otherwise.
        overrides = (
            ("noisy_label_key", noisy_label_key),
            ("clean_label_key", clean_label_key),
        )
        for attr_name, attr_value in overrides:
            if attr_value is not None:
                setattr(self.cfg, attr_name, attr_value)

        super().__init__(cfg, train, preprocess)
7196

7297
class Cifar10_clean(CIFAR10):
7398

@@ -159,4 +184,16 @@ class Cifar100_clean(Cifar10_clean):
159184
test_transform = transforms.Compose([
160185
transforms.ToTensor(),
161186
transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
162-
])
187+
])
188+
189+
190+
class Cifar100N(Cifar100_noisy):
    """CIFAR-100 dataset whose labels come from ``CIFAR-100_human.pt``.

    The constructor points ``cfg.label_path`` at the human-annotated label
    file, sets the noisy/clean label keys on ``cfg``, and then defers to
    ``Cifar100_noisy`` for the actual loading.
    """

    def __init__(self, cfg, train=True, preprocess=None, noisy_label_key=None, clean_label_key=None) -> None:
        """Configure ``cfg`` for CIFAR-100N and build the underlying dataset.

        Args:
            cfg: Configuration object; it is stored on the instance and
                mutated in place (``label_path`` and both label keys).
            train: Selects ``train_transform`` (True) or ``test_transform``
                (False) when ``preprocess`` is not supplied; also forwarded
                to the parent constructor.
            preprocess: Optional transform that replaces the default one.
            noisy_label_key: Key written to ``cfg.noisy_label_key``.
                Defaults to ``"noisy_label"`` when None.
            clean_label_key: Key written to ``cfg.clean_label_key``.
                Defaults to ``"clean_label"`` when None.
        """
        # NOTE(review): train_transform / test_transform are not defined in
        # this class; presumably inherited from Cifar100_noisy -- confirm.
        if preprocess is None:
            preprocess = self.train_transform if train else self.test_transform

        self.cfg = cfg
        self.cfg.label_path = "./data/cifar/CIFAR-100_human.pt"

        # Mirror Cifar10N's signature: callers may choose the label keys,
        # while the defaults keep the previously hard-coded values.
        self.cfg.noisy_label_key = "noisy_label" if noisy_label_key is None else noisy_label_key
        self.cfg.clean_label_key = "clean_label" if clean_label_key is None else clean_label_key

        super().__init__(cfg, train, preprocess)

0 commit comments

Comments
 (0)