|
import ffmpeg |
|
import os |
|
import torch |
|
import uuid |
|
import youtube_dl |
|
|
|
import numpy as np |
|
import streamlit as st |
|
|
|
from sentence_transformers import SentenceTransformer, util, models |
|
from clip import CLIPModel |
|
from PIL import Image |
|
|
|
@st.cache(allow_output_mutation=True, max_entries=1)
def get_model():
    """Build the text and image encoders once and reuse them across reruns.

    ``st.cache`` with ``max_entries=1`` keeps a single cached result, and
    ``allow_output_mutation=True`` stops Streamlit from hashing and
    mutation-checking the returned model objects.

    Returns:
        tuple: ``(txt_model, vis_model)``. Both are ``SentenceTransformer``
        instances moved to the CPU in float32. ``txt_model`` wraps the
        ``'clip-ViT-B-32-multilingual-v1'`` checkpoint. ``vis_model`` wraps the
        project's ``CLIPModel`` module.
    """
    cpu = torch.device('cpu')

    text_encoder = SentenceTransformer('clip-ViT-B-32-multilingual-v1')
    text_encoder = text_encoder.to(dtype=torch.float32, device=cpu)

    image_encoder = SentenceTransformer(modules=[CLIPModel()])
    image_encoder = image_encoder.to(dtype=torch.float32, device=cpu)

    return text_encoder, image_encoder
|
|
|
|
|
def get_embedding(txt_model, vis_model, query, video):
    """Embed a text query and a sequence of video frames.

    Args:
        txt_model: Object whose ``encode`` embeds the text ``query``.
        vis_model: Object whose ``encode`` embeds a list of PIL images.
        query: Text to embed.
        video: Iterable of frame arrays accepted by ``PIL.Image.fromarray``
            (presumably the ``(H, W, 3)`` uint8 RGB frames built in
            ``find_frames`` -- confirm for other callers).

    Returns:
        tuple: ``(text_embedding, image_embeddings)`` exactly as returned by
        the two ``encode`` calls, both run with ``device='cpu'``.
    """
    query_embedding = txt_model.encode(query, device='cpu')

    pil_frames = [Image.fromarray(frame) for frame in video]
    frame_embeddings = vis_model.encode(pil_frames, device='cpu')

    return query_embedding, frame_embeddings
|
|
|
def _probe_frame_size(url):
    """Return ``(width, height)`` of the first video stream ffprobe reports for ``url``, or ``None`` if there is none."""
    probe = ffmpeg.probe(url)
    video_stream = next(
        (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
        None,
    )
    if video_stream is None:
        return None
    return int(video_stream['width']), int(video_stream['height'])


def _decode_rgb24(url, seconds):
    """Decode at most the first ``seconds`` of ``url`` with ffmpeg into raw packed RGB24 bytes."""
    out, _ = (
        ffmpeg
        .input(url, t=seconds)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24')
        .run(capture_stdout=True)
    )
    return out


def _show_gif(frames, duration=200):
    """Write ``frames`` (PIL images) to a temporary looping GIF, display it, and always delete the file."""
    gif_path = f"./{uuid.uuid4().hex}.gif"
    try:
        # duration is milliseconds per frame; loop=0 makes the GIF repeat forever.
        frames[0].save(fp=gif_path, format='GIF', append_images=frames[1:],
                       save_all=True, duration=duration, loop=0)
        st.image(gif_path)
    finally:
        # Clean up even if saving or displaying failed, so temp GIFs never pile up.
        if os.path.exists(gif_path):
            os.remove(gif_path)


def find_frames(url, txt_model, vis_model, desc, seconds, top_k, frame_step=10):
    """Find and display the frames of a video that best match a text description.

    Shows the ``top_k`` best-matching frames twice, both in chronological
    order: first as an animated GIF, then as individual images.

    Args:
        url: Media URL or path that ffmpeg/ffprobe can read.
        txt_model: Text encoder passed to ``get_embedding``.
        vis_model: Image encoder passed to ``get_embedding``.
        desc: Text query to search for.
        seconds: Only the first ``seconds`` of the video are decoded.
        top_k: Number of best-scoring frames to show.
        frame_step: Keep every ``frame_step``-th decoded frame (must be >= 1).
            Defaults to 10, the previously hard-coded stride.

    Raises:
        Whatever ffmpeg or the encoders raise. The loading indicators are
        cleared before the exception propagates.
    """
    text = st.text("Downloading video (Descargando video)...")
    gif_runner = st.image("./loading.gif")
    try:
        frame_size = _probe_frame_size(url)
        if frame_size is None:
            st.error("No video stream found (No se encontró ninguna pista de video).")
            return
        width, height = frame_size
        raw = _decode_rgb24(url, seconds)

        text.text("Processing video (Procesando video)...")
        # rawvideo output is whole frames back to back, so the bytes reshape to
        # (frames, H, W, RGB). Subsampling bounds the number of frames to encode.
        video = np.frombuffer(raw, np.uint8).reshape([-1, height, width, 3])[::frame_step]
        if len(video) == 0:
            st.error("No frames could be decoded (No se pudieron decodificar fotogramas).")
            return

        txt_embd, img_embds = get_embedding(txt_model, vis_model, desc, video)
        # One query against all frames gives a (1, num_frames) score matrix.
        cos_scores = np.array(util.cos_sim(txt_embd, img_embds))
        # argsort is ascending, so the last top_k indices are the best matches.
        # Sorting them restores chronological order.
        best_ids = sorted(np.argsort(cos_scores)[0][-top_k:])
        imgs = [Image.fromarray(video[i]) for i in best_ids]
    finally:
        # Remove the loading indicators whether the search succeeded or failed.
        gif_runner.empty()
        text.empty()

    _show_gif(imgs)
    st.image(imgs)
|
|
|
# Markdown bodies for the English ("Home") and Spanish ("Inicio") landing pages.
# They are read once at import time, relative to the current working directory.
# The encoding is set explicitly to UTF-8 so that non-ASCII text (e.g. Spanish
# accents) decodes the same way whatever the platform's default locale
# encoding is.
with open("HOME.md", "r", encoding="utf-8") as f:
    HOME_PAGE = f.read()

with open("INICIO.md", "r", encoding="utf-8") as f:
    INICIO_PAGINA = f.read()
|
|
|
def main_page(txt_model, vis_model):
    """Render the English landing page: a title followed by the ``HOME_PAGE`` markdown.

    ``txt_model`` and ``vis_model`` are not used here. They are accepted so
    that every page renderer can be called with the same arguments.
    """
    heading = "Introducing Youtube CLIFS"
    st.title(heading)
    st.markdown(HOME_PAGE)
|
|
|
def inicio_pagina(txt_model, vis_model):
    """Render the Spanish landing page: a title followed by the ``INICIO_PAGINA`` markdown.

    ``txt_model`` and ``vis_model`` are not used here. They are accepted so
    that every page renderer can be called with the same arguments.
    """
    heading = "Presentando Youtube CLIFS"
    st.title(heading)
    st.markdown(INICIO_PAGINA)
|
|
|
def clifs_page(txt_model, vis_model):
    """Render the CLIFS search page.

    The page has sidebar controls for duration, top-k, query text, video URL
    and download quality. When the user presses Search, it resolves a direct
    stream URL with youtube_dl and runs ``find_frames`` on it. Resolution
    failures are shown as an in-app error instead of crashing the page.

    Args:
        txt_model: Text encoder forwarded to ``find_frames``.
        vis_model: Image encoder forwarded to ``find_frames``.
    """
    st.title("CLIFS")

    st.sidebar.markdown("### Controls (Controles):")
    seconds = st.sidebar.slider(
        "How many seconds of video to consider? (¿Cuántos segundos de video considerar?)",
        min_value=10,
        max_value=120,
        value=60,
        step=1,
    )
    top_k = st.sidebar.slider(
        "Top K",
        min_value=1,
        max_value=20,
        value=10,
        step=1,
    )
    desc = st.sidebar.text_input(
        "Search Query (Consulta de Búsqueda)",
        value="Pancake in the shape of an otter",
        help="Text description of what you want to find in the video (Descripción de texto de lo que desea encontrar en el video)",
    )
    url = st.sidebar.text_input(
        "Youtube Video URL (URL del Video de Youtube)",
        value='https://youtu.be/xUv6XgPwGaQ',
        help="Youtube video you want to search (Video de Youtube que desea buscar)",
    )
    quality = st.sidebar.radio(
        "Quality of the Video (Calidad del Video)",
        [144, 240, 360, 480],
        help="Quality of the video to download. Higher quality takes more time (Calidad del video para descargar. Calidad más alta toma más tiempo)",
    )

    submit_button = st.sidebar.button("Search (Buscar)")
    if not submit_button:
        return

    # Only an mp4 format at exactly the selected height is accepted. youtube_dl
    # reports failures such as "requested format not available" or an
    # unavailable video as DownloadError.
    ydl_opts = {"format": f"mp4[height={quality}]"}
    try:
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=False)
    except youtube_dl.utils.DownloadError as err:
        st.error(f"Could not fetch the video (No se pudo obtener el video): {err}")
        return

    # Without a direct media URL (e.g. a playlist link), ffmpeg has nothing to read.
    video_url = info_dict.get("url")
    if not video_url:
        st.error("No direct video stream was found for this URL (No se encontró un flujo de video directo para esta URL).")
        return

    find_frames(video_url, txt_model, vis_model, desc, seconds, top_k)
|
|
|
# Sidebar navigation table mapping each label to its page renderer. Every
# renderer is called as renderer(txt_model, vis_model). Insertion order is the
# order the labels appear in the navigation radio.
PAGES = dict(
    CLIFS=clifs_page,
    Home=main_page,
    Inicio=inicio_pagina,
)
|
|
|
|
|
|
|
def run():
    """App entry point: configure the page, load the cached models and render the selected page."""
    # Streamlit requires set_page_config to be the first Streamlit command of a script run.
    st.set_page_config(page_title="Youtube CLIFS")

    txt_model, vis_model = get_model()

    st.sidebar.title("Navigation (Navegación)")
    selection = st.sidebar.radio("Go to (Ir a)", list(PAGES.keys()))

    # Page renderers draw straight into the app and return nothing, so the result is not kept.
    PAGES[selection](txt_model, vis_model)


if __name__ == "__main__":
    # `streamlit run <file>` executes this script as __main__, so this also runs under Streamlit.
    run()