Menyu commited on
Commit
a9fd8d2
·
verified ·
1 Parent(s): 6a22f12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -15
app.py CHANGED
@@ -3,24 +3,30 @@ import gradio as gr
3
  import numpy as np
4
  import spaces
5
  import torch
6
- from diffusers import AutoPipelineForText2Image, AutoencoderKL #,EulerDiscreteScheduler
 
 
 
7
 
8
  if not torch.cuda.is_available():
9
- DESCRIPTION += "\n<p>你现在运行在CPU上 但是只支持GPU.</p>"
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
  MAX_IMAGE_SIZE = 4096
13
 
14
  if torch.cuda.is_available():
15
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
 
 
16
  pipe = AutoPipelineForText2Image.from_pretrained(
17
  "John6666/noobai-xl-nai-xl-epsilonpred075version-sdxl",
18
  vae=vae,
19
  torch_dtype=torch.float16,
20
  use_safetensors=True,
21
- add_watermarker=False
22
  )
23
- #pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 
24
  pipe.to("cuda")
25
 
26
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
@@ -43,10 +49,40 @@ def infer(
43
  progress=gr.Progress(track_tqdm=True),
44
  ):
45
  seed = int(randomize_seed_fn(seed, randomize_seed))
46
- generator = torch.Generator().manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  image = pipe(
48
- prompt=prompt,
49
- negative_prompt=negative_prompt,
 
 
50
  width=width,
51
  height=height,
52
  guidance_scale=guidance_scale,
@@ -62,13 +98,13 @@ examples = [
62
  ]
63
 
64
  css = '''
65
- .gradio-container{max-width: 560px !important}
66
- h1{text-align:center}
67
- footer {
68
- visibility: hidden
69
- }
70
  '''
71
-
72
  with gr.Blocks(css=css) as demo:
73
  gr.Markdown("""# 梦羽的模型生成器
74
  ### 快速生成NoobXL的模型图片.""")
@@ -140,7 +176,7 @@ with gr.Blocks(css=css) as demo:
140
  )
141
 
142
  gr.on(
143
- triggers=[prompt.submit,run_button.click],
144
  fn=infer,
145
  inputs=[
146
  prompt,
 
3
  import numpy as np
4
  import spaces
5
  import torch
6
+ from diffusers import AutoPipelineForText2Image, AutoencoderKL # , EulerDiscreteScheduler
7
+
8
+ # 添加导入语句
9
+ from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl
10
 
11
  if not torch.cuda.is_available():
12
+ DESCRIPTION += "\n<p>你现在运行在CPU上,但是该程序仅支持GPU。</p>"
13
 
14
  MAX_SEED = np.iinfo(np.int32).max
15
  MAX_IMAGE_SIZE = 4096
16
 
17
  if torch.cuda.is_available():
18
+ vae = AutoencoderKL.from_pretrained(
19
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
20
+ )
21
  pipe = AutoPipelineForText2Image.from_pretrained(
22
  "John6666/noobai-xl-nai-xl-epsilonpred075version-sdxl",
23
  vae=vae,
24
  torch_dtype=torch.float16,
25
  use_safetensors=True,
26
+ add_watermarker=False,
27
  )
28
+ # pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
29
+ pipe.tokenizer.model_max_length = 512
30
  pipe.to("cuda")
31
 
32
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 
49
  progress=gr.Progress(track_tqdm=True),
50
  ):
51
  seed = int(randomize_seed_fn(seed, randomize_seed))
52
+ generator = torch.Generator("cuda").manual_seed(seed)
53
+
54
+ # 使用 get_weighted_text_embeddings_sdxl 获取文本嵌入
55
+ if use_negative_prompt and negative_prompt:
56
+ (
57
+ prompt_embeds,
58
+ prompt_neg_embeds,
59
+ pooled_prompt_embeds,
60
+ negative_pooled_prompt_embeds,
61
+ ) = get_weighted_text_embeddings_sdxl(
62
+ pipe,
63
+ prompt=prompt,
64
+ neg_prompt=negative_prompt,
65
+ device=pipe.device,
66
+ )
67
+ else:
68
+ (
69
+ prompt_embeds,
70
+ _,
71
+ pooled_prompt_embeds,
72
+ _,
73
+ ) = get_weighted_text_embeddings_sdxl(
74
+ pipe,
75
+ prompt=prompt,
76
+ device=pipe.device,
77
+ )
78
+ prompt_neg_embeds = None
79
+ negative_pooled_prompt_embeds = None
80
+
81
  image = pipe(
82
+ prompt_embeds=prompt_embeds,
83
+ negative_prompt_embeds=prompt_neg_embeds,
84
+ pooled_prompt_embeds=pooled_prompt_embeds,
85
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
86
  width=width,
87
  height=height,
88
  guidance_scale=guidance_scale,
 
98
  ]
99
 
100
  css = '''
101
+ .gradio-container{max-width: 560px !important}
102
+ h1{text-align:center}
103
+ footer {
104
+ visibility: hidden
105
+ }
106
  '''
107
+
108
  with gr.Blocks(css=css) as demo:
109
  gr.Markdown("""# 梦羽的模型生成器
110
  ### 快速生成NoobXL的模型图片.""")
 
176
  )
177
 
178
  gr.on(
179
+ triggers=[prompt.submit, run_button.click],
180
  fn=infer,
181
  inputs=[
182
  prompt,