hysts HF staff commited on
Commit
e2e6064
·
1 Parent(s): 68ff7c6

Use diffusers

Browse files
Files changed (3) hide show
  1. app_image_to_3d.py +1 -3
  2. model.py +34 -117
  3. requirements.txt +2 -2
app_image_to_3d.py CHANGED
@@ -24,9 +24,7 @@ def create_demo(model: Model) -> gr.Blocks:
24
 
25
  with gr.Blocks() as demo:
26
  with gr.Box():
27
- image = gr.Image(label='Input image',
28
- show_label=False,
29
- type='filepath')
30
  run_button = gr.Button('Run')
31
  result = gr.Model3D(label='Result', show_label=False)
32
  with gr.Accordion('Advanced options', open=False):
 
24
 
25
  with gr.Blocks() as demo:
26
  with gr.Box():
27
+ image = gr.Image(label='Input image', show_label=False, type='pil')
 
 
28
  run_button = gr.Button('Run')
29
  result = gr.Model3D(label='Result', show_label=False)
30
  with gr.Accordion('Advanced options', open=False):
model.py CHANGED
@@ -1,99 +1,33 @@
1
  import tempfile
2
 
3
  import numpy as np
 
4
  import torch
5
  import trimesh
6
- from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
7
- from shap_e.diffusion.sample import sample_latents
8
- from shap_e.models.download import load_config, load_model
9
- from shap_e.models.nn.camera import (DifferentiableCameraBatch,
10
- DifferentiableProjectiveCamera)
11
- from shap_e.models.transmitter.base import Transmitter, VectorDecoder
12
- from shap_e.rendering.torch_mesh import TorchMesh
13
- from shap_e.util.collections import AttrDict
14
- from shap_e.util.image_util import load_image
15
-
16
-
17
- # Copied from https://github.com/openai/shap-e/blob/d99cedaea18e0989e340163dbaeb4b109fa9e8ec/shap_e/util/notebooks.py#L15-L42
18
- def create_pan_cameras(size: int,
19
- device: torch.device) -> DifferentiableCameraBatch:
20
- origins = []
21
- xs = []
22
- ys = []
23
- zs = []
24
- for theta in np.linspace(0, 2 * np.pi, num=20):
25
- z = np.array([np.sin(theta), np.cos(theta), -0.5])
26
- z /= np.sqrt(np.sum(z**2))
27
- origin = -z * 4
28
- x = np.array([np.cos(theta), -np.sin(theta), 0.0])
29
- y = np.cross(z, x)
30
- origins.append(origin)
31
- xs.append(x)
32
- ys.append(y)
33
- zs.append(z)
34
- return DifferentiableCameraBatch(
35
- shape=(1, len(xs)),
36
- flat_camera=DifferentiableProjectiveCamera(
37
- origin=torch.from_numpy(np.stack(origins,
38
- axis=0)).float().to(device),
39
- x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
40
- y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
41
- z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
42
- width=size,
43
- height=size,
44
- x_fov=0.7,
45
- y_fov=0.7,
46
- ),
47
- )
48
-
49
-
50
- # Copied from https://github.com/openai/shap-e/blob/8625e7c15526d8510a2292f92165979268d0e945/shap_e/util/notebooks.py#LL64C1-L76C33
51
- @torch.no_grad()
52
- def decode_latent_mesh(
53
- xm: Transmitter | VectorDecoder,
54
- latent: torch.Tensor,
55
- ) -> TorchMesh:
56
- decoded = xm.renderer.render_views(
57
- AttrDict(cameras=create_pan_cameras(
58
- 2, latent.device)), # lowest resolution possible
59
- params=(xm.encoder if isinstance(xm, Transmitter) else
60
- xm).bottleneck_to_params(latent[None]),
61
- options=AttrDict(rendering_mode='stf', render_with_direction=False),
62
- )
63
- return decoded.raw_meshes[0]
64
 
65
 
66
  class Model:
67
  def __init__(self):
68
  self.device = torch.device(
69
  'cuda' if torch.cuda.is_available() else 'cpu')
70
- self.xm = load_model('transmitter', device=self.device)
71
- self.diffusion = diffusion_from_config(load_config('diffusion'))
72
- self.model_text = None
73
- self.model_image = None
74
 
75
- def load_model(self, model_name: str) -> None:
76
- assert model_name in ['text300M', 'image300M']
77
- if model_name == 'text300M' and self.model_text is None:
78
- self.model_text = load_model(model_name, device=self.device)
79
- elif model_name == 'image300M' and self.model_image is None:
80
- self.model_image = load_model(model_name, device=self.device)
81
 
82
- def to_glb(self, latent: torch.Tensor) -> str:
83
- ply_path = tempfile.NamedTemporaryFile(suffix='.ply',
84
- delete=False,
85
- mode='w+b')
86
- decode_latent_mesh(self.xm, latent).tri_mesh().write_ply(ply_path)
87
-
88
- mesh = trimesh.load(ply_path.name)
89
  rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
90
  mesh = mesh.apply_transform(rot)
91
  rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
92
  mesh = mesh.apply_transform(rot)
93
-
94
  mesh_path = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
95
  mesh.export(mesh_path.name, file_type='glb')
96
-
97
  return mesh_path.name
98
 
99
  def run_text(self,
@@ -101,48 +35,31 @@ class Model:
101
  seed: int = 0,
102
  guidance_scale: float = 15.0,
103
  num_steps: int = 64) -> str:
104
- self.load_model('text300M')
105
- torch.manual_seed(seed)
106
-
107
- latents = sample_latents(
108
- batch_size=1,
109
- model=self.model_text,
110
- diffusion=self.diffusion,
111
- guidance_scale=guidance_scale,
112
- model_kwargs=dict(texts=[prompt]),
113
- progress=True,
114
- clip_denoised=True,
115
- use_fp16=True,
116
- use_karras=True,
117
- karras_steps=num_steps,
118
- sigma_min=1e-3,
119
- sigma_max=160,
120
- s_churn=0,
121
- )
122
- return self.to_glb(latents[0])
123
 
124
  def run_image(self,
125
- image_path: str,
126
  seed: int = 0,
127
  guidance_scale: float = 3.0,
128
  num_steps: int = 64) -> str:
129
- self.load_model('image300M')
130
- torch.manual_seed(seed)
131
-
132
- image = load_image(image_path)
133
- latents = sample_latents(
134
- batch_size=1,
135
- model=self.model_image,
136
- diffusion=self.diffusion,
137
- guidance_scale=guidance_scale,
138
- model_kwargs=dict(images=[image]),
139
- progress=True,
140
- clip_denoised=True,
141
- use_fp16=True,
142
- use_karras=True,
143
- karras_steps=num_steps,
144
- sigma_min=1e-3,
145
- sigma_max=160,
146
- s_churn=0,
147
- )
148
- return self.to_glb(latents[0])
 
1
  import tempfile
2
 
3
  import numpy as np
4
+ import PIL.Image
5
  import torch
6
  import trimesh
7
+ from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
8
+ from diffusers.utils import export_to_ply
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  class Model:
12
  def __init__(self):
13
  self.device = torch.device(
14
  'cuda' if torch.cuda.is_available() else 'cpu')
15
+ self.pipe = ShapEPipeline.from_pretrained('YiYiXu/shap-e',
16
+ torch_dtype=torch.float16)
17
+ self.pipe.to(self.device)
 
18
 
19
+ self.pipe_img = ShapEImg2ImgPipeline.from_pretrained(
20
+ 'YiYiXu/shap-e-img2img', torch_dtype=torch.float16)
21
+ self.pipe_img.to(self.device)
 
 
 
22
 
23
+ def to_glb(self, ply_path: str) -> str:
24
+ mesh = trimesh.load(ply_path)
 
 
 
 
 
25
  rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
26
  mesh = mesh.apply_transform(rot)
27
  rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
28
  mesh = mesh.apply_transform(rot)
 
29
  mesh_path = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
30
  mesh.export(mesh_path.name, file_type='glb')
 
31
  return mesh_path.name
32
 
33
  def run_text(self,
 
35
  seed: int = 0,
36
  guidance_scale: float = 15.0,
37
  num_steps: int = 64) -> str:
38
+ generator = torch.Generator(device=self.device).manual_seed(seed)
39
+ images = self.pipe(prompt,
40
+ generator=generator,
41
+ guidance_scale=guidance_scale,
42
+ num_inference_steps=num_steps,
43
+ output_type='mesh').images
44
+ ply_path = tempfile.NamedTemporaryFile(suffix='.ply',
45
+ delete=False,
46
+ mode='w+b')
47
+ export_to_ply(images[0], ply_path.name)
48
+ return self.to_glb(ply_path.name)
 
 
 
 
 
 
 
 
49
 
50
  def run_image(self,
51
+ image: PIL.Image.Image,
52
  seed: int = 0,
53
  guidance_scale: float = 3.0,
54
  num_steps: int = 64) -> str:
55
+ generator = torch.Generator(device=self.device).manual_seed(seed)
56
+ images = self.pipe_img(image,
57
+ generator=generator,
58
+ guidance_scale=guidance_scale,
59
+ num_inference_steps=num_steps,
60
+ output_type='mesh').images
61
+ ply_path = tempfile.NamedTemporaryFile(suffix='.ply',
62
+ delete=False,
63
+ mode='w+b')
64
+ export_to_ply(images[0], ply_path.name)
65
+ return self.to_glb(ply_path.name)
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- git+https://github.com/openai/shap-e@8625e7c
2
  gradio==3.36.1
3
  torch==2.0.1
4
  torchvision==0.15.2
5
- trimesh==3.22.1
 
1
+ git+https://github.com/huggingface/diffusers@shap-ee-mesh
2
  gradio==3.36.1
3
  torch==2.0.1
4
  torchvision==0.15.2
5
+ trimesh==3.22.3