aidevhund committed
Commit 775ca8a · verified · 1 Parent(s): c441619

Update app.py

Files changed (1):
  app.py: +116 / -32
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import spaces
 from gradio_litmodel3d import LitModel3D
 from transformers import pipeline
+
 import os
 import shutil
 os.environ['SPCONV_ALGO'] = 'native'
@@ -16,6 +17,7 @@ from trellis.pipelines import TrellisImageTo3DPipeline
 from trellis.representations import Gaussian, MeshExtractResult
 from trellis.utils import render_utils, postprocessing_utils

+
 MAX_SEED = np.iinfo(np.int32).max
 TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 os.makedirs(TMP_DIR, exist_ok=True)
@@ -30,21 +32,26 @@ try:
 except:
     pass

+
 def start_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     print(f'Creating user directory: {user_dir}')
     os.makedirs(user_dir, exist_ok=True)
-
+
+
 def end_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     print(f'Removing user directory: {user_dir}')
     shutil.rmtree(user_dir)

+
 def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
     """
     Preprocess the input image.
+
     Args:
         image (Image.Image): The input image.
+
     Returns:
         str: uuid of the trial.
         Image.Image: The preprocessed image.
@@ -52,6 +59,7 @@ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
     processed_image = pipeline.preprocess_image(image)
     return processed_image

+
 def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
     return {
         'gaussian': {
@@ -68,7 +76,8 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
         },
         'trial_id': trial_id,
     }
-
+
+
 def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
     gs = Gaussian(
         aabb=state['gaussian']['aabb'],
@@ -91,12 +100,14 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:

     return gs, mesh, state['trial_id']

+
 def get_seed(randomize_seed: bool, seed: int) -> int:
     """
     Get the random seed.
     """
     return np.random.randint(0, MAX_SEED) if randomize_seed else seed

+
 @spaces.GPU
 def image_to_3d(
     image: Image.Image,
@@ -109,6 +120,7 @@
 ) -> Tuple[dict, str]:
     """
     Convert an image to a 3D model.
+
     Args:
         image (Image.Image): The input image.
         seed (int): The random seed.
@@ -116,6 +128,7 @@
         ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
         slat_guidance_strength (float): The guidance strength for structured latent generation.
         slat_sampling_steps (int): The number of sampling steps for structured latent generation.
+
     Returns:
         dict: The information of the generated 3D model.
         str: The path to the video of the 3D model.
@@ -145,6 +158,7 @@
     torch.cuda.empty_cache()
     return state, video_path

+
 @spaces.GPU
 def extract_glb(
     state: dict,
@@ -154,10 +168,12 @@
 ) -> Tuple[str, str]:
     """
     Extract a GLB file from the 3D model.
+
     Args:
         state (dict): The state of the generated 3D model.
         mesh_simplify (float): The mesh simplification factor.
         texture_size (int): The texture resolution.
+
     Returns:
         str: The path to the extracted GLB file.
     """
@@ -169,55 +185,123 @@
     torch.cuda.empty_cache()
     return glb_path, glb_path

-with gr.Blocks(delete_cache=(600, 600)) as demo:
-    gr.Markdown("""<h1 style='color: #4CAF50; text-align: center;'>Image to 3D Model Generator</h1>""")

+with open("styles.css", "w") as f:
+    f.write("""
+    .gradio-container {
+        background-color: #f0f0f0;
+        font-family: sans-serif;
+    }
+    .my-button {
+        background-color: #4CAF50;
+        color: white;
+        padding: 10px 20px;
+        border: none;
+        border-radius: 5px;
+        cursor: pointer;
+    }
+    .my-button:hover {
+        background-color: #3e8e41;
+    }
+    .gradio-container h1{
+        text-align: center;
+        color: #333;
+        margin-bottom: 20px;
+    }
+    """)
+
+
+with gr.Blocks(css=open("styles.css").read(), delete_cache=(600, 600)) as demo:
+    gr.Markdown("# 3D Model Generator")
+
     with gr.Row():
-        with gr.Column():
-            image_prompt = gr.Image(
-                label="Image Prompt",
-                format="png",
-                image_mode="RGBA",
-                type="pil",
-                height=300,
-                style={"border": "2px solid #4CAF50", "border-radius": "10px", "background-color": "#F9F9F9"}
-            )
-
-            with gr.Accordion(label="Generation Settings", open=False, style={"border": "1px solid #2196F3", "border-radius": "8px"}):
-                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1, style={"color": "#2196F3"})
-                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True, style={"color": "#2196F3"})
-                gr.Markdown("<strong>Stage 1: Sparse Structure Generation</strong>", style={"color": "#2196F3"})
+        with gr.Column(style={"padding": "20px"}):
+            image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
+
+            with gr.Accordion(label="Generation Settings", open=False):
+                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                gr.Markdown("Stage 1: Sparse Structure Generation")
                 with gr.Row():
                     ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                     ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-                gr.Markdown("<strong>Stage 2: Structured Latent Generation</strong>", style={"color": "#2196F3"})
+                gr.Markdown("Stage 2: Structured Latent Generation")
                 with gr.Row():
                     slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                     slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)

-            generate_btn = gr.Button("Generate", style={"background-color": "#4CAF50", "color": "white", "border": "none", "border-radius": "8px"})
-
-            with gr.Accordion(label="GLB Extraction Settings", open=False, style={"border": "1px solid #2196F3", "border-radius": "8px"}):
+            generate_btn = gr.Button("Generate", style={"classes": ["my-button"]})
+
+            with gr.Accordion(label="GLB Extraction Settings", open=False):
                 mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                 texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+
+            extract_glb_btn = gr.Button("Extract GLB", interactive=False)

-            extract_glb_btn = gr.Button("Extract GLB", interactive=False, style={"background-color": "#FF9800", "color": "white", "border-radius": "8px"})
+        with gr.Column(style={"padding": "20px"}):
+            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+            model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
+            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+
+    output_buf = gr.State()

-        with gr.Column():
-            video_output = gr.Video(label="Generated 3D Model", interactive=True, style={"border-radius": "10px"})
-            glb_output = gr.File(label="Download GLB", interactive=True)
+    # Example images at the bottom of the page
+    with gr.Row():
+        examples = gr.Examples(
+            examples=[
+                f'assets/example_image/{image}'
+                for image in os.listdir("assets/example_image")
+            ],
+            inputs=[image_prompt],
+            fn=preprocess_image,
+            outputs=[image_prompt],
+            run_on_click=True,
+            examples_per_page=64,
+        )
+
+    # Handlers
+    demo.load(start_session)
+    demo.unload(end_session)
+
+    image_prompt.upload(
+        preprocess_image,
+        inputs=[image_prompt],
+        outputs=[image_prompt],
+    )

-    # Trigger functions for generating 3D model and extracting GLB
     generate_btn.click(
+        get_seed,
+        inputs=[randomize_seed, seed],
+        outputs=[seed],
+    ).then(
         image_to_3d,
-        inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, gr.Request()],
-        outputs=[glb_output, video_output]
+        inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
+        outputs=[output_buf, video_output],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[extract_glb_btn],
+    )
+
+    video_output.clear(
+        lambda: gr.Button(interactive=False),
+        outputs=[extract_glb_btn],
     )

     extract_glb_btn.click(
         extract_glb,
-        inputs=[state, mesh_simplify, texture_size, gr.Request()],
-        outputs=[glb_output, glb_output]
+        inputs=[output_buf, mesh_simplify, texture_size],
+        outputs=[model_output, download_glb],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_glb],
+    )
+
+    model_output.clear(
+        lambda: gr.Button(interactive=False),
+        outputs=[download_glb],
     )
+

-demo.launch() # Start the demo
+# Launch the Gradio app
+if __name__ == "__main__":
+    demo.launch()
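The rewired handlers lean on Gradio's chained-event API: each .then() step runs only after the previous listener finishes, and the gr.State component (output_buf) carries the packed model from image_to_3d to extract_glb. Below is a minimal, self-contained sketch of that pattern, assuming Gradio 4.x; pick_seed and fake_generate are illustrative stand-ins rather than functions from this Space.

# Minimal sketch of the chained .click().then() pattern used above.
# Assumes Gradio 4.x; pick_seed and fake_generate are illustrative only.
import random
import gradio as gr

def pick_seed(randomize: bool, seed: int) -> int:
    # Same role as get_seed in the commit: resolve the seed first.
    return random.randint(0, 2**31 - 1) if randomize else seed

def fake_generate(seed: int):
    # Stand-in for image_to_3d: returns a state payload plus a status string.
    return {"seed": seed}, f"generated with seed {seed}"

with gr.Blocks() as demo:
    randomize = gr.Checkbox(label="Randomize Seed", value=True)
    seed = gr.Slider(0, 2**31 - 1, step=1, label="Seed")
    status = gr.Textbox(label="Status")
    state = gr.State()          # carries the intermediate result, like output_buf
    btn = gr.Button("Generate")

    # Each .then() step starts only after the previous one has returned,
    # so the seed written back to the slider is the one generation uses.
    btn.click(pick_seed, inputs=[randomize, seed], outputs=[seed]).then(
        fake_generate, inputs=[seed], outputs=[state, status]
    )

if __name__ == "__main__":
    demo.launch()

The same chaining is what lets the commit flip the Extract GLB button to interactive=True only after image_to_3d has produced output_buf.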
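One caveat with the new UI code: current Gradio releases (4.x) do not accept a style= keyword on components or layout blocks, so gr.Column(style={"padding": "20px"}) and gr.Button("Generate", style={"classes": ["my-button"]}) would raise a TypeError at startup. A sketch of how the same styling could be wired under that assumption: the CSS is passed to gr.Blocks as a string and classes are attached with elem_classes (the padded-column class name is illustrative, not from the commit).

# Sketch only: assumes Gradio 4.x, where styling goes through css= on gr.Blocks
# plus elem_classes / elem_id on individual components. "padded-column" is a
# hypothetical class name added for illustration.
import gradio as gr

CSS = """
.my-button { background-color: #4CAF50; color: white; border: none; border-radius: 5px; }
.my-button:hover { background-color: #3e8e41; }
.padded-column { padding: 20px; }
"""

with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# 3D Model Generator")
    with gr.Row():
        with gr.Column(elem_classes=["padded-column"]):
            generate_btn = gr.Button("Generate", elem_classes=["my-button"])

if __name__ == "__main__":
    demo.launch()

Passing the CSS as a string also sidesteps writing and re-reading styles.css at import time.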