Skip to content

Commit 3b357ee

Browse files
authored
Fix checkpoint loading for KNN eval (#376)
1 parent 9f60a1b commit 3b357ee

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

main_knn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import torch
2626
import torch.nn as nn
27+
from omegaconf import OmegaConf
2728
from torch.utils.data import DataLoader
2829
from tqdm import tqdm
2930

@@ -126,12 +127,10 @@ def main():
126127
# load arguments
127128
with open(args_path) as f:
128129
method_args = json.load(f)
130+
cfg = OmegaConf.create(method_args)
129131

130132
# build the model
131-
model = METHODS[method_args["method"]].load_from_checkpoint(
132-
ckpt_path, strict=False, **method_args
133-
)
134-
model.cuda()
133+
model = METHODS[method_args["method"]].load_from_checkpoint(ckpt_path, strict=False, cfg=cfg)
135134

136135
# prepare data
137136
_, T = prepare_transforms(args.dataset)

solo/methods/mocov3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,9 @@ def forward(self, X: torch.Tensor) -> Dict[str, Any]:
171171
"""
172172

173173
out = super().forward(X)
174-
q = self.predictor(self.projector(out["feats"]))
175-
out.update({"q": q})
174+
z = self.projector(out["feats"])
175+
q = self.predictor(z)
176+
out.update({"q": q, "z": z})
176177
return out
177178

178179
@torch.no_grad()

solo/methods/swav.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def __init__(self, cfg: omegaconf.DictConfig):
4343
temperature (float): temperature for the softmax normalization.
4444
queue_size (int): number of samples to hold in the queue.
4545
epoch_queue_starts (int): epochs the queue starts.
46-
freeze_prototypes_epochs (int): number of epochs during which the prototypes are frozen.
46+
freeze_prototypes_epochs (int): number of epochs during which
47+
the prototypes are frozen.
4748
"""
4849

4950
super().__init__(cfg)

0 commit comments

Comments
 (0)