mboss committed
Commit 4d8c3d6 · 1 Parent(s): 64fccd8

Update inference to latest

Files changed (6)
  1. __init__.py +7 -2
  2. gradio_app.py +1 -1
  3. run.py +12 -2
  4. spar3d/models/network.py +5 -2
  5. spar3d/system.py +242 -47
  6. spar3d/utils.py +1 -1
__init__.py CHANGED
@@ -29,14 +29,19 @@ class SPAR3DLoader:
 
     @classmethod
     def INPUT_TYPES(cls):
-        return {"required": {}}
+        return {
+            "required": {
+                "low_vram_mode": ("BOOLEAN", {"default": False}),
+            }
+        }
 
-    def load(self):
+    def load(self, low_vram_mode=False):
         device = comfy.model_management.get_torch_device()
         model = SPAR3D.from_pretrained(
             SPAR3D_MODEL_NAME,
             config_name="config.yaml",
             weight_name="model.safetensors",
+            low_vram_mode=low_vram_mode,
         )
         model.to(device)
         model.eval()
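
The new BOOLEAN input is threaded straight through to SPAR3D.from_pretrained. A minimal sketch of what the node now exposes; calling the node class directly, outside ComfyUI's graph, is purely illustrative and assumes a working comfy environment:

    # Hypothetical direct use of the loader node (normally driven by ComfyUI).
    loader = SPAR3DLoader()
    print(loader.INPUT_TYPES())
    # {'required': {'low_vram_mode': ('BOOLEAN', {'default': False})}}
    loader.load(low_vram_mode=True)  # forwards the flag to SPAR3D.from_pretrained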
gradio_app.py CHANGED
@@ -148,7 +148,7 @@ def run_model(
     start = time.time()
     with torch.no_grad():
         with (
-            torch.autocast(device_type=device, dtype=torch.float16)
+            torch.autocast(device_type=device, dtype=torch.bfloat16)
             if "cuda" in device
             else nullcontext()
         ):
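
The autocast context switches from float16 to bfloat16: bfloat16 keeps float32's 8-bit exponent range (trading mantissa precision), so it avoids the overflow and NaN issues float16 autocast can hit. A self-contained sketch of the same pattern in isolation:

    import torch
    from contextlib import nullcontext

    device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        with (
            torch.autocast(device_type=device, dtype=torch.bfloat16)
            if "cuda" in device
            else nullcontext()
        ):
            x = torch.randn(4, 4, device=device)
            y = x @ x  # matmul runs in bfloat16 under autocast on CUDA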
run.py CHANGED
@@ -54,6 +54,15 @@ if __name__ == "__main__":
         type=int,
         help="Texture atlas resolution. Default: 1024",
     )
+    parser.add_argument(
+        "--low-vram-mode",
+        action="store_true",
+        help=(
+            "Use low VRAM mode. SPAR3D consumes 10.5GB of VRAM by default. "
+            "This mode will reduce the VRAM consumption to roughly 7GB but in exchange "
+            "the model will be slower. Default: False"
+        ),
+    )
 
     remesh_choices = ["none"]
     if TRIANGLE_REMESH_AVAILABLE:
@@ -102,6 +111,7 @@ if __name__ == "__main__":
         args.pretrained_model,
         config_name="config.yaml",
         weight_name="model.safetensors",
+        low_vram_mode=args.low_vram_mode,
     )
     model.to(device)
     model.eval()
@@ -149,7 +159,7 @@ if __name__ == "__main__":
         torch.cuda.reset_peak_memory_stats()
     with torch.no_grad():
         with (
-            torch.autocast(device_type=device, dtype=torch.float16)
+            torch.autocast(device_type=device, dtype=torch.bfloat16)
             if "cuda" in device
             else nullcontext()
         ):
@@ -157,7 +167,7 @@ if __name__ == "__main__":
             image,
             bake_resolution=args.texture_resolution,
             remesh=args.remesh_option,
-            vertex_count=args.target_vertex_count,
+            vertex_count=vertex_count,
             return_points=True,
         )
     if torch.cuda.is_available():
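
Besides the --low-vram-mode flag, this hunk also changes the generate call to pass the local vertex_count rather than args.target_vertex_count. The 10.5GB vs ~7GB figures in the help text were presumably measured with peak-allocation bookkeeping like the reset_peak_memory_stats call visible above; a sketch of that measurement (the print line is illustrative, not part of run.py):

    import torch

    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    # ... run inference here ...
    if torch.cuda.is_available():
        peak_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"Peak VRAM: {peak_gb:.2f} GB")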
spar3d/models/network.py CHANGED
@@ -7,8 +7,8 @@ import torch.nn.functional as F
 from einops import rearrange
 from jaxtyping import Float
 from torch import Tensor
+from torch.amp import custom_bwd, custom_fwd
 from torch.autograd import Function
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 from spar3d.models.utils import BaseModule, normalize
 from spar3d.utils import get_device
@@ -79,7 +79,10 @@ class _TruncExp(Function):  # pylint: disable=abstract-method
     # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
     @staticmethod
     @conditional_decorator(
-        custom_fwd, "cuda" in get_device(), cast_inputs=torch.float32
+        custom_fwd,
+        "cuda" in get_device(),
+        cast_inputs=torch.float32,
+        device_type="cuda",
     )
     def forward(ctx, x):  # pylint: disable=arguments-differ
         ctx.save_for_backward(x)
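
The deprecated torch.cuda.amp decorators are replaced with their torch.amp equivalents, which require an explicit device_type argument. A sketch of how a conditional_decorator helper of this shape could apply them; this re-implementation is hypothetical (the repo's actual helper is not shown in the diff and may differ):

    import torch
    from torch.amp import custom_bwd, custom_fwd
    from torch.autograd import Function

    def conditional_decorator(dec, condition, **kwargs):
        # Apply dec(fn, **kwargs) only when condition holds, else return fn unchanged.
        def wrapper(fn):
            return dec(fn, **kwargs) if condition else fn
        return wrapper

    class TruncExp(Function):
        @staticmethod
        @conditional_decorator(
            custom_fwd,
            torch.cuda.is_available(),
            cast_inputs=torch.float32,
            device_type="cuda",
        )
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return torch.exp(x)

        @staticmethod
        @conditional_decorator(custom_bwd, torch.cuda.is_available(), device_type="cuda")
        def backward(ctx, g):
            (x,) = ctx.saved_tensors
            return g * torch.exp(x.clamp(-15, 15))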
spar3d/system.py CHANGED
@@ -12,7 +12,7 @@ from huggingface_hub import hf_hub_download
 from jaxtyping import Float
 from omegaconf import OmegaConf
 from PIL import Image
-from safetensors.torch import load_model
+from safetensors.torch import load_file, load_model
 from torch import Tensor
 
 from spar3d.models.diffusion.gaussian_diffusion import (
@@ -115,11 +115,17 @@ class SPAR3D(BaseModule):
         sigma_max: float = 120.0
         s_churn: float = 3.0
 
+        low_vram_mode: bool = False
+
     cfg: Config
 
     @classmethod
     def from_pretrained(
-        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
+        cls,
+        pretrained_model_name_or_path: str,
+        config_name: str,
+        weight_name: str,
+        low_vram_mode: bool = False,
     ):
         base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
         if os.path.isdir(os.path.join(base_dir, pretrained_model_name_or_path)):
@@ -139,8 +145,18 @@ class SPAR3D(BaseModule):
 
         cfg = OmegaConf.load(config_path)
         OmegaConf.resolve(cfg)
+        # Add in low_vram_mode to the config
+        if os.environ.get("SPAR3D_LOW_VRAM", "0") == "1" and torch.cuda.is_available():
+            cfg.low_vram_mode = True
+        else:
+            cfg.low_vram_mode = low_vram_mode if torch.cuda.is_available() else False
         model = cls(cfg)
-        load_model(model, weight_path, strict=False)
+
+        if not model.cfg.low_vram_mode:
+            load_model(model, weight_path, strict=False)
+        else:
+            model._state_dict = load_file(weight_path, device="cpu")
+
         return model
 
     @property
@@ -148,39 +164,52 @@ class SPAR3D(BaseModule):
         return next(self.parameters()).device
 
     def configure(self):
-        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
-            self.cfg.image_tokenizer
-        )
-        self.point_embedder = find_class(self.cfg.point_embedder_cls)(
-            self.cfg.point_embedder
-        )
-        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
-        self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
-            self.cfg.camera_embedder
-        )
-        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
-        self.post_processor = find_class(self.cfg.post_processor_cls)(
-            self.cfg.post_processor
-        )
-        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
-        self.image_estimator = find_class(self.cfg.image_estimator_cls)(
-            self.cfg.image_estimator
-        )
-        self.global_estimator = find_class(self.cfg.global_estimator_cls)(
-            self.cfg.global_estimator
-        )
-
-        # point diffusion modules
-        self.pdiff_image_tokenizer = find_class(self.cfg.pdiff_image_tokenizer_cls)(
-            self.cfg.pdiff_image_tokenizer
-        )
-        self.pdiff_camera_embedder = find_class(self.cfg.pdiff_camera_embedder_cls)(
-            self.cfg.pdiff_camera_embedder
-        )
-        self.pdiff_backbone = find_class(self.cfg.pdiff_backbone_cls)(
-            self.cfg.pdiff_backbone
-        )
+        # Initialize all modules as None
+        self.image_tokenizer = None
+        self.point_embedder = None
+        self.tokenizer = None
+        self.camera_embedder = None
+        self.backbone = None
+        self.post_processor = None
+        self.decoder = None
+        self.image_estimator = None
+        self.global_estimator = None
+        self.pdiff_image_tokenizer = None
+        self.pdiff_camera_embedder = None
+        self.pdiff_backbone = None
+        self.diffusion_spaced = None
+        self.sampler = None
+
+        # Dummy parameter to save the device placement for dynamic loading
+        self.dummy_param = torch.nn.Parameter(torch.tensor(0.0))
 
+        channel_scales = [self.cfg.scale_factor_xyz] * 3
+        channel_scales += [self.cfg.scale_factor_rgb] * 3
+        channel_biases = [self.cfg.bias_xyz] * 3
+        channel_biases += [self.cfg.bias_rgb] * 3
+        channel_scales = np.array(channel_scales)
+        channel_biases = np.array(channel_biases)
+
+        betas = get_named_beta_schedule(
+            self.cfg.diffu_sched, self.cfg.train_time_steps, self.cfg.diffu_sched_exp
+        )
+
+        self.diffusion_kwargs = dict(
+            betas=betas,
+            model_mean_type=self.cfg.mean_type,
+            model_var_type=self.cfg.var_type,
+            channel_scales=channel_scales,
+            channel_biases=channel_biases,
+        )
+
+        self.is_low_vram = self.cfg.low_vram_mode and get_device() == "cuda"
+
+        # Create CPU shadow copy if in low VRAM mode
+        if not self.is_low_vram:
+            self._load_all_modules()
+        else:
+            print("Loading in low VRAM mode")
+
         self.bbox: Float[Tensor, "2 3"]
         self.register_buffer(
             "bbox",
@@ -206,30 +235,151 @@ class SPAR3D(BaseModule):
         self.baker = TextureBaker()
         self.image_processor = ImageProcessor()
 
-        channel_scales = [self.cfg.scale_factor_xyz] * 3
-        channel_scales += [self.cfg.scale_factor_rgb] * 3
-        channel_biases = [self.cfg.bias_xyz] * 3
-        channel_biases += [self.cfg.bias_rgb] * 3
-        channel_scales = np.array(channel_scales)
-        channel_biases = np.array(channel_biases)
-
-        betas = get_named_beta_schedule(
-            self.cfg.diffu_sched, self.cfg.train_time_steps, self.cfg.diffu_sched_exp
-        )
-
-        diffusion_kwargs = dict(
-            betas=betas,
-            model_mean_type=self.cfg.mean_type,
-            model_var_type=self.cfg.var_type,
-            channel_scales=channel_scales,
-            channel_biases=channel_biases,
-        )
+    def _load_all_modules(self):
+        """Load all modules into memory"""
+        # Load modules to specified device
+        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
+            self.cfg.image_tokenizer
+        ).to(self.device)
+        self.point_embedder = find_class(self.cfg.point_embedder_cls)(
+            self.cfg.point_embedder
+        ).to(self.device)
+        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer).to(
+            self.device
+        )
+        self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
+            self.cfg.camera_embedder
+        ).to(self.device)
+        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone).to(
+            self.device
+        )
+        self.post_processor = find_class(self.cfg.post_processor_cls)(
+            self.cfg.post_processor
+        ).to(self.device)
+        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder).to(
+            self.device
+        )
+        self.image_estimator = find_class(self.cfg.image_estimator_cls)(
+            self.cfg.image_estimator
+        ).to(self.device)
+        self.global_estimator = find_class(self.cfg.global_estimator_cls)(
+            self.cfg.global_estimator
+        ).to(self.device)
+        self.pdiff_image_tokenizer = find_class(self.cfg.pdiff_image_tokenizer_cls)(
+            self.cfg.pdiff_image_tokenizer
+        ).to(self.device)
+        self.pdiff_camera_embedder = find_class(self.cfg.pdiff_camera_embedder_cls)(
+            self.cfg.pdiff_camera_embedder
+        ).to(self.device)
+        self.pdiff_backbone = find_class(self.cfg.pdiff_backbone_cls)(
+            self.cfg.pdiff_backbone
+        ).to(self.device)
+
+        self.diffusion_spaced = SpacedDiffusion(
+            use_timesteps=space_timesteps(
+                self.cfg.train_time_steps,
+                "ddim" + str(self.cfg.inference_time_steps),
+            ),
+            **self.diffusion_kwargs,
+        )
+        self.sampler = PointCloudSampler(
+            model=self.pdiff_backbone,
+            diffusion=self.diffusion_spaced,
+            num_points=512,
+            point_dim=6,
+            guidance_scale=self.cfg.guidance_scale,
+            clip_denoised=True,
+            sigma_min=1e-3,
+            sigma_max=self.cfg.sigma_max,
+            s_churn=self.cfg.s_churn,
+        )
+
+    def _load_main_modules(self):
+        """Load the main processing modules"""
+        if all(
+            [
+                self.image_tokenizer,
+                self.point_embedder,
+                self.tokenizer,
+                self.camera_embedder,
+                self.backbone,
+                self.post_processor,
+                self.decoder,
+            ]
+        ):
+            return  # Main modules already loaded
+
+        device = next(self.parameters()).device  # Get the current device
+
+        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
+            self.cfg.image_tokenizer
+        ).to(device)
+        self.point_embedder = find_class(self.cfg.point_embedder_cls)(
+            self.cfg.point_embedder
+        ).to(device)
+        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer).to(
+            device
+        )
+        self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
+            self.cfg.camera_embedder
+        ).to(device)
+        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone).to(device)
+        self.post_processor = find_class(self.cfg.post_processor_cls)(
+            self.cfg.post_processor
+        ).to(device)
+        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder).to(device)
+
+        # Restore weights if we have a checkpoint path
+        if hasattr(self, "_state_dict"):
+            self.load_state_dict(self._state_dict, strict=False)
+
+    def _load_estimator_modules(self):
+        """Load the estimator modules"""
+        if all([self.image_estimator, self.global_estimator]):
+            return  # Estimator modules already loaded
+
+        device = next(self.parameters()).device  # Get the current device
+
+        self.image_estimator = find_class(self.cfg.image_estimator_cls)(
+            self.cfg.image_estimator
+        ).to(device)
+        self.global_estimator = find_class(self.cfg.global_estimator_cls)(
+            self.cfg.global_estimator
+        ).to(device)
+
+        # Restore weights if we have a checkpoint path
+        if hasattr(self, "_state_dict"):
+            self.load_state_dict(self._state_dict, strict=False)
+
+    def _load_pdiff_modules(self):
+        """Load only the point diffusion modules"""
+        if all(
+            [
+                self.pdiff_image_tokenizer,
+                self.pdiff_camera_embedder,
+                self.pdiff_backbone,
+            ]
+        ):
+            return  # PDiff modules already loaded
+
+        device = next(self.parameters()).device  # Get the current device
+
+        self.pdiff_image_tokenizer = find_class(self.cfg.pdiff_image_tokenizer_cls)(
+            self.cfg.pdiff_image_tokenizer
+        ).to(device)
+        self.pdiff_camera_embedder = find_class(self.cfg.pdiff_camera_embedder_cls)(
+            self.cfg.pdiff_camera_embedder
+        ).to(device)
+        self.pdiff_backbone = find_class(self.cfg.pdiff_backbone_cls)(
+            self.cfg.pdiff_backbone
+        ).to(device)
+
         self.diffusion_spaced = SpacedDiffusion(
             use_timesteps=space_timesteps(
                 self.cfg.train_time_steps,
                 "ddim" + str(self.cfg.inference_time_steps),
             ),
-            **diffusion_kwargs,
+            **self.diffusion_kwargs,
         )
         self.sampler = PointCloudSampler(
             model=self.pdiff_backbone,
@@ -243,6 +393,35 @@ class SPAR3D(BaseModule):
             s_churn=self.cfg.s_churn,
         )
 
+        # Restore weights if we have a checkpoint path
+        if hasattr(self, "_state_dict"):
+            self.load_state_dict(self._state_dict, strict=False)
+
+    def _unload_pdiff_modules(self):
+        """Unload point diffusion modules to free memory"""
+        self.pdiff_image_tokenizer = None
+        self.pdiff_camera_embedder = None
+        self.pdiff_backbone = None
+        self.diffusion_spaced = None
+        self.sampler = None
+        torch.cuda.empty_cache()
+
+    def _unload_main_modules(self):
+        """Unload main processing modules to free memory"""
+        self.image_tokenizer = None
+        self.point_embedder = None
+        self.tokenizer = None
+        self.camera_embedder = None
+        self.backbone = None
+        self.post_processor = None
+        torch.cuda.empty_cache()
+
+    def _unload_estimator_modules(self):
+        """Unload estimator modules to free memory"""
+        self.image_estimator = None
+        self.global_estimator = None
+        torch.cuda.empty_cache()
+
     def triplane_to_meshes(
         self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
     ) -> list[Mesh]:
@@ -303,6 +482,11 @@ class SPAR3D(BaseModule):
         return out
 
     def get_scene_codes(self, batch) -> Float[Tensor, "B 3 C H W"]:
+        if self.is_low_vram:
+            self._unload_pdiff_modules()
+            self._unload_estimator_modules()
+            self._load_main_modules()
+
         # if batch[rgb_cond] is only one view, add a view dimension
         if len(batch["rgb_cond"].shape) == 4:
             batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
@@ -340,9 +524,15 @@ class SPAR3D(BaseModule):
 
         direct_codes = self.tokenizer.detokenize(tokens)
        scene_codes = self.post_processor(direct_codes)
+
         return scene_codes, direct_codes
 
     def forward_pdiff_cond(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+        if self.is_low_vram:
+            self._unload_main_modules()
+            self._unload_estimator_modules()
+            self._load_pdiff_modules()
+
         if len(batch["rgb_cond"].shape) == 4:
             batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
             batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
@@ -512,6 +702,11 @@ class SPAR3D(BaseModule):
         output_rotation = rotation2 @ rotation
 
         global_dict = {}
+        if self.is_low_vram:
+            self._unload_pdiff_modules()
+            self._unload_main_modules()
+            self._load_estimator_modules()
+
         if self.image_estimator is not None:
             global_dict.update(
                 self.image_estimator(
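
Taken together, the system.py changes split the model into three independently loadable groups (point diffusion, main triplane path, estimators), so that in low VRAM mode at most one group occupies the GPU at a time while the full weights stay parked in a CPU-side _state_dict. A usage sketch; the repo id below is a placeholder assumption, substitute your local checkout or actual model path:

    import torch
    from spar3d.system import SPAR3D

    model = SPAR3D.from_pretrained(
        "stabilityai/stable-point-aware-3d",  # placeholder repo id
        config_name="config.yaml",
        weight_name="model.safetensors",
        low_vram_mode=True,  # or set SPAR3D_LOW_VRAM=1 in the environment
    )
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    # In low VRAM mode the _load_*/_unload_* helpers swap module groups in and
    # out around forward_pdiff_cond, get_scene_codes, and the estimator step.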
spar3d/utils.py CHANGED
@@ -10,7 +10,7 @@ import spar3d.models.utils as spar3d_utils
 
 
 def get_device():
-    if os.environ.get("SF3D_USE_CPU", "0") == "1":
+    if os.environ.get("SPAR3D_USE_CPU", "0") == "1":
         return "cpu"
 
     device = "cpu"
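
The CPU-override environment variable drops the old SF3D prefix, matching the SPAR3D_LOW_VRAM variable introduced in system.py. A quick check; get_device reads the variable at call time, so set it before calling:

    import os
    from spar3d.utils import get_device

    os.environ["SPAR3D_USE_CPU"] = "1"
    print(get_device())  # "cpu", regardless of CUDA availability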