Skip to content

Commit

Permalink
Add scale_stick_for_xinsr_cn on DWPose and OpenPose. Close #447
Browse files Browse the repository at this point in the history
  • Loading branch information
Fannovel16 committed Sep 4, 2024
1 parent df91818 commit b1094d7
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 19 deletions.
8 changes: 5 additions & 3 deletions node_wrappers/dwpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ def INPUT_TYPES(s):
pose_estimator=INPUT.COMBO(
["dw-ll_ucoco_384_bs5.torchscript.pt", "dw-ll_ucoco_384.onnx", "dw-ll_ucoco.onnx"],
default="dw-ll_ucoco_384_bs5.torchscript.pt"
)
),
scale_stick_for_xinsr_cn=INPUT.COMBO(["disable", "enable"])
)

RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
FUNCTION = "estimate_pose"

CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"

def estimate_pose(self, image, detect_hand="enable", detect_body="enable", detect_face="enable", resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="dw-ll_ucoco_384.onnx", **kwargs):
def estimate_pose(self, image, detect_hand="enable", detect_body="enable", detect_face="enable", resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="dw-ll_ucoco_384.onnx", scale_stick_for_xinsr_cn="disable", **kwargs):
if bbox_detector == "yolox_l.onnx":
yolo_repo = DWPOSE_MODEL_NAME
elif "yolox" in bbox_detector:
Expand Down Expand Up @@ -78,13 +79,14 @@ def estimate_pose(self, image, detect_hand="enable", detect_body="enable", detec
detect_hand = detect_hand == "enable"
detect_body = detect_body == "enable"
detect_face = detect_face == "enable"
scale_stick_for_xinsr_cn = scale_stick_for_xinsr_cn == "enable"
self.openpose_dicts = []
def func(image, **kwargs):
pose_img, openpose_dict = model(image, **kwargs)
self.openpose_dicts.append(openpose_dict)
return pose_img

out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, resolution=resolution)
out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, resolution=resolution, xinsr_stick_scaling=scale_stick_for_xinsr_cn)
del model
return {
'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
Expand Down
8 changes: 5 additions & 3 deletions node_wrappers/openpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,22 @@ def INPUT_TYPES(s):
detect_hand=INPUT.COMBO(["enable", "disable"]),
detect_body=INPUT.COMBO(["enable", "disable"]),
detect_face=INPUT.COMBO(["enable", "disable"]),
resolution=INPUT.RESOLUTION()
resolution=INPUT.RESOLUTION(),
scale_stick_for_xinsr_cn=INPUT.COMBO(["disable", "enable"])
)

RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
FUNCTION = "estimate_pose"

CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"

def estimate_pose(self, image, detect_hand, detect_body, detect_face, resolution=512, **kwargs):
def estimate_pose(self, image, detect_hand="enable", detect_body="enable", detect_face="enable", scale_stick_for_xinsr_cn="disable", resolution=512, **kwargs):
from custom_controlnet_aux.open_pose import OpenposeDetector

detect_hand = detect_hand == "enable"
detect_body = detect_body == "enable"
detect_face = detect_face == "enable"
scale_stick_for_xinsr_cn = scale_stick_for_xinsr_cn == "enable"

model = OpenposeDetector.from_pretrained().to(model_management.get_torch_device())
self.openpose_dicts = []
Expand All @@ -31,7 +33,7 @@ def func(image, **kwargs):
self.openpose_dicts.append(openpose_dict)
return pose_img

out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, resolution=resolution)
out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, xinsr_stick_scaling=scale_stick_for_xinsr_cn, resolution=resolution)
del model
return {
'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "comfyui_controlnet_aux"
description = "Plug-and-play ComfyUI node sets for making ControlNet hint images"

version = "1.0.4-alpha.8"
version = "1.0.4-alpha.9"
dependencies = ["torch", "importlib_metadata", "huggingface_hub", "scipy", "opencv-python>=4.7.0.72", "filelock", "numpy", "Pillow", "einops", "torchvision", "pyyaml", "scikit-image", "python-dateutil", "mediapipe", "svglib", "fvcore", "yapf", "omegaconf", "ftfy", "addict", "yacs", "trimesh[easy]", "albumentations", "scikit-learn", "matplotlib"]

[project.urls]
Expand Down
8 changes: 4 additions & 4 deletions src/custom_controlnet_aux/dwpose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def draw_animalpose(canvas: np.ndarray, keypoints: list[Keypoint]) -> np.ndarray
return canvas


def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True, xinsr_stick_scaling=False):
"""
Draw the detected poses on an empty canvas.
Expand All @@ -110,7 +110,7 @@ def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, dr

for pose in poses:
if draw_body:
canvas = util.draw_bodypose(canvas, pose.body.keypoints)
canvas = util.draw_bodypose(canvas, pose.body.keypoints, xinsr_stick_scaling)

if draw_hand:
canvas = util.draw_handpose(canvas, pose.left_hand)
Expand Down Expand Up @@ -252,7 +252,7 @@ def detect_poses(self, oriImg) -> List[PoseResult]:
keypoints_info = self.dw_pose_estimation(oriImg.copy())
return Wholebody.format_result(keypoints_info)

def __call__(self, input_image, detect_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", **kwargs):
def __call__(self, input_image, detect_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", xinsr_stick_scaling=False, **kwargs):
if hand_and_face is not None:
warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning)
include_hand = hand_and_face
Expand All @@ -262,7 +262,7 @@ def __call__(self, input_image, detect_resolution=512, include_body=True, includ
input_image, _ = resize_image_with_pad(input_image, 0, upscale_method)
poses = self.detect_poses(input_image)

canvas = draw_poses(poses, input_image.shape[0], input_image.shape[1], draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
canvas = draw_poses(poses, input_image.shape[0], input_image.shape[1], draw_body=include_body, draw_hand=include_hand, draw_face=include_face, xinsr_stick_scaling=xinsr_stick_scaling)
canvas, remove_pad = resize_image_with_pad(canvas, detect_resolution, upscale_method)
detected_map = HWC3(remove_pad(canvas))

Expand Down
13 changes: 11 additions & 2 deletions src/custom_controlnet_aux/dwpose/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,14 @@ def is_normalized(keypoints: List[Optional[Keypoint]]) -> bool:
return all(point_normalized)


def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint], xinsr_stick_scaling: bool = False) -> np.ndarray:
"""
Draw keypoints and limbs representing body pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose.
keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn.
xinsr_stick_scaling (bool): Whether to scale the stick (limb) width by canvas size, as expected by the xinsir ControlNet.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose.
Expand All @@ -98,8 +99,16 @@ def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
else:
H, W, _ = canvas.shape

CH, CW, _ = canvas.shape
stickwidth = 4

# Ref: https://huggingface.co/xinsir/controlnet-openpose-sdxl-1.0
max_side = max(CW, CH)
if xinsr_stick_scaling:
stick_scale = 1 if max_side < 500 else min(2 + (max_side // 1000), 7)
else:
stick_scale = 1

limbSeq = [
[2, 3], [2, 6], [3, 4], [4, 5],
[6, 7], [7, 8], [2, 9], [9, 10],
Expand All @@ -125,7 +134,7 @@ def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth*stick_scale), int(angle), 0, 360, 1)
cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])

for keypoint, color in zip(keypoints, colors):
Expand Down
8 changes: 4 additions & 4 deletions src/custom_controlnet_aux/open_pose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class PoseResult(NamedTuple):
right_hand: Union[HandResult, None]
face: Union[FaceResult, None]

def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True, xinsr_stick_scaling=False):
"""
Draw the detected poses on an empty canvas.
Expand All @@ -55,7 +55,7 @@ def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, dr

for pose in poses:
if draw_body:
canvas = util.draw_bodypose(canvas, pose.body.keypoints)
canvas = util.draw_bodypose(canvas, pose.body.keypoints, xinsr_stick_scaling)

if draw_hand:
canvas = util.draw_handpose(canvas, pose.left_hand)
Expand Down Expand Up @@ -216,7 +216,7 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P

return results

def __call__(self, input_image, detect_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", **kwargs):
def __call__(self, input_image, detect_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", xinsr_stick_scaling=False, **kwargs):
if hand_and_face is not None:
warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning)
include_hand = hand_and_face
Expand All @@ -226,7 +226,7 @@ def __call__(self, input_image, detect_resolution=512, include_body=True, includ
input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

poses = self.detect_poses(input_image, include_hand=include_hand, include_face=include_face)
canvas = draw_poses(poses, input_image.shape[0], input_image.shape[1], draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
canvas = draw_poses(poses, input_image.shape[0], input_image.shape[1], draw_body=include_body, draw_hand=include_hand, draw_face=include_face, xinsr_stick_scaling=xinsr_stick_scaling)
detected_map = HWC3(remove_pad(canvas))

if output_type == "pil":
Expand Down
11 changes: 9 additions & 2 deletions src/custom_controlnet_aux/open_pose/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ def transfer(model, model_weights):
return transfered_model_weights


def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint], xinsr_stick_scaling: bool = False) -> np.ndarray:
"""
Draw keypoints and limbs representing body pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose.
keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn.
xinsr_stick_scaling (bool): Whether to scale the stick (limb) width by canvas size, as expected by the xinsir ControlNet.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose.
Expand All @@ -83,6 +84,12 @@ def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
"""
H, W, C = canvas.shape
stickwidth = 4
# Ref: https://huggingface.co/xinsir/controlnet-openpose-sdxl-1.0
max_side = max(H, W)
if xinsr_stick_scaling:
stick_scale = 1 if max_side < 500 else min(2 + (max_side // 1000), 7)
else:
stick_scale = 1

limbSeq = [
[2, 3], [2, 6], [3, 4], [4, 5],
Expand All @@ -109,7 +116,7 @@ def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth*stick_scale), int(angle), 0, 360, 1)
cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])

for keypoint, color in zip(keypoints, colors):
Expand Down

2 comments on commit b1094d7

@Satoshi-Yoda
Copy link

@Satoshi-Yoda Satoshi-Yoda commented on b1094d7 Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, wow, that was fast!! ^_^

  • Another important node is also affected: "RenderPeopleKps" aka "Render Pose JSON (Human)" currently renders the default thin lines. It looks like it assumes the DWPose format as input?
  • In the reference implementation the scale was also applied to the dots; that probably has less impact than the lines, but who can tell...
    Around here: cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) // the radius would become 4 * stick_scale
  • I believe it is spelled "xinsir" on Hugging Face, if I am not missing something =)

@Fannovel16
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Satoshi-Yoda

  • I forgot it xD
  • Will add that. Thanks!
  • Changing the name will be fine as long as the order of parameters is kept

Please sign in to comment.