Skip to content

Commit 0b5c08c

Browse files
Merge pull request #216 from fastlabel/feature/update-coco-format-rotation
COCO形式のフィールドに回転の要素を追加
2 parents 270074f + c110ebd commit 0b5c08c

File tree

1 file changed

+71
-1
lines changed

1 file changed

+71
-1
lines changed

fastlabel/converters.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def get_annotation_points(anno, _):
7373
"annotation_value": annotation["value"],
7474
"annotation_type": annotation["type"],
7575
"annotation_points": get_annotation_points(annotation, index),
76+
"annotation_rotation": annotation.get("rotation", 0),
7677
"annotation_keypoints": annotation.get("keypoints"),
7778
"annotation_attributes": _get_coco_annotation_attributes(
7879
annotation
@@ -204,6 +205,7 @@ def __to_coco_annotation(data: dict) -> dict:
204205
image_id = data["image_id"]
205206
points = data["annotation_points"]
206207
keypoints = data["annotation_keypoints"]
208+
rotation = data["annotation_rotation"]
207209
annotation_type = data["annotation_type"]
208210
annotation_value = data["annotation_value"]
209211
annotation_id = 0
@@ -237,6 +239,7 @@ def __to_coco_annotation(data: dict) -> dict:
237239
image_id,
238240
annotation_type,
239241
annotation_attributes,
242+
rotation
240243
)
241244

242245

@@ -268,6 +271,7 @@ def __get_coco_annotation(
268271
image_id: str,
269272
annotation_type: str,
270273
annotation_attributes: Dict[str, AttributeValue],
274+
rotation: int
271275
) -> dict:
272276
annotation = {}
273277
annotation["num_keypoints"] = len(keypoints) if keypoints else 0
@@ -278,13 +282,75 @@ def __get_coco_annotation(
278282
annotation["iscrowd"] = 0
279283
annotation["area"] = __to_area(annotation_type, points)
280284
annotation["image_id"] = image_id
281-
annotation["bbox"] = __to_bbox(annotation_type, points)
285+
annotation["bbox"] = (
286+
__get_coco_bbox(points, rotation)
287+
if annotation_type == AnnotationType.bbox
288+
else __to_bbox(annotation_type, points)
289+
)
290+
annotation["rotation"] = rotation
282291
annotation["category_id"] = category["id"]
283292
annotation["id"] = id_
284293
annotation["attributes"] = annotation_attributes
285294
return annotation
286295

287296

297+
def __rotate_point(
298+
cx: float, cy: float, angle: float, px: float, py: float
299+
) -> np.ndarray:
300+
px -= cx
301+
py -= cy
302+
303+
x_new = px * math.cos(angle) - py * math.sin(angle)
304+
y_new = px * math.sin(angle) + py * math.cos(angle)
305+
306+
px = x_new + cx
307+
py = y_new + cy
308+
return np.array([px, py])
309+
310+
311+
def __get_rotated_rectangle_coordinates(
312+
coords: np.ndarray, rotation: int
313+
) -> np.ndarray:
314+
top_left = coords[0]
315+
bottom_right = coords[1]
316+
317+
cx = (top_left[0] + bottom_right[0]) / 2
318+
cy = (top_left[1] + bottom_right[1]) / 2
319+
320+
top_right = np.array([bottom_right[0], top_left[1]])
321+
bottom_left = np.array([top_left[0], bottom_right[1]])
322+
323+
corners = np.array([top_left, top_right, bottom_right, bottom_left])
324+
325+
angle_rad = math.radians(rotation)
326+
rotated_corners = np.array(
327+
[__rotate_point(cx, cy, angle_rad, x, y) for x, y in corners]
328+
)
329+
330+
return rotated_corners
331+
332+
def __get_coco_bbox(
333+
points: list,
334+
rotation: int,
335+
) -> list[float]:
336+
if not points:
337+
return []
338+
points_splitted = [points[idx : idx + 2] for idx in range(0, len(points), 2)]
339+
polygon_geo = geojson.Polygon(points_splitted)
340+
coords = np.array(list(geojson.utils.coords(polygon_geo)))
341+
rotated_coords = __get_rotated_rectangle_coordinates(coords, rotation)
342+
x_min = rotated_coords[:, 0].min()
343+
y_min = rotated_coords[:, 1].min()
344+
x_max = rotated_coords[:, 0].max()
345+
y_max = rotated_coords[:, 1].max()
346+
return [
347+
x_min, # x
348+
y_min, # y
349+
x_max - x_min, # width
350+
y_max - y_min, # height
351+
]
352+
353+
288354
def __get_without_hollowed_points(points: list) -> list:
289355
return [region[0] for region in points]
290356

@@ -295,6 +361,10 @@ def __to_coco_segmentation(annotation_type: str, points: list) -> list:
295361
if annotation_type == AnnotationType.segmentation.value:
296362
# Remove hollowed points
297363
return __get_without_hollowed_points(points)
364+
if annotation_type == AnnotationType.bbox.value:
365+
x1, y1, x2, y2 = points
366+
rectangle_points = [x1, y1, x2, y1, x2, y2, x1, y2, x1, y1]
367+
return [rectangle_points]
298368
return [points]
299369

300370

0 commit comments

Comments
 (0)