|
import os

import cv2
import gradio as gr
import numpy as np
import PIL
import PIL.Image

from src.deoldify import device
from src.deoldify.device_id import DeviceId
from src.deoldify.visualize import *
from src.app_utils import get_model_bin
|
|
|
# Select the CPU as the DeOldify device at import time, before any colorizer
# is built by the functions below.
device.set(device=DeviceId.CPU)
|
|
|
def load_model(model_dir, option):
    """Fetch the weights for a colorizer variant and build the colorizer.

    Args:
        model_dir: Directory path under which the weight file is placed.
        option: Variant name, ``'artistic'`` or ``'stable'``
            (case-insensitive).

    Returns:
        The object returned by ``get_image_colorizer`` for the variant.

    Raises:
        ValueError: If ``option`` is not one of the supported variants.
    """
    # variant name -> (weights URL, weights file name, ``artistic`` flag)
    variants = {
        'artistic': (
            'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth',
            "ColorizeArtistic_gen.pth",
            True,
        ),
        'stable': (
            "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0",
            "ColorizeStable_gen.pth",
            False,
        ),
    }

    variant = option.lower()
    if variant not in variants:
        # Fail loudly instead of falling through to an unbound local.
        raise ValueError(
            f"Unknown model option {option!r}; "
            f"expected one of: {', '.join(sorted(variants))}"
        )

    model_url, weights_name, artistic = variants[variant]
    get_model_bin(model_url, os.path.join(model_dir, weights_name))
    return get_image_colorizer(artistic=artistic)
|
|
|
def resize_img(input_img, max_size):
    """Shrink an image so its longest side is at most ``max_size`` pixels.

    The aspect ratio is preserved: the longer side becomes ``max_size`` and
    the shorter side is scaled by the same factor. Images that already fit
    are returned as a copy.

    Args:
        input_img: Image array whose first two axes are (height, width).
        max_size: Maximum allowed length of the longer side, in pixels.

    Returns:
        A new array: the resized image, or a copy of ``input_img`` when no
        resizing is needed.
    """
    img_height, img_width = input_img.shape[:2]

    if max(img_height, img_width) <= max_size:
        return input_img.copy()

    if img_height > img_width:
        # Portrait: height is the limiting side.
        new_height = max_size
        new_width = img_width * (max_size / img_height)
    else:
        # Landscape or square: width is the limiting side.
        new_width = max_size
        new_height = img_height * (max_size / img_width)

    # Clamp to 1 px so extreme aspect ratios never truncate to a zero-sized
    # target. cv2.resize expects the target size as (width, height).
    target_size = (max(1, int(new_width)), max(1, int(new_height)))
    return cv2.resize(input_img, target_size)
|
|
|
def colorize_image(input_image, colorizer, img_size=800, render_factor=35):
    """Colorize an image with the given colorizer.

    The input is converted to RGB, passed through ``resize_img`` with
    ``img_size`` as the size limit, and then handed to the colorizer.

    Args:
        input_image: Image object providing ``convert("RGB")`` (a PIL image).
        colorizer: Object providing ``plot_transformed_pil_image``.
        img_size: Size limit forwarded to ``resize_img``.
        render_factor: Value forwarded to the colorizer's ``render_factor``
            argument.

    Returns:
        Whatever ``colorizer.plot_transformed_pil_image`` returns for the
        resized image.
    """
    rgb_array = np.array(input_image.convert("RGB"))
    resized_array = resize_img(rgb_array, img_size)
    return colorizer.plot_transformed_pil_image(
        PIL.Image.fromarray(resized_array),
        render_factor=render_factor,
        compare=False,
    )
|
|
|
# Colorizers that have already been built, keyed by lower-cased model name, so
# the model is not rebuilt on every request.
_COLORIZER_CACHE = {}


def app(input_image, model='Artistic'):
    """Colorize one uploaded image; this is the Gradio callback.

    Args:
        input_image: The uploaded PIL image, or ``None`` when nothing was
            submitted.
        model: Colorizer variant name passed on to ``load_model``.

    Returns:
        The colorized image produced by ``colorize_image``, or ``None`` when
        no input image was provided.
    """
    if input_image is None:
        # Nothing to colorize; return an empty output instead of crashing.
        return None

    cache_key = model.lower()
    colorizer = _COLORIZER_CACHE.get(cache_key)
    if colorizer is None:
        colorizer = load_model('models/', model)
        _COLORIZER_CACHE[cache_key] = colorizer

    return colorize_image(input_image, colorizer)
|
|
|
|
|
|
|
title = "<span style='color: #191970;'>Aiconvert.online</span>"

# One PIL image in, one PIL image out. Both sides use gr.Image: the legacy
# gr.inputs namespace is deprecated.
demo = gr.Interface(
    app,
    gr.Image(type="pil", label="Input"),
    gr.Image(type="pil", label="Output", show_share_button=False),
    title=title,
    css="footer{display:none !important;}",
    theme=gr.themes.Base(),
    # The string form replaces the deprecated boolean ``False``.
    allow_flagging="never",
)

# ``queue()`` replaces the deprecated ``enable_queue=True`` constructor flag.
demo.queue().launch()
|
|