File size: 8,347 Bytes
232c234
 
1a688bc
cb5daed
4d6f2bc
7736f5f
4d6f2bc
 
232c234
4d6f2bc
232c234
eb8fc69
4d6f2bc
 
232c234
48c31e7
dffd0bb
23f4f95
eb8fc69
232c234
ca5a1e4
 
4d6f2bc
dffd0bb
 
 
232c234
 
 
 
 
 
53eff53
232c234
 
 
 
 
 
 
 
 
 
 
 
1a688bc
4d6f2bc
 
 
48c31e7
4d6f2bc
 
48c31e7
4d6f2bc
 
 
 
 
 
48c31e7
4d6f2bc
 
 
 
 
 
053b3a4
232c234
053b3a4
1a688bc
232c234
053b3a4
1a688bc
 
 
 
 
 
 
232c234
eb8fc69
 
 
 
 
 
 
 
 
232c234
 
 
 
 
eb8fc69
 
 
 
cb5daed
4d6f2bc
 
 
60849d7
61ad3d2
 
23f4f95
1a688bc
4d6f2bc
1a688bc
 
48c31e7
 
1128e78
1a688bc
60849d7
4d6f2bc
1a688bc
48c31e7
c348e53
48c31e7
 
4d6f2bc
60849d7
05246f1
 
5c4e8c1
4d6f2bc
1128e78
5c4e8c1
1128e78
1a688bc
 
 
48c31e7
22a0476
cb5daed
22a0476
48c31e7
22a0476
48c31e7
 
 
 
 
 
 
 
4d6f2bc
60849d7
 
61ad3d2
 
 
 
 
4d6f2bc
7736f5f
4d6f2bc
05246f1
60849d7
61ad3d2
cb5daed
 
 
 
c348e53
60849d7
05246f1
22a0476
60849d7
cb5daed
4d6f2bc
23f4f95
 
 
 
 
 
 
 
 
 
61ad3d2
 
 
23f4f95
 
61ad3d2
23f4f95
4d6f2bc
 
60849d7
 
 
 
22a0476
48c31e7
60849d7
4d6f2bc
 
 
48c31e7
dffd0bb
 
1a688bc
 
dffd0bb
 
4d6f2bc
 
48c31e7
22a0476
1128e78
dffd0bb
 
 
 
1a688bc
 
dffd0bb
 
 
 
 
4d6f2bc
60849d7
 
 
 
 
 
 
 
 
 
 
 
 
eb8fc69
60849d7
61ad3d2
232c234
 
 
61ad3d2
c5cf566
 
 
 
 
 
 
 
4d6f2bc
 
 
 
cb5daed
05246f1
 
4d6f2bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import functools
import inspect
import json
import os
import re
import time
from datetime import datetime
from itertools import product
from typing import Callable, TypeVar

import anyio
import numpy as np
import spaces
import torch
from anyio import Semaphore
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
from PIL import Image
from typing_extensions import ParamSpec

from .loader import Loader

# Quiet transformers at import time: drop its FutureWarnings and only log errors.
__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
__import__("transformers").logging.set_verbosity_error()

T = TypeVar("T")
P = ParamSpec("P")

# Maximum number of async_call jobs allowed to run in worker threads at once.
MAX_CONCURRENT_THREADS = 1
MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)

# Style presets used by apply_style. The path is relative to the process's
# current working directory. JSON is UTF-8 by specification, so decode it
# explicitly rather than relying on the locale's default encoding.
with open("./data/styles.json", encoding="utf-8") as f:
    STYLES = json.load(f)


# like the original but supports args and kwargs instead of a dict
# https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
    """Run the blocking callable *fn* in a worker thread and await its result.

    The arguments are bound against *fn*'s signature first, so a mismatch raises
    ``TypeError`` immediately, and missing defaults are filled in. The call then
    waits on ``MAX_THREADS_GUARD``, which allows at most ``MAX_CONCURRENT_THREADS``
    calls to run at once, before dispatching to ``anyio.to_thread.run_sync``.

    Returns:
        Whatever *fn* returns.
    """
    # Bind outside the semaphore so invalid arguments fail without queueing.
    bound_args = inspect.signature(fn).bind(*args, **kwargs)
    bound_args.apply_defaults()
    # BoundArguments.args / .kwargs re-split the bound values correctly for
    # positional-only, *args and **kwargs parameters. Passing
    # bound_args.arguments as keywords would break all three.
    partial_fn = functools.partial(fn, *bound_args.args, **bound_args.kwargs)
    async with MAX_THREADS_GUARD:
        return await anyio.to_thread.run_sync(partial_fn)


# parse prompts with arrays
def parse_prompt(prompt: str) -> list[str]:
    """Expand every ``[[a,b,...]]`` array in *prompt* into all combinations.

    Each array body is split on commas and every option is stripped of
    surrounding whitespace. The result is the Cartesian product of all arrays in
    ``itertools.product`` order (the last array varies fastest). A prompt with
    no arrays is returned as a one-element list.

    >>> parse_prompt("a [[red, blue]] cat")
    ['a red cat', 'a blue cat']
    """
    # With one capture group, re.split alternates literal text (even indices)
    # with array bodies (odd indices). Each array is therefore substituted at
    # its own position, never by re-searching for its text after earlier
    # substitutions may have produced a look-alike.
    parts = re.split(r"\[\[(.*?)\]\]", prompt)
    if len(parts) == 1:
        return [prompt]

    literals = parts[0::2]
    options = [[token.strip() for token in body.split(",")] for body in parts[1::2]]

    prompts = []
    for combo in product(*options):
        pieces = [literals[0]]
        for choice, literal in zip(combo, literals[1:]):
            pieces.append(choice)
            pieces.append(literal)
        prompts.append("".join(pieces))
    return prompts


def apply_style(prompt, style_id, negative=False, styles=None):
    """Apply the style preset *style_id* to *prompt*.

    Args:
        prompt: The prompt text to style.
        style_id: ``id`` of an entry in *styles*. A falsy value or the string
            ``"None"`` means "no style".
        negative: When true, append the style's ``negative_prompt`` (joined with
            ``" . "``) instead of formatting the style's ``prompt`` template.
        styles: Style presets to search. Defaults to the module-level ``STYLES``
            loaded from ``./data/styles.json``.

    Returns:
        The styled prompt, or *prompt* unchanged when no style is requested or
        the id is not found.
    """
    if not style_id or style_id == "None":
        return prompt
    if styles is None:
        styles = STYLES

    style = next((s for s in styles if s["id"] == style_id), None)
    if style is None:
        return prompt

    if negative:
        # An empty negative prompt would otherwise leave a dangling " . ".
        if not prompt:
            return style["negative_prompt"]
        return prompt + " . " + style["negative_prompt"]
    return style["prompt"].format(prompt=prompt)


def prepare_image(input, size=None):
    """Normalize an image input to an RGB ``PIL.Image.Image``.

    Args:
        input: A PIL image, a numpy array (converted with ``Image.fromarray``),
            or a path to an existing file. The name shadows the builtin but is
            kept because callers may pass it by keyword.
        size: Optional ``(width, height)``. When given, the image is resized
            with Lanczos resampling.

    Returns:
        An RGB image.

    Raises:
        ValueError: If *input* is not one of the supported kinds, including a
            string that is not an existing file. This is raised whether or not
            *size* is given.
    """
    if isinstance(input, Image.Image):
        image = input.convert("RGB")
    elif isinstance(input, np.ndarray):
        image = Image.fromarray(input).convert("RGB")
    elif isinstance(input, str) and os.path.isfile(input):
        # convert() returns a copy, so the file can be closed right away.
        with Image.open(input) as opened:
            image = opened.convert("RGB")
    else:
        raise ValueError("Invalid image prompt")

    if size is not None:
        image = image.resize(size, Image.Resampling.LANCZOS)
    return image


def _has_image(value):
    """Return whether *value* counts as a supplied image input.

    Mirrors ``bool(value)`` (so ``None`` and ``""`` mean "no image"), except for
    numpy arrays, whose truth value is ambiguous. Those count as supplied when
    they contain any elements.
    """
    if isinstance(value, np.ndarray):
        return value.size > 0
    return bool(value)


def _load_embeddings(pipe, embeddings, negative_prompt, Error):
    """Load textual-inversion embeddings into *pipe* and reference them in the negative prompt.

    Each name ``foo`` is loaded from ``../embeddings/foo.pt`` (relative to this
    module's directory) under the token ``<foo>``. It is then appended to
    *negative_prompt* as ``(<foo>)1.1``.

    Returns:
        The updated negative prompt.

    Raises:
        Error: If an embedding cannot be loaded.
    """
    embeddings_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "embeddings"))
    for embedding in embeddings:
        try:
            pipe.load_textual_inversion(
                pretrained_model_name_or_path=f"{embeddings_dir}/{embedding}.pt",
                token=f"<{embedding}>",
            )
        except (EnvironmentError, HFValidationError, RepositoryNotFoundError) as e:
            raise Error(f"Invalid embedding: <{embedding}>") from e
        weighted = f"(<{embedding}>)1.1"
        negative_prompt = f"{negative_prompt}, {weighted}" if negative_prompt else weighted
    return negative_prompt


@spaces.GPU(duration=40)
def generate(
    positive_prompt,
    negative_prompt="",
    image_prompt=None,
    ip_image=None,
    ip_face=False,
    embeddings=None,
    style=None,
    seed=None,
    model="runwayml/stable-diffusion-v1-5",
    scheduler="PNDM",
    width=512,
    height=512,
    guidance_scale=7.5,
    inference_steps=50,
    denoising_strength=0.8,
    num_images=1,
    karras=False,
    taesd=False,
    freeu=False,
    clip_skip=False,
    truncate_prompts=False,
    increment_seed=True,
    deepcache=1,
    scale=1,
    Info: Callable[[str], None] = None,
    Error=Exception,
):
    """Generate ``num_images`` images and return them with their seeds.

    The pipeline is obtained from ``Loader().load`` (txt2img, or img2img when
    *image_prompt* is given, optionally with an IP-Adapter). Prompts are encoded
    with Compel. *positive_prompt* may contain ``[[a,b]]`` arrays (see
    ``parse_prompt``). The expanded prompts are used round-robin across the
    images.

    Args:
        positive_prompt / negative_prompt: Prompt texts. *style* is applied to both.
        image_prompt: Init image for img2img (anything ``prepare_image`` accepts).
        ip_image / ip_face: IP-Adapter image. ``ip_face`` selects the
            ``"full-face"`` adapter and skips resizing; otherwise ``"plus"`` is used.
        embeddings: Names of textual-inversion embeddings to load and add to the
            negative prompt. ``None`` means none.
        seed: Base seed. ``None`` or negative derives one from the current time.
            With *increment_seed*, each subsequent image uses ``seed + i``.
        scale: Values > 1 run the pipeline's upscaler on each output.
        Info: Optional callback that receives a timing summary.
        Error: Exception type raised for user-facing errors.
        Other parameters are forwarded to ``Loader.load`` or the pipeline call.

    Returns:
        A list of ``(image, seed_as_str)`` tuples.
    """
    if not torch.cuda.is_available():
        raise Error("CUDA not available")

    if embeddings is None:
        embeddings = []

    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)

    DEVICE = torch.device("cuda")

    # bfloat16 on devices whose compute capability major version is >= 8, else float16.
    # (CUDA availability was already checked above.)
    DTYPE = torch.bfloat16 if torch.cuda.get_device_properties(DEVICE).major >= 8 else torch.float16

    EMBEDDINGS_TYPE = (
        ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
        if clip_skip
        else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
    )

    KIND = "img2img" if image_prompt is not None else "txt2img"

    IP_ADAPTER = None
    # _has_image avoids the ambiguous truth value of numpy-array inputs.
    if _has_image(ip_image):
        IP_ADAPTER = "full-face" if ip_face else "plus"

    # Input images and prompt expansions do not change between iterations, so
    # compute them once. An invalid image fails before the model is loaded.
    init_image = prepare_image(image_prompt, (width, height)) if KIND == "img2img" else None
    ip_adapter_image = None
    if IP_ADAPTER:
        # don't resize full-face images
        ip_adapter_image = prepare_image(ip_image, None if ip_face else (width, height))
    all_positive_prompts = parse_prompt(positive_prompt)

    with torch.inference_mode():
        start = time.perf_counter()
        loader = Loader()
        pipe, upscaler = loader.load(
            KIND,
            IP_ADAPTER,
            model,
            scheduler,
            karras,
            taesd,
            freeu,
            deepcache,
            scale,
            DEVICE,
            DTYPE,
        )

        images = []
        # Everything after loading runs under try/finally. Embeddings are then
        # unloaded on every exit path, including failures while loading them or
        # parsing prompts. NOTE(review): this assumes the pipe may be reused by
        # later calls (e.g. cached by Loader); confirm.
        try:
            negative_prompt = _load_embeddings(pipe, embeddings, negative_prompt, Error)

            # prompt embeds
            compel = Compel(
                device=pipe.device,
                tokenizer=pipe.tokenizer,
                text_encoder=pipe.text_encoder,
                truncate_long_prompts=truncate_prompts,
                dtype_for_device_getter=lambda _: DTYPE,
                returned_embeddings_type=EMBEDDINGS_TYPE,
                textual_inversion_manager=DiffusersTextualInversionManager(pipe),
            )

            try:
                styled_negative_prompt = apply_style(negative_prompt, style, negative=True)
                neg_embeds = compel(styled_negative_prompt)
            except PromptParser.ParsingException as e:
                raise Error("ParsingException: Invalid negative prompt") from e

            current_seed = seed
            for i in range(num_images):
                # seeded generator for each iteration
                generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
                pos_prompt = all_positive_prompts[i % len(all_positive_prompts)]

                try:
                    pos_embeds = compel(apply_style(pos_prompt, style))
                    # Pad against the unpadded base negative embedding each time,
                    # so one image's conditioning does not depend on the lengths
                    # of earlier prompts.
                    pos_embeds, padded_neg_embeds = compel.pad_conditioning_tensors_to_same_length(
                        [pos_embeds, neg_embeds]
                    )
                except PromptParser.ParsingException as e:
                    raise Error("ParsingException: Invalid prompt") from e

                kwargs = {
                    "width": width,
                    "height": height,
                    "generator": generator,
                    "prompt_embeds": pos_embeds,
                    "guidance_scale": guidance_scale,
                    "negative_prompt_embeds": padded_neg_embeds,
                    "num_inference_steps": inference_steps,
                    # the upscaler is fed numpy output; otherwise return PIL images
                    "output_type": "np" if scale > 1 else "pil",
                }

                if KIND == "img2img":
                    kwargs["strength"] = denoising_strength
                    kwargs["image"] = init_image

                if IP_ADAPTER:
                    kwargs["ip_adapter_image"] = ip_adapter_image

                try:
                    image = pipe(**kwargs).images[0]
                    if scale > 1:
                        image = upscaler.predict(image)
                    images.append((image, str(current_seed)))
                finally:
                    torch.cuda.empty_cache()

                if increment_seed:
                    current_seed += 1
        finally:
            pipe.unload_textual_inversion()
            torch.cuda.empty_cache()

        diff = time.perf_counter() - start
        if Info:
            Info(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
        return images