Spaces:
Runtime error
Runtime error
import pandas as pd | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torchvision | |
import clip | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import gradio as gr | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = 'ViT-B/16' #@param ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
# Load directly onto DEVICE. jit=False requests the regular nn.Module: the
# code below reads `model.visual.input_resolution` and backpropagates through
# the image encoder, neither of which the TorchScript archive is meant for.
model, preprocess = clip.load(model_name, device=DEVICE, jit=False)
model.eval()
# Only the input colour is optimised. Freezing the CLIP weights stops every
# backward() call from computing and accumulating (never-zeroed, never-used)
# gradients for the whole network; gradients w.r.t. the input still flow.
model.requires_grad_(False)
# Side length, in pixels, of the square image the visual encoder expects.
resolution = model.visual.input_resolution
# Stretches a (N, 3, H, W) tensor to the encoder's input size.
resizer = torchvision.transforms.Resize(size=(resolution, resolution))
def create_rgb_tensor(color):
    """Return ``color`` as a float32 tensor of shape (1, 3, 1, 1) on DEVICE.

    ``color`` is a sequence of three channel values, e.g. ``[1, 0, 0]``.
    The dtype is forced to float32 so that an all-integer input such as
    ``[1, 0, 0]`` does not silently become an int64 tensor, which the
    downstream resize / image encoder are not meant to receive.

    Raises:
        RuntimeError: if ``color`` does not contain exactly three values
            (raised by ``reshape``).
    """
    return torch.tensor(color, dtype=torch.float32, device=DEVICE).reshape((1, 3, 1, 1))
def encode_color(color):
    """Embed a solid colour with the CLIP image encoder.

    ``color`` is a three-value sequence, e.g. ``[1, 0, 0]``. It is turned
    into a single-pixel image, stretched to the encoder's input size and
    passed through ``model.encode_image``; the resulting embedding is returned.
    """
    single_pixel = create_rgb_tensor(color)
    full_size = resizer(single_pixel)
    return model.encode_image(full_size)
def encode_text(text):
    """Tokenize ``text`` with CLIP and return its text-encoder embedding."""
    tokens = clip.tokenize(text)
    # The tokens must live on the same device as the model.
    tokens = tokens.to(DEVICE)
    embedding = model.encode_text(tokens)
    return embedding
class RGBModel(torch.nn.Module):
    """A module whose only parameter is one RGB colour of shape (1, 3, 1, 1).

    The colour starts at 0.5 in every channel. Calling the module clamps the
    stored values into [0, 1] and returns the parameter itself (not a copy),
    so it can be handed to an optimiser or fed into a differentiable pipeline.
    """

    def __init__(self, device):
        """Create the colour parameter on ``device``."""
        super().__init__()
        initial = torch.full((1, 3, 1, 1), 0.5, device=device)
        self.color = torch.nn.Parameter(initial)

    def forward(self):
        """Clamp the colour into the closed interval [0, 1] and return it."""
        # The clamp is out-of-place and rebinds `.data`: it is not recorded
        # by autograd, and tensors/arrays taken from earlier calls keep the
        # values they had at that time.
        clamped = torch.clamp(self.color.data, min=0, max=1)
        self.color.data = clamped
        return self.color
# Input widgets for the web UI, declared in the order gradio_fn receives them
# (see the `inputs=` list passed to gr.Interface).
# NOTE(review): the `gr.inputs` namespace is deprecated in newer Gradio
# releases — confirm the pinned Gradio version still provides it.
text_input = gr.inputs.Textbox(
    label="Text Prompt",
    lines=1,
    default='A solid red square',
)
lr_input = gr.inputs.Number(
    label="Adam Optimizer Learning Rate",
    default=0.06,
)
decay_input = gr.inputs.Number(
    label="Adam Optimizer Weight Decay",
    default=0.01,
)
steps_input = gr.inputs.Slider(
    label="Training Steps",
    minimum=1,
    maximum=100,
    step=1,
    default=11,
)
def gradio_fn(text_prompt, adam_learning_rate, adam_weight_decay, n_iterations=50):
    """Optimise one RGB colour towards a text prompt and plot its trajectory.

    A fresh ``RGBModel`` is trained with AdamW to maximise the cosine
    similarity between the CLIP embedding of a solid-colour image and the
    CLIP embedding of ``text_prompt``.

    Args:
        text_prompt: Prompt whose CLIP text embedding is the target.
        adam_learning_rate: Learning rate passed to ``torch.optim.AdamW``.
        adam_weight_decay: Weight decay passed to ``torch.optim.AdamW``.
        n_iterations: Number of optimisation steps; coerced to ``int``.

    Returns:
        A 400x100 ``PIL.Image`` made of ``n_iterations + 1`` vertical colour
        bands, left to right: the initial colour followed by the colour
        after each step.
    """
    # A UI slider may deliver its value as a float, which range() rejects.
    n_iterations = int(n_iterations)

    rgb_model = RGBModel(device=DEVICE)
    opt = torch.optim.AdamW(
        rgb_model.parameters(),
        lr=adam_learning_rate,
        weight_decay=adam_weight_decay,
    )

    # The target is a constant: compute it once, outside the autograd graph.
    with torch.no_grad():
        target_embedding = encode_text(text_prompt)

    def snapshot():
        # Calling the model clamps the colour into [0, 1] first. The explicit
        # copy matters on CPU, where .numpy() would otherwise return a view
        # of the parameter's storage that later optimiser steps could mutate.
        return rgb_model().detach().cpu().numpy().reshape(3).copy()

    def training_step():
        opt.zero_grad()
        color_img = resizer(rgb_model())
        image_embedding = model.encode_image(color_img)
        # Maximise similarity == minimise its negative; mean() reduces the
        # one-element result to the scalar that backward() expects.
        loss = -torch.cosine_similarity(target_embedding, image_embedding, dim=-1).mean()
        loss.backward()
        opt.step()

    steps = [snapshot()]
    for _ in range(n_iterations):
        training_step()
        steps.append(snapshot())

    # Shape (1, n_iterations + 1, 3): a one-pixel-tall RGB strip.
    strip = np.stack(steps)[np.newaxis, :, :]
    pixels = (strip * 255).astype(np.uint8)
    # Nearest-neighbour resampling keeps each step a crisp, solid band.
    img_train = Image.fromarray(pixels).resize((400, 100), Image.NEAREST)
    return img_train
# Wire the widgets to gradio_fn; the order of `inputs` must match the order of
# gradio_fn's parameters (prompt, learning rate, weight decay, steps).
iface = gr.Interface(
    fn=gradio_fn,
    inputs=[text_input, lr_input, decay_input, steps_input],
    outputs="image",
)
# Start serving the UI.
iface.launch()