Skip to content

Commit c9906ab

Browse files
authored
Merge branch 'master' into master
2 parents aa35c70 + 7e327d0 commit c9906ab

File tree

5 files changed

+97
-5
lines changed

5 files changed

+97
-5
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
augment
12
CNN
23
train_config.yaml
34
notebooks/*.json

data.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from pathlib import Path
22

3+
import Augmentor
34
import numpy as np
5+
import torch
6+
import yaml
47
from PIL import Image
58
from torch.utils.data.dataset import Dataset
69
from torchvision import transforms
10+
from tqdm import tqdm
711

812
KAGGLE_PATH = Path.home() / '.kaggle/competitions/digit-recognizer/'
913
KAGGLE_TEST_PATH = KAGGLE_PATH / 'test.csv'
@@ -15,6 +19,7 @@ def __init__(self, X, Y, pretransform=False):
1519
self.pretransform = pretransform
1620
self.transform = transforms.Compose([
1721
transforms.Resize((32, 32)),
22+
transforms.RandomRotation(15),
1823
transforms.ToTensor()])
1924
if self.pretransform:
2025
X = list(map(self.transform, map(Image.fromarray, X)))
@@ -66,3 +71,50 @@ def get_test_dataset(csv_file, pretransform=False):
6671
X = data.reshape(-1, 28, 28)
6772
test_dataset = DigitRecognizerDataset(X, None, pretransform=pretransform)
6873
return test_dataset
74+
75+
76+
def save_as_png(csv_file):
77+
data = np.genfromtxt(csv_file, delimiter=',', skip_header=1)
78+
Y, X = np.split(data, [1], axis=1)
79+
Y = np.squeeze(Y).astype(int)
80+
root_path = Path('augment/original')
81+
print(Y.shape)
82+
root_path.mkdir(parents=True, exist_ok=True)
83+
class_paths = [root_path.joinpath(str(i)) for i in range(10)]
84+
for path in class_paths:
85+
path.mkdir(exist_ok=True)
86+
idx = [0 for _ in range(10)]
87+
for i in tqdm(range(len(X))):
88+
path = class_paths[Y[i]]
89+
path = path.joinpath(str(idx[Y[i]]).zfill(5) + '.png')
90+
idx[Y[i]] += 1
91+
x = X[i].reshape(28, 28)
92+
image = Image.fromarray(x).convert('RGB')
93+
with path.open('wb') as f:
94+
image.save(f, format='PNG')
95+
96+
97+
def augment_data(root_path):
98+
src_path = root_path.joinpath('original')
99+
class_paths = [src_path.joinpath(str(i)).resolve() for i in range(10)]
100+
dst_path = root_path.joinpath('out')
101+
dst_path.mkdir(exist_ok=True)
102+
out_paths = [dst_path.joinpath(str(i)) for i in range(10)]
103+
for path in out_paths:
104+
path.mkdir(exist_ok=True)
105+
out_paths = [path.resolve() for path in out_paths]
106+
p = Augmentor.Pipeline(source_directory=str(
107+
class_paths[0]), output_directory=str(out_paths[0]), save_format='PNG')
108+
for i in range(1, 10):
109+
p.add_further_directory(str(class_paths[i]), str(out_paths[i]))
110+
p.random_distortion(1.0, 5, 5, 1)
111+
p.sample(80000)
112+
113+
114+
if __name__ == '__main__':
115+
config_path = Path('train_config.yaml')
116+
config = None
117+
with config_path.open('r') as f:
118+
config = yaml.load(f)
119+
# save_as_png(config['train_path'])
120+
augment_data(Path('augment'))

doc.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@ Just an experiment, maybe a baseline for other models, it gets ~95%.
1717
# PracticalCNN
1818

1919
Just a different name, different layer parameters, nothing interested.
20+
21+
# RichCNN
22+
23+
Potentially little bit better, more filters in conv layers

net.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@
77
class SimpleCNN(nn.Module):
88
def __init__(self):
99
super(SimpleCNN, self).__init__()
10-
self.bn = nn.BatchNorm2d(1)
1110
self.conv1 = nn.Conv2d(1, 6, 5)
1211
self.conv2 = nn.Conv2d(6, 16, 5)
1312
self.fc1 = nn.Linear(16 * 5 * 5, 120)
1413
self.fc2 = nn.Linear(120, 84)
1514
self.fc3 = nn.Linear(84, 10)
1615

1716
def forward(self, x):
18-
x = self.bn(x)
1917
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
2018
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
2119
x = x.view(-1, reduce(lambda x, y: x * y, x.size()[1:]))
@@ -28,24 +26,61 @@ def forward(self, x):
2826
class CNN(nn.Module):
2927
def __init__(self):
3028
super(CNN, self).__init__()
29+
self.bn1 = nn.BatchNorm2d(1)
3130
self.conv1 = nn.Conv2d(1, 6, 5)
31+
self.bn2 = nn.BatchNorm2d(6)
3232
self.conv2 = nn.Conv2d(6, 16, 5)
33+
self.bn3 = nn.BatchNorm2d(16)
3334
self.fc1 = nn.Linear(16 * 5 * 5, 120)
35+
self.bn4 = nn.BatchNorm1d(120)
3436
self.fc2 = nn.Linear(120, 84)
37+
self.bn5 = nn.BatchNorm1d(84)
3538
self.drop = nn.Dropout()
3639
self.fc3 = nn.Linear(84, 10)
3740

3841
def forward(self, x):
39-
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
42+
x = self.bn1(x)
43+
x = F.max_pool2d(F.relu(self.conv1(x)), 2)
44+
x = self.bn2(x)
4045
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
46+
x = self.bn3(x)
4147
x = x.view(-1, reduce(lambda x, y: x * y, x.size()[1:]))
4248
x = F.relu(self.fc1(x))
49+
x = self.bn4(x)
4350
x = F.relu(self.fc2(x))
51+
x = self.bn5(x)
4452
x = self.drop(x)
4553
x = self.fc3(x)
4654
return x
4755

4856

57+
class RichCNN(nn.Module):
58+
def __init__(self):
59+
super(RichCNN, self).__init__()
60+
self.bn1 = nn.BatchNorm2d(1)
61+
self.conv1 = nn.Conv2d(1, 32, 5)
62+
self.bn2 = nn.BatchNorm2d(32)
63+
self.conv2 = nn.Conv2d(32, 64, 5)
64+
self.bn3 = nn.BatchNorm2d(64)
65+
self.fc1 = nn.Linear(64 * 5 * 5, 1024)
66+
self.bn4 = nn.BatchNorm1d(1024)
67+
self.drop = nn.Dropout()
68+
self.fc2 = nn.Linear(1024, 10)
69+
70+
def forward(self, x):
71+
x = self.bn1(x)
72+
x = F.max_pool2d(F.relu(self.conv1(x)), 2)
73+
x = self.bn2(x)
74+
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
75+
x = self.bn3(x)
76+
x = x.view(-1, reduce(lambda x, y: x * y, x.size()[1:]))
77+
x = F.relu(self.fc1(x))
78+
x = self.bn4(x)
79+
x = self.drop(x)
80+
x = self.fc2(x)
81+
return x
82+
83+
4984
class PracticalCNN(nn.Module):
5085
def __init__(self):
5186
super(PracticalCNN, self).__init__()

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
validation_classes[idx] += counts
3737
H['validation_classes'] = validation_classes.tolist()
3838

39-
net = CNN()
39+
net = RichCNN()
4040
H['net'] = type(net).__name__
4141
net_dir = Path('./' + H['net'])
4242
net_dir.mkdir(parents=True, exist_ok=True)
@@ -45,7 +45,7 @@
4545
optimizer = torch.optim.Adam(net.parameters(), lr=config['learning_rate'])
4646
H['optimizer'] = str(optimizer)
4747
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
48-
optimizer, mode='max')
48+
optimizer, mode='max', verbose=True)
4949
H['lr_scheduler'] = str(lr_scheduler)
5050
criterion = nn.CrossEntropyLoss()
5151
H['criterion'] = str(criterion)

0 commit comments

Comments
 (0)