Freak-ppa's picture
Upload 36 files
d5779bb verified
import torch
import comfy.utils
from .Pytorch_Retinaface.pytorch_retinaface import Pytorch_RetinaFace
from comfy.model_management import get_torch_device
class AutoCropFaces:
    """ComfyUI node that detects faces with RetinaFace and outputs them as cropped images.

    Outputs:
        face: IMAGE batch holding the selected face crops, all resized to a common size.
        crop_data: List with the crop info of each selected face, as produced by
            ``Pytorch_RetinaFace.center_and_crop_rescale``.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "number_of_faces": ("INT", {
                    "default": 5,
                    "min": 1,
                    "max": 100,
                    "step": 1,
                }),
                "scale_factor": ("FLOAT", {
                    "default": 1.5,
                    "min": 0.5,
                    "max": 10,
                    "step": 0.5,
                    "display": "slider"
                }),
                "shift_factor": ("FLOAT", {
                    "default": 0.45,
                    "min": 0,
                    "max": 1,
                    "step": 0.01,
                    "display": "slider"
                }),
                "start_index": ("INT", {
                    "default": 0,
                    "step": 1,
                    "display": "number"
                }),
                "max_faces_per_image": ("INT", {
                    "default": 50,
                    "min": 1,
                    "max": 1000,
                    "step": 1,
                }),
                "aspect_ratio": (["9:16", "2:3", "3:4", "4:5", "1:1", "5:4", "4:3", "3:2", "16:9"], {
                    "default": "1:1",
                }),
            },
        }

    RETURN_TYPES = ("IMAGE", "CROP_DATA")
    # One name per entry in RETURN_TYPES (the second output used to be unnamed).
    RETURN_NAMES = ("face", "crop_data")
    FUNCTION = "auto_crop_faces"
    CATEGORY = "Faces"

    def aspect_ratio_string_to_float(self, str_aspect_ratio="1:1"):
        """Convert an "a:b" ratio string (e.g. "16:9") into the float a / b."""
        a, b = map(float, str_aspect_ratio.split(':'))
        return a / b

    def _make_detector(self, max_number_of_faces):
        """Build a RetinaFace detector that keeps at most ``max_number_of_faces`` detections."""
        return Pytorch_RetinaFace(top_k=50, keep_top_k=max_number_of_faces, device=get_torch_device())

    def _full_frame_crop_data(self, image):
        """Return one full-frame crop tuple (0, 0, width, height) per image in the batch.

        ``image`` is a ComfyUI IMAGE tensor laid out as (batch, height, width, channels).
        NOTE(review): confirm this tuple layout matches what center_and_crop_rescale emits.
        """
        batch, height, width = image.shape[0], image.shape[1], image.shape[2]
        return [(0, 0, width, height) for _ in range(batch)]

    def auto_crop_faces_in_image(self, image, max_number_of_faces, scale_factor, shift_factor, aspect_ratio, method='lanczos', detector=None):
        """Detect faces in a single image and return the crops plus their crop info.

        Args:
            image: One image without a batch dimension. Presumably (height, width,
                channels) with values in [0, 1]; TODO confirm.
            max_number_of_faces: keep_top_k for the detector. Only used when
                ``detector`` is None.
            scale_factor: Padding around each detected face, forwarded to the cropper.
            shift_factor: Vertical pan relative to the face, forwarded to the cropper.
            aspect_ratio: Crop aspect ratio as a float (width / height).
            method: Unused here; kept for signature compatibility.
            detector: Optional pre-built Pytorch_RetinaFace to reuse. A new one is
                created when omitted.

        Returns:
            (faces, crop_info): the cropped faces, each with a leading batch
            dimension of 1, and the crop info from center_and_crop_rescale.
        """
        if detector is None:
            detector = self._make_detector(max_number_of_faces)
        # Detection is run on the image scaled by 255; cropping uses the original image.
        dets = detector.detect_faces(image * 255)
        cropped_faces, bbox_info = detector.center_and_crop_rescale(
            image, dets, scale_factor=scale_factor, shift_factor=shift_factor, aspect_ratio=aspect_ratio)
        return [face.unsqueeze(0) for face in cropped_faces], bbox_info

    def auto_crop_faces(self, image, number_of_faces, start_index, max_faces_per_image, scale_factor, shift_factor, aspect_ratio, method='lanczos'):
        """Detect faces across a batch, pick a circular window of them and batch the crops.

        Args:
            image: ComfyUI IMAGE tensor of shape (batch, height, width, channels).
                A single image is a batch of one.
            number_of_faces: How many detected faces to return. This is capped at
                the number of faces actually detected.
            start_index: Index of the first face to return. It wraps modulo the
                number of detected faces, so negative values count from the end.
            max_faces_per_image: Maximum detections kept per image (the detector's
                keep_top_k).
            scale_factor: How much padding to keep around each detected face.
            shift_factor: Pan up or down relative to the face; 0.5 should be right
                in the center.
            aspect_ratio: Crop aspect ratio as an "a:b" string, e.g. "1:1".
            method: Interpolation method passed to comfy.utils.common_upscale when
                crops differ in size.

        Returns:
            (faces, crop_data). If no face is selected, the input ``image`` is
            returned unchanged with one full-frame (0, 0, width, height) tuple per
            input image.
        """
        ratio = self.aspect_ratio_string_to_float(aspect_ratio)
        # Build the detector once and reuse it for every image in the batch.
        detector = self._make_detector(max_faces_per_image)

        detected_faces, detected_crop_data = [], []
        for single_image in image:  # iterating the batch dimension yields image[i]
            faces, infos = self.auto_crop_faces_in_image(
                single_image, max_faces_per_image, scale_factor, shift_factor, ratio, method,
                detector=detector)
            detected_faces.extend(faces)
            detected_crop_data.extend(infos)

        if not detected_faces:
            return (image, self._full_frame_crop_data(image))

        # Circular selection: rotate so start_index comes first, then keep number_of_faces.
        first = start_index % len(detected_faces)
        selected_faces = (detected_faces[first:] + detected_faces[:first])[:number_of_faces]
        selected_crop_data = (detected_crop_data[first:] + detected_crop_data[:first])[:number_of_faces]

        if not selected_faces:  # only reachable when number_of_faces < 1
            return (image, self._full_frame_crop_data(image))
        if len(selected_faces) == 1:
            return (selected_faces[0], selected_crop_data)

        # All crops must share one (height, width) to be concatenated into a batch.
        # Faces are (1, H, W, C); the tallest crop (largest dim 1) is the target size.
        reference = max(selected_faces, key=lambda face: face.shape[1])
        target_height, target_width = reference.shape[1], reference.shape[2]

        fitted = []
        for face in selected_faces:
            if tuple(face.shape[1:3]) != (target_height, target_width):
                # common_upscale expects (batch, channels, height, width) and takes
                # width before height; crop mode "disabled" means no center-crop.
                face = comfy.utils.common_upscale(
                    face.movedim(-1, 1), target_width, target_height, method, "disabled"
                ).movedim(1, -1)
            fitted.append(face)
        return (torch.cat(fitted, dim=0), selected_crop_data)
# Maps this node's internal type name to the class that implements it
# (presumably read by ComfyUI when it loads this custom-node package).
NODE_CLASS_MAPPINGS = dict(
    AutoCropFaces=AutoCropFaces,
)

# Friendly, human-readable titles for the nodes, keyed by the same type names.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    AutoCropFaces="Auto Crop Faces",
)