# Bwcdocke / app (17).py
# Uploaded by Ashrafb ("Upload app (17).py"), commit 52eb1db (verified).
import functools
import os

import cv2
import gradio as gr
import numpy as np
import PIL

from src.deoldify import device
from src.deoldify.device_id import DeviceId
from src.deoldify.visualize import *
from src.app_utils import get_model_bin
# Run DeOldify on the CPU. The device is selected once at import time,
# before load_model() builds any colorizer.
device.set(device=DeviceId.CPU)
# Pretrained DeOldify generator weights, keyed by the lower-cased model option.
# Each entry is (download URL, weights file name, value for get_image_colorizer(artistic=...)).
_MODEL_SPECS = {
    "artistic": (
        "https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth",
        "ColorizeArtistic_gen.pth",
        True,
    ),
    "stable": (
        # dl=1 makes Dropbox serve the file itself; dl=0 returns an HTML preview page.
        "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=1",
        "ColorizeStable_gen.pth",
        False,
    ),
}


@functools.lru_cache(maxsize=len(_MODEL_SPECS))
def _load_colorizer(model_dir, key):
    """Make the weights for ``key`` available and build its colorizer (memoized)."""
    url, filename, artistic = _MODEL_SPECS[key]
    # get_model_bin is project code; presumably it downloads url to this path
    # when the file is missing - confirm in src.app_utils.
    get_model_bin(url, os.path.join(model_dir, filename))
    # NOTE(review): get_image_colorizer is not given model_dir; presumably it reads
    # the weights from its own default location - confirm that matches model_dir.
    return get_image_colorizer(artistic=artistic)


def load_model(model_dir, option):
    """Return a DeOldify colorizer for the requested model variant.

    Colorizers are cached per (model_dir, option). Repeated calls, one per
    Gradio request, therefore reuse the already-built model instead of
    rebuilding it.

    Args:
        model_dir: Directory the weights file is fetched into.
        option: Model variant, case-insensitive: 'artistic' or 'stable'.

    Returns:
        The colorizer object produced by ``get_image_colorizer``.

    Raises:
        ValueError: If ``option`` is not a known model variant. Previously this
            surfaced as an UnboundLocalError.
    """
    key = option.lower()
    if key not in _MODEL_SPECS:
        raise ValueError(
            f"Unknown model option {option!r}; expected one of: "
            + ", ".join(sorted(_MODEL_SPECS))
        )
    return _load_colorizer(model_dir, key)
def resize_img(input_img, max_size):
    """Downscale an image so that its longer side is at most ``max_size`` pixels.

    The aspect ratio is preserved for both portrait and landscape inputs.
    Images already within the limit are returned as an unchanged copy.

    Args:
        input_img: Image array indexed as ``img[row, col, ...]``. Only the first
            two axes (height, width) are read.
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        A new array whose longer side equals ``max_size`` when the input was
        larger; otherwise a copy of ``input_img``.
    """
    img_height, img_width = input_img.shape[:2]
    longest = max(img_height, img_width)
    if longest <= max_size:
        return input_img.copy()

    # One scale factor for both axes keeps the aspect ratio. The previous
    # landscape branch swapped the dimensions and distorted wide images.
    scale = max_size / longest
    # round() keeps the long side exactly max_size despite float error.
    # The 1 px floor stops extreme aspect ratios producing a zero-sized
    # target, which cv2.resize rejects.
    new_width = max(1, round(img_width * scale))
    new_height = max(1, round(img_height * scale))
    # cv2.resize takes dsize as (width, height), not (rows, cols).
    return cv2.resize(input_img, (new_width, new_height))
def colorize_image(input_image, colorizer, img_size=800, render_factor=35):
    """Colorize a PIL image with a DeOldify colorizer.

    Args:
        input_image: Source PIL image. It is converted to RGB first, so
            single-channel or alpha inputs are normalised before processing.
        colorizer: Colorizer returned by ``load_model``.
        img_size: The image is downscaled with ``resize_img`` so its longer side
            is at most this many pixels before colorization.
        render_factor: Forwarded unchanged to
            ``colorizer.plot_transformed_pil_image``. Defaults to 35, the value
            previously hard-coded here.

    Returns:
        The PIL image produced by ``plot_transformed_pil_image`` (with
        ``compare=False``).
    """
    rgb_array = np.array(input_image.convert("RGB"))
    downscaled = PIL.Image.fromarray(resize_img(rgb_array, img_size))
    return colorizer.plot_transformed_pil_image(
        downscaled, render_factor=render_factor, compare=False
    )
def app(input_image, model='Artistic'):
    """Gradio callback that colorizes one uploaded image.

    Args:
        input_image: PIL image from the input component. This may be ``None``,
            e.g. when the form is submitted without an upload.
        model: Model variant forwarded to ``load_model`` ('Artistic' or 'Stable').

    Returns:
        The colorized PIL image, or ``None`` when there is no input.
        Presumably Gradio then shows an empty output.
    """
    if input_image is None:
        # Nothing to colorize. Skip the expensive model load instead of
        # crashing later on None.convert().
        return None
    colorizer = load_model('models/', model)
    return colorize_image(input_image, colorizer)
# Page title shown by the Interface (an inline-styled HTML span).
title = "<span style='color: #191970;'>Aiconvert.online</span>"

demo = gr.Interface(
    fn=app,
    # gr.Image replaces the deprecated gr.inputs.Image namespace with the same settings.
    inputs=gr.Image(type="pil", label="Input"),
    outputs=gr.Image(type="pil", label="Output", show_share_button=False),
    title=title,
    # Hide Gradio's default page footer.
    css="footer{display:none !important;}",
    theme=gr.themes.Base(),
    enable_queue=True,
    # Gradio expects a string here; "never" is what boolean False was mapped to.
    allow_flagging="never",
)

# Start the server only when run as a script, so importing this module
# (e.g. in tests) has no side effects.
if __name__ == "__main__":
    demo.launch()