Skip to content

Commit

Permalink
support batch size > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxue0827 committed Dec 21, 2019
1 parent d5d1103 commit 0de3f0d
Show file tree
Hide file tree
Showing 30 changed files with 2,408 additions and 1,424 deletions.
10 changes: 6 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ __pycache__/
*.json
*.zip

*/tools/demos/*
*/output/*
*/data/pretrained_weights/*
*/data/tfrecord/*
tools/demos/*
tools/test_dota/*
tools/test_icdar2015/*
output/summary/*
data/pretrained_weights/*
data/tfrecord/*
47 changes: 29 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ This is a tensorflow re-implementation of [Focal Loss for Dense Object Detection
![1](voc_2007.gif)

### Performance
| Model | Backbone | Training data | Val data | mAP | Train Schedule | GPU | Image/GPU | Configuration File |
|:------------:|:------------:|:------------:|:---------:|:-----------:|:----------:|:----------:|:-----------:|:-----------:|
| [Faster-RCNN](https://github.com/DetectionTeamUCAS/Faster-RCNN_Tensorflow) | ResNet50_v1 600 | VOC07 trainval | VOC07 test | 73.09 | - | 1X GTX 1080Ti | 1 | - |
| [FPN](https://github.com/DetectionTeamUCAS/FPN_Tensorflow) | ResNet50_v1 600 | VOC07 trainval | VOC07 test | 74.26 | - | 1X GTX 1080Ti | 1 | - |
| RetinaNet | ResNet50_v1 600 | VOC07 trainval | VOC07 test | 73.16 | - | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_voc07_v3.py |
| RetinaNet | ResNet50_v1d 600 | VOC07 trainval | VOC07 test | 73.26 | - | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_voc07_v4.py |
| RetinaNet | ResNet50_v1d 600 | VOC07+12 trainval | VOC07 test | 79.66 | - | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_voc0712_v1.py |
| RetinaNet | ResNet101_v1d 600 | VOC07+12 trainval | VOC07 test | 81.69 | - | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_voc0712_v4.py |
| RetinaNet | ResNet101_v1d 800 | VOC07+12 trainval | VOC07 test | 80.69 | - | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_voc0712_v3.py |
| RetinaNet | ResNet50_v1 600 | COCO train2017 | COCO val2017 (coco minival) | 33.4 | 1x | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_coco_1x_v4.py |
| Model | Backbone | Training data | Val data | mAP | Inf time (fps) | Model Link | Train Schedule | GPU | Image/GPU | Configuration File |
|:------------:|:------------:|:------------:|:---------:|:-----------:|:----------:|:----------:|:----------:|:----------:|:-----------:|:-----------:|
| [Faster-RCNN](https://github.com/DetectionTeamUCAS/Faster-RCNN_Tensorflow) | ResNet50_v1 600 | VOC07 trainval | VOC07 test | 73.09 | - | - | - | 1X GTX 1080Ti | 1 | - |
| [FPN](https://github.com/DetectionTeamUCAS/FPN_Tensorflow) | ResNet50_v1 600 | VOC07 trainval | VOC07 test | 74.26 | - | - | - | 1X GTX 1080Ti | 1 | - |
| RetinaNet | ResNet50_v1 600 | VOC07 trainval | VOC07 test | 73.16 | 14.6 | - | - | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_voc07_v3.py |
| RetinaNet | ResNet50_v1d 600 | VOC07 trainval | VOC07 test | 73.26 | 14.6 | - | - | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_voc07_v4.py |
| RetinaNet | ResNet50_v1d 600 | VOC07 trainval | VOC07 test | 74.00 | 14.6 | [model](https://drive.google.com/file/d/1qjYsAi5uHB-6KgnrgWTN42a7Njkah-rA/view?usp=sharing) | - | 4X GeForce RTX 2080 Ti | 2 | cfgs_res50_voc07_v5.py |
| RetinaNet | ResNet50_v1d 600 | VOC07+12 trainval | VOC07 test | 79.66 | 14.6 | - | - | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_voc0712_v1.py |
| RetinaNet | ResNet101_v1d 600 | VOC07+12 trainval | VOC07 test | 81.69 | 14.6 | - | - | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_voc0712_v4.py |
| RetinaNet | ResNet101_v1d 800 | VOC07+12 trainval | VOC07 test | 80.69 | 14.6 | - | - | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_voc0712_v3.py |
| RetinaNet | ResNet50_v1 600 | COCO train2017 | COCO val2017 (coco minival) | 33.4 | - | - | 1x | 8X GeForce RTX 2080 Ti | 1 | cfgs_res50_coco_1x_v4.py |
| RetinaNet | ResNet50_v1 600 | COCO train2017 | COCO val2017 (coco minival) | | - | - | 1x | 4X GeForce RTX 2080 Ti | 2 | cfgs_res50_coco_1x_v5.py |

## My Development Environment
1、python3.5 (anaconda recommend)
Expand All @@ -27,10 +29,9 @@ This is a tensorflow re-implementation of [Focal Loss for Dense Object Detection
## Download Model
### Pretrain weights
1、Please download [resnet50_v1](http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz), [resnet101_v1](http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz) pre-trained models on Imagenet, put it to data/pretrained_weights.
2、Or you can choose to use a better backbone, refer to [gluon2TF](https://github.com/yangJirui/gluon2TF). [Pretrain Model Link](https://pan.baidu.com/s/1GpqKg0dOaaWmwshvv1qWGg), password: 5ht9.

### Trained weights
**Select a configuration file in the folder ($PATH_ROOT/libs/configs/) and copy its contents into cfgs.py, then download the corresponding [weights](https://github.com/DetectionTeamUCAS/Models/tree/master/RetinaNet_Tensorflow).**
2、**(Recommend in this repo)** Or you can choose to use a better backbone, refer to [gluon2TF](https://github.com/yangJirui/gluon2TF).
* [Baidu Drive](https://pan.baidu.com/s/1GpqKg0dOaaWmwshvv1qWGg), password: 5ht9.
* [Google Drive](https://drive.google.com/drive/folders/1BM8ffn1WnsRRb5RcuAcyJAHX8NS2M1Gz?usp=sharing)

## Compile
```
Expand All @@ -44,7 +45,7 @@ python setup.py build_ext --inplace
```
(1) Modify parameters (such as CLASS_NUM, DATASET_NAME, VERSION, etc.) in $PATH_ROOT/libs/configs/cfgs.py
(2) Add category information in $PATH_ROOT/libs/label_name_dict/lable_dict.py
(3) Add data_name to line 76 of $PATH_ROOT/data/io/read_tfrecord.py
(3) Add data_name to $PATH_ROOT/data/io/read_tfrecord.py
```

2、make tfrecord
Expand All @@ -58,25 +59,35 @@ python convert_data_to_tfrecord_coco.py --VOC_dir='/PATH/TO/JSON/FILE/'
3、multi-gpu train
```
cd $PATH_ROOT/tools
python multi_gpu_train.py
python multi_gpu_train.py (multi_gpu_train_batch.py)
```

## Eval
### COCO
```
cd $PATH_ROOT/tools
python eval_coco.py --eval_data='/PATH/TO/IMAGES/'
--eval_gt='/PATH/TO/TEST/ANNOTATION/'
--GPU='0'
--gpu='0'
```

```
cd $PATH_ROOT/tools
python eval_coco_multiprocessing.py --eval_data='/PATH/TO/IMAGES/'
--eval_gt='/PATH/TO/TEST/ANNOTATION/'
--gpu_ids='0,1,2,3,4,5,6,7'
--gpus='0,1,2,3,4,5,6,7'
```

### PASCAL VOC
```
cd $PATH_ROOT/tools
python eval.py --eval_dir='/PATH/TO/IMAGES/'
--annotation_dir='/PATH/TO/TEST/ANNOTATION/'
--gpu='0'
```

## Tensorboard
```
cd $PATH_ROOT/output/summary
Expand Down
8 changes: 4 additions & 4 deletions data/io/convert_data_to_tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from libs.label_name_dict.label_dict import *
from help_utils.tools import *

tf.app.flags.DEFINE_string('VOC_dir', '/data/code/VOC2007/VOCdevkit/VOC2007/', 'Voc dir')
tf.app.flags.DEFINE_string('VOC_dir', '/data/yangxue/dataset/VOC2007/VOCdevkit/VOC2007', 'Voc dir')
tf.app.flags.DEFINE_string('xml_dir', 'Annotations', 'xml dir')
tf.app.flags.DEFINE_string('image_dir', 'JPEGImages', 'image dir')
tf.app.flags.DEFINE_string('save_name', 'train', 'save name')
Expand Down Expand Up @@ -71,9 +71,9 @@ def read_xml_gtbox_and_label(xml_path):


def convert_pascal_to_tfrecord():
xml_path = FLAGS.VOC_dir + FLAGS.xml_dir
image_path = FLAGS.VOC_dir + FLAGS.image_dir
save_path = FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord'
xml_path = os.path.join(FLAGS.VOC_dir, FLAGS.xml_dir)
image_path = os.path.join(FLAGS.VOC_dir, FLAGS.image_dir)
save_path = os.path.join(FLAGS.save_dir, FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord')
mkdir(FLAGS.save_dir)

# writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
Expand Down
6 changes: 3 additions & 3 deletions data/io/convert_data_to_tfrecord_voc2012.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def read_xml_gtbox_and_label(xml_path):


def convert_pascal_to_tfrecord():
xml_path = FLAGS.VOC_dir + FLAGS.xml_dir
image_path = FLAGS.VOC_dir + FLAGS.image_dir
save_path = FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord'
xml_path = os.path.join(FLAGS.VOC_dir, FLAGS.xml_dir)
image_path = os.path.join(FLAGS.VOC_dir, FLAGS.image_dir)
save_path = os.path.join(FLAGS.save_dir, FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord')
mkdir(FLAGS.save_dir)

# writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
Expand Down
Empty file added libs/configs/COCO/__init__.py
Empty file.
Original file line number Diff line number Diff line change
@@ -1,104 +1,103 @@
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import os
import tensorflow as tf
import math

"""
epoch-00: 00.0 epoch-01: 7.40
epoch-02: 15.4 epoch-03: 18.8
epoch-04: 20.7 epoch-05: 23.0
epoch-06: 23.6 epoch-07: 25.3
epoch-08: 24.7 epoch-09: 26.7
epoch-11: 26.2 epoch-12: 30.7
epoch-13: 30.8 epoch-14: 31.1
epoch-15: 31.2 epoch-16: 31.4
epoch-19: 31.5
"""

# ------------------------------------------------
VERSION = 'RetinaNet_COCO_1x_20190522'
NET_NAME = 'resnet_v1_50' # 'MobilenetV2'
ADD_BOX_IN_TENSORBOARD = True

# ---------------------------------------- System_config
ROOT_PATH = os.path.abspath('../')
print(20*"++--")
print(ROOT_PATH)
GPU_GROUP = "0,1,2,3,4,5,6,7"
NUM_GPU = len(GPU_GROUP.strip().split(','))
SHOW_TRAIN_INFO_INTE = 20
SMRY_ITER = 200
SAVE_WEIGHTS_INTE = 20000 * 5

SUMMARY_PATH = ROOT_PATH + '/output/summary'
TEST_SAVE_PATH = ROOT_PATH + '/tools/test_result'

if NET_NAME.startswith("resnet"):
weights_name = NET_NAME
elif NET_NAME.startswith("MobilenetV2"):
weights_name = "mobilenet/mobilenet_v2_1.0_224"
else:
raise Exception('net name must in [resnet_v1_101, resnet_v1_50, MobilenetV2]')

PRETRAINED_CKPT = ROOT_PATH + '/data/pretrained_weights/' + weights_name + '.ckpt'
TRAINED_CKPT = os.path.join(ROOT_PATH, 'output/trained_weights')
EVALUATE_DIR = ROOT_PATH + '/output/evaluate_result_pickle/'

# ------------------------------------------ Train config
RESTORE_FROM_RPN = False
FIXED_BLOCKS = 1 # allow 0~3
FREEZE_BLOCKS = [True, False, False, False, False] # for gluoncv backbone
USE_07_METRIC = True

MUTILPY_BIAS_GRADIENT = None # 2.0 # if None, will not multipy
GRADIENT_CLIPPING_BY_NORM = None # 10.0 if None, will not clip

BATCH_SIZE = 1
EPSILON = 1e-5
MOMENTUM = 0.9
LR = 5e-4 * NUM_GPU * BATCH_SIZE
DECAY_STEP = [SAVE_WEIGHTS_INTE*12, SAVE_WEIGHTS_INTE*16, SAVE_WEIGHTS_INTE*20]
MAX_ITERATION = SAVE_WEIGHTS_INTE*20
WARM_SETP = int(1.0 / 8.0 * SAVE_WEIGHTS_INTE)

# -------------------------------------------- Data_preprocess_config
DATASET_NAME = 'coco' # 'pascal', 'coco'
PIXEL_MEAN = [123.68, 116.779, 103.939] # R, G, B. In tf, channel is RGB. In openCV, channel is BGR
PIXEL_MEAN_ = [0.485, 0.456, 0.406]
PIXEL_STD = [0.229, 0.224, 0.225] # R, G, B. In tf, channel is RGB. In openCV, channel is BGR
IMG_SHORT_SIDE_LEN = 600
IMG_MAX_LENGTH = 1000
CLASS_NUM = 80

# --------------------------------------------- Network_config
BATCH_SIZE = 1
SUBNETS_WEIGHTS_INITIALIZER = tf.random_normal_initializer(mean=0.0, stddev=0.01, seed=None)
SUBNETS_BIAS_INITIALIZER = tf.constant_initializer(value=0.0)
PROBABILITY = 0.01
FINAL_CONV_BIAS_INITIALIZER = tf.constant_initializer(value=-math.log((1.0 - PROBABILITY) / PROBABILITY))
WEIGHT_DECAY = 1e-4

# ---------------------------------------------Anchor config
LEVEL = ['P3', 'P4', 'P5', 'P6', 'P7']
BASE_ANCHOR_SIZE_LIST = [32, 64, 128, 256, 512]
ANCHOR_STRIDE = [8, 16, 32, 64, 128]
ANCHOR_SCALES = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]
ANCHOR_RATIOS = [0.5, 1.0, 2.0]
ANCHOR_SCALE_FACTORS = None
USE_CENTER_OFFSET = True

# --------------------------------------------RPN config
SHARE_NET = True
USE_P5 = True
IOU_POSITIVE_THRESHOLD = 0.5
IOU_NEGATIVE_THRESHOLD = 0.4

NMS = True
NMS_IOU_THRESHOLD = 0.5
MAXIMUM_DETECTIONS = 100
FILTERED_SCORE = 0.05
VIS_SCORE = 0.5


# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import os
import tensorflow as tf
import math

"""
epoch-00: 00.0 epoch-01: 7.40
epoch-02: 15.4 epoch-03: 18.8
epoch-04: 20.7 epoch-05: 23.0
epoch-06: 23.6 epoch-07: 25.3
epoch-08: 24.7 epoch-09: 26.7
epoch-11: 26.2 epoch-12: 30.7
epoch-13: 30.8 epoch-14: 31.1
epoch-15: 31.2 epoch-16: 31.4
epoch-19: 31.5
"""

# ------------------------------------------------
VERSION = 'RetinaNet_COCO_1x_20190522'
NET_NAME = 'resnet_v1_50' # 'MobilenetV2'
ADD_BOX_IN_TENSORBOARD = True

# ---------------------------------------- System_config
ROOT_PATH = os.path.abspath('../')
print(20*"++--")
print(ROOT_PATH)
GPU_GROUP = "0,1,2,3,4,5,6,7"
NUM_GPU = len(GPU_GROUP.strip().split(','))
SHOW_TRAIN_INFO_INTE = 20
SMRY_ITER = 200
SAVE_WEIGHTS_INTE = 20000 * 5

SUMMARY_PATH = ROOT_PATH + '/output/summary'
TEST_SAVE_PATH = ROOT_PATH + '/tools/test_result'

if NET_NAME.startswith("resnet"):
weights_name = NET_NAME
elif NET_NAME.startswith("MobilenetV2"):
weights_name = "mobilenet/mobilenet_v2_1.0_224"
else:
raise Exception('net name must in [resnet_v1_101, resnet_v1_50, MobilenetV2]')

PRETRAINED_CKPT = ROOT_PATH + '/data/pretrained_weights/' + weights_name + '.ckpt'
TRAINED_CKPT = os.path.join(ROOT_PATH, 'output/trained_weights')
EVALUATE_DIR = ROOT_PATH + '/output/evaluate_result_pickle/'

# ------------------------------------------ Train config
RESTORE_FROM_RPN = False
FIXED_BLOCKS = 1 # allow 0~3
FREEZE_BLOCKS = [True, False, False, False, False] # for gluoncv backbone
USE_07_METRIC = True

MUTILPY_BIAS_GRADIENT = None # 2.0 # if None, will not multipy
GRADIENT_CLIPPING_BY_NORM = None # 10.0 if None, will not clip

BATCH_SIZE = 1
EPSILON = 1e-5
MOMENTUM = 0.9
LR = 5e-4 * NUM_GPU * BATCH_SIZE
DECAY_STEP = [SAVE_WEIGHTS_INTE*12, SAVE_WEIGHTS_INTE*16, SAVE_WEIGHTS_INTE*20]
MAX_ITERATION = SAVE_WEIGHTS_INTE*20
WARM_SETP = int(1.0 / 8.0 * SAVE_WEIGHTS_INTE)

# -------------------------------------------- Data_preprocess_config
DATASET_NAME = 'coco' # 'pascal', 'coco'
PIXEL_MEAN = [123.68, 116.779, 103.939] # R, G, B. In tf, channel is RGB. In openCV, channel is BGR
PIXEL_MEAN_ = [0.485, 0.456, 0.406]
PIXEL_STD = [0.229, 0.224, 0.225] # R, G, B. In tf, channel is RGB. In openCV, channel is BGR
IMG_SHORT_SIDE_LEN = 600
IMG_MAX_LENGTH = 1000
CLASS_NUM = 80

# --------------------------------------------- Network_config
SUBNETS_WEIGHTS_INITIALIZER = tf.random_normal_initializer(mean=0.0, stddev=0.01, seed=None)
SUBNETS_BIAS_INITIALIZER = tf.constant_initializer(value=0.0)
PROBABILITY = 0.01
FINAL_CONV_BIAS_INITIALIZER = tf.constant_initializer(value=-math.log((1.0 - PROBABILITY) / PROBABILITY))
WEIGHT_DECAY = 1e-4

# ---------------------------------------------Anchor config
LEVEL = ['P3', 'P4', 'P5', 'P6', 'P7']
BASE_ANCHOR_SIZE_LIST = [32, 64, 128, 256, 512]
ANCHOR_STRIDE = [8, 16, 32, 64, 128]
ANCHOR_SCALES = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]
ANCHOR_RATIOS = [0.5, 1.0, 2.0]
ANCHOR_SCALE_FACTORS = None
USE_CENTER_OFFSET = True

# --------------------------------------------RPN config
SHARE_NET = True
USE_P5 = True
IOU_POSITIVE_THRESHOLD = 0.5
IOU_NEGATIVE_THRESHOLD = 0.4

NMS = True
NMS_IOU_THRESHOLD = 0.5
MAXIMUM_DETECTIONS = 100
FILTERED_SCORE = 0.05
VIS_SCORE = 0.5


Loading

0 comments on commit 0de3f0d

Please sign in to comment.