Spaces:
Sleeping
Sleeping
# App code based on: https://github.com/petergro-hub/ComicInpainting | |
# Model based on: https://github.com/saic-mdal/lama | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
import os | |
from datetime import datetime | |
from PIL import Image | |
from streamlit_drawable_canvas import st_canvas | |
from io import BytesIO | |
from copy import deepcopy | |
from src.core import process_inpaint | |
def image_download_button(pil_image, filename: str, fmt: str, label="Download"): | |
if fmt not in ["jpg", "png"]: | |
raise Exception(f"Unknown image format (Available: {fmt} - case sensitive)") | |
pil_format = "JPEG" if fmt == "jpg" else "PNG" | |
file_format = "jpg" if fmt == "jpg" else "png" | |
mime = "image/jpeg" if fmt == "jpg" else "image/png" | |
buf = BytesIO() | |
pil_image.save(buf, format=pil_format) | |
return st.download_button( | |
label=label, | |
data=buf.getvalue(), | |
file_name=f'{filename}.{file_format}', | |
mime=mime, | |
) | |
if "button_id" not in st.session_state: | |
st.session_state["button_id"] = "" | |
if "color_to_label" not in st.session_state: | |
st.session_state["color_to_label"] = {} | |
if 'reuse_image' not in st.session_state: | |
st.session_state.reuse_image = None | |
def set_image(img): | |
st.session_state.reuse_image = img | |
st.title("AI Photo Object Removal") | |
st.image(open("assets/demo.png", "rb").read()) | |
st.markdown( | |
""" | |
So you want to remove an object in your photo? You don't need to learn photo editing skills. | |
**Just draw over the parts of the image you want to remove, then our AI will remove them.** | |
""" | |
) | |
uploaded_file = st.file_uploader("Choose image", accept_multiple_files=False, type=["png", "jpg", "jpeg"]) | |
if uploaded_file is not None: | |
if st.session_state.reuse_image is not None: | |
img_input = Image.fromarray(st.session_state.reuse_image) | |
else: | |
bytes_data = uploaded_file.getvalue() | |
img_input = Image.open(BytesIO(bytes_data)).convert("RGBA") | |
# Resize the image while maintaining aspect ratio | |
max_size = 2000 | |
img_width, img_height = img_input.size | |
if img_width > max_size or img_height > max_size: | |
if img_width > img_height: | |
new_width = max_size | |
new_height = int((max_size / img_width) * img_height) | |
else: | |
new_height = max_size | |
new_width = int((max_size / img_height) * img_width) | |
img_input = img_input.resize((new_width, new_height)) | |
stroke_width = st.slider("Brush size", 1, 100, 50) | |
st.write("**Now draw (brush) the part of image that you want to remove.**") | |
# Canvas size logic | |
canvas_bg = deepcopy(img_input) | |
aspect_ratio = canvas_bg.width / canvas_bg.height | |
streamlit_width = 720 | |
# Max width is 720. Resize the height to maintain its aspectratio. | |
if canvas_bg.width > streamlit_width: | |
canvas_bg = canvas_bg.resize((streamlit_width, int(streamlit_width / aspect_ratio))) | |
canvas_result = st_canvas( | |
stroke_color="rgba(255, 0, 255, 1)", | |
stroke_width=stroke_width, | |
background_image=canvas_bg, | |
width=canvas_bg.width, | |
height=canvas_bg.height, | |
drawing_mode="freedraw", | |
key="compute_arc_length", | |
) | |
if canvas_result.image_data is not None: | |
im = np.array(Image.fromarray(canvas_result.image_data.astype(np.uint8)).resize(img_input.size)) | |
background = np.where( | |
(im[:, :, 0] == 0) & | |
(im[:, :, 1] == 0) & | |
(im[:, :, 2] == 0) | |
) | |
drawing = np.where( | |
(im[:, :, 0] == 255) & | |
(im[:, :, 1] == 0) & | |
(im[:, :, 2] == 255) | |
) | |
im[background]=[0,0,0,255] | |
im[drawing]=[0,0,0,0] # RGBA | |
reuse = False | |
if st.button('Submit'): | |
with st.spinner("AI is doing the magic!"): | |
output = process_inpaint(np.array(img_input), np.array(im)) #TODO Put button here | |
img_output = Image.fromarray(output).convert("RGB") | |
st.write("AI has finished the job!") | |
st.image(img_output) | |
# reuse = st.button('Edit again (Re-use this image)', on_click=set_image, args=(inpainted_img, )) | |
uploaded_name = os.path.splitext(uploaded_file.name)[0] | |
image_download_button( | |
pil_image=img_output, | |
filename=uploaded_name, | |
fmt="jpg", | |
label="Download Image" | |
) | |
st.info("**TIP**: If the result is not perfect, you can download it then " | |
"upload then remove the artifacts.") | |