yhzhai commited on
Commit
9f15f5c
·
1 Parent(s): eeba63e

detect device

Browse files
Files changed (1) hide show
  1. app.py +30 -11
app.py CHANGED
@@ -29,9 +29,14 @@ def get_modelscope_pipeline(
29
  mcm_variant: Optional[str] = "WebVid",
30
  ):
31
  model_id = "ali-vilab/text-to-video-ms-1.7b"
32
- pipe = DiffusionPipeline.from_pretrained(
33
- model_id, torch_dtype=torch.float16, variant="fp16"
34
- )
 
 
 
 
 
35
  scheduler = LCMScheduler.from_pretrained(
36
  model_id,
37
  subfolder="scheduler",
@@ -82,14 +87,23 @@ def get_animatediff_pipeline(
82
  else:
83
  raise ValueError(f"Unknown real_variant {real_variant}")
84
 
85
- adapter = MotionAdapter.from_pretrained(
86
- motion_module_path, torch_dtype=torch.float16
87
- )
88
- pipe = AnimateDiffPipeline.from_pretrained(
89
- model_id,
90
- motion_adapter=adapter,
91
- torch_dtype=torch.float16,
92
- )
 
 
 
 
 
 
 
 
 
93
  scheduler = LCMScheduler.from_pretrained(
94
  model_id,
95
  subfolder="scheduler",
@@ -306,6 +320,11 @@ with gr.Blocks(css=css) as demo:
306
  """
307
  )
308
 
 
 
 
 
 
309
  with gr.Row():
310
  base_model = gr.Dropdown(
311
  label="Base model",
 
29
  mcm_variant: Optional[str] = "WebVid",
30
  ):
31
  model_id = "ali-vilab/text-to-video-ms-1.7b"
32
+ if torch.cuda.is_available():
33
+ pipe = DiffusionPipeline.from_pretrained(
34
+ model_id, torch_dtype=torch.float16, variant="fp16"
35
+ )
36
+ else:
37
+ pipe = DiffusionPipeline.from_pretrained(
38
+ model_id
39
+ )
40
  scheduler = LCMScheduler.from_pretrained(
41
  model_id,
42
  subfolder="scheduler",
 
87
  else:
88
  raise ValueError(f"Unknown real_variant {real_variant}")
89
 
90
+ if torch.cuda.is_available():
91
+ adapter = MotionAdapter.from_pretrained(
92
+ motion_module_path, torch_dtype=torch.float16
93
+ )
94
+ pipe = AnimateDiffPipeline.from_pretrained(
95
+ model_id,
96
+ motion_adapter=adapter,
97
+ torch_dtype=torch.float16,
98
+ )
99
+ else:
100
+ adapter = MotionAdapter.from_pretrained(
101
+ motion_module_path
102
+ )
103
+ pipe = AnimateDiffPipeline.from_pretrained(
104
+ model_id,
105
+ motion_adapter=adapter,
106
+ )
107
  scheduler = LCMScheduler.from_pretrained(
108
  model_id,
109
  subfolder="scheduler",
 
320
  """
321
  )
322
 
323
+ gr.Markdown(
324
+ f"""
325
+ <p align="center"> Currently running on {device}.</p>
326
+ """
327
+ )
328
  with gr.Row():
329
  base_model = gr.Dropdown(
330
  label="Base model",