joyson committed on
Commit 9d9968c · verified · 1 Parent(s): cf8c8d1

Upload 5 files

Files changed (5):
  1. app.py +169 -0
  2. image_to_image.py +60 -0
  3. image_to_text.py +29 -0
  4. text_to_image.py +55 -0
  5. utils.py +21 -0
app.py ADDED
@@ -0,0 +1,169 @@
+ import gradio as gr
+ from PIL import Image
+ import asyncio
+
+ from text_to_image import TextToImage
+ from image_to_text import ImageToText
+ from image_to_image import ImageToImage
+
+ # ============================================
+ # Initialize Model Classes
+ # ============================================
+
+ text_to_image = TextToImage()
+ image_to_text = ImageToText()
+ image_to_image = ImageToImage()
+
+ # ============================================
+ # Gradio Interface Functions with Async and Error Handling
+ # ============================================
+
+ async def async_text_to_image(prompt):
+     """
+     Asynchronous interface function for Text-to-Image generation with error handling.
+     """
+     try:
+         image = await text_to_image.generate_image(prompt)
+         return image
+     except Exception as e:
+         raise gr.Error(f"Text-to-Image Generation Failed: {str(e)}")
+
+ async def async_image_to_text(image):
+     """
+     Asynchronous interface function for Image-to-Text captioning with error handling.
+     """
+     try:
+         caption = await image_to_text.generate_caption(image)
+         return caption
+     except Exception as e:
+         raise gr.Error(f"Image-to-Text Captioning Failed: {str(e)}")
+
+ async def async_image_to_image(image, prompt):
+     """
+     Asynchronous interface function for Image-to-Image transformation with error handling.
+     """
+     try:
+         transformed_image = await image_to_image.transform_image(image, prompt)
+         return transformed_image
+     except Exception as e:
+         raise gr.Error(f"Image-to-Image Transformation Failed: {str(e)}")
+
+ # ============================================
+ # Gradio UI Design
+ # ============================================
+
+ with gr.Blocks(css=".gradio-container {background-color: #f0f8ff}") as demo:
+     # Title Section
+     gr.Markdown("# 🎨 AI Creativity Hub 🚀")
+     gr.Markdown("### Unleash the power of AI to transform your ideas into reality!")
+
+     # Task Selection Radio
+     with gr.Tab("✨ Choose Your Magic ✨"):
+         task = gr.Radio(
+             ["🖼️ Text-to-Image", "📝 Image-to-Text", "🖌️ Image-to-Image"],
+             label="Select a Task",
+             interactive=True,
+             value="🖼️ Text-to-Image"
+         )
+
+     # Text-to-Image Section (visible on load to match the Radio's default value)
+     with gr.Row(visible=True) as text_to_image_tab:
+         with gr.Column():
+             gr.Markdown("## 🖼️ Text-to-Image Generator")
+             prompt_input = gr.Textbox(
+                 label="📝 Enter your prompt:",
+                 placeholder="e.g., A serene sunset over the mountains",
+                 lines=2
+             )
+             generate_btn = gr.Button("🎨 Generate Image")
+             with gr.Row():
+                 # type="pil" so the download handler receives an object with .save()
+                 output_image = gr.Image(label="🖼️ Generated Image", type="pil")
+                 download_btn = gr.Button("📥 Download Image")
+
+     # Image-to-Text Section
+     with gr.Row(visible=False) as image_to_text_tab:
+         with gr.Column():
+             gr.Markdown("## 📝 Image-to-Text Captioning")
+             image_input = gr.Image(
+                 label="📸 Upload an image:",
+                 type="pil"
+             )
+             generate_caption_btn = gr.Button("🖋️ Generate Caption")
+             caption_output = gr.Textbox(
+                 label="📝 Generated Caption:",
+                 lines=2
+             )
+
+     # Image-to-Image Section
+     with gr.Row(visible=False) as image_to_image_tab:
+         with gr.Column():
+             gr.Markdown("## 🖌️ Image-to-Image Transformer")
+             init_image_input = gr.Image(
+                 label="📸 Upload an image:",
+                 type="pil"
+             )
+             transformation_prompt = gr.Textbox(
+                 label="📝 Enter transformation prompt:",
+                 placeholder="e.g., Make it look like a Van Gogh painting",
+                 lines=2
+             )
+             transform_btn = gr.Button("🔄 Transform Image")
+             with gr.Row():
+                 # type="pil" so the download handler receives an object with .save()
+                 transformed_image = gr.Image(label="🖌️ Transformed Image", type="pil")
+                 download_transformed_btn = gr.Button("📥 Download Image")
+
+     # Define Visibility Based on Task Selection
+     # (Rows are components, so visibility must be changed via gr.update(),
+     # not by returning bare booleans.)
+     def toggle_visibility(selected_task):
+         return {
+             text_to_image_tab: gr.update(visible=selected_task == "🖼️ Text-to-Image"),
+             image_to_text_tab: gr.update(visible=selected_task == "📝 Image-to-Text"),
+             image_to_image_tab: gr.update(visible=selected_task == "🖌️ Image-to-Image"),
+         }
+
+     task.change(
+         fn=toggle_visibility,
+         inputs=task,
+         outputs=[text_to_image_tab, image_to_text_tab, image_to_image_tab]
+     )
+
+     # Define Button Actions
+     generate_btn.click(
+         fn=async_text_to_image,
+         inputs=prompt_input,
+         outputs=output_image
+     )
+
+     # Note: this saves the file on the server; guard against an empty image.
+     download_btn.click(
+         fn=lambda img: img.save("generated_image.png") if img is not None else None,
+         inputs=output_image,
+         outputs=None
+     )
+
+     generate_caption_btn.click(
+         fn=async_image_to_text,
+         inputs=image_input,
+         outputs=caption_output
+     )
+
+     transform_btn.click(
+         fn=async_image_to_image,
+         inputs=[init_image_input, transformation_prompt],
+         outputs=transformed_image
+     )
+
+     download_transformed_btn.click(
+         fn=lambda img: img.save("transformed_image.png") if img is not None else None,
+         inputs=transformed_image,
+         outputs=None
+     )
+
+     # Footer Section with Quirky Elements
+     gr.Markdown("----")
+     gr.Markdown("### 🌟 Explore the endless possibilities with AI! 🌟")
+     gr.Markdown("#### 🚀 Built with ❤️ by [Your Name]")
+
+ # ============================================
+ # Launch the Gradio App
+ # ============================================
+
+ demo.launch()
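Note: the handlers above can take minutes per request on CPU, so a queued launch is usually safer than a bare `demo.launch()`. A minimal sketch; the queue size and launch options are illustrative assumptions, not part of this commit:

```python
# Queue requests so long-running diffusion jobs are processed in order
# instead of timing out; max_size here is an arbitrary example value.
demo.queue(max_size=8).launch(
    show_error=True  # surface the gr.Error messages raised by the async handlers
)
```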
image_to_image.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ from diffusers import StableDiffusionXLImg2ImgPipeline, EulerDiscreteScheduler
+ from PIL import Image
+ from io import BytesIO
+ from utils import load_unet_model
+
+ class ImageToImage:
+     """
+     Class to handle Image-to-Image transformations using Stable Diffusion XL.
+     """
+     def __init__(self, device="cpu"):
+         # Model and repository details
+         self.base = "stabilityai/stable-diffusion-xl-base-1.0"
+         self.repo = "ByteDance/SDXL-Lightning"
+         self.ckpt = "sdxl_lightning_4step_unet.safetensors"
+         self.device = device
+
+         # Load the UNet model
+         print("Loading Image-to-Image model...")
+         self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)
+
+         # Initialize the pipeline
+         self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+             self.base,
+             unet=self.unet,
+             torch_dtype=torch.float32
+         ).to(self.device)
+
+         # Set the scheduler; SDXL-Lightning expects "trailing" timestep spacing
+         self.pipe.scheduler = EulerDiscreteScheduler.from_config(
+             self.pipe.scheduler.config,
+             timestep_spacing="trailing"
+         )
+         print("Image-to-Image model loaded successfully.")
+
+     async def transform_image(self, image, prompt):
+         """
+         Transform an uploaded image based on a text prompt.
+
+         Args:
+             image (PIL.Image): The input image to transform.
+             prompt (str): The text prompt to guide the transformation.
+
+         Returns:
+             PIL.Image: The transformed image.
+         """
+         if image is None:
+             raise ValueError("An input image is required.")
+         if not prompt:
+             raise ValueError("Prompt cannot be empty.")
+
+ # Resize the image as required by the model
52
+ init_image = image.resize((768, 512))
53
+ with torch.no_grad():
54
+ transformed_image = self.pipe(
55
+ prompt=prompt,
56
+ image=init_image,
57
+ strength=0.75,
58
+ guidance_scale=7.5
59
+ ).images[0]
60
+ return transformed_image
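Note: a minimal sketch of driving `ImageToImage` outside Gradio; the filename and prompt are hypothetical:

```python
import asyncio
from PIL import Image
from image_to_image import ImageToImage

async def main():
    model = ImageToImage(device="cpu")   # downloads SDXL + Lightning weights on first run
    src = Image.open("input.jpg")        # hypothetical local file
    out = await model.transform_image(src, "Make it look like a watercolor painting")
    out.save("watercolor.png")

asyncio.run(main())
```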
image_to_text.py ADDED
@@ -0,0 +1,29 @@
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ from PIL import Image
+
+ class ImageToText:
+     """
+     Class to handle Image-to-Text captioning using BLIP.
+     """
+     def __init__(self):
+         # Initialize the processor and model
+         print("Loading Image-to-Text model...")
+         self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+         self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+         print("Image-to-Text model loaded successfully.")
+
+     async def generate_caption(self, image):
+         """
+         Generate a descriptive caption for an uploaded image.
+
+         Args:
+             image (PIL.Image): The image to caption.
+
+         Returns:
+             str: The generated caption.
+         """
+         if image is None:
+             raise ValueError("An input image is required.")
+
+         # BLIP expects a 3-channel image; uploads may arrive as RGBA
+         inputs = self.processor(image.convert("RGB"), return_tensors="pt")
+         out = self.model.generate(**inputs)
+         caption = self.processor.decode(out[0], skip_special_tokens=True)
+         return caption
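Note: `generate_caption` is an async method, so standalone use goes through an event loop. A minimal sketch, assuming a hypothetical local `photo.jpg`:

```python
import asyncio
from PIL import Image
from image_to_text import ImageToText

async def main():
    captioner = ImageToText()            # downloads BLIP weights on first run
    caption = await captioner.generate_caption(Image.open("photo.jpg"))
    print(caption)

asyncio.run(main())
```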
text_to_image.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
+ from PIL import Image
+ from io import BytesIO
+ from utils import load_unet_model
+
+ class TextToImage:
+     """
+     Class to handle Text-to-Image generation using Stable Diffusion XL.
+     """
+     def __init__(self, device="cpu"):
+         # Model and repository details
+         self.base = "stabilityai/stable-diffusion-xl-base-1.0"
+         self.repo = "ByteDance/SDXL-Lightning"
+         self.ckpt = "sdxl_lightning_4step_unet.safetensors"
+         self.device = device
+
+         # Load the UNet model
+         print("Loading Text-to-Image model...")
+         self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)
+
+         # Initialize the pipeline
+         self.pipe = StableDiffusionXLPipeline.from_pretrained(
+             self.base,
+             unet=self.unet,
+             torch_dtype=torch.float32,
+         ).to(self.device)
+
+         # Set the scheduler; SDXL-Lightning expects "trailing" timestep spacing
+         self.pipe.scheduler = EulerDiscreteScheduler.from_config(
+             self.pipe.scheduler.config,
+             timestep_spacing="trailing"
+         )
+         print("Text-to-Image model loaded successfully.")
+
+     async def generate_image(self, prompt):
+         """
+         Generate an image from a text prompt.
+
+         Args:
+             prompt (str): The text prompt to generate the image.
+
+         Returns:
+             PIL.Image: The generated image.
+         """
+         if not prompt:
+             raise ValueError("Prompt cannot be empty.")
+
+         with torch.no_grad():
+             # The 4-step Lightning UNet is distilled for exactly 4 steps
+             # with no classifier-free guidance
+             image = self.pipe(
+                 prompt,
+                 num_inference_steps=4,
+                 guidance_scale=0
+             ).images[0]
+         return image
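Note: a minimal standalone sketch of `TextToImage`; the output filename is illustrative:

```python
import asyncio
from text_to_image import TextToImage

async def main():
    generator = TextToImage(device="cpu")  # CPU inference works but is slow
    image = await generator.generate_image("A serene sunset over the mountains")
    image.save("sunset.png")

asyncio.run(main())
```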
utils.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ from diffusers import UNet2DConditionModel
+
+ def load_unet_model(base, repo, ckpt, device="cpu"):
+     """
+     Load the distilled UNet weights from the Hugging Face Hub.
+
+     Args:
+         base (str): Base model repo id (supplies the UNet config).
+         repo (str): Repo id hosting the distilled checkpoint.
+         ckpt (str): Checkpoint filename.
+         device (str): Device to load the model on.
+
+     Returns:
+         UNet2DConditionModel: Loaded UNet model.
+     """
+     # Build the UNet from the base model's config, then swap in the
+     # distilled weights. float32 matches the float32 pipelines used in
+     # this app (float16 is poorly supported on CPU).
+     unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float32)
+     unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
+     return unet
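Note: `load_unet_model` builds the UNet from the base model's config and then loads the Lightning checkpoint into it. A minimal sanity-check sketch; the printed values are what the loader above should produce:

```python
import torch
from utils import load_unet_model

# Load the distilled UNet on its own and inspect it.
unet = load_unet_model(
    "stabilityai/stable-diffusion-xl-base-1.0",  # config source
    "ByteDance/SDXL-Lightning",                  # weight source
    "sdxl_lightning_4step_unet.safetensors",
    device="cpu",
)
print(type(unet).__name__)            # UNet2DConditionModel
print(next(unet.parameters()).dtype)  # torch.float32
```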