# Source: flux-controlnet-inpainting / preprocessor.py
# Revision 0611e52 ("bugfix"), committed by jiuface
import gc
import numpy as np
import PIL.Image
import torch
import torchvision
from controlnet_aux import (
CannyDetector,
ContentShuffleDetector,
HEDdetector,
LineartAnimeDetector,
LineartDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
OpenposeDetector,
PidiNetDetector,
)
from controlnet_aux.util import HWC3
from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor
from kornia.core import Tensor
# Build the annotators once, at import time. Preprocessor.load() below hands
# out these shared instances instead of constructing a new detector per call.
# NOTE(review): from_pretrained presumably fetches/caches weights for the
# "lllyasviel/Annotators" repo, so importing this module may be slow — confirm.
# HED is imported above but currently disabled:
# HED = HEDdetector.from_pretrained("lllyasviel/Annotators")
Midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
MLSD = MLSDdetector.from_pretrained("lllyasviel/Annotators")
Canny = CannyDetector()  # constructed directly, without from_pretrained
OPENPOSE = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
class Preprocessor:
    """Select one of the module-level annotators and run it on an image.

    Call :meth:`load` with one of ``"Midas"``, ``"MLSD"``, ``"Openpose"`` or
    ``"Canny"`` to choose the detector, then call the instance with a PIL
    image to obtain the corresponding control image.
    """

    # Not referenced by the methods of this class; kept for external users.
    MODEL_ID = "lllyasviel/Annotators"

    # Names accepted by load(), used to build a helpful error message.
    _SUPPORTED = ("Midas", "MLSD", "Openpose", "Canny")
    # Resolution used when the caller does not pass one explicitly.
    _DEFAULT_RESOLUTION = 512
    # Detectors whose result is explicitly converted to RGB before returning.
    _FORCE_RGB = frozenset({"Canny", "MLSD"})

    def __init__(self):
        """Start with no detector selected."""
        self.model = None
        self.name = ""

    def load(self, name: str) -> None:
        """Select the detector called *name*.

        Does nothing when *name* is already the active detector. The
        detectors are the shared instances created at module import time, so
        no new model is constructed here.

        Raises:
            ValueError: if *name* is not one of the supported detector names.
                The previously selected detector is left untouched.
        """
        if name == self.name:
            return
        if name == "Midas":
            selected = Midas
        elif name == "MLSD":
            selected = MLSD
        elif name == "Openpose":
            selected = OPENPOSE
        elif name == "Canny":
            selected = Canny
        else:
            raise ValueError(
                f"Unknown preprocessor {name!r}; "
                f"expected one of {', '.join(self._SUPPORTED)}"
            )
        self.model = selected
        # Collect garbage first so that memory released by the collection can
        # actually be returned by the CUDA cache flush that follows.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: "PIL.Image.Image", **kwargs) -> "PIL.Image.Image":
        """Run the active detector on *image* and return the control image.

        Keyword Args:
            detect_resolution: resolution passed to ``resize_image`` before
                detection (default 512).
            image_resolution: resolution passed to ``resize_image`` for the
                detector output (default 512).
            **kwargs: every remaining keyword is forwarded to the detector.

        Returns:
            The detector output as a PIL image. For ``"Canny"`` and ``"MLSD"``
            the result is additionally converted to RGB.

        Raises:
            RuntimeError: if no detector has been selected with :meth:`load`.
        """
        if self.model is None:
            raise RuntimeError(
                "No preprocessor loaded; call load(name) before processing an image."
            )

        # Both resolutions are consumed here and never reach the detector.
        detect_resolution = kwargs.pop("detect_resolution", self._DEFAULT_RESOLUTION)
        image_resolution = kwargs.pop("image_resolution", self._DEFAULT_RESOLUTION)

        detector_input = self._to_resized_array(image, detect_resolution)
        detected = self.model(detector_input, **kwargs)
        control = self._to_resized_array(detected, image_resolution)

        result = PIL.Image.fromarray(control)
        if self.name in self._FORCE_RGB:
            result = result.convert('RGB')
        return result

    @staticmethod
    def _to_resized_array(image, resolution):
        """Convert *image* to an array, pass it through HWC3, and resize it.

        NOTE(review): HWC3 presumably normalizes the array to three channels
        and resize_image presumably scales it to *resolution* — confirm
        against controlnet_aux.util and cv_utils.
        """
        return resize_image(HWC3(np.array(image)), resolution=resolution)