aidevhund commited on
Commit
e59aaec
·
verified ·
1 Parent(s): 0c615e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -12
app.py CHANGED
@@ -17,12 +17,13 @@ from trellis.pipelines import TrellisImageTo3DPipeline
17
  from trellis.representations import Gaussian, MeshExtractResult
18
  from trellis.utils import render_utils, postprocessing_utils
19
 
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
  # Initialize pipeline here
25
- pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
26
  pipeline.cuda()
27
 
28
  # Preload rembg (optional)
@@ -31,6 +32,7 @@ try:
31
  except:
32
  pass
33
 
 
34
  def start_session(req: gr.Request):
35
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
36
  print(f'Creating user directory: {user_dir}')
@@ -42,6 +44,7 @@ def end_session(req: gr.Request):
42
  print(f'Removing user directory: {user_dir}')
43
  shutil.rmtree(user_dir)
44
 
 
45
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
46
  """
47
  Preprocess the input image.
@@ -56,6 +59,7 @@ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
56
  processed_image = pipeline.preprocess_image(image)
57
  return processed_image
58
 
 
59
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
60
  return {
61
  'gaussian': {
@@ -96,12 +100,14 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
96
 
97
  return gs, mesh, state['trial_id']
98
 
 
99
  def get_seed(randomize_seed: bool, seed: int) -> int:
100
  """
101
  Get the random seed.
102
  """
103
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
104
 
 
105
  @spaces.GPU
106
  def image_to_3d(
107
  image: Image.Image,
@@ -152,6 +158,7 @@ def image_to_3d(
152
  torch.cuda.empty_cache()
153
  return state, video_path
154
 
 
155
  @spaces.GPU
156
  def extract_glb(
157
  state: dict,
@@ -178,14 +185,14 @@ def extract_glb(
178
  torch.cuda.empty_cache()
179
  return glb_path, glb_path
180
 
181
- # UI Initialization with Custom CSS
182
- with gr.Blocks(css="style.css") as demo:
183
- gr.Markdown("""# 3D Model Generator with Enhanced UI""")
184
 
 
 
 
185
  with gr.Row():
186
  with gr.Column():
187
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
188
-
189
  with gr.Accordion(label="Generation Settings", open=False):
190
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
191
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
@@ -198,25 +205,39 @@ with gr.Blocks(css="style.css") as demo:
198
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
199
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
200
 
201
- generate_btn = gr.Button("Generate")
202
-
203
  with gr.Accordion(label="GLB Extraction Settings", open=False):
204
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
205
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
206
-
207
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
208
 
209
  with gr.Column():
210
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
211
  model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
212
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
213
-
214
  output_buf = gr.State()
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  # Handlers
217
  demo.load(start_session)
218
  demo.unload(end_session)
219
-
220
  image_prompt.upload(
221
  preprocess_image,
222
  inputs=[image_prompt],
@@ -254,7 +275,8 @@ with gr.Blocks(css="style.css") as demo:
254
  lambda: gr.Button(interactive=False),
255
  outputs=[download_glb],
256
  )
 
257
 
258
  # Launch the Gradio app
259
  if __name__ == "__main__":
260
- demo.launch()
 
17
  from trellis.representations import Gaussian, MeshExtractResult
18
  from trellis.utils import render_utils, postprocessing_utils
19
 
20
+
21
  MAX_SEED = np.iinfo(np.int32).max
22
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
23
  os.makedirs(TMP_DIR, exist_ok=True)
24
 
25
  # Initialize pipeline here
26
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("aidevhund/3dgameasset")
27
  pipeline.cuda()
28
 
29
  # Preload rembg (optional)
 
32
  except:
33
  pass
34
 
35
+
36
  def start_session(req: gr.Request):
37
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
38
  print(f'Creating user directory: {user_dir}')
 
44
  print(f'Removing user directory: {user_dir}')
45
  shutil.rmtree(user_dir)
46
 
47
+
48
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
49
  """
50
  Preprocess the input image.
 
59
  processed_image = pipeline.preprocess_image(image)
60
  return processed_image
61
 
62
+
63
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
64
  return {
65
  'gaussian': {
 
100
 
101
  return gs, mesh, state['trial_id']
102
 
103
+
104
  def get_seed(randomize_seed: bool, seed: int) -> int:
105
  """
106
  Get the random seed.
107
  """
108
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
109
 
110
+
111
  @spaces.GPU
112
  def image_to_3d(
113
  image: Image.Image,
 
158
  torch.cuda.empty_cache()
159
  return state, video_path
160
 
161
+
162
  @spaces.GPU
163
  def extract_glb(
164
  state: dict,
 
185
  torch.cuda.empty_cache()
186
  return glb_path, glb_path
187
 
 
 
 
188
 
189
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
190
+ gr.Markdown("""#HundAI 3D Game Asset Creator""")
191
+
192
  with gr.Row():
193
  with gr.Column():
194
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
195
+
196
  with gr.Accordion(label="Generation Settings", open=False):
197
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
198
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
205
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
206
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
207
 
208
+ generate_btn = gr.Button("Create")
209
+
210
  with gr.Accordion(label="GLB Extraction Settings", open=False):
211
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
212
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
213
+
214
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
215
 
216
  with gr.Column():
217
+ video_output = gr.Video(label="Created 3D Asset", autoplay=True, loop=True, height=300)
218
  model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
219
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
220
+
221
  output_buf = gr.State()
222
 
223
+ # Example images at the bottom of the page
224
+ with gr.Row():
225
+ examples = gr.Examples(
226
+ examples=[
227
+ f'assets/example_image/{image}'
228
+ for image in os.listdir("assets/example_image")
229
+ ],
230
+ inputs=[image_prompt],
231
+ fn=preprocess_image,
232
+ outputs=[image_prompt],
233
+ run_on_click=True,
234
+ examples_per_page=5,
235
+ )
236
+
237
  # Handlers
238
  demo.load(start_session)
239
  demo.unload(end_session)
240
+
241
  image_prompt.upload(
242
  preprocess_image,
243
  inputs=[image_prompt],
 
275
  lambda: gr.Button(interactive=False),
276
  outputs=[download_glb],
277
  )
278
+
279
 
280
  # Launch the Gradio app
281
  if __name__ == "__main__":
282
+ demo.launch()