|
import datetime |
|
import gradio |
|
import subprocess |
|
from PIL import Image |
|
import torch, torch.backends.cudnn, torch.backends.cuda |
|
from min_dalle import MinDalle |
|
from emoji import demojize |
|
import string |
|
|
|
def filename_from_text(text: str) -> str:
    """Derive a short, filesystem-friendly base name from a prompt.

    The prompt is passed through ``demojize`` (with empty delimiters),
    lower-cased and reduced to ASCII. Only the letters ``a``-``z`` and
    spaces are kept, the result is cut to 64 characters, and the remaining
    words are joined with ``-``. If nothing is left, ``'blank'`` is returned.
    """
    ascii_text = (
        demojize(text, delimiters=['', ''])
        .lower()
        .encode('ascii', errors='ignore')
        .decode()
    )
    permitted = set(string.ascii_lowercase + ' ')
    kept = ''.join(filter(permitted.__contains__, ascii_text))
    # split() with no arguments already discards leading/trailing whitespace.
    words = kept[:64].split()
    return '-'.join(words) or 'blank'
|
|
|
def log_gpu_memory():
    """Print the current timestamp followed by the output of ``nvidia-smi``.

    This is diagnostic logging only, so a missing or failing ``nvidia-smi``
    is reported inside the log line instead of being raised to the caller.
    """
    try:
        report = subprocess.check_output(['nvidia-smi']).decode('utf-8', errors='replace')
    except (OSError, subprocess.CalledProcessError) as err:
        # OSError covers a missing/non-executable binary; CalledProcessError
        # covers a non-zero exit status (e.g. no driver loaded).
        report = 'unavailable ({})'.format(err)
    print("Date:{}, GPU memory:{}".format(str(datetime.datetime.now()), report))
|
|
|
# Log GPU memory immediately before and after the model object is
# constructed, so the two log entries bracket the construction.
log_gpu_memory()

model = MinDalle(is_mega=True, is_reusable=True, device='cuda', dtype=torch.float32)

log_gpu_memory()
|
|
|
def _coerce(convert, value, message):
    """Return ``convert(value)``; raise ``ValueError(message)`` if conversion fails."""
    try:
        return convert(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(message) from None


def run_model(
    text: str,
    grid_size: int,
    is_seamless: bool,
    save_as_png: bool,
    temperature: float,
    supercondition: str,
    top_k: str
) -> str:
    """Generate an image for ``text`` and return the path of the saved file.

    Args:
        text: Prompt passed to ``model.generate_image``; also used to derive
            the output file name via ``filename_from_text``.
        grid_size: Converted with ``int``; must be between 1 and 5.
        is_seamless: Converted with ``bool`` and forwarded to the model.
        save_as_png: When truthy the file extension is ``png``, else ``jpg``.
        temperature: Converted with ``float``; must be greater than 1e-6.
        supercondition: Converted with ``float`` and forwarded as
            ``supercondition_factor``.
        top_k: Converted with ``int``; must be between 1 and 16384.

    Returns:
        The relative path ``'<name>.<ext>'`` the image was saved to.

    Raises:
        ValueError: If temperature, grid size or top-k cannot be converted or
            is out of range. The messages match the ones shown to the user.
    """
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

    print("Date:{}".format(str(datetime.datetime.now())))
    print('text:', text)
    print('grid_size:', grid_size)
    print('is_seamless:', is_seamless)
    print('temperature:', temperature)
    print('supercondition:', supercondition)
    print('top_k:', top_k)

    # Validate with explicit checks rather than `assert`: assertions are
    # stripped under `python -O`, which would silently disable the limits.
    temperature_error = 'Temperature must be a positive nonzero number'
    temperature = _coerce(float, temperature, temperature_error)
    # `not (x > eps)` also rejects NaN, which compares false to everything.
    if not temperature > 1e-6:
        raise ValueError(temperature_error)

    grid_size_error = 'Grid size must be between 1 and 5'
    grid_size = _coerce(int, grid_size, grid_size_error)
    if not 1 <= grid_size <= 5:
        raise ValueError(grid_size_error)

    top_k_error = 'Top k must be between 1 and 16384'
    top_k = _coerce(int, top_k, top_k_error)
    if not 1 <= top_k <= 16384:
        raise ValueError(top_k_error)

    with torch.no_grad():
        image = model.generate_image(
            text=text,
            seed=-1,
            grid_size=grid_size,
            is_seamless=bool(is_seamless),
            temperature=temperature,
            supercondition_factor=float(supercondition),
            top_k=top_k,
            is_verbose=True
        )

    log_gpu_memory()

    ext = 'png' if bool(save_as_png) else 'jpg'
    filename = filename_from_text(text)
    image_path = '{}.{}'.format(filename, ext)
    image.save(image_path)

    return image_path
|
|
|
# Gradio UI: a prompt/output column next to a settings column, followed by
# example prompts and the wiring of the "Generate Image" button to run_model.
demo = gradio.Blocks(analytics_enabled=True)

with demo:
    with gradio.Row():
        with gradio.Column():
            input_text = gradio.Textbox(
                label='Input Text',
                value='Moai statue giving a TED Talk',
                lines=3
            )
            run_button = gradio.Button(value='Generate Image').style(full_width=True)
            # This component must exist because run_button.click() below
            # lists it as the output; with its definition disabled the script
            # failed with NameError. The default example image is omitted on
            # purpose: the file is not guaranteed to be present.
            output_image = gradio.Image(
                label='Output Image',
                type='file',
                interactive=False
            )

        with gradio.Column():
            gradio.Markdown('## Settings')
            with gradio.Row():
                grid_size = gradio.Slider(
                    label='Grid Size',
                    value=5,
                    minimum=1,
                    maximum=5,
                    step=1
                )
                save_as_png = gradio.Checkbox(
                    label='Output PNG',
                    value=False
                )
                is_seamless = gradio.Checkbox(
                    label='Seamless',
                    value=False
                )
            gradio.Markdown('#### Advanced')
            with gradio.Row():
                temperature = gradio.Number(
                    label='Temperature',
                    value=1
                )
                # Dropdown values are strings; run_model converts them.
                top_k = gradio.Dropdown(
                    label='Top-k',
                    choices=[str(2 ** i) for i in range(15)],
                    value='128'
                )
                supercondition = gradio.Dropdown(
                    label='Super Condition',
                    choices=[str(2 ** i) for i in range(2, 7)],
                    value='16'
                )

            gradio.Markdown(
                """
####
- **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.
- **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.
- **Seamless**: Tile images in image token space instead of pixel space.
- **Temperature**: High temperature increases the probability of sampling low scoring image tokens.
- **Top-k**: Each image token is sampled from the top-k scoring tokens.
- **Super Condition**: Higher values can result in better agreement with the text.
"""
            )

    # Each example row supplies values for (input_text, grid_size).
    gradio.Examples(
        examples=[
            ['Astronaut riding a horse hyperrealistic', 1],
        ],
        inputs=[
            input_text,
            grid_size,
        ],
        examples_per_page=20
    )

    # The order of `inputs` must match run_model's positional parameters.
    run_button.click(
        fn=run_model,
        inputs=[
            input_text,
            grid_size,
            is_seamless,
            save_as_png,
            temperature,
            supercondition,
            top_k
        ],
        outputs=[
            output_image
        ]
    )

demo.launch()
|
|