aidevhund commited on
Commit
e80e051
·
verified ·
1 Parent(s): 775ca8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -58
app.py CHANGED
@@ -17,7 +17,6 @@ from trellis.pipelines import TrellisImageTo3DPipeline
17
  from trellis.representations import Gaussian, MeshExtractResult
18
  from trellis.utils import render_utils, postprocessing_utils
19
 
20
-
21
  MAX_SEED = np.iinfo(np.int32).max
22
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
23
  os.makedirs(TMP_DIR, exist_ok=True)
@@ -32,7 +31,6 @@ try:
32
  except:
33
  pass
34
 
35
-
36
  def start_session(req: gr.Request):
37
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
38
  print(f'Creating user directory: {user_dir}')
@@ -44,7 +42,6 @@ def end_session(req: gr.Request):
44
  print(f'Removing user directory: {user_dir}')
45
  shutil.rmtree(user_dir)
46
 
47
-
48
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
49
  """
50
  Preprocess the input image.
@@ -59,7 +56,6 @@ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
59
  processed_image = pipeline.preprocess_image(image)
60
  return processed_image
61
 
62
-
63
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
64
  return {
65
  'gaussian': {
@@ -100,14 +96,12 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
100
 
101
  return gs, mesh, state['trial_id']
102
 
103
-
104
  def get_seed(randomize_seed: bool, seed: int) -> int:
105
  """
106
  Get the random seed.
107
  """
108
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
109
 
110
-
111
  @spaces.GPU
112
  def image_to_3d(
113
  image: Image.Image,
@@ -158,7 +152,6 @@ def image_to_3d(
158
  torch.cuda.empty_cache()
159
  return state, video_path
160
 
161
-
162
  @spaces.GPU
163
  def extract_glb(
164
  state: dict,
@@ -185,39 +178,14 @@ def extract_glb(
185
  torch.cuda.empty_cache()
186
  return glb_path, glb_path
187
 
 
 
 
188
 
189
- with open("styles.css", "w") as f:
190
- f.write("""
191
- .gradio-container {
192
- background-color: #f0f0f0;
193
- font-family: sans-serif;
194
- }
195
- .my-button {
196
- background-color: #4CAF50;
197
- color: white;
198
- padding: 10px 20px;
199
- border: none;
200
- border-radius: 5px;
201
- cursor: pointer;
202
- }
203
- .my-button:hover {
204
- background-color: #3e8e41;
205
- }
206
- .gradio-container h1{
207
- text-align: center;
208
- color: #333;
209
- margin-bottom: 20px;
210
- }
211
- """)
212
-
213
-
214
- with gr.Blocks(css=open("styles.css").read(), delete_cache=(600, 600)) as demo:
215
- gr.Markdown("# 3D Model Generator")
216
-
217
  with gr.Row():
218
- with gr.Column(style={"padding": "20px"}):
219
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
220
-
221
  with gr.Accordion(label="Generation Settings", open=False):
222
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
223
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
@@ -230,39 +198,25 @@ with gr.Blocks(css=open("styles.css").read(), delete_cache=(600, 600)) as demo:
230
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
231
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
232
 
233
- generate_btn = gr.Button("Generate", style={"classes": ["my-button"]})
234
-
235
  with gr.Accordion(label="GLB Extraction Settings", open=False):
236
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
237
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
238
-
239
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
240
 
241
- with gr.Column(style={"padding": "20px"}):
242
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
243
  model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
244
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
245
-
246
- output_buf = gr.State()
247
 
248
- # Example images at the bottom of the page
249
- with gr.Row():
250
- examples = gr.Examples(
251
- examples=[
252
- f'assets/example_image/{image}'
253
- for image in os.listdir("assets/example_image")
254
- ],
255
- inputs=[image_prompt],
256
- fn=preprocess_image,
257
- outputs=[image_prompt],
258
- run_on_click=True,
259
- examples_per_page=64,
260
- )
261
 
262
  # Handlers
263
  demo.load(start_session)
264
  demo.unload(end_session)
265
-
266
  image_prompt.upload(
267
  preprocess_image,
268
  inputs=[image_prompt],
@@ -300,7 +254,6 @@ with gr.Blocks(css=open("styles.css").read(), delete_cache=(600, 600)) as demo:
300
  lambda: gr.Button(interactive=False),
301
  outputs=[download_glb],
302
  )
303
-
304
 
305
  # Launch the Gradio app
306
  if __name__ == "__main__":
 
17
  from trellis.representations import Gaussian, MeshExtractResult
18
  from trellis.utils import render_utils, postprocessing_utils
19
 
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
  os.makedirs(TMP_DIR, exist_ok=True)
 
31
  except:
32
  pass
33
 
 
34
  def start_session(req: gr.Request):
35
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
36
  print(f'Creating user directory: {user_dir}')
 
42
  print(f'Removing user directory: {user_dir}')
43
  shutil.rmtree(user_dir)
44
 
 
45
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
46
  """
47
  Preprocess the input image.
 
56
  processed_image = pipeline.preprocess_image(image)
57
  return processed_image
58
 
 
59
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
60
  return {
61
  'gaussian': {
 
96
 
97
  return gs, mesh, state['trial_id']
98
 
 
99
  def get_seed(randomize_seed: bool, seed: int) -> int:
100
  """
101
  Get the random seed.
102
  """
103
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
104
 
 
105
  @spaces.GPU
106
  def image_to_3d(
107
  image: Image.Image,
 
152
  torch.cuda.empty_cache()
153
  return state, video_path
154
 
 
155
  @spaces.GPU
156
  def extract_glb(
157
  state: dict,
 
178
  torch.cuda.empty_cache()
179
  return glb_path, glb_path
180
 
181
+ # UI Initialization with Custom CSS
182
+ with gr.Blocks(css="style.css") as demo:
183
+ gr.Markdown("""# 3D Model Generator with Enhanced UI""")
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  with gr.Row():
186
+ with gr.Column():
187
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
188
+
189
  with gr.Accordion(label="Generation Settings", open=False):
190
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
191
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
198
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
199
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
200
 
201
+ generate_btn = gr.Button("Generate")
202
+
203
  with gr.Accordion(label="GLB Extraction Settings", open=False):
204
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
205
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
206
+
207
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
208
 
209
+ with gr.Column():
210
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
211
  model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
212
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
 
 
213
 
214
+ output_buf = gr.State()
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  # Handlers
217
  demo.load(start_session)
218
  demo.unload(end_session)
219
+
220
  image_prompt.upload(
221
  preprocess_image,
222
  inputs=[image_prompt],
 
254
  lambda: gr.Button(interactive=False),
255
  outputs=[download_glb],
256
  )
 
257
 
258
  # Launch the Gradio app
259
  if __name__ == "__main__":