Commit d47c297 (parent: ce1a7f8)

added code for easy training and evaluation.

10 files changed: +854 −626 lines

Quickstart.ipynb (+4 −3)

@@ -83,7 +83,8 @@
    "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
   },
   "kernelspec": {
-   "display_name": "Python 3.8.8 64-bit ('env2': conda)",
+   "display_name": "Python 3",
+   "language": "python",
    "name": "python3"
   },
   "language_info": {
@@ -96,9 +97,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.8"
+   "version": "3.8.10"
   }
  },
 "nbformat": 4,
-"nbformat_minor": 2
+"nbformat_minor": 4
 }

Readme.md (+7 −5)

@@ -1,5 +1,5 @@
 # Image Segmentation Using Text and Image Prompts
-This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".
+This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003).
 
 <img src="overview.png" alt="drawing" height="200em"/>
 
@@ -44,22 +44,24 @@ git clone https://github.com/juhongm999/hsnet.git
 - [CLIPSeg-D64](https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth) (4.1MB, without CLIP weights)
 - [CLIPSeg-D16](https://github.com/timojl/clipseg/raw/master/weights/rd16-uni.pth) (1.1MB, without CLIP weights)
 
-### Training
+### Training and Evaluation
+
+To train, use the `training.py` script with an experiment file and an experiment id as parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment, which is defined by `configuration` and the first entry of `individual_configurations`. Model weights will be written to `logs/`.
+
+For evaluation, use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will evaluate the first phrasecut experiment, using `test_configuration` and the first configuration in `individual_configurations`.
 
-See the experiment folder for yaml definitions of the training configurations. The training code is in `experiment_setup.py`.
 
 ### Usage of PFENet Wrappers
 
 In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder.
 `git clone https://github.com/Jia-Research-Lab/PFENet.git `
 
 ### Citation
-
 ```
 @article{lueddecke21
   title={Image Segmentation Using Text and Image Prompts},
   author={Timo Lüddecke and Alexander Ecker},
-  journal={...},
+  journal={arXiv preprint arXiv:2112.10003},
   year={2021}
 }
 ```
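The training and evaluation commands above take an experiment yaml file and an index into its configurations. As a rough, non-authoritative sketch of how such a call could be resolved, assuming only the keys named in the README (`configuration`, `individual_configurations`); the repository's actual loader lives in `experiment_setup.py` and may well differ:

```
# Sketch only: resolve one training run from an experiment yaml, given an
# experiment id. Assumes the keys named in the README; this is not the
# repository's actual loader (experiment_setup.py).
import sys

import yaml  # PyYAML


def resolve_experiment(yaml_path, experiment_id):
    with open(yaml_path) as f:
        experiment = yaml.safe_load(f)

    # Shared settings, overridden by the selected individual configuration.
    config = dict(experiment.get('configuration', {}))
    config.update(experiment['individual_configurations'][experiment_id])
    return config


if __name__ == '__main__':
    # Mirrors the argument order of `python training.py phrasecut.yaml 0`.
    print(resolve_experiment(sys.argv[1], int(sys.argv[2])))
```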

datasets/pascal_zeroshot.py (+3 −3)

@@ -7,9 +7,9 @@
 from general_utils import log
 from torchvision import transforms
 
-# PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
-#                          ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
-#                          ['chair.n.01', 'pot_plant.n.01']]
+PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
+                         ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
+                         ['chair.n.01', 'pot_plant.n.01']]
 
 
 class PascalZeroShot(object):
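These pairs are the unseen-class splits for zero-shot segmentation; the `datasets/phrasecut.py` change below flattens the first `stop` pairs into a list of classes to exclude. A tiny self-contained illustration of that pattern (the value of `stop` is hypothetical):

```
# Illustration only: flatten the first `stop` zero-shot class pairs into a
# list of classes to exclude, mirroring the comprehension in phrasecut.py.
PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
                         ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
                         ['chair.n.01', 'pot_plant.n.01']]

stop = 2  # hypothetical: exclude the first two pairs
avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
print(avoid)  # ['cattle.n.01', 'motorcycle.n.01', 'aeroplane.n.01', 'sofa.n.01']
```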

datasets/phrasecut.py (+6)

@@ -66,6 +66,7 @@ def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=F
         self.image_size = image_size
         self.with_visual = with_visual
         self.only_visual = only_visual
+        self.phrase_form = '{}'
         self.mask = mask
         self.aug_crop = aug_crop
 
@@ -125,7 +126,9 @@ def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=F
 
         elif remove_classes[0] == 'zs':
             stop = remove_classes[1]
+
             from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS
+
             avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
             print(avoid)
 
@@ -209,6 +212,7 @@ def load_sample(self, sample_i, j):
 
         polys_phrase0 = img_ref_data['gt_Polygons'][j]
         phrase = img_ref_data['phrases'][j]
+        phrase = self.phrase_form.format(phrase)
 
         masks = []
         for polys in polys_phrase0:
@@ -248,6 +252,8 @@ def load_sample(self, sample_i, j):
 
         img = self.normalize(img)
 
+
+
         return img, seg, phrase
 
     def __getitem__(self, i):
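The `phrase_form` attribute added above is a format template applied to each phrase before the sample is returned; the default `'{}'` leaves phrases unchanged. A minimal sketch of the mechanism (the non-trivial template below is a hypothetical example, not something set in this commit):

```
# Sketch of the phrase_form templating introduced above.
phrase = 'red car'

phrase_form = '{}'                   # the commit's default: identity template
print(phrase_form.format(phrase))    # red car

phrase_form = 'a photo of a {}.'     # hypothetical prompt template
print(phrase_form.format(phrase))    # a photo of a red car.
```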

evaluation_utils.py (+1)

@@ -29,6 +29,7 @@ def norm(img):
     std = torch.Tensor([0.229, 0.224, 0.225])
     return (img - mean[:,None,None]) / std[:,None,None]
 
+
 def compute_shift(name, w, datasets, size=1, seed=1):
 
     if type(name) == str:
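For context, `norm` standardizes an image channel-wise with ImageNet statistics. A self-contained sketch is below; the mean values are the standard ImageNet means and are an assumption here, since that line falls outside the hunk shown above.

```
import torch


def norm(img):
    # Channel-wise standardization of a CHW tensor. The std values match the
    # hunk above; the mean values are the standard ImageNet means (assumed).
    mean = torch.Tensor([0.485, 0.456, 0.406])
    std = torch.Tensor([0.229, 0.224, 0.225])
    return (img - mean[:, None, None]) / std[:, None, None]


img = torch.rand(3, 352, 352)   # dummy CHW image in [0, 1]
print(norm(img).shape)          # torch.Size([3, 352, 352])
```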
