-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
309 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
|
||
data/ | ||
.vscode | ||
.idea | ||
.DS_Store | ||
|
||
# custom | ||
*.pkl | ||
*.pkl.json | ||
*.log.json | ||
work_dirs/ | ||
|
||
# Pytorch | ||
*.pth | ||
*.py~ | ||
*.sh~ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
[settings] | ||
known_third_party = tensorflow | ||
known_third_party = tensorflow,torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .gaussian_target import gaussian_radius, gen_gaussian_target | ||
|
||
__all__ = ['gaussian_radius', 'gen_gaussian_target'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
from math import sqrt | ||
|
||
import torch | ||
|
||
|
||
def gaussian2D(radius, sigma=1, dtype=torch.float32, device='cpu'): | ||
"""Generate 2D gaussian kernel. | ||
Args: | ||
radius (int): Radius of gaussian kernel. | ||
sigma (int): Sigma of gaussian function. Default: 1. | ||
dtype (torch.dtype): Dtype of gaussian tensor. Default: torch.float32. | ||
device (str): Device of gaussian tensor. Default: 'cpu'. | ||
Returns: | ||
h (Tensor): Gaussian kernel with a | ||
``(2 * radius + 1) * (2 * radius + 1)`` shape. | ||
""" | ||
x = torch.arange( | ||
-radius, radius + 1, dtype=dtype, device=device).view(1, -1) | ||
y = torch.arange( | ||
-radius, radius + 1, dtype=dtype, device=device).view(-1, 1) | ||
|
||
h = (-(x * x + y * y) / (2 * sigma * sigma)).exp() | ||
|
||
h[h < torch.finfo(h.dtype).eps * h.max()] = 0 | ||
return h | ||
|
||
|
||
def gen_gaussian_target(heatmap, center, radius, k=1): | ||
"""Generate 2D gaussian heatmap. | ||
Args: | ||
heatmap (Tensor): Input heatmap, the gaussian kernel will cover on | ||
it and maintain the max value. | ||
center (list[int]): Coord of gaussian kernel's center. | ||
radius (int): Radius of gaussian kernel. | ||
k (int): Coefficient of gaussian kernel. Default: 1. | ||
Returns: | ||
out_heatmap (Tensor): Updated heatmap covered by gaussian kernel. | ||
""" | ||
diameter = 2 * radius + 1 | ||
gaussian_kernel = gaussian2D( | ||
radius, sigma=diameter / 6, dtype=heatmap.dtype, device=heatmap.device) | ||
|
||
x, y = center | ||
|
||
height, width = heatmap.shape[:2] | ||
|
||
left, right = min(x, radius), min(width - x, radius + 1) | ||
top, bottom = min(y, radius), min(height - y, radius + 1) | ||
|
||
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] | ||
masked_gaussian = gaussian_kernel[radius - top:radius + bottom, | ||
radius - left:radius + right] | ||
out_heatmap = heatmap | ||
torch.max( | ||
masked_heatmap, | ||
masked_gaussian * k, | ||
out=out_heatmap[y - top:y + bottom, x - left:x + right]) | ||
|
||
return out_heatmap | ||
|
||
|
||
def gaussian_radius(det_size, min_overlap): | ||
r"""Generate 2D gaussian radius. | ||
This function is modified from the `official github repo | ||
<https://github.com/princeton-vl/CornerNet-Lite/blob/master/core/sample/ | ||
utils.py#L65>`_. | ||
Given ``min_overlap``, radius could computed by a quadratic equation | ||
according to Vieta's formulas. | ||
There are 3 cases for computing gaussian radius, details are following: | ||
- Explanation of figure: ``lt`` and ``br`` indicates the left-top and | ||
bottom-right corner of ground truth box. ``x`` indicates the | ||
generated corner at the limited position when ``radius=r``. | ||
- Case1: one corner is inside the gt box and the other is outside. | ||
.. code:: text | ||
|< width >| | ||
lt-+----------+ - | ||
| | | ^ | ||
+--x----------+--+ | ||
| | | | | ||
| | | | height | ||
| | overlap | | | ||
| | | | | ||
| | | | v | ||
+--+---------br--+ - | ||
| | | | ||
+----------+--x | ||
To ensure IoU of generated box and gt box is larger than ``min_overlap``: | ||
.. math:: | ||
\cfrac{(w-r)*(h-r)}{w*h+(w+h)r-r^2} \ge {iou} \quad\Rightarrow\quad | ||
{r^2-(w+h)r+\cfrac{1-iou}{1+iou}*w*h} \ge 0 \\ | ||
{a} = 1,\quad{b} = {-(w+h)},\quad{c} = {\cfrac{1-iou}{1+iou}*w*h} | ||
{r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a} | ||
- Case2: both two corners are inside the gt box. | ||
.. code:: text | ||
|< width >| | ||
lt-+----------+ - | ||
| | | ^ | ||
+--x-------+ | | ||
| | | | | ||
| |overlap| | height | ||
| | | | | ||
| +-------x--+ | ||
| | | v | ||
+----------+-br - | ||
To ensure IoU of generated box and gt box is larger than ``min_overlap``: | ||
.. math:: | ||
\cfrac{(w-2*r)*(h-2*r)}{w*h} \ge {iou} \quad\Rightarrow\quad | ||
{4r^2-2(w+h)r+(1-iou)*w*h} \ge 0 \\ | ||
{a} = 4,\quad {b} = {-2(w+h)},\quad {c} = {(1-iou)*w*h} | ||
{r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a} | ||
- Case3: both two corners are outside the gt box. | ||
.. code:: text | ||
|< width >| | ||
x--+----------------+ | ||
| | | | ||
+-lt-------------+ | - | ||
| | | | ^ | ||
| | | | | ||
| | overlap | | height | ||
| | | | | ||
| | | | v | ||
| +------------br--+ - | ||
| | | | ||
+----------------+--x | ||
To ensure IoU of generated box and gt box is larger than ``min_overlap``: | ||
.. math:: | ||
\cfrac{w*h}{(w+2*r)*(h+2*r)} \ge {iou} \quad\Rightarrow\quad | ||
{4*iou*r^2+2*iou*(w+h)r+(iou-1)*w*h} \le 0 \\ | ||
{a} = {4*iou},\quad {b} = {2*iou*(w+h)},\quad {c} = {(iou-1)*w*h} \\ | ||
{r} \le \cfrac{-b+\sqrt{b^2-4*a*c}}{2*a} | ||
Args: | ||
det_size (list[int]): Shape of object. | ||
min_overlap (float): Min IoU with ground truth for boxes generated by | ||
keypoints inside the gaussian kernel. | ||
Returns: | ||
radius (int): Radius of gaussian kernel. | ||
""" | ||
height, width = det_size | ||
|
||
a1 = 1 | ||
b1 = (height + width) | ||
c1 = width * height * (1 - min_overlap) / (1 + min_overlap) | ||
sq1 = sqrt(b1**2 - 4 * a1 * c1) | ||
r1 = (b1 - sq1) / (2 * a1) | ||
|
||
a2 = 4 | ||
b2 = 2 * (height + width) | ||
c2 = (1 - min_overlap) * width * height | ||
sq2 = sqrt(b2**2 - 4 * a2 * c2) | ||
r2 = (b2 - sq2) / (2 * a2) | ||
|
||
a3 = 4 * min_overlap | ||
b3 = -2 * min_overlap * (height + width) | ||
c3 = (min_overlap - 1) * width * height | ||
sq3 = sqrt(b3**2 - 4 * a3 * c3) | ||
r3 = (b3 + sq3) / (2 * a3) | ||
return min(r1, r2, r3) |