Fabrice-TIERCELIN commited on
Commit
f8b8c3b
·
verified ·
1 Parent(s): 758a3f9

Show the interface on CPU

Browse files
Files changed (1) hide show
  1. gradio_demo.py +25 -24
gradio_demo.py CHANGED
@@ -31,32 +31,33 @@ server_ip = args.ip
31
  server_port = args.port
32
  use_llava = not args.no_llava
33
 
34
- if torch.cuda.device_count() >= 2:
35
- SUPIR_device = 'cuda:0'
36
- LLaVA_device = 'cuda:1'
37
- elif torch.cuda.device_count() == 1:
38
- SUPIR_device = 'cuda:0'
39
- LLaVA_device = 'cuda:0'
40
- else:
41
- SUPIR_device = 'cpu'
42
- LLaVA_device = 'cpu'
 
43
 
44
- # load SUPIR
45
- model, default_setting = create_SUPIR_model(args.opt, SUPIR_sign='Q', load_default_setting=True)
46
- if args.loading_half_params:
47
- model = model.half()
48
- if args.use_tile_vae:
49
- model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
50
- model = model.to(SUPIR_device)
51
- model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
52
- model.current_model = 'v0-Q'
53
- ckpt_Q, ckpt_F = load_QF_ckpt(args.opt)
54
 
55
- # load LLaVA
56
- if use_llava:
57
- llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
58
- else:
59
- llava_agent = None
60
 
61
  def stage1_process(input_image, gamma_correction):
62
  if torch.cuda.device_count() == 0:
 
31
  server_port = args.port
32
  use_llava = not args.no_llava
33
 
34
+ if torch.cuda.device_count() > 0:
35
+ if torch.cuda.device_count() >= 2:
36
+ SUPIR_device = 'cuda:0'
37
+ LLaVA_device = 'cuda:1'
38
+ elif torch.cuda.device_count() == 1:
39
+ SUPIR_device = 'cuda:0'
40
+ LLaVA_device = 'cuda:0'
41
+ else:
42
+ SUPIR_device = 'cpu'
43
+ LLaVA_device = 'cpu'
44
 
45
+ # load SUPIR
46
+ model, default_setting = create_SUPIR_model(args.opt, SUPIR_sign='Q', load_default_setting=True)
47
+ if args.loading_half_params:
48
+ model = model.half()
49
+ if args.use_tile_vae:
50
+ model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
51
+ model = model.to(SUPIR_device)
52
+ model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
53
+ model.current_model = 'v0-Q'
54
+ ckpt_Q, ckpt_F = load_QF_ckpt(args.opt)
55
 
56
+ # load LLaVA
57
+ if use_llava:
58
+ llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
59
+ else:
60
+ llava_agent = None
61
 
62
  def stage1_process(input_image, gamma_correction):
63
  if torch.cuda.device_count() == 0: