Spaces:
Build error
Build error
import torch | |
import comfy.utils | |
from .Pytorch_Retinaface.pytorch_retinaface import Pytorch_RetinaFace | |
from comfy.model_management import get_torch_device | |
class AutoCropFaces:
    """ComfyUI node: detect faces with Pytorch_RetinaFace and output the cropped faces.

    Outputs an IMAGE batch of the selected face crops plus the matching CROP_DATA
    entries. All crops in the output batch are resized to one common size so they
    fit in a single tensor.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI calls this on the class itself (NodeClass.INPUT_TYPES()), so it
        # must be a classmethod; a plain method would fail with a missing argument.
        return {
            "required": {
                "image": ("IMAGE",),
                "number_of_faces": ("INT", {
                    "default": 5,
                    "min": 1,
                    "max": 100,
                    "step": 1,
                }),
                "scale_factor": ("FLOAT", {
                    "default": 1.5,
                    "min": 0.5,
                    "max": 10,
                    "step": 0.5,
                    "display": "slider"
                }),
                "shift_factor": ("FLOAT", {
                    "default": 0.45,
                    "min": 0,
                    "max": 1,
                    "step": 0.01,
                    "display": "slider"
                }),
                "start_index": ("INT", {
                    "default": 0,
                    "step": 1,
                    "display": "number"
                }),
                "max_faces_per_image": ("INT", {
                    "default": 50,
                    "min": 1,
                    "max": 1000,
                    "step": 1,
                }),
                "aspect_ratio": (["9:16", "2:3", "3:4", "4:5", "1:1", "5:4", "4:3", "3:2", "16:9"], {
                    "default": "1:1",
                }),
            },
        }

    RETURN_TYPES = ("IMAGE", "CROP_DATA")
    # One name per output; previously the CROP_DATA output had no name.
    RETURN_NAMES = ("face", "crop_data")
    FUNCTION = "auto_crop_faces"
    CATEGORY = "Faces"

    def aspect_ratio_string_to_float(self, str_aspect_ratio="1:1"):
        """Convert a "W:H" string such as "16:9" into the float W / H."""
        a, b = map(float, str_aspect_ratio.split(':'))
        return a / b

    def auto_crop_faces_in_image(self, image, max_number_of_faces, scale_factor, shift_factor, aspect_ratio, method='lanczos'):
        """Detect and crop the faces of ONE image (no batch dimension).

        Returns ``(faces, bbox_info)`` where each face has a leading batch
        dimension of 1 added and ``bbox_info`` is whatever
        ``Pytorch_RetinaFace.center_and_crop_rescale`` returns for the crops.
        ``method`` is accepted for interface compatibility but unused here.
        """
        # Detection is fed a 0-255 scaled copy; cropping uses the original
        # values (presumably ComfyUI's 0-1 floats — confirm against the detector).
        image_255 = image * 255
        rf = Pytorch_RetinaFace(top_k=50, keep_top_k=max_number_of_faces, device=get_torch_device())
        dets = rf.detect_faces(image_255)
        cropped_faces, bbox_info = rf.center_and_crop_rescale(
            image, dets, scale_factor=scale_factor, shift_factor=shift_factor, aspect_ratio=aspect_ratio)
        cropped_faces_with_batch = [face.unsqueeze(0) for face in cropped_faces]
        return cropped_faces_with_batch, bbox_info

    @staticmethod
    def _circular_indices(total, start_index, count):
        """Return up to ``count`` indices into a sequence of length ``total``, starting at
        ``start_index`` and wrapping past the end; no index is repeated.

        ``start_index`` may be any integer (it is reduced modulo ``total``).
        Returns an empty list when ``total`` is 0 or ``count`` is not positive.
        """
        if total <= 0:
            return []
        start = start_index % total
        return [(start + offset) % total for offset in range(min(count, total))]

    @staticmethod
    def _full_image_crop_data(image):
        """Return one full-frame crop entry ``(0, 0, width, height)`` per image in the batch.

        ComfyUI IMAGE tensors are (batch, height, width, channels), so width is
        dim 2 and height is dim 1.
        """
        # NOTE(review): assumes this 4-tuple layout matches the entries produced by
        # center_and_crop_rescale; for a full frame (x, y, w, h) and (x1, y1, x2, y2)
        # coincide.
        height, width = image.shape[1], image.shape[2]
        return [(0, 0, width, height) for _ in range(image.shape[0])]

    @staticmethod
    def _stack_faces(faces, method):
        """Concatenate (1, H, W, C) face crops into one (N, H, W, C) batch.

        The target size is that of the crop with the largest height (first one
        on ties); every crop of a different size is resampled to it with
        ``comfy.utils.common_upscale`` using ``method``.
        """
        reference = max(faces, key=lambda face: face.shape[1])
        target_height, target_width = reference.shape[1], reference.shape[2]
        fitted = []
        for face in faces:
            if tuple(face.shape[1:3]) != (target_height, target_width):
                # common_upscale works on (batch, channels, height, width) and takes
                # width before height; any crop mode other than "center" leaves the
                # image uncropped, which is what we want here.
                face = comfy.utils.common_upscale(
                    face.movedim(-1, 1),
                    target_width,
                    target_height,
                    method,
                    "",
                ).movedim(1, -1)
            fitted.append(face)
        # A single concatenation avoids re-copying the growing batch each iteration.
        return torch.cat(fitted, dim=0)

    def auto_crop_faces(self, image, number_of_faces, start_index, max_faces_per_image, scale_factor, shift_factor, aspect_ratio, method='lanczos'):
        """Detect faces in every image of the batch and return a selection of the crops.

        Args:
            image: ComfyUI IMAGE batch of shape (batch, height, width, channels);
                a single image is still a batch of one.
            number_of_faces: How many of the detected faces (pooled across the
                whole batch) to return.
            start_index: Index of the first detected face to return; selection
                wraps around, and the index is taken modulo the number of detections.
            max_faces_per_image: Passed to Pytorch_RetinaFace as ``keep_top_k``,
                bounding detections per image.
            scale_factor: Padding/crop factor around each face, forwarded to
                ``center_and_crop_rescale``.
            shift_factor: Vertical pan relative to the face, forwarded to
                ``center_and_crop_rescale`` (the original author notes 0.5 centres it).
            aspect_ratio: "W:H" string from the node's dropdown.
            method: Resampling method used when crops must be resized to a common size.

        Returns:
            ``(faces, crop_data)``. If no face is selected, the input batch is
            returned unchanged together with one full-frame crop entry per image.
        """
        aspect_ratio = self.aspect_ratio_string_to_float(aspect_ratio)

        detected_faces, detected_crop_data = [], []
        for i in range(image.shape[0]):
            faces, infos = self.auto_crop_faces_in_image(
                image[i],
                max_faces_per_image,
                scale_factor,
                shift_factor,
                aspect_ratio,
                method)
            detected_faces.extend(faces)
            detected_crop_data.extend(infos)

        # Use one index list for both lists so faces and crop data stay aligned.
        indices = self._circular_indices(len(detected_faces), start_index, number_of_faces)
        selected_faces = [detected_faces[k] for k in indices]
        selected_crop_data = [detected_crop_data[k] for k in indices]

        if not selected_faces:
            return (image, self._full_image_crop_data(image))
        if len(selected_faces) == 1:
            return (selected_faces[0], selected_crop_data)
        return (self._stack_faces(selected_faces, method), selected_crop_data)
# Registry of the nodes this module exports, keyed by node type name.
NODE_CLASS_MAPPINGS = dict(AutoCropFaces=AutoCropFaces)

# Friendly, human-readable titles for the nodes, keyed by the same type name.
NODE_DISPLAY_NAME_MAPPINGS = dict(AutoCropFaces="Auto Crop Faces")