# Source: V1/pipelines.py (commit 04c7187 "up" by michaelapplydesign, 1.5 kB)
import collections
import logging
import threading
import time

import torch
from diffusers import StableDiffusionInpaintPipeline

from helpers import flush
# Module-level logger for this file, named after the module.
LOGGING = logging.getLogger(__name__)
class SDPipeline:
    """Shared Stable Diffusion 2 inpainting pipeline with FIFO-serialized calls.

    The wrapped ``StableDiffusionInpaintPipeline`` is loaded from
    ``"stabilityai/stable-diffusion-2-inpainting"`` in float16, without a
    safety checker, with xformers memory-efficient attention, and moved to
    CUDA. Calls from several threads (presumably web-request handlers -- confirm
    against the caller) are executed one at a time, in arrival order.
    """

    def __init__(self):
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
            safety_checker=None,
        )
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to("cuda")
        # Tickets of calls that are running or waiting, oldest first; the
        # head of the queue is the call allowed to use the pipeline.
        self.waiting_queue = collections.deque()
        # Last ticket number handed out.
        self.count = 0
        # Guards ``count``/``waiting_queue`` and lets waiting callers sleep
        # until they reach the head instead of polling.
        self._turn_cond = threading.Condition()

    @property
    def queue_size(self):
        """Number of calls currently running or waiting for their turn."""
        return len(self.waiting_queue)

    def _take_ticket(self):
        """Atomically draw the next ticket and append it to the queue.

        Returns:
            tuple: ``(ticket, queue_snapshot)`` where ``queue_snapshot`` is a
            list copy of the queue right after enqueueing.
        """
        with self._turn_cond:
            self.count += 1
            ticket = self.count
            self.waiting_queue.append(ticket)
            return ticket, list(self.waiting_queue)

    def _wait_for_turn(self, ticket):
        """Block until ``ticket`` is at the head of the queue."""
        with self._turn_cond:
            # ``ticket`` stays in the queue until released, so the queue is
            # never empty while we wait.
            self._turn_cond.wait_for(lambda: self.waiting_queue[0] == ticket)

    def _release_turn(self, ticket):
        """Drop ``ticket`` from the queue and wake every waiting caller.

        Uses ``remove`` rather than popping the head so a ticket abandoned
        before reaching the head is cleaned up as well.
        """
        with self._turn_cond:
            self.waiting_queue.remove(ticket)
            self._turn_cond.notify_all()

    def __call__(self, **kwargs):
        """Run the pipeline once this call reaches the head of the queue.

        Args:
            **kwargs: forwarded unchanged to ``self.pipe``.

        Returns:
            Whatever ``self.pipe(**kwargs)`` returns.
        """
        number, queue_snapshot = self._take_ticket()
        try:
            if queue_snapshot[0] != number:
                LOGGING.info(
                    "Wait for your turn %s in queue %s", number, queue_snapshot
                )
            self._wait_for_turn(number)
            LOGGING.info("It's the turn of %s", number)
            return self.pipe(**kwargs)
        finally:
            # Give up the turn even when the pipeline raised; otherwise every
            # later caller would wait forever on a ticket that never leaves.
            self._release_turn(number)
            # helpers.flush (not visible here) runs after every call,
            # including failed ones.
            flush()
def get_inpainting_pipeline() -> "SDPipeline":
    """Load the Stable Diffusion 2 inpainting pipeline behind a FIFO call queue.

    Every call constructs a new ``SDPipeline``, which loads the model again
    via ``from_pretrained`` and moves it to CUDA; keep and reuse the result.

    Returns:
        SDPipeline: callable wrapper around a ``StableDiffusionInpaintPipeline``
        (not the diffusers pipeline itself); call it with the keyword
        arguments the diffusers pipeline accepts.
    """
    return SDPipeline()