Spaces:
Build error
Build error
File size: 7,995 Bytes
d5779bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import torch
import comfy.utils
from .Pytorch_Retinaface.pytorch_retinaface import Pytorch_RetinaFace
from comfy.model_management import get_torch_device
class AutoCropFaces:
    """ComfyUI node that detects faces in a batch of images and outputs the crops.

    Every image of the input batch is run through ``Pytorch_RetinaFace``; the
    detected faces of all images are pooled into one list, a circular window of
    that list is selected, and the selected crops are resized to a common size
    so they can be returned as a single image batch together with their crop
    data.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs (types, defaults and ranges)."""
        return {
            "required": {
                "image": ("IMAGE",),
                "number_of_faces": ("INT", {
                    "default": 5,
                    "min": 1,
                    "max": 100,
                    "step": 1,
                }),
                "scale_factor": ("FLOAT", {
                    "default": 1.5,
                    "min": 0.5,
                    "max": 10,
                    "step": 0.5,
                    "display": "slider"
                }),
                "shift_factor": ("FLOAT", {
                    "default": 0.45,
                    "min": 0,
                    "max": 1,
                    "step": 0.01,
                    "display": "slider"
                }),
                "start_index": ("INT", {
                    "default": 0,
                    "step": 1,
                    "display": "number"
                }),
                "max_faces_per_image": ("INT", {
                    "default": 50,
                    "min": 1,
                    "max": 1000,
                    "step": 1,
                }),
                # Offered as "width:height" choices; converted to a float
                # by aspect_ratio_string_to_float before use.
                "aspect_ratio": (["9:16", "2:3", "3:4", "4:5", "1:1", "5:4", "4:3", "3:2", "16:9"], {
                    "default": "1:1",
                }),
            },
        }

    RETURN_TYPES = ("IMAGE", "CROP_DATA")
    # NOTE(review): only the first output is named although two are returned;
    # confirm the label the UI falls back to for CROP_DATA is the intended one.
    RETURN_NAMES = ("face",)

    FUNCTION = "auto_crop_faces"

    CATEGORY = "Faces"

    def aspect_ratio_string_to_float(self, str_aspect_ratio="1:1"):
        """Convert an ``"a:b"`` aspect ratio string into the float ``a / b``."""
        a, b = map(float, str_aspect_ratio.split(':'))
        return a / b

    def auto_crop_faces_in_image(self, image, max_number_of_faces, scale_factor, shift_factor, aspect_ratio, method='lanczos'):
        """Detect and crop the faces of a single image.

        Args:
            image: One image without a batch dimension.
            max_number_of_faces: Forwarded to the detector as ``keep_top_k``.
            scale_factor: Forwarded to ``center_and_crop_rescale``.
            shift_factor: Forwarded to ``center_and_crop_rescale``.
            aspect_ratio: Float aspect ratio forwarded to ``center_and_crop_rescale``.
            method: Unused here; kept so existing callers keep working.

        Returns:
            A tuple ``(faces, bbox_info)`` where ``faces`` is a list of cropped
            faces, each with a leading batch dimension of 1, and ``bbox_info``
            is whatever ``center_and_crop_rescale`` reports for those crops.
        """
        # Detection runs on the image scaled by 255 (presumably the detector
        # expects 0-255 pixel values -- TODO confirm); cropping uses the
        # unscaled image.
        image_255 = image * 255
        rf = Pytorch_RetinaFace(top_k=50, keep_top_k=max_number_of_faces, device=get_torch_device())
        dets = rf.detect_faces(image_255)
        cropped_faces, bbox_info = rf.center_and_crop_rescale(image, dets, scale_factor=scale_factor, shift_factor=shift_factor, aspect_ratio=aspect_ratio)

        # Add a batch dimension to each cropped face
        cropped_faces_with_batch = [face.unsqueeze(0) for face in cropped_faces]
        return cropped_faces_with_batch, bbox_info

    @staticmethod
    def _full_frame_crop_data(image):
        """Build one full-frame crop entry per image of a (batch, height, width, channel) tensor."""
        height, width = image.shape[1], image.shape[2]
        # NOTE(review): real crop data comes from center_and_crop_rescale, which
        # is not visible here; confirm (0, 0, width, height) matches its layout.
        return [(0, 0, width, height) for _ in range(image.shape[0])]

    @staticmethod
    def _circular_window(items, start, count):
        """Return up to ``count`` items starting at ``start``, wrapping around the end."""
        rotated = items[start:] + items[:start]
        return rotated[:max(count, 0)]

    def auto_crop_faces(self, image, number_of_faces, start_index, max_faces_per_image, scale_factor, shift_factor, aspect_ratio, method='lanczos'):
        """Detect faces in every image of the batch and return the selected crops.

        Args:
            image: One image or a batch of images with shape
                (batch, height, width, channel count).
            number_of_faces: Maximum number of faces to output, taken from the
                pooled list of faces detected across the whole batch.
            start_index: Index of the first face to output within the pooled
                list. It wraps around, so any integer is accepted.
            max_faces_per_image: Maximum number of faces the detector keeps
                for each individual image.
            scale_factor: How much crop factor or padding to keep around each
                detected face.
            shift_factor: Pan up or down relative to the face; 0.5 should be
                right in the center.
            aspect_ratio: Aspect ratio of the crops as an ``"a:b"`` string.
            method: Pixel sampling method used when crops must be resized to a
                common size.

        Returns:
            A tuple ``(images, crop_data)``. When no face is detected or
            selected, ``images`` is the unmodified input and ``crop_data``
            holds one full-frame entry per input image.
        """
        # Turn aspect ratio to float value
        aspect_ratio = self.aspect_ratio_string_to_float(aspect_ratio)

        detected_cropped_faces = []
        detected_crop_data = []

        # Even a single input image arrives as a batch, so always iterate.
        for single_image in image:
            cropped_images, infos = self.auto_crop_faces_in_image(
                single_image,
                max_faces_per_image,
                scale_factor,
                shift_factor,
                aspect_ratio,
                method)
            detected_cropped_faces.extend(cropped_images)
            detected_crop_data.extend(infos)

        # Nothing detected: hand back the original images with default crop data.
        if not detected_cropped_faces:
            return (image, self._full_frame_crop_data(image))

        # Circular index calculation
        first = start_index % len(detected_cropped_faces)
        selected_faces = self._circular_window(detected_cropped_faces, first, number_of_faces)
        selected_crop_data = self._circular_window(detected_crop_data, first, number_of_faces)

        # Nothing selected (e.g. number_of_faces <= 0): return the original images.
        if not selected_faces:
            return (image, self._full_frame_crop_data(image))

        # A single face needs no size harmonisation.
        if len(selected_faces) == 1:
            return (selected_faces[0], selected_crop_data)

        # Each face is (1, height, width, channel); the tallest crop defines
        # the common output size.
        tallest = max(selected_faces, key=lambda face: face.shape[1])
        target_height, target_width = tallest.shape[1], tallest.shape[2]

        # All images need the same height/width to be concatenated into one batch.
        fitted_faces = []
        for face_image in selected_faces:
            if (face_image.shape[1], face_image.shape[2]) != (target_height, target_width):
                face_image = comfy.utils.common_upscale(  # Expects (batch, channel, height, width)
                    face_image.movedim(-1, 1),  # Channels-last -> channels-first
                    target_width,
                    target_height,
                    method,  # Pixel sampling method.
                    ""  # Only "center" is implemented right now, and we don't want to use that.
                ).movedim(1, -1)
            fitted_faces.append(face_image)

        # Concatenate once instead of growing the tensor face by face.
        return (torch.cat(fitted_faces, dim=0), selected_crop_data)
# Node registry: each node identifier is mapped to the class that implements it.
NODE_CLASS_MAPPINGS = {"AutoCropFaces": AutoCropFaces}

# Friendly, human-readable titles for the nodes, keyed by the same identifiers.
NODE_DISPLAY_NAME_MAPPINGS = {"AutoCropFaces": "Auto Crop Faces"}
|