from enum import Enum, auto

import torch
from huggingface_hub import (  # pyright: ignore[reportMissingTypeStubs]
    hf_hub_download,  # pyright: ignore[reportUnknownVariableType]
)
from PIL import Image
from refiners.fluxion.utils import load_from_safetensors, tensor_to_image
from refiners.foundationals.clip import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight


def load_ic_light(device: torch.device, dtype: torch.dtype) -> ICLight:
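    """Load IC-Light (the "fc" patch weights) on top of a Realistic Vision v5.1 SD 1.5 pipeline."""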
    return ICLight(
        patch_weights=load_from_safetensors(
            path=hf_hub_download(
                repo_id="refiners/sd15.ic_light.fc",
                filename="model.safetensors",
                revision="ea10b4403e97c786a98afdcbdf0e0fec794ea542",
            ),
        ),
        unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
            tensors_path=hf_hub_download(
                repo_id="refiners/sd15.realistic_vision.v5_1.unet",
                filename="model.safetensors",
                revision="94f74be7adfd27bee330ea1071481c0254c29989",
            )
        ),
        clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
            tensors_path=hf_hub_download(
                repo_id="refiners/sd15.realistic_vision.v5_1.text_encoder",
                filename="model.safetensors",
                revision="7f6fa1e870c8f197d34488e14b89e63fb8d7fd6e",
            )
        ),
        lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
            tensors_path=hf_hub_download(
                repo_id="refiners/sd15.realistic_vision.v5_1.autoencoder",
                filename="model.safetensors",
                revision="99f089787a6e1a852a0992da1e286a19fcbbaa50",
            )
        ),
        device=device,
        dtype=dtype,
    )
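
# Illustrative usage sketch (not part of the original helpers): load the pipeline on
# whatever hardware is available. The float16-on-CUDA choice is an assumption, not a
# requirement of the loader.
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   dtype = torch.float16 if device.type == "cuda" else torch.float32
#   ic_light = load_ic_light(device=device, dtype=dtype)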


def resize_modulo_8(
    image: Image.Image,
    size: int = 768,
    resample: Image.Resampling | None = None,
    on_short: bool = True,
) -> Image.Image:
    """Resize an image respecting the aspect ratio and ensuring the size is a multiple of 8.

    The `on_short` parameter determines whether the resizing is based on the shortest side.
    """
    assert size % 8 == 0, "Size must be a multiple of 8 because this is the latent compression size."
    side_size = min(image.size) if on_short else max(image.size)
    scale = size / (side_size * 8)
    new_size = (int(image.width * scale) * 8, int(image.height * scale) * 8)
    return image.resize(new_size, resample=resample or Image.Resampling.LANCZOS)
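
# Worked example (for reference): with the defaults (size=768, on_short=True), a
# 1000x750 input has short side 750, so scale = 768 / (750 * 8) = 0.128 and the
# output size is (int(1000 * 0.128) * 8, int(750 * 0.128) * 8) = (1024, 768):
# both dimensions are multiples of 8 and the aspect ratio is preserved (up to rounding).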


class LightingPreference(str, Enum):
    LEFT = auto()
    RIGHT = auto()
    TOP = auto()
    BOTTOM = auto()
    NONE = auto()

    def get_init_image(self, width: int, height: int, interval: tuple[float, float] = (0.0, 1.0)) -> Image.Image | None:
        """Generate an image with a linear gradient based on the lighting preference.

        In the original code, `interval` is always (0., 1.); we added it as a parameter to make the function more
        flexible and allow for less contrasted images with a narrower interval.
        See https://github.com/lllyasviel/IC-Light/blob/7886874/gradio_demo.py#L242
        """
        start, end = interval
        match self:
            case LightingPreference.LEFT:
                tensor = torch.linspace(end, start, width).repeat(1, 1, height, 1)
            case LightingPreference.RIGHT:
                tensor = torch.linspace(start, end, width).repeat(1, 1, height, 1)
            case LightingPreference.TOP:
                tensor = torch.linspace(end, start, height).repeat(1, 1, width, 1).transpose(2, 3)
            case LightingPreference.BOTTOM:
                tensor = torch.linspace(start, end, height).repeat(1, 1, width, 1).transpose(2, 3)
            case LightingPreference.NONE:
                return None

        return tensor_to_image(tensor).convert("RGB")

    @classmethod
    def from_str(cls, value: str) -> "LightingPreference":
        match value.lower():
            case "left":
                return LightingPreference.LEFT
            case "right":
                return LightingPreference.RIGHT
            case "top":
                return LightingPreference.TOP
            case "bottom":
                return LightingPreference.BOTTOM
            case "none":
                return LightingPreference.NONE
            case _:
                raise ValueError(f"Invalid lighting preference: {value}")
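

# Minimal runnable sketch (assumed invocation, not part of the original helpers):
# build a "left-lit" gradient init image at a latent-friendly resolution. The
# width/height values below are arbitrary examples.
if __name__ == "__main__":
    preference = LightingPreference.from_str("left")
    init_image = preference.get_init_image(width=768, height=512)
    if init_image is not None:
        print(f"Gradient init image size: {init_image.size}")  # (768, 512)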