
Commit 6b934b7

update
1 parent 6b660de commit 6b934b7

3 files changed: +5 -5 lines changed

SimCLR.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ def init_model(self):

         elif self.arch == 'resnet50':
             # backbone = resnet.resnet50(mode=self.mode)
-            backbone = resnet.ResNetPreTrained(type='resnet50')
+            backbone = resnet.resnet50(type='resnet50')
         return backbone

     def forward(self, x):
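A quick way to sanity-check the new constructor call in isolation is a small smoke test such as the sketch below. It assumes the repository root is on PYTHONPATH so that module/resnet.py imports as module.resnet, and that resnet50 accepts the type='resnet50' keyword exactly as the diff shows; the 32x32 input size matches CIFAR-10.

import torch

from module import resnet  # assumption: repo root is on PYTHONPATH

# Build the backbone with the call now used in init_model().
backbone = resnet.resnet50(type='resnet50')
backbone.eval()

# Run a dummy CIFAR-10-sized batch through it and inspect the output shape.
x = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    feats = backbone(x)
print(feats.shape)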

module/resnet.py

Lines changed: 3 additions & 3 deletions
@@ -78,7 +78,7 @@ def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False, mo
         super(ResNet, self).__init__()
         self.in_planes = 64
         self.mode = mode
-        if self.mode == 'cifar':
+        if self.mode == 'cifar10':
             self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                                    bias=False)
             self.bn1 = nn.BatchNorm2d(64)

@@ -125,8 +125,8 @@ def _make_layer(self, block, planes, num_blocks, stride):

     def forward(self, x):
         out = F.relu(self.bn1(self.conv1(x)))
-        if self.mode == 'cifar':
-            out = self.maxpool(out)
+        # if self.mode == 'cifar10':
+        #     out = self.maxpool(out)
         out = self.layer1(out)
         out = self.layer2(out)
         out = self.layer3(out)
train_cifar10.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 train_data = CIFAR10(download=True,root="./cifar10",transform=MultiViewDataInjector([train_transform,train_transform,val_test_transform]))
 train_len = len(train_data)
 num_class = len(np.unique(train_data.targets))
-train_loader = DataLoader(dataset = train_data, batch_size = 1024, num_workers=48)
+train_loader = DataLoader(dataset = train_data, batch_size = 1024, num_workers=12)
 # test_loader = DataLoader(dataset = test_data, batch_size = 16)
 # valid_loader = DataLoader(dataset = val_data, batch_size= 16)
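The only change here lowers the DataLoader's num_workers from 48 to 12, i.e. fewer CPU worker processes for data loading; asking for 48 workers is wasteful, or can even stall, on a machine with fewer cores. A standalone sketch of the same loader setup (using a plain ToTensor transform in place of the repo's MultiViewDataInjector, and deriving the worker count from the visible CPUs instead of hard-coding it) might look like:

import os

import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

# Hypothetical variant: cap the worker count at the 12 chosen in this commit,
# but never ask for more workers than there are CPUs.
num_workers = min(12, os.cpu_count() or 1)

train_data = CIFAR10(download=True, root="./cifar10", transform=T.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=1024, num_workers=num_workers)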